CoACT initialize (#292)

This commit is contained in:
Linxin Song
2025-07-30 19:35:20 -07:00
committed by GitHub
parent 862d704b8c
commit b968155757
228 changed files with 42386 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .dependency_injection import BaseContext, ChatContext, Depends
from .function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
from .tool import Tool, tool
from .toolkit import Toolkit
# Public, re-exported API of this package (kept alphabetically sorted).
__all__ = [
    "BaseContext",
    "ChatContext",
    "Depends",
    "Tool",
    "Toolkit",
    "get_function_schema",
    "load_basemodels_if_needed",
    "serialize_to_str",
    "tool",
]

View File

@@ -0,0 +1,9 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .time import TimeTool
# Public API of the contrib tools package.
__all__ = [
    "TimeTool",
]

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .time import TimeTool
# Public API of this subpackage.
__all__ = ["TimeTool"]

View File

@@ -0,0 +1,41 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from typing import Annotated
from autogen.tools import Tool
from ....doc_utils import export_module
# Public API of this module.
__all__ = ["TimeTool"]
@export_module("autogen.tools.contrib")  # API Reference: autogen > tools > contrib > TimeAgent
class TimeTool(Tool):
    """Outputs the current date and time of the computer."""

    def __init__(
        self,
        *,
        date_time_format: str = "%Y-%m-%d %H:%M:%S",  # This is a parameter that is unique to this tool
    ):
        """Get the date and time of the computer.

        Args:
            date_time_format (str, optional): The format of the date and time. Defaults to "%Y-%m-%d %H:%M:%S".
        """
        self._date_time_format = date_time_format

        async def get_date_and_time(
            # The configured format is baked in as the default value; a caller may
            # still override it per invocation.  Note: the default is evaluated
            # once, at tool-construction time.
            date_time_format: Annotated[str, "date/time Python format"] = self._date_time_format,
        ) -> str:
            return datetime.now().strftime(date_time_format)

        super().__init__(
            name="date_time",
            description="Get the current computer's date and time.",
            func_or_tool=get_date_and_time,
        )

View File

@@ -0,0 +1,254 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import functools
import inspect
import sys
from abc import ABC
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, Union, get_type_hints
from ..agentchat import Agent
from ..doc_utils import export_module
from ..fast_depends import Depends as FastDepends
from ..fast_depends import inject
from ..fast_depends.dependencies import model
if TYPE_CHECKING:
from ..agentchat.conversable_agent import ConversableAgent
# Public names exported by this module (kept alphabetically sorted).
__all__ = [
    "BaseContext",
    "ChatContext",
    "Depends",
    "Field",
    "get_context_params",
    "inject_params",
    "on",
    "remove_params",
]
@export_module("autogen.tools")
class BaseContext(ABC):
    """Base class for context classes.

    This is the base class for defining various context types that may be used
    throughout the application. It serves as a parent for specific context classes.
    """

    # Marker base: carries no behavior of its own; `_is_context_param` uses
    # issubclass checks against it to detect injectable context parameters.
    pass
@export_module("autogen.tools")
class ChatContext(BaseContext):
    """Chat-scoped context that extends BaseContext.

    Wraps a ConversableAgent and exposes read-only views over the chat
    history that the agent holds.
    """

    def __init__(self, agent: "ConversableAgent") -> None:
        """Store the agent whose chat history this context exposes.

        Args:
            agent: The agent to use for retrieving chat messages.
        """
        self._agent = agent

    @property
    def chat_messages(self) -> dict[Agent, list[dict[Any, Any]]]:
        """All chat messages, keyed by the conversation-partner agent."""
        messages = self._agent.chat_messages
        return messages

    @property
    def last_message(self) -> Optional[dict[str, Any]]:
        """The most recent message in the chat, or None when there is none."""
        return self._agent.last_message()
T = TypeVar("T")


def on(x: T) -> Callable[[], T]:
    """Wrap a plain value so it can serve as a dependency provider.

    Returns a zero-argument callable that always returns ``x``.  The inner
    parameter is deliberately named ``ag2_x`` (with ``x`` bound as its default)
    so it is unlikely to collide with real parameter names when the
    dependency-injection layer inspects the provider's signature.
    """

    def inner(ag2_x: T = x) -> T:
        return ag2_x

    return inner
@export_module("autogen.tools")
def Depends(x: Any) -> Any:  # noqa: N802
    """Creates a dependency for injection based on the provided context or type.

    Args:
        x: The context or dependency to be injected.

    Returns:
        A FastDepends object that will resolve the dependency for injection.
    """
    # A BaseContext instance is captured in a closure so the very same object
    # is handed out on every resolution; anything else is used as-is.
    provider = (lambda: x) if isinstance(x, BaseContext) else x
    return FastDepends(provider)
def get_context_params(func: Callable[..., Any], subclass: Union[type[BaseContext], type[ChatContext]]) -> list[str]:
    """Gets the names of the context parameters in a function signature.

    Args:
        func: The function to inspect for context parameters.
        subclass: The subclass to search for.

    Returns:
        A list of parameter names that are instances of the specified subclass.
    """
    signature = inspect.signature(func)
    matching_names: list[str] = []
    for parameter in signature.parameters.values():
        if _is_context_param(parameter, subclass=subclass):
            matching_names.append(parameter.name)
    return matching_names
def _is_context_param(
    param: inspect.Parameter, subclass: Union[type[BaseContext], type[ChatContext]] = BaseContext
) -> bool:
    """Return True if *param* is annotated with (a subclass of) *subclass*."""
    annotation = param.annotation
    # Annotated[MyContext, Depends(MyContext(b=2))] keeps the real type in __args__[0].
    if hasattr(annotation, "__args__"):
        annotation = annotation.__args__[0]
    if not isinstance(annotation, type):
        return False
    try:
        return issubclass(annotation, subclass)
    except TypeError:
        # Non-class "types" (e.g. generic aliases) are never context params.
        return False
def _is_depends_param(param: inspect.Parameter) -> bool:
    """Return True if *param* declares a dependency to be injected.

    A parameter counts as a "Depends" parameter when either:
      * its default value is a ``model.Depends`` marker, or
      * it is ``Annotated[...]`` and the first metadata entry is a
        ``model.Depends`` marker.
    """
    if isinstance(param.default, model.Depends):
        return True
    # ``__metadata__`` only exists on Annotated[...] annotations.  Use
    # isinstance rather than the exact-type comparison ``type(x) == tuple``,
    # and guard against an empty metadata tuple before indexing it.
    metadata = getattr(param.annotation, "__metadata__", None)
    return isinstance(metadata, tuple) and bool(metadata) and isinstance(metadata[0], model.Depends)
def remove_params(func: Callable[..., Any], sig: inspect.Signature, params: Iterable[str]) -> None:
    """Overwrite ``func.__signature__`` with *sig* minus the named *params*.

    Args:
        func: The function whose public signature is rewritten in place.
        sig: The signature to start from.
        params: Names of parameters to drop.
    """
    # Materialize once: supports one-shot iterables (a generator would be
    # exhausted after the first membership test) and gives O(1) lookups.
    names_to_drop = set(params)
    kept = [p for p in sig.parameters.values() if p.name not in names_to_drop]
    func.__signature__ = sig.replace(parameters=kept)  # type: ignore[attr-defined]
def _remove_injected_params_from_signature(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hide context/Depends parameters from *func*'s public signature.

    Injected parameters are resolved at call time by the injection layer, so
    they must not appear in the signature used for schema generation.
    """
    # This is a workaround for Python 3.9+ where staticmethod.__func__ is accessible
    if sys.version_info >= (3, 9) and isinstance(func, staticmethod) and hasattr(func, "__func__"):
        func = _fix_staticmethod(func)

    sig = inspect.signature(func)
    # A parameter is dropped when it is a context object or a Depends marker.
    params_to_remove = [p.name for p in sig.parameters.values() if _is_context_param(p) or _is_depends_param(p)]
    remove_params(func, sig, params_to_remove)

    return func
class Field:
    """Represents a description field for use in type annotations.

    Stores a human-readable description for an annotated parameter so that
    downstream tooling can surface it (e.g. in generated schemas).
    """

    def __init__(self, description: str) -> None:
        """Initializes the Field with a description.

        Args:
            description: The description text for the field.
        """
        self._text = description

    @property
    def description(self) -> str:
        """The stored description text."""
        return self._text
def _string_metadata_to_description_field(func: Callable[..., Any]) -> Callable[..., Any]:
    """Convert plain-string Annotated metadata on *func*'s parameters into Field objects.

    E.g. ``Annotated[str, "a name"]`` becomes ``Annotated[str, Field(description="a name")]``.

    NOTE(review): this mutates the annotation objects' ``__metadata__`` in
    place, so the rewrite is visible anywhere the same Annotated alias object
    is shared — confirm this sharing is intended.
    """
    type_hints = get_type_hints(func, include_extras=True)

    for _, annotation in type_hints.items():
        # Check if the annotation itself has metadata (using __metadata__)
        if hasattr(annotation, "__metadata__"):
            metadata = annotation.__metadata__
            if metadata and isinstance(metadata[0], str):
                # Replace string metadata with Field
                annotation.__metadata__ = (Field(description=metadata[0]),)
        # For Python < 3.11, annotations like `Optional` are stored as `Union`, so metadata
        # would be in the first element of __args__ (e.g., `__args__[0]` for `int` in `Optional[int]`)
        elif hasattr(annotation, "__args__") and hasattr(annotation.__args__[0], "__metadata__"):
            metadata = annotation.__args__[0].__metadata__
            if metadata and isinstance(metadata[0], str):
                # Replace string metadata with Field
                annotation.__args__[0].__metadata__ = (Field(description=metadata[0]),)

    return func
def _fix_staticmethod(f: Callable[..., Any]) -> Callable[..., Any]:
# This is a workaround for Python 3.9+ where staticmethod.__func__ is accessible
if sys.version_info >= (3, 9) and isinstance(f, staticmethod) and hasattr(f, "__func__"):
@wraps(f.__func__)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return f.__func__(*args, **kwargs) # type: ignore[attr-defined]
wrapper.__name__ = f.__func__.__name__
f = wrapper
return f
def _set_return_annotation_to_any(f: Callable[..., Any]) -> Callable[..., Any]:
if inspect.iscoroutinefunction(f):
@functools.wraps(f)
async def _a_wrapped_func(*args: Any, **kwargs: Any) -> Any:
return await f(*args, **kwargs)
wrapped_func = _a_wrapped_func
else:
@functools.wraps(f)
def _wrapped_func(*args: Any, **kwargs: Any) -> Any:
return f(*args, **kwargs)
wrapped_func = _wrapped_func
sig = inspect.signature(f)
# Change the return annotation directly on the signature of the wrapper
wrapped_func.__signature__ = sig.replace(return_annotation=Any) # type: ignore[attr-defined]
return wrapped_func
def inject_params(f: Callable[..., Any]) -> Callable[..., Any]:
    """Injects parameters into a function, removing injected dependencies from its signature.

    This function is used to modify a function by injecting dependencies and removing
    injected parameters from the function's signature.

    Args:
        f: The function to modify with dependency injection.

    Returns:
        The modified function with injected dependencies and updated signature.
    """
    # This is a workaround for Python 3.9+ where staticmethod.__func__ is accessible
    if sys.version_info >= (3, 9) and isinstance(f, staticmethod) and hasattr(f, "__func__"):
        f = _fix_staticmethod(f)

    # Pipeline: normalize string metadata to Field objects, loosen the return
    # annotation to Any, wire up fast_depends injection, then hide the
    # injected parameters from the public signature.
    f = _string_metadata_to_description_field(f)
    f = _set_return_annotation_to_any(f)

    f = inject(f)
    f = _remove_injected_params_from_signature(f)

    return f

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .browser_use import BrowserUseTool
from .crawl4ai import Crawl4AITool
from .deep_research import DeepResearchTool
from .duckduckgo import DuckDuckGoSearchTool
from .google_search import GoogleSearchTool, YoutubeSearchTool
from .messageplatform import (
DiscordRetrieveTool,
DiscordSendTool,
SlackRetrieveRepliesTool,
SlackRetrieveTool,
SlackSendTool,
TelegramRetrieveTool,
TelegramSendTool,
)
from .perplexity import PerplexitySearchTool
from .reliable import ReliableTool, ReliableToolError, SuccessfulExecutionParameters, ToolExecutionDetails
from .tavily import TavilySearchTool
from .web_search_preview import WebSearchPreviewTool
from .wikipedia import WikipediaPageLoadTool, WikipediaQueryRunTool
# Public API of the experimental tools package (kept alphabetically sorted).
__all__ = [
    "BrowserUseTool",
    "Crawl4AITool",
    "DeepResearchTool",
    "DiscordRetrieveTool",
    "DiscordSendTool",
    "DuckDuckGoSearchTool",
    "GoogleSearchTool",
    "PerplexitySearchTool",
    "ReliableTool",
    "ReliableToolError",
    "SlackRetrieveRepliesTool",
    "SlackRetrieveTool",
    "SlackSendTool",
    "SuccessfulExecutionParameters",
    "TavilySearchTool",
    "TelegramRetrieveTool",
    "TelegramSendTool",
    "ToolExecutionDetails",
    "WebSearchPreviewTool",
    "WikipediaPageLoadTool",
    "WikipediaQueryRunTool",
    "YoutubeSearchTool",
]

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .browser_use import BrowserUseResult, BrowserUseTool, ExtractedContent
# Public API of this subpackage.
__all__ = ["BrowserUseResult", "BrowserUseTool", "ExtractedContent"]

View File

@@ -0,0 +1,161 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Annotated, Any, Optional, Union
from pydantic import BaseModel, field_validator
from ....doc_utils import export_module
from ....import_utils import optional_import_block, require_optional_import
from ....llm_config import LLMConfig
from ... import Depends, Tool
from ...dependency_injection import on
with optional_import_block():
from browser_use import Agent, Controller
from browser_use.browser.browser import Browser, BrowserConfig
from ....interop.langchain.langchain_chat_model_factory import LangChainChatModelFactory
# Public API of this module.
__all__ = ["BrowserUseResult", "BrowserUseTool", "ExtractedContent"]
@export_module("autogen.tools.experimental.browser_use")
class ExtractedContent(BaseModel):
    """One piece of content extracted from the browser.

    Attributes:
        content: The extracted content.
        url: The URL of the extracted content
    """

    content: str
    url: Optional[str]

    @field_validator("url")
    @classmethod
    def check_url(cls, v: str) -> Optional[str]:
        """Normalise the URL: the placeholder ``about:blank`` becomes ``None``.

        Args:
            v: The URL to check.
        """
        return None if v == "about:blank" else v
@export_module("autogen.tools.experimental.browser_use")
class BrowserUseResult(BaseModel):
    """The result of using the browser to perform a task.

    Attributes:
        extracted_content: List of extracted content.
        final_result: The final result.
    """

    extracted_content: list[ExtractedContent]
    final_result: Optional[str]
@require_optional_import(
    [
        "langchain_anthropic",
        "langchain_google_genai",
        "langchain_ollama",
        "langchain_openai",
        "langchain_core",
        "browser_use",
    ],
    "browser-use",
)
@export_module("autogen.tools.experimental")
class BrowserUseTool(Tool):
    """BrowserUseTool is a tool that uses the browser to perform a task."""

    def __init__(  # type: ignore[no-any-unimported]
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        browser: Optional["Browser"] = None,
        agent_kwargs: Optional[dict[str, Any]] = None,
        browser_config: Optional[dict[str, Any]] = None,
    ):
        """Use the browser to perform a task.

        Args:
            llm_config: The LLM configuration.
            browser: The browser to use. If defined, browser_config must be None
            agent_kwargs: Additional keyword arguments to pass to the Agent
            browser_config: The browser configuration to use. If defined, browser must be None
        """
        if agent_kwargs is None:
            agent_kwargs = {}

        if browser_config is None:
            browser_config = {}

        # browser and browser_config are mutually exclusive ways of configuring the browser.
        if browser is not None and browser_config:
            raise ValueError(
                f"Cannot provide both browser and additional keyword parameters: {browser=}, {browser_config=}"
            )

        async def browser_use(  # type: ignore[no-any-unimported]
            task: Annotated[str, "The task to perform."],
            llm_config: Annotated[Union[LLMConfig, dict[str, Any]], Depends(on(llm_config))],
            browser: Annotated[Optional[Browser], Depends(on(browser))],
            agent_kwargs: Annotated[dict[str, Any], Depends(on(agent_kwargs))],
            browser_config: Annotated[dict[str, Any], Depends(on(browser_config))],
        ) -> BrowserUseResult:
            # Copy so repeated tool invocations never observe a mutated config.
            agent_kwargs = agent_kwargs.copy()
            browser_config = browser_config.copy()
            if browser is None:
                # set default value for headless
                headless = browser_config.pop("headless", True)

                # NOTE: rebinds the dict-typed local to a BrowserConfig instance.
                browser_config = BrowserConfig(headless=headless, **browser_config)
                browser = Browser(config=browser_config)

            # set default value for generate_gif
            if "generate_gif" not in agent_kwargs:
                agent_kwargs["generate_gif"] = False

            llm = LangChainChatModelFactory.create_base_chat_model(llm_config)

            max_steps = agent_kwargs.pop("max_steps", 100)
            agent = Agent(
                task=task,
                llm=llm,
                browser=browser,
                controller=BrowserUseTool._get_controller(llm_config),
                **agent_kwargs,
            )

            result = await agent.run(max_steps=max_steps)

            # Pair each extracted snippet with the URL it came from.
            extracted_content = [
                ExtractedContent(content=content, url=url)
                for content, url in zip(result.extracted_content(), result.urls())
            ]
            return BrowserUseResult(
                extracted_content=extracted_content,
                final_result=result.final_result(),
            )

        super().__init__(
            name="browser_use",
            description="Use the browser to perform a task.",
            func_or_tool=browser_use,
        )

    @staticmethod
    def _get_controller(llm_config: Union[LLMConfig, dict[str, Any]]) -> Any:
        # Pull response_format from the first config_list entry when present.
        # NOTE(review): assumes llm_config supports `in` and `.get` whether it is a
        # dict or an LLMConfig — confirm LLMConfig implements the mapping protocol.
        response_format = (
            llm_config["config_list"][0].get("response_format", None)
            if "config_list" in llm_config
            else llm_config.get("response_format")
        )
        return Controller(output_model=response_format)

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .crawl4ai import Crawl4AITool
# Public API of this subpackage.
__all__ = ["Crawl4AITool"]

View File

@@ -0,0 +1,153 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Annotated, Any, Optional, Union
from pydantic import BaseModel
from ....doc_utils import export_module
from ....import_utils import optional_import_block, require_optional_import
from ....interop import LiteLLmConfigFactory
from ....llm_config import LLMConfig
from ... import Tool
from ...dependency_injection import Depends, on
with optional_import_block():
from crawl4ai import AsyncWebCrawler, BrowserConfig, CacheMode, CrawlerRunConfig
from crawl4ai.extraction_strategy import LLMExtractionStrategy
# Public API of this module.
__all__ = ["Crawl4AITool"]
@require_optional_import(["crawl4ai"], "crawl4ai")
@export_module("autogen.tools.experimental")
class Crawl4AITool(Tool):
    """
    Crawl a website and extract information using the crawl4ai library.
    """

    def __init__(
        self,
        llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None,
        extraction_model: Optional[type[BaseModel]] = None,
        llm_strategy_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Initialize the Crawl4AITool.

        Args:
            llm_config: The config dictionary for the LLM model. If None, the tool will run without LLM.
            extraction_model: The Pydantic model to use for extraction. If None, the tool will use the default schema.
            llm_strategy_kwargs: The keyword arguments to pass to the LLM extraction strategy.
        """
        Crawl4AITool._validate_llm_strategy_kwargs(llm_strategy_kwargs, llm_config_provided=(llm_config is not None))

        async def crawl4ai_helper(  # type: ignore[no-any-unimported]
            url: str,
            browser_cfg: Optional["BrowserConfig"] = None,
            crawl_config: Optional["CrawlerRunConfig"] = None,
        ) -> Any:
            # Shared crawl primitive: plain markdown when no crawler config is
            # given, otherwise the configured extraction strategy's output.
            async with AsyncWebCrawler(config=browser_cfg) as crawler:
                result = await crawler.arun(
                    url=url,
                    config=crawl_config,
                )

            if crawl_config is None:
                response = result.markdown
            else:
                response = result.extracted_content if result.success else result.error_message

            return response

        async def crawl4ai_without_llm(
            url: Annotated[str, "The url to crawl and extract information from."],
        ) -> Any:
            return await crawl4ai_helper(url=url)

        async def crawl4ai_with_llm(
            url: Annotated[str, "The url to crawl and extract information from."],
            instruction: Annotated[str, "The instruction to provide on how and what to extract."],
            llm_config: Annotated[Any, Depends(on(llm_config))],
            llm_strategy_kwargs: Annotated[Optional[dict[str, Any]], Depends(on(llm_strategy_kwargs))],
            extraction_model: Annotated[Optional[type[BaseModel]], Depends(on(extraction_model))],
        ) -> Any:
            browser_cfg = BrowserConfig(headless=True)
            crawl_config = Crawl4AITool._get_crawl_config(
                llm_config=llm_config,
                instruction=instruction,
                extraction_model=extraction_model,
                llm_strategy_kwargs=llm_strategy_kwargs,
            )

            return await crawl4ai_helper(url=url, browser_cfg=browser_cfg, crawl_config=crawl_config)

        super().__init__(
            name="crawl4ai",
            description="Crawl a website and extract information.",
            func_or_tool=crawl4ai_without_llm if llm_config is None else crawl4ai_with_llm,
        )

    @staticmethod
    def _validate_llm_strategy_kwargs(llm_strategy_kwargs: Optional[dict[str, Any]], llm_config_provided: bool) -> None:
        """Validate user-supplied LLM-strategy kwargs.

        Args:
            llm_strategy_kwargs: Keyword arguments destined for LLMExtractionStrategy.
            llm_config_provided: Whether an llm_config was passed to the tool.

        Raises:
            ValueError: If kwargs are given without an llm_config, or if they
                contain keys the tool sets automatically.
        """
        if not llm_strategy_kwargs:
            return

        if not llm_config_provided:
            raise ValueError("llm_strategy_kwargs can only be provided if llm_config is also provided.")

        # Collect every offending key so the user sees all problems at once.
        error_messages: list[str] = []
        for key in ("provider", "api_token"):
            if key in llm_strategy_kwargs:
                error_messages.append(
                    f"'{key}' should not be provided in llm_strategy_kwargs. It is automatically set based on llm_config.\n"
                )

        if "schema" in llm_strategy_kwargs:
            error_messages.append(
                "'schema' should not be provided in llm_strategy_kwargs. It is automatically set based on extraction_model type.\n"
            )

        if "instruction" in llm_strategy_kwargs:
            error_messages.append(
                "'instruction' should not be provided in llm_strategy_kwargs. It is provided at the time of calling the tool.\n"
            )

        if error_messages:
            raise ValueError("".join(error_messages))

    @staticmethod
    def _get_crawl_config(  # type: ignore[no-any-unimported]
        llm_config: Union[LLMConfig, dict[str, Any]],
        instruction: str,
        llm_strategy_kwargs: Optional[dict[str, Any]] = None,
        extraction_model: Optional[type[BaseModel]] = None,
    ) -> "CrawlerRunConfig":
        """Build the CrawlerRunConfig for an LLM-driven extraction run.

        Args:
            llm_config: The LLM configuration to derive provider/token from.
            instruction: The extraction instruction for this call.
            llm_strategy_kwargs: Extra kwargs forwarded to LLMExtractionStrategy.
            extraction_model: Optional Pydantic model whose JSON schema guides extraction.
        """
        lite_llm_config = LiteLLmConfigFactory.create_lite_llm_config(llm_config)

        if llm_strategy_kwargs is None:
            llm_strategy_kwargs = {}

        schema = (
            extraction_model.model_json_schema()
            if (extraction_model and issubclass(extraction_model, BaseModel))
            else None
        )

        # Default to schema-based extraction only when a schema is available.
        extraction_type = llm_strategy_kwargs.pop("extraction_type", "schema" if schema else "block")

        # 1. Define the LLM extraction strategy
        llm_strategy = LLMExtractionStrategy(
            **lite_llm_config,
            schema=schema,
            extraction_type=extraction_type,
            instruction=instruction,
            **llm_strategy_kwargs,
        )

        # 2. Build the crawler config
        crawl_config = CrawlerRunConfig(extraction_strategy=llm_strategy, cache_mode=CacheMode.BYPASS)

        return crawl_config

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .deep_research import DeepResearchTool
# Public API of this subpackage.
__all__ = ["DeepResearchTool"]

View File

@@ -0,0 +1,328 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import copy
from typing import Annotated, Any, Callable, Union
from pydantic import BaseModel, Field
from ....agentchat import ConversableAgent
from ....doc_utils import export_module
from ....llm_config import LLMConfig
from ... import Depends, Tool
from ...dependency_injection import on
# Public API of this module.
__all__ = ["DeepResearchTool"]
class Subquestion(BaseModel):
    """A single standalone question to be researched."""

    question: Annotated[str, Field(description="The original question.")]

    def format(self) -> str:
        """Render the question as a display line."""
        return "Question: " + self.question + "\n"
class SubquestionAnswer(Subquestion):
    """A subquestion together with its researched answer."""

    answer: Annotated[str, Field(description="The answer to the question.")]

    def format(self) -> str:
        """Render question and answer as display lines."""
        return "Question: " + self.question + "\n" + self.answer + "\n"
class Task(BaseModel):
    """A research question plus the subquestions it decomposes into."""

    question: Annotated[str, Field(description="The original question.")]
    subquestions: Annotated[list[Subquestion], Field(description="The subquestions that need to be answered.")]

    def format(self) -> str:
        """Render the task followed by each numbered subquestion."""
        rendered = [f"Subquestion {index + 1}:\n{sub.format()}" for index, sub in enumerate(self.subquestions)]
        return f"Task: {self.question}\n\n" + "\n".join(rendered)
class CompletedTask(BaseModel):
    """A research question plus its answered subquestions."""

    question: Annotated[str, Field(description="The original question.")]
    subquestions: Annotated[list[SubquestionAnswer], Field(description="The subquestions and their answers")]

    def format(self) -> str:
        """Render the task followed by each numbered, answered subquestion."""
        rendered = [f"Subquestion {index + 1}:\n{sub.format()}" for index, sub in enumerate(self.subquestions)]
        return f"Task: {self.question}\n\n" + "\n".join(rendered)
class InformationCrumb(BaseModel):
    # One piece of information gathered from the web: where it came from and
    # which part of it is relevant to the question being researched.
    source_url: str
    source_title: str
    source_summary: str
    relevant_info: str
class GatheredInformation(BaseModel):
    """A collection of web findings with a human-readable renderer."""

    information: list[InformationCrumb]

    def format(self) -> str:
        """Render all gathered crumbs as one annotated text block."""
        rendered_crumbs = [
            f"URL: {info.source_url}\nTitle: {info.source_title}\nSummary: {info.source_summary}\nRelevant Information: {info.relevant_info}\n\n"
            for info in self.information
        ]
        return "Here is the gathered information: \n" + "\n".join(rendered_crumbs)
@export_module("autogen.tools.experimental")
class DeepResearchTool(Tool):
    """A tool that delegates a web research task to the subteams of agents."""

    # Prefix the summarizer/critic pair uses to signal a confirmed final answer.
    ANSWER_CONFIRMED_PREFIX = "Answer confirmed:"

    def __init__(
        self,
        llm_config: Union[LLMConfig, dict[str, Any]],
        max_web_steps: int = 30,
    ):
        """Initialize the DeepResearchTool.

        Args:
            llm_config (LLMConfig, dict[str, Any]): The LLM configuration.
            max_web_steps (int, optional): The maximum number of web steps. Defaults to 30.
        """
        self.llm_config = llm_config

        self.summarizer_agent = ConversableAgent(
            name="SummarizerAgent",
            system_message=(
                "You are an agent with a task of answering the question provided by the user."
                "First you need to split the question into subquestions by calling the 'split_question_and_answer_subquestions' method."
                # Fixed typos in the prompt: "sintesize the answers the original question".
                "Then you need to synthesize the answer to the original question by combining the answers to the subquestions."
            ),
            is_termination_msg=lambda x: x.get("content", "")
            and x.get("content", "").startswith(self.ANSWER_CONFIRMED_PREFIX),
            llm_config=llm_config,
            human_input_mode="NEVER",
        )

        self.critic_agent = ConversableAgent(
            name="CriticAgent",
            system_message=(
                "You are a critic agent responsible for evaluating the answer provided by the summarizer agent.\n"
                "Your task is to assess the quality of the answer based on its coherence, relevance, and completeness.\n"
                "Provide constructive feedback on how the answer can be improved.\n"
                # BUG FIX: the registered tool is named 'confirm_summary' (see below);
                # the prompt previously told the LLM to call a nonexistent 'confirm_answer'.
                "If the answer is satisfactory, call the 'confirm_summary' method to end the task.\n"
            ),
            is_termination_msg=lambda x: x.get("content", "")
            and x.get("content", "").startswith(self.ANSWER_CONFIRMED_PREFIX),
            llm_config=llm_config,
            human_input_mode="NEVER",
        )

        def delegate_research_task(
            task: Annotated[str, "The task to perform a research on."],
            llm_config: Annotated[Union[LLMConfig, dict[str, Any]], Depends(on(llm_config))],
            max_web_steps: Annotated[int, Depends(on(max_web_steps))],
        ) -> str:
            """Delegate a research task to the agent.

            Args:
                task (str): The task to perform a research on.
                llm_config (LLMConfig, dict[str, Any]): The LLM configuration.
                max_web_steps (int): The maximum number of web steps.

            Returns:
                str: The answer to the research task.
            """

            @self.summarizer_agent.register_for_execution()
            @self.critic_agent.register_for_llm(description="Call this method to confirm the final answer.")
            def confirm_summary(answer: str, reasoning: str) -> str:
                return f"{self.ANSWER_CONFIRMED_PREFIX}" + answer + "\nReasoning: " + reasoning

            split_question_and_answer_subquestions = DeepResearchTool._get_split_question_and_answer_subquestions(
                llm_config=llm_config,
                max_web_steps=max_web_steps,
            )

            self.summarizer_agent.register_for_llm(description="Split the question into subquestions and get answers.")(
                split_question_and_answer_subquestions
            )
            self.critic_agent.register_for_execution()(split_question_and_answer_subquestions)

            result = self.critic_agent.initiate_chat(
                self.summarizer_agent,
                message="Please answer the following question: " + task,
                # This outer chat should preserve the history of the conversation
                clear_history=False,
            )

            return result.summary

        super().__init__(
            name=delegate_research_task.__name__,
            description="Delegate a research task to the deep research agent.",
            func_or_tool=delegate_research_task,
        )

    # Prefix the decomposition pair uses to signal that all subquestions are answered.
    SUBQUESTIONS_ANSWER_PREFIX = "Subquestions answered:"

    @staticmethod
    def _get_split_question_and_answer_subquestions(
        llm_config: Union[LLMConfig, dict[str, Any]], max_web_steps: int
    ) -> Callable[..., Any]:
        """Build the 'split_question_and_answer_subquestions' tool function."""

        def split_question_and_answer_subquestions(
            question: Annotated[str, "The question to split and answer."],
            llm_config: Annotated[Union[LLMConfig, dict[str, Any]], Depends(on(llm_config))],
            max_web_steps: Annotated[int, Depends(on(max_web_steps))],
        ) -> str:
            decomposition_agent = ConversableAgent(
                name="DecompositionAgent",
                system_message=(
                    "You are an expert at breaking down complex questions into smaller, focused subquestions.\n"
                    "Your task is to take any question provided and divide it into clear, actionable subquestions that can be individually answered.\n"
                    "Ensure the subquestions are logical, non-redundant, and cover all key aspects of the original question.\n"
                    "Avoid providing answers or interpretations—focus solely on decomposition.\n"
                    "Do not include banal, general knowledge questions\n"
                    "Do not include questions that go into unnecessary detail that is not relevant to the original question\n"
                    # Fixed typo: "question that require".
                    "Do not include questions that require knowledge of the original or other subquestions to answer\n"
                    "Some rule of thumb is to have only one subquestion for easy questions, 3 for medium questions, and 5 for hard questions.\n"
                ),
                llm_config=llm_config,
                is_termination_msg=lambda x: x.get("content", "")
                and x.get("content", "").startswith(DeepResearchTool.SUBQUESTIONS_ANSWER_PREFIX),
                human_input_mode="NEVER",
            )

            # Shown to the critic so it knows the exact argument shape to pass.
            example_task = Task(
                question="What is the capital of France?",
                subquestions=[Subquestion(question="What is the capital of France?")],
            )

            decomposition_critic = ConversableAgent(
                name="DecompositionCritic",
                system_message=(
                    "You are a critic agent responsible for evaluating the subquestions provided by the initial analysis agent.\n"
                    "You need to confirm whether the subquestions are clear, actionable, and cover all key aspects of the original question.\n"
                    # Fixed typo: "subqestions".
                    "Do not accept redundant or unnecessary subquestions, focus solely on the minimal viable subset of subquestions necessary to answer the original question. \n"
                    "Do not accept banal, general knowledge questions\n"
                    "Do not accept questions that go into unnecessary detail that is not relevant to the original question\n"
                    "Remove questions that can be answered with combining knowledge from other questions\n"
                    "After you are satisfied with the subquestions, call the 'generate_subquestions' method to answer each subquestion.\n"
                    "This is an example of an argument that can be passed to the 'generate_subquestions' method:\n"
                    f"{{'task': {example_task.model_dump()}}}\n"
                    "Some rule of thumb is to have only one subquestion for easy questions, 3 for medium questions, and 5 for hard questions.\n"
                ),
                llm_config=llm_config,
                is_termination_msg=lambda x: x.get("content", "")
                and x.get("content", "").startswith(DeepResearchTool.SUBQUESTIONS_ANSWER_PREFIX),
                human_input_mode="NEVER",
            )

            generate_subquestions = DeepResearchTool._get_generate_subquestions(
                llm_config=llm_config, max_web_steps=max_web_steps
            )

            decomposition_agent.register_for_execution()(generate_subquestions)
            decomposition_critic.register_for_llm(description="Generate subquestions for a task.")(
                generate_subquestions
            )

            result = decomposition_critic.initiate_chat(
                decomposition_agent,
                # Fixed typo: "subqestions".
                message="Analyse and gather subquestions for the following question: " + question,
            )

            return result.summary

        return split_question_and_answer_subquestions

    @staticmethod
    def _get_generate_subquestions(
        llm_config: Union[LLMConfig, dict[str, Any]],
        max_web_steps: int,
    ) -> Callable[..., str]:
        """Get the generate_subquestions method.

        Args:
            llm_config (Union[LLMConfig, dict[str, Any]]): The LLM configuration.
            max_web_steps (int): The maximum number of web steps.

        Returns:
            Callable[..., str]: The generate_subquestions method.
        """

        def generate_subquestions(
            task: Task,
            llm_config: Annotated[Union[LLMConfig, dict[str, Any]], Depends(on(llm_config))],
            max_web_steps: Annotated[int, Depends(on(max_web_steps))],
        ) -> str:
            # Fall back to treating the whole question as its own single subquestion.
            if not task.subquestions:
                task.subquestions = [Subquestion(question=task.question)]

            subquestions_answers: list[SubquestionAnswer] = []
            for subquestion in task.subquestions:
                answer = DeepResearchTool._answer_question(
                    subquestion.question, llm_config=llm_config, max_web_steps=max_web_steps
                )
                subquestions_answers.append(SubquestionAnswer(question=subquestion.question, answer=answer))

            completed_task = CompletedTask(question=task.question, subquestions=subquestions_answers)

            return f"{DeepResearchTool.SUBQUESTIONS_ANSWER_PREFIX} \n" + completed_task.format()

        return generate_subquestions

    @staticmethod
    def _answer_question(
        question: str,
        llm_config: Union[LLMConfig, dict[str, Any]],
        max_web_steps: int,
    ) -> str:
        """Answer a single subquestion with a websurfer/critic agent pair."""
        from ....agents.experimental.websurfer import WebSurferAgent

        # Deep-copy so the structured-output tweak never leaks into the caller's config.
        websurfer_config = copy.deepcopy(llm_config)
        websurfer_config["config_list"][0]["response_format"] = GatheredInformation

        def is_termination_msg(x: dict[str, Any]) -> bool:
            content = x.get("content", "")
            return (content is not None) and content.startswith(DeepResearchTool.ANSWER_CONFIRMED_PREFIX)

        websurfer_agent = WebSurferAgent(
            llm_config=llm_config,
            web_tool_llm_config=websurfer_config,
            name="WebSurferAgent",
            system_message=(
                "You are a web surfer agent responsible for gathering information from the web to provide information for answering a question\n"
                "You will be asked to find information related to the question and provide a summary of the information gathered.\n"
                "The summary should include the URL, title, summary, and relevant information for each piece of information gathered.\n"
            ),
            is_termination_msg=is_termination_msg,
            human_input_mode="NEVER",
            web_tool_kwargs={
                "agent_kwargs": {"max_steps": max_web_steps},
            },
        )

        websurfer_critic = ConversableAgent(
            name="WebSurferCritic",
            system_message=(
                "You are a critic agent responsible for evaluating the answer provided by the web surfer agent.\n"
                "You need to confirm whether the information provided by the websurfer is correct and sufficient to answer the question.\n"
                "You can ask the web surfer to provide more information or provide and confirm the answer.\n"
            ),
            llm_config=llm_config,
            is_termination_msg=is_termination_msg,
            human_input_mode="NEVER",
        )

        @websurfer_agent.register_for_execution()
        @websurfer_critic.register_for_llm(
            description="Call this method when you agree that the original question can be answered with the gathered information and provide the answer."
        )
        def confirm_answer(answer: str) -> str:
            return f"{DeepResearchTool.ANSWER_CONFIRMED_PREFIX} " + answer

        # Let the critic execute the websurfer's own web tool as well.
        websurfer_critic.register_for_execution()(websurfer_agent.tool)

        result = websurfer_critic.initiate_chat(
            websurfer_agent,
            message="Please find the answer to this question: " + question,
        )

        return result.summary

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .duckduckgo_search import DuckDuckGoSearchTool
__all__ = ["DuckDuckGoSearchTool"]

View File

@@ -0,0 +1,109 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Annotated, Any
from ....doc_utils import export_module
from ....import_utils import optional_import_block, require_optional_import
from ... import Tool
with optional_import_block():
from duckduckgo_search import DDGS
@require_optional_import(
    [
        "duckduckgo_search",
    ],
    "duckduckgo_search",
)
def _execute_duckduckgo_query(
    query: str,
    num_results: int = 5,
) -> list[dict[str, Any]]:
    """Run a text search against the DuckDuckGo Search API.

    Args:
        query (str): The search query string.
        num_results (int, optional): The maximum number of results to return. Defaults to 5.

    Returns:
        list[dict[str, Any]]: Raw DuckDuckGo result records; empty list on failure.
    """
    hits: list[dict[str, Any]] = []
    with DDGS() as client:
        try:
            # region='wt-wt' means worldwide
            hits = list(client.text(query, region="wt-wt", max_results=num_results))
        except Exception as e:
            # Best-effort: report the failure and fall back to no results.
            print(f"DuckDuckGo Search failed: {e}")
    return hits
def _duckduckgo_search(
    query: str,
    num_results: int = 5,
) -> list[dict[str, Any]]:
    """Perform a DuckDuckGo search and normalize the results.

    Runs the query through `_execute_duckduckgo_query` and reshapes each raw
    record into a dictionary with 'title', 'link', and 'snippet' keys.

    Args:
        query (str): The search query string.
        num_results (int, optional): The maximum number of results to return. Defaults to 5.

    Returns:
        list[dict[str, Any]]: Normalized search results; empty if nothing was found.
    """
    raw_results = _execute_duckduckgo_query(
        query=query,
        num_results=num_results,
    )
    formatted: list[dict[str, Any]] = []
    for entry in raw_results:
        formatted.append({
            "title": entry.get("title", ""),
            "link": entry.get("href", ""),
            "snippet": entry.get("body", ""),
        })
    return formatted
@export_module("autogen.tools.experimental")
class DuckDuckGoSearchTool(Tool):
    """Search tool backed by the DuckDuckGo engine.

    Lets agents query DuckDuckGo for information retrieval. No API key is
    required, which makes the tool usable out of the box.
    """

    def __init__(self) -> None:
        """Create the tool and register its search function."""

        def duckduckgo_search(
            query: Annotated[str, "The search query."],
            num_results: Annotated[int, "The number of results to return."] = 5,
        ) -> list[dict[str, Any]]:
            """Run a DuckDuckGo search and return formatted results.

            Args:
                query: The search query string.
                num_results: The maximum number of results to return. Defaults to 5.

            Returns:
                A list of dictionaries, each containing 'title', 'link', and 'snippet' of a search result.
            """
            return _duckduckgo_search(query=query, num_results=num_results)

        super().__init__(
            name="duckduckgo_search",
            description="Use the DuckDuckGo Search API to perform a search.",
            func_or_tool=duckduckgo_search,
        )

View File

@@ -0,0 +1,14 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .authentication import GoogleCredentialsLocalProvider, GoogleCredentialsProvider
from .drive import GoogleDriveToolkit
from .toolkit_protocol import GoogleToolkitProtocol
__all__ = [
"GoogleCredentialsLocalProvider",
"GoogleCredentialsProvider",
"GoogleDriveToolkit",
"GoogleToolkitProtocol",
]

View File

@@ -0,0 +1,11 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .credentials_local_provider import GoogleCredentialsLocalProvider
from .credentials_provider import GoogleCredentialsProvider
__all__ = [
"GoogleCredentialsLocalProvider",
"GoogleCredentialsProvider",
]

View File

@@ -0,0 +1,43 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .....doc_utils import export_module
from .....import_utils import optional_import_block
from .credentials_provider import GoogleCredentialsProvider
with optional_import_block():
from google.oauth2.credentials import Credentials
__all__ = ["GoogleCredenentialsHostedProvider"]
@export_module("autogen.tools.experimental.google.authentication")
class GoogleCredenentialsHostedProvider(GoogleCredentialsProvider):
    """Placeholder for a Google credentials provider backed by a hosted OAuth flow.

    NOTE(review): the class name is misspelled ("Credenentials"). The original
    name is kept for compatibility; a correctly spelled alias is defined below
    and should be preferred by new callers.
    """

    def __init__(
        self,
        host: str,
        port: int = 8080,
        *,
        kwargs: dict[str, str],
    ) -> None:
        """Store connection parameters for the (not yet implemented) hosted flow.

        Args:
            host: The host from which to get the credentials.
            port: The port from which to get the credentials. Defaults to 8080.
            kwargs: Extra provider-specific options (keyword-only).

        Raises:
            NotImplementedError: Always; this provider is a stub.
        """
        self._host = host
        self._port = port
        self._kwargs = kwargs
        raise NotImplementedError("This class is not implemented yet.")

    @property
    def host(self) -> str:
        """The host from which to get the credentials."""
        return self._host

    @property
    def port(self) -> int:
        """The port from which to get the credentials."""
        return self._port

    def get_credentials(self) -> "Credentials":  # type: ignore[no-any-unimported]
        """Not implemented; raises unconditionally."""
        raise NotImplementedError("This class is not implemented yet.")


# Correctly spelled, backward-compatible alias for the misspelled class name.
GoogleCredentialsHostedProvider = GoogleCredenentialsHostedProvider

View File

@@ -0,0 +1,91 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
from .....doc_utils import export_module
from .....import_utils import optional_import_block, require_optional_import
from .credentials_provider import GoogleCredentialsProvider
with optional_import_block():
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
__all__ = ["GoogleCredentialsLocalProvider"]
@export_module("autogen.tools.experimental.google.authentication")
class GoogleCredentialsLocalProvider(GoogleCredentialsProvider):
    """Obtains Google OAuth2 credentials via a local browser flow, optionally caching the token on disk."""

    def __init__(
        self,
        client_secret_file: str,
        scopes: list[str],  # e.g. ['https://www.googleapis.com/auth/drive/readonly']
        token_file: Optional[str] = None,
        port: int = 8080,
    ) -> None:
        """A Google credentials provider that gets the credentials locally.

        Args:
            client_secret_file (str): The path to the client secret file.
            scopes (list[str]): The scopes to request.
            token_file (str): Optional path to the token file. If not provided, the token will not be saved.
            port (int): The port from which to get the credentials.
        """
        self.client_secret_file = client_secret_file
        self.scopes = scopes
        self.token_file = token_file
        self._port = port

    @property
    def host(self) -> str:
        """Localhost is the default host."""
        return "localhost"

    @property
    def port(self) -> int:
        """The port from which to get the credentials."""
        return self._port

    @require_optional_import(
        [
            "google_auth_httplib2",
            "google_auth_oauthlib",
        ],
        "google-api",
    )
    def _refresh_or_get_new_credentials(self, creds: Optional["Credentials"]) -> "Credentials":  # type: ignore[no-any-unimported]
        """Refresh expired credentials when possible, otherwise run the interactive local OAuth flow."""
        if creds and creds.expired and creds.refresh_token:
            # A refresh token lets us renew silently without user interaction.
            creds.refresh(Request())  # type: ignore[no-untyped-call]
        else:
            # Launch the browser-based consent flow on self.host:self.port.
            flow = InstalledAppFlow.from_client_secrets_file(self.client_secret_file, self.scopes)
            creds = flow.run_local_server(host=self.host, port=self.port)
        return creds  # type: ignore[return-value]

    @require_optional_import(
        [
            "google_auth_httplib2",
            "google_auth_oauthlib",
        ],
        "google-api",
    )
    def get_credentials(self) -> "Credentials":  # type: ignore[no-any-unimported]
        """Get the Google credentials.

        Loads a cached token when available; otherwise refreshes or runs the
        interactive flow, then persists the new token if a token_file is set.
        """
        creds = None
        if self.token_file and os.path.exists(self.token_file):
            # Reuse the token cached by a previous run.
            # NOTE(review): scopes are not passed here, so a cached token's
            # scopes are not validated against self.scopes — confirm intended.
            creds = Credentials.from_authorized_user_file(self.token_file)  # type: ignore[no-untyped-call]

        # If there are no (valid) credentials available, let the user log in.
        if not creds or not creds.valid:
            creds = self._refresh_or_get_new_credentials(creds)

            if self.token_file:
                # Save the credentials for the next run
                with open(self.token_file, "w") as token:
                    token.write(creds.to_json())  # type: ignore[no-untyped-call]

        return creds  # type: ignore[no-any-return]

View File

@@ -0,0 +1,35 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Protocol, runtime_checkable
from .....doc_utils import export_module
from .....import_utils import optional_import_block
with optional_import_block():
from google.oauth2.credentials import Credentials
__all__ = ["GoogleCredentialsProvider"]
@runtime_checkable
@export_module("autogen.tools.experimental.google.authentication")
class GoogleCredentialsProvider(Protocol):
    """A protocol for Google credentials providers.

    Implementations (e.g. local or hosted OAuth flows) supply OAuth2
    credentials plus the host/port used to obtain them. Being
    runtime_checkable, conformance can be tested with isinstance().
    """

    def get_credentials(self) -> Optional["Credentials"]:  # type: ignore[no-any-unimported]
        """Get the Google credentials, or None if none are available."""
        ...

    @property
    def host(self) -> str:
        """The host from which to get the credentials."""
        ...

    @property
    def port(self) -> int:
        """The port from which to get the credentials."""
        ...

View File

@@ -0,0 +1,9 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .toolkit import GoogleDriveToolkit
__all__ = [
"GoogleDriveToolkit",
]

View File

@@ -0,0 +1,124 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import io
from pathlib import Path
from typing import Any, Optional
from .....import_utils import optional_import_block, require_optional_import
from ..model import GoogleFileInfo
with optional_import_block():
from googleapiclient.http import MediaIoBaseDownload
__all__ = [
"download_file",
"list_files_and_folders",
]
@require_optional_import(
    [
        "googleapiclient",
    ],
    "google-api",
)
def list_files_and_folders(service: Any, page_size: int, folder_id: Optional[str]) -> list[GoogleFileInfo]:
    """List Drive files/folders, optionally restricted to one folder's children.

    Args:
        service: An authorized Google Drive API service object.
        page_size: Number of entries to request per page.
        folder_id: If given, only list direct, non-trashed children of this folder.

    Returns:
        list[GoogleFileInfo]: Parsed file metadata records.

    Raises:
        ValueError: If the API response's "files" field is not a list.
    """
    request_params: dict[str, Any] = {
        "pageSize": page_size,
        "fields": "nextPageToken, files(id, name, mimeType)",
    }
    if folder_id:
        # Restrict the query to direct children of the folder, skipping trashed items.
        request_params["q"] = f"'{folder_id}' in parents and trashed=false"

    files = service.files().list(**request_params).execute().get("files", [])
    if not isinstance(files, list):
        raise ValueError(f"Expected a list of files, but got {files}")
    return [GoogleFileInfo(**file_info) for file_info in files]
def _get_file_extension(mime_type: str) -> Optional[str]:
"""Returns the correct file extension for a given MIME type."""
mime_extensions = {
"application/vnd.google-apps.document": "docx", # Google Docs → Word
"application/vnd.google-apps.spreadsheet": "csv", # Google Sheets → CSV
"application/vnd.google-apps.presentation": "pptx", # Google Slides → PowerPoint
"video/quicktime": "mov",
"application/vnd.google.colaboratory": "ipynb",
"application/pdf": "pdf",
"image/jpeg": "jpg",
"image/png": "png",
"text/plain": "txt",
"application/zip": "zip",
}
return mime_extensions.get(mime_type)
@require_optional_import(
    [
        "googleapiclient",
    ],
    "google-api",
)
def download_file(
    service: Any,
    file_id: str,
    file_name: str,
    mime_type: str,
    download_folder: Path,
    subfolder_path: Optional[str] = None,
) -> str:
    """Download or export file based on its MIME type, optionally saving to a subfolder.

    Google-native files (Docs/Sheets/Slides) are exported to an Office/CSV
    format; everything else is fetched verbatim. Returns a human-readable
    success or failure message.
    """
    # Make sure the saved name carries the extension implied by the MIME type.
    extension = _get_file_extension(mime_type)
    if extension and (not file_name.lower().endswith(extension.lower())):
        file_name = f"{file_name}.{extension}"

    # Export formats for Google-native documents.
    export_mime_types = {
        "application/vnd.google-apps.document": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",  # Google Docs → Word
        "application/vnd.google-apps.spreadsheet": "text/csv",  # Google Sheets → CSV
        "application/vnd.google-apps.presentation": "application/vnd.openxmlformats-officedocument.presentationml.presentation",  # Google Slides → PowerPoint
    }

    # Google-native files cannot be fetched with get_media(); they must be exported.
    if mime_type in export_mime_types:
        request = service.files().export(fileId=file_id, mimeType=export_mime_types[mime_type])
    else:
        request = service.files().get_media(fileId=file_id)

    # Resolve (and create, if needed) the destination directory.
    destination_dir = download_folder
    if subfolder_path:
        destination_dir = download_folder / subfolder_path
        destination_dir.mkdir(parents=True, exist_ok=True)
    file_path = destination_dir / file_name

    # Path reported back to the caller, relative to the download folder.
    relative_path = Path(subfolder_path) / file_name if subfolder_path else Path(file_name)

    try:
        # Stream the content into memory, then write it out in one go.
        with io.BytesIO() as buffer:
            downloader = MediaIoBaseDownload(buffer, request)
            finished = False
            while not finished:
                _, finished = downloader.next_chunk()
            buffer.seek(0)
            with open(file_path, "wb") as f:
                f.write(buffer.getvalue())
        return f"✅ Downloaded: {relative_path}"
    except Exception as e:
        return f"❌ FAILED to download {relative_path}: {e}"

View File

@@ -0,0 +1,88 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Annotated, Literal, Optional, Union
from .....doc_utils import export_module
from .....import_utils import optional_import_block
from .... import Toolkit, tool
from ..model import GoogleFileInfo
from ..toolkit_protocol import GoogleToolkitProtocol
from .drive_functions import download_file, list_files_and_folders
with optional_import_block():
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
__all__ = [
"GoogleDriveToolkit",
]
@export_module("autogen.tools.experimental.google.drive")
class GoogleDriveToolkit(Toolkit, GoogleToolkitProtocol):
    """A tool map for Google Drive."""

    def __init__(  # type: ignore[no-any-unimported]
        self,
        *,
        credentials: "Credentials",
        download_folder: Union[Path, str],
        exclude: Optional[list[Literal["list_drive_files_and_folders", "download_file_from_drive"]]] = None,
        api_version: str = "v3",
    ) -> None:
        """Initialize the Google Drive tool map.

        Args:
            credentials: The Google OAuth2 credentials.
            download_folder: The folder to download files to.
            exclude: The tool names to exclude.
            api_version: The Google Drive API version to use.
        """
        self.service = build(serviceName="drive", version=api_version, credentials=credentials)

        # Normalize to a Path and make sure the download location exists.
        if isinstance(download_folder, str):
            download_folder = Path(download_folder)
        download_folder.mkdir(parents=True, exist_ok=True)

        @tool(description="List files and folders in a Google Drive")
        def list_drive_files_and_folders(
            page_size: Annotated[int, "The number of files to list per page."] = 10,
            folder_id: Annotated[
                Optional[str],
                "The ID of the folder to list files from. If not provided, lists all files in the root folder.",
            ] = None,
        ) -> list[GoogleFileInfo]:
            return list_files_and_folders(service=self.service, page_size=page_size, folder_id=folder_id)

        @tool(description="download a file from Google Drive")
        def download_file_from_drive(
            file_info: Annotated[GoogleFileInfo, "The file info to download."],
            subfolder_path: Annotated[
                Optional[str],
                "The subfolder path to save the file in. If not provided, saves in the main download folder.",
            ] = None,
        ) -> str:
            # `download_folder` is captured by this closure rather than stored on self.
            return download_file(
                service=self.service,
                file_id=file_info.id,
                file_name=file_info.name,
                mime_type=file_info.mime_type,
                download_folder=download_folder,
                subfolder_path=subfolder_path,
            )

        if exclude is None:
            exclude = []
        # Register only the tools whose names were not excluded.
        tools = [tool for tool in [list_drive_files_and_folders, download_file_from_drive] if tool.name not in exclude]
        super().__init__(tools=tools)

    @classmethod
    def recommended_scopes(cls) -> list[str]:
        """Return the recommended scopes mandatory for using tools from this tool map."""
        return [
            "https://www.googleapis.com/auth/drive.readonly",
        ]

View File

@@ -0,0 +1,17 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Annotated
from pydantic import BaseModel, Field
__all__ = [
"GoogleFileInfo",
]
class GoogleFileInfo(BaseModel):
    """Metadata for a single Google Drive file.

    Populated from the Drive API's camelCase payload; note the ``mimeType``
    alias on ``mime_type``.
    """

    name: Annotated[str, Field(description="The name of the file.")]
    id: Annotated[str, Field(description="The ID of the file.")]
    mime_type: Annotated[str, Field(alias="mimeType", description="The MIME type of the file.")]

View File

@@ -0,0 +1,19 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Protocol, runtime_checkable
__all__ = [
"GoogleToolkitProtocol",
]
@runtime_checkable
class GoogleToolkitProtocol(Protocol):
    """A protocol for Google tool maps.

    Conforming toolkits must advertise the OAuth scopes their tools need.
    """

    @classmethod
    def recommended_scopes(cls) -> list[str]:
        """Return the OAuth scopes required by this toolkit's tools (no implementation here)."""
        ...

View File

@@ -0,0 +1,8 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .google_search import GoogleSearchTool
from .youtube_search import YoutubeSearchTool
__all__ = ["GoogleSearchTool", "YoutubeSearchTool"]

View File

@@ -0,0 +1,93 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Annotated, Any, Optional
from ....doc_utils import export_module
from ....import_utils import optional_import_block, require_optional_import
from ... import Depends, Tool
from ...dependency_injection import on
with optional_import_block():
from googleapiclient.discovery import build
@require_optional_import(
    [
        "googleapiclient",
    ],
    "google-search",
)
def _execute_query(query: str, search_api_key: str, search_engine_id: str, num_results: int) -> Any:
    """Run a query against the Google Custom Search API and return the raw response.

    Args:
        query: The search query string.
        search_api_key: API key for the Custom Search API.
        search_engine_id: The custom search engine (cx) identifier.
        num_results: Number of results to request.

    Returns:
        The raw API response dictionary.
    """
    client = build("customsearch", "v1", developerKey=search_api_key)
    request = client.cse().list(q=query, cx=search_engine_id, num=num_results)
    return request.execute()
def _google_search(
    query: str,
    search_api_key: str,
    search_engine_id: str,
    num_results: int,
) -> list[dict[str, Any]]:
    """Run a Google Custom Search and normalize each hit to title/link/snippet.

    Args:
        query: The search query string.
        search_api_key: API key for the Custom Search API.
        search_engine_id: The custom search engine (cx) identifier.
        num_results: Number of results to request.

    Returns:
        list[dict[str, Any]]: One dict per hit with 'title', 'link', 'snippet'.
    """
    response = _execute_query(
        query=query, search_api_key=search_api_key, search_engine_id=search_engine_id, num_results=num_results
    )
    formatted: list[dict[str, Any]] = []
    for item in response.get("items", []):
        formatted.append({
            "title": item.get("title", ""),
            "link": item.get("link", ""),
            "snippet": item.get("snippet", ""),
        })
    return formatted
@export_module("autogen.tools.experimental")
class GoogleSearchTool(Tool):
    """GoogleSearchTool is a tool that uses the Google Search API to perform a search."""

    def __init__(
        self,
        *,
        search_api_key: Optional[str] = None,
        search_engine_id: Optional[str] = None,
        use_internal_llm_tool_if_available: bool = True,
    ):
        """Create a Google Search tool.

        Args:
            search_api_key: The API key for the Google Search API.
            search_engine_id: The search engine ID for the Google Search API.
            use_internal_llm_tool_if_available: Whether to use the predefined (e.g. Gemini GenaAI) search tool. Currently, this can only be used for agents with the Gemini (GenAI) configuration.
        """
        # Validate the key/engine combination up front.
        if not use_internal_llm_tool_if_available and (search_api_key is None or search_engine_id is None):
            raise ValueError(
                "search_api_key and search_engine_id must be provided if use_internal_llm_tool_if_available is False"
            )
        if use_internal_llm_tool_if_available and (search_api_key is not None or search_engine_id is not None):
            logging.warning("search_api_key and search_engine_id will be ignored if internal LLM tool is available")

        self.search_api_key = search_api_key
        self.search_engine_id = search_engine_id
        self.use_internal_llm_tool_if_available = use_internal_llm_tool_if_available

        def google_search(
            query: Annotated[str, "The search query."],
            search_api_key: Annotated[Optional[str], Depends(on(search_api_key))],
            search_engine_id: Annotated[Optional[str], Depends(on(search_engine_id))],
            num_results: Annotated[int, "The number of results to return."] = 10,
        ) -> list[dict[str, Any]]:
            # Without explicit credentials this tool only works when the LLM
            # provides its own built-in search.
            if search_api_key is not None and search_engine_id is not None:
                return _google_search(query, search_api_key, search_engine_id, num_results)
            raise ValueError(
                "Your LLM is not configured to use prebuilt google-search tool.\n"
                "Please provide search_api_key and search_engine_id.\n"
            )

        super().__init__(
            # GeminiClient will look for a tool with the name "prebuilt_google_search"
            name="prebuilt_google_search" if use_internal_llm_tool_if_available else "google_search",
            description="Use the Google Search API to perform a search.",
            func_or_tool=google_search,
        )

View File

@@ -0,0 +1,181 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Annotated, Any, Dict, List, Optional
from ....doc_utils import export_module
from ....import_utils import optional_import_block, require_optional_import
from ... import Depends, Tool
from ...dependency_injection import on
with optional_import_block():
import googleapiclient.errors
from googleapiclient.discovery import build
@require_optional_import(
    ["googleapiclient"],
    "google-search",
)
def _execute_search_query(query: str, youtube_api_key: str, max_results: int) -> Any:
    """Execute a YouTube search query using the YouTube Data API.

    Args:
        query: The search query string.
        youtube_api_key: The API key for the YouTube Data API.
        max_results: The maximum number of results to return.

    Returns:
        The raw search response from the YouTube Data API.

    Raises:
        googleapiclient.errors.HttpError: If the API call fails (logged, then re-raised).
    """
    client = build("youtube", "v3", developerKey=youtube_api_key)
    try:
        return client.search().list(q=query, part="id,snippet", maxResults=max_results, type="video").execute()
    except googleapiclient.errors.HttpError as e:
        logging.error(f"An HTTP error occurred: {e}")
        raise
@require_optional_import(
    ["googleapiclient"],
    "google-search",
)
def _get_video_details(video_ids: List[str], youtube_api_key: str) -> Any:
    """Fetch statistics and content details for the given YouTube video IDs.

    Args:
        video_ids: List of YouTube video IDs.
        youtube_api_key: The API key for the YouTube Data API.

    Returns:
        The raw videos().list response; an empty {"items": []} when no IDs are given.

    Raises:
        googleapiclient.errors.HttpError: If the API call fails (logged, then re-raised).
    """
    if not video_ids:
        # Nothing to look up; mirror the API's empty response shape.
        return {"items": []}

    client = build("youtube", "v3", developerKey=youtube_api_key)
    try:
        return client.videos().list(id=",".join(video_ids), part="snippet,contentDetails,statistics").execute()
    except googleapiclient.errors.HttpError as e:
        logging.error(f"An HTTP error occurred: {e}")
        raise
def _youtube_search(
    query: str,
    youtube_api_key: str,
    max_results: int,
    include_video_details: bool = True,
) -> List[Dict[str, Any]]:
    """Search YouTube videos and optionally enrich hits with video details.

    Args:
        query: The search query string.
        youtube_api_key: The API key for the YouTube Data API.
        max_results: The maximum number of results to return.
        include_video_details: Whether to include detailed video information.

    Returns:
        A list of dictionaries containing information about the videos.
    """
    search_response = _execute_search_query(query=query, youtube_api_key=youtube_api_key, max_results=max_results)

    results: List[Dict[str, Any]] = []
    video_ids: List[str] = []
    # Basic fields come straight from the search response; only videos count.
    for item in search_response.get("items", []):
        if item["id"]["kind"] != "youtube#video":
            continue
        vid = item["id"]["videoId"]
        video_ids.append(vid)
        snippet = item["snippet"]
        results.append({
            "title": snippet["title"],
            "description": snippet["description"],
            "publishedAt": snippet["publishedAt"],
            "channelTitle": snippet["channelTitle"],
            "videoId": vid,
            "url": f"https://www.youtube.com/watch?v={vid}",
        })

    # Optionally enrich each hit with statistics and content details.
    if include_video_details and video_ids:
        detail_items = _get_video_details(video_ids, youtube_api_key).get("items", [])
        details_by_id = {item["id"]: item for item in detail_items}
        for entry in results:
            details = details_by_id.get(entry["videoId"])
            if details is None:
                continue
            stats = details["statistics"]
            content = details["contentDetails"]
            entry.update({
                "viewCount": stats.get("viewCount"),
                "likeCount": stats.get("likeCount"),
                "commentCount": stats.get("commentCount"),
                "duration": content.get("duration"),
                "definition": content.get("definition"),
            })

    return results
@export_module("autogen.tools.experimental")
class YoutubeSearchTool(Tool):
    """YoutubeSearchTool is a tool that uses the YouTube Data API to search for videos."""

    def __init__(
        self,
        *,
        youtube_api_key: Optional[str] = None,
    ):
        """Initialize a YouTube search tool.

        Args:
            youtube_api_key: The API key for the YouTube Data API.

        Raises:
            ValueError: If no API key is provided.
        """
        if youtube_api_key is None:
            raise ValueError("youtube_api_key must be provided")
        self.youtube_api_key = youtube_api_key

        def youtube_search(
            query: Annotated[str, "The search query for YouTube videos."],
            youtube_api_key: Annotated[str, Depends(on(youtube_api_key))],
            max_results: Annotated[int, "The maximum number of results to return."] = 5,
            include_video_details: Annotated[bool, "Whether to include detailed video information."] = True,
        ) -> List[Dict[str, Any]]:
            """Search for YouTube videos based on a query.

            Args:
                query: The search query string.
                youtube_api_key: The API key for the YouTube Data API (injected).
                max_results: The maximum number of results to return.
                include_video_details: Whether to include detailed video information.

            Returns:
                A list of dictionaries containing information about the videos.
            """
            # Defensive: the injected key should never be None at this point.
            if youtube_api_key is None:
                raise ValueError("YouTube API key is required")
            return _youtube_search(query, youtube_api_key, max_results, include_video_details)

        super().__init__(
            name="youtube_search",
            description="Search for YouTube videos based on a query, optionally including detailed information.",
            func_or_tool=youtube_search,
        )

View File

@@ -0,0 +1,17 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .discord import DiscordRetrieveTool, DiscordSendTool
from .slack import SlackRetrieveRepliesTool, SlackRetrieveTool, SlackSendTool
from .telegram import TelegramRetrieveTool, TelegramSendTool
__all__ = [
"DiscordRetrieveTool",
"DiscordSendTool",
"SlackRetrieveRepliesTool",
"SlackRetrieveTool",
"SlackSendTool",
"TelegramRetrieveTool",
"TelegramSendTool",
]

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .discord import DiscordRetrieveTool, DiscordSendTool
__all__ = ["DiscordRetrieveTool", "DiscordSendTool"]

View File

@@ -0,0 +1,288 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
from datetime import datetime, timezone
from typing import Annotated, Any, Union
from .....doc_utils import export_module
from .....import_utils import optional_import_block, require_optional_import
from .... import Tool
from ....dependency_injection import Depends, on
__all__ = ["DiscordRetrieveTool", "DiscordSendTool"]
with optional_import_block():
from discord import Client, Intents, utils
MAX_MESSAGE_LENGTH = 2000
MAX_BATCH_RETRIEVE_MESSAGES = 100 # Discord's max per request
@require_optional_import(["discord"], "commsagent-discord")
@export_module("autogen.tools.experimental")
class DiscordSendTool(Tool):
    """Sends a message to a Discord channel."""

    def __init__(self, *, bot_token: str, channel_name: str, guild_name: str) -> None:
        """
        Initialize the DiscordSendTool.

        Args:
            bot_token: The bot token to use for sending messages.
            channel_name: The name of the channel to send messages to.
            guild_name: The name of the guild for the channel.
        """

        # Function that sends the message, uses dependency injection for bot token / channel / guild
        async def discord_send_message(
            message: Annotated[str, "Message to send to the channel."],
            bot_token: Annotated[str, Depends(on(bot_token))],
            guild_name: Annotated[str, Depends(on(guild_name))],
            channel_name: Annotated[str, Depends(on(channel_name))],
        ) -> Any:
            """
            Sends a message to a Discord channel.

            Args:
                message: The message to send to the channel.
                bot_token: The bot token to use for Discord. (uses dependency injection)
                guild_name: The name of the server. (uses dependency injection)
                channel_name: The name of the channel. (uses dependency injection)
            """
            # Intents needed to see guilds/channels and send messages.
            intents = Intents.default()
            intents.message_content = True
            intents.guilds = True
            intents.guild_messages = True

            # A fresh client per call; the result is handed back via a Future
            # resolved inside the on_ready handler.
            client = Client(intents=intents)
            result_future: asyncio.Future[str] = asyncio.Future()  # Stores the result of the send

            # When the client is ready, we'll send the message
            @client.event  # type: ignore[misc]
            async def on_ready() -> None:
                try:
                    # Server
                    guild = utils.get(client.guilds, name=guild_name)
                    if guild:
                        # Channel
                        channel = utils.get(guild.text_channels, name=channel_name)
                        if channel:
                            # Send the message
                            # Discord rejects messages over MAX_MESSAGE_LENGTH,
                            # so long messages are split into chunks.
                            if len(message) > MAX_MESSAGE_LENGTH:
                                chunks = [
                                    message[i : i + (MAX_MESSAGE_LENGTH - 1)]
                                    for i in range(0, len(message), (MAX_MESSAGE_LENGTH - 1))
                                ]
                                for i, chunk in enumerate(chunks):
                                    sent = await channel.send(chunk)

                                    # Store ID for the first chunk
                                    if i == 0:
                                        sent_message_id = str(sent.id)
                                result_future.set_result(
                                    f"Message sent successfully ({len(chunks)} chunks, first ID: {sent_message_id}):\n{message}"
                                )
                            else:
                                sent = await channel.send(message)
                                result_future.set_result(f"Message sent successfully (ID: {sent.id}):\n{message}")
                        else:
                            result_future.set_result(f"Message send failed, could not find channel: {channel_name}")
                    else:
                        result_future.set_result(f"Message send failed, could not find guild: {guild_name}")
                except Exception as e:
                    # Propagate the failure to the awaiting caller.
                    result_future.set_exception(e)
                finally:
                    # Always shut the client down so client.start() returns.
                    try:
                        await client.close()
                    except Exception as e:
                        raise Exception(f"Unable to close Discord client: {e}")

            # Start the client and when it's ready it'll send the message in on_ready
            try:
                await client.start(bot_token)

                # Capture the result of the send
                return await result_future
            except Exception as e:
                raise Exception(f"Failed to start Discord client: {e}")

        super().__init__(
            name="discord_send",
            description="Sends a message to a Discord channel.",
            func_or_tool=discord_send_message,
        )
@require_optional_import(["discord"], "commsagent-discord")
@export_module("autogen.tools.experimental")
class DiscordRetrieveTool(Tool):
    """Retrieves messages from a Discord channel."""

    def __init__(self, *, bot_token: str, channel_name: str, guild_name: str) -> None:
        """
        Initialize the DiscordRetrieveTool.

        Args:
            bot_token: The bot token to use for retrieving messages.
            channel_name: The name of the channel to retrieve messages from.
            guild_name: The name of the guild for the channel.
        """

        # Inner coroutine registered as the tool function; credentials and
        # target names are bound via dependency injection so the LLM never
        # sees them as call arguments.
        async def discord_retrieve_messages(
            bot_token: Annotated[str, Depends(on(bot_token))],
            guild_name: Annotated[str, Depends(on(guild_name))],
            channel_name: Annotated[str, Depends(on(channel_name))],
            messages_since: Annotated[
                Union[str, None],
                "Date to retrieve messages from (ISO format) OR Discord snowflake ID. If None, retrieves latest messages.",
            ] = None,
            maximum_messages: Annotated[
                Union[int, None], "Maximum number of messages to retrieve. If None, retrieves all messages since date."
            ] = None,
        ) -> Any:
            """
            Retrieves messages from a Discord channel.

            Args:
                bot_token: The bot token to use for Discord. (uses dependency injection)
                guild_name: The name of the server. (uses dependency injection)
                channel_name: The name of the channel. (uses dependency injection)
                messages_since: ISO format date string OR Discord snowflake ID, to retrieve messages from. If None, retrieves latest messages.
                maximum_messages: Maximum number of messages to retrieve. If None, retrieves all messages since date.

            Returns:
                A list of message dicts (id/content/author/timestamp), or a
                single-element list with an "error" key on lookup failure.
            """
            # Intents required to read guilds and message content.
            intents = Intents.default()
            intents.message_content = True
            intents.guilds = True
            intents.guild_messages = True
            client = Client(intents=intents)

            # Resolved by on_ready with either the messages or an exception.
            result_future: asyncio.Future[list[dict[str, Any]]] = asyncio.Future()

            # Normalize messages_since: a snowflake ID is converted to the
            # ISO timestamp it encodes; anything else is assumed to already
            # be an ISO date string.
            messages_since_date: Union[str, None] = None
            if messages_since is not None:
                if DiscordRetrieveTool._is_snowflake(messages_since):
                    messages_since_date = DiscordRetrieveTool._snowflake_to_iso(messages_since)
                else:
                    messages_since_date = messages_since

            @client.event  # type: ignore[misc]
            async def on_ready() -> None:
                # Runs once the gateway connection is established; does the
                # actual retrieval, then always closes the client.
                try:
                    messages = []

                    # Get guild and channel
                    guild = utils.get(client.guilds, name=guild_name)
                    if not guild:
                        result_future.set_result([{"error": f"Could not find guild: {guild_name}"}])
                        return

                    channel = utils.get(guild.text_channels, name=channel_name)
                    if not channel:
                        result_future.set_result([{"error": f"Could not find channel: {channel_name}"}])
                        return

                    # Setup retrieval parameters
                    last_message_id = None  # NOTE: holds a Message object (used directly as 'before'), not an ID
                    messages_retrieved = 0

                    # Convert to ISO format
                    after_date = None
                    if messages_since_date:
                        try:
                            from datetime import datetime

                            after_date = datetime.fromisoformat(messages_since_date)
                        except ValueError:
                            result_future.set_result([
                                {"error": f"Invalid date format: {messages_since_date}. Use ISO format."}
                            ])
                            return

                    while True:
                        # Setup fetch options
                        # NOTE(review): both 'before' (pagination cursor) and
                        # 'after' (date filter) are passed to history();
                        # confirm against discord.py docs that this cursor
                        # direction does not skip messages when 'after' is set.
                        fetch_options = {
                            "limit": MAX_BATCH_RETRIEVE_MESSAGES,
                            "before": last_message_id if last_message_id else None,
                            "after": after_date if after_date else None,
                        }

                        # Fetch batch of messages
                        message_batch = []
                        async for message in channel.history(**fetch_options):  # type: ignore[arg-type]
                            message_batch.append(message)
                            messages_retrieved += 1

                            # Check if we've reached the maximum
                            if maximum_messages and messages_retrieved >= maximum_messages:
                                break

                        if not message_batch:
                            break

                        # Process messages
                        for msg in message_batch:
                            messages.append({
                                "id": str(msg.id),
                                "content": msg.content,
                                "author": str(msg.author),
                                "timestamp": msg.created_at.isoformat(),
                            })

                        # Update last message ID for pagination
                        last_message_id = message_batch[-1]  # Use message object directly as 'before' parameter

                        # Break if we've reached the maximum
                        if maximum_messages and messages_retrieved >= maximum_messages:
                            break

                    result_future.set_result(messages)
                except Exception as e:
                    result_future.set_exception(e)
                finally:
                    # Always tear down the gateway connection so start() returns.
                    try:
                        await client.close()
                    except Exception as e:
                        raise Exception(f"Unable to close Discord client: {e}")

            # Start the client; on_ready performs the retrieval and resolves the future.
            try:
                await client.start(bot_token)
                return await result_future
            except Exception as e:
                raise Exception(f"Failed to start Discord client: {e}")

        super().__init__(
            name="discord_retrieve",
            description="Retrieves messages from a Discord channel based datetime/message ID and/or number of latest messages.",
            func_or_tool=discord_retrieve_messages,
        )

    @staticmethod
    def _is_snowflake(value: str) -> bool:
        """Check if a string is a valid Discord snowflake ID."""
        # Must be numeric and 17-20 digits
        if not value.isdigit():
            return False
        digit_count = len(value)
        return 17 <= digit_count <= 20

    @staticmethod
    def _snowflake_to_iso(snowflake: str) -> str:
        """Convert a Discord snowflake ID to ISO timestamp string."""
        if not DiscordRetrieveTool._is_snowflake(snowflake):
            raise ValueError(f"Invalid snowflake ID: {snowflake}")

        # Discord epoch (2015-01-01)
        discord_epoch = 1420070400000

        # Convert ID to int and shift right 22 bits to get timestamp
        timestamp_ms = (int(snowflake) >> 22) + discord_epoch

        # Convert to datetime and format as ISO string
        dt = datetime.fromtimestamp(timestamp_ms / 1000.0, tz=timezone.utc)
        return dt.isoformat()

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .slack import SlackRetrieveRepliesTool, SlackRetrieveTool, SlackSendTool
__all__ = ["SlackRetrieveRepliesTool", "SlackRetrieveTool", "SlackSendTool"]

View File

@@ -0,0 +1,391 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
from datetime import datetime, timedelta
from typing import Annotated, Any, Optional, Tuple, Union
from .....doc_utils import export_module
from .....import_utils import optional_import_block, require_optional_import
from .... import Tool
from ....dependency_injection import Depends, on
__all__ = ["SlackSendTool"]
with optional_import_block():
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
MAX_MESSAGE_LENGTH = 40000
@require_optional_import(["slack_sdk"], "commsagent-slack")
@export_module("autogen.tools.experimental")
class SlackSendTool(Tool):
    """Sends a message to a Slack channel."""

    def __init__(self, *, bot_token: str, channel_id: str) -> None:
        """
        Initialize the SlackSendTool.

        Args:
            bot_token: Bot User OAuth Token starting with "xoxb-".
            channel_id: Channel ID where messages will be sent.
        """

        # Tool function; the token and channel are injected so they are not
        # exposed as LLM-visible parameters.
        async def slack_send_message(
            message: Annotated[str, "Message to send to the channel."],
            bot_token: Annotated[str, Depends(on(bot_token))],
            channel_id: Annotated[str, Depends(on(channel_id))],
        ) -> Any:
            """
            Sends a message to a Slack channel.

            Args:
                message: The message to send to the channel.
                bot_token: The bot token to use for Slack. (uses dependency injection)
                channel_id: The ID of the channel. (uses dependency injection)
            """
            try:
                slack_client = WebClient(token=bot_token)

                # Short messages fit in a single API call.
                if len(message) <= MAX_MESSAGE_LENGTH:
                    response = slack_client.chat_postMessage(channel=channel_id, text=message)
                    if not response["ok"]:
                        return f"Message send failed, Slack response error: {response['error']}"
                    return f"Message sent successfully (ID: {response['ts']}):\n{message}"

                # Otherwise split into chunks just under the Slack limit and
                # post them in order, remembering the first message's ts.
                step = MAX_MESSAGE_LENGTH - 1
                chunks = [message[start : start + step] for start in range(0, len(message), step)]
                sent_message_id = None
                for i, chunk in enumerate(chunks):
                    response = slack_client.chat_postMessage(channel=channel_id, text=chunk)
                    if not response["ok"]:
                        return f"Message send failed on chunk {i + 1}, Slack response error: {response['error']}"
                    # Store ID for the first chunk
                    if i == 0:
                        sent_message_id = response["ts"]
                return f"Message sent successfully ({len(chunks)} chunks, first ID: {sent_message_id}):\n{message}"
            except SlackApiError as e:
                return f"Message send failed, Slack API exception: {e.response['error']} (See https://api.slack.com/automation/cli/errors#{e.response['error']})"
            except Exception as e:
                return f"Message send failed, exception: {e}"

        super().__init__(
            name="slack_send",
            description="Sends a message to a Slack channel.",
            func_or_tool=slack_send_message,
        )
@require_optional_import(["slack_sdk"], "commsagent-slack")
@export_module("autogen.tools.experimental")
class SlackRetrieveTool(Tool):
    """Retrieves messages from a Slack channel."""

    def __init__(self, *, bot_token: str, channel_id: str) -> None:
        """
        Initialize the SlackRetrieveTool.

        Args:
            bot_token: Bot User OAuth Token starting with "xoxb-".
            channel_id: Channel ID where messages will be retrieved from.
        """

        async def slack_retrieve_messages(
            bot_token: Annotated[str, Depends(on(bot_token))],
            channel_id: Annotated[str, Depends(on(channel_id))],
            messages_since: Annotated[
                Union[str, None],
                "Date to retrieve messages from (ISO format) OR Slack message ID. If None, retrieves latest messages.",
            ] = None,
            maximum_messages: Annotated[
                Union[int, None], "Maximum number of messages to retrieve. If None, retrieves all messages since date."
            ] = None,
        ) -> Any:
            """
            Retrieves messages from a Slack channel.

            Args:
                bot_token: The bot token to use for Slack. (uses dependency injection)
                channel_id: The ID of the channel. (uses dependency injection)
                messages_since: ISO format date string OR Slack message ID, to retrieve messages from. If None, retrieves latest messages.
                maximum_messages: Maximum number of messages to retrieve. If None, retrieves all messages since date.

            Returns:
                A dict with message_count / messages / start_time on success,
                or an error string on failure.
            """
            try:
                web_client = WebClient(token=bot_token)

                # Convert ISO datetime to Unix timestamp if needed; Slack's
                # 'oldest' parameter accepts a message ts directly.
                oldest = None
                if messages_since:
                    if "." in messages_since:  # Likely a Slack message ID
                        oldest = messages_since
                    else:  # Assume ISO format
                        try:
                            dt = datetime.fromisoformat(messages_since.replace("Z", "+00:00"))
                            oldest = str(dt.timestamp())
                        except ValueError as e:
                            return f"Invalid date format. Please provide either a Slack message ID or ISO format date (e.g., '2025-01-25T00:00:00Z'). Error: {e}"

                messages = []
                cursor = None  # Slack pagination cursor from response_metadata

                while True:
                    try:
                        # Prepare API call parameters
                        params = {
                            "channel": channel_id,
                            "limit": min(1000, maximum_messages) if maximum_messages else 1000,
                        }
                        if oldest:
                            params["oldest"] = oldest
                        if cursor:
                            params["cursor"] = cursor

                        # Make API call
                        response = web_client.conversations_history(**params)  # type: ignore[arg-type]

                        if not response["ok"]:
                            return f"Message retrieval failed, Slack response error: {response['error']}"

                        # Add messages to our list
                        messages.extend(response["messages"])

                        # Check if we've hit our maximum
                        if maximum_messages and len(messages) >= maximum_messages:
                            messages = messages[:maximum_messages]
                            break

                        # Check if there are more messages
                        if not response["has_more"]:
                            break

                        cursor = response["response_metadata"]["next_cursor"]
                    except SlackApiError as e:
                        return f"Message retrieval failed on pagination, Slack API error: {e.response['error']}"

                return {
                    "message_count": len(messages),
                    "messages": messages,
                    "start_time": oldest or "latest",
                }
            except SlackApiError as e:
                return f"Message retrieval failed, Slack API exception: {e.response['error']} (See https://api.slack.com/automation/cli/errors#{e.response['error']})"
            except Exception as e:
                return f"Message retrieval failed, exception: {e}"

        super().__init__(
            name="slack_retrieve",
            description="Retrieves messages from a Slack channel based datetime/message ID and/or number of latest messages.",
            func_or_tool=slack_retrieve_messages,
        )
@require_optional_import(["slack_sdk"], "commsagent-slack")
@export_module("autogen.tools.experimental")
class SlackRetrieveRepliesTool(Tool):
    """Retrieves replies to a specific Slack message from both threads and the channel."""

    def __init__(self, *, bot_token: str, channel_id: str) -> None:
        """
        Initialize the SlackRetrieveRepliesTool.

        Args:
            bot_token: Bot User OAuth Token starting with "xoxb-".
            channel_id: Channel ID where the parent message exists.
        """

        async def slack_retrieve_replies(
            message_ts: Annotated[str, "Timestamp (ts) of the parent message to retrieve replies for."],
            bot_token: Annotated[str, Depends(on(bot_token))],
            channel_id: Annotated[str, Depends(on(channel_id))],
            min_replies: Annotated[
                Optional[int],
                "Minimum number of replies to wait for before returning (thread + channel). If None, returns immediately.",
            ] = None,
            timeout_seconds: Annotated[
                int, "Maximum time in seconds to wait for the requested number of replies."
            ] = 60,
            poll_interval: Annotated[int, "Time in seconds between polling attempts when waiting for replies."] = 5,
            include_channel_messages: Annotated[
                bool, "Whether to include messages in the channel after the original message."
            ] = True,
        ) -> Any:
            """
            Retrieves replies to a specific Slack message, from both threads and the main channel.

            Args:
                message_ts: The timestamp (ts) identifier of the parent message.
                bot_token: The bot token to use for Slack. (uses dependency injection)
                channel_id: The ID of the channel. (uses dependency injection)
                min_replies: Minimum number of combined replies to wait for before returning. If None, returns immediately.
                timeout_seconds: Maximum time in seconds to wait for the requested number of replies.
                poll_interval: Time in seconds between polling attempts when waiting for replies.
                include_channel_messages: Whether to include messages posted in the channel after the original message.

            Returns:
                A dict summarizing thread replies and channel messages, or an
                error string if any Slack call fails.
            """
            try:
                web_client = WebClient(token=bot_token)

                # Function to get current thread replies.
                # Each helper returns (result, error): exactly one is non-None.
                async def get_thread_replies() -> tuple[Optional[list[dict[str, Any]]], Optional[str]]:
                    try:
                        response = web_client.conversations_replies(
                            channel=channel_id,
                            ts=message_ts,
                        )

                        if not response["ok"]:
                            return None, f"Thread reply retrieval failed, Slack response error: {response['error']}"

                        # The first message is the parent message itself, so exclude it when counting replies
                        replies = response["messages"][1:] if len(response["messages"]) > 0 else []
                        return replies, None
                    except SlackApiError as e:
                        return None, f"Thread reply retrieval failed, Slack API exception: {e.response['error']}"
                    except Exception as e:
                        return None, f"Thread reply retrieval failed, exception: {e}"

                # Function to get messages in the channel after the original message
                async def get_channel_messages() -> Tuple[Optional[list[dict[str, Any]]], Optional[str]]:
                    try:
                        response = web_client.conversations_history(
                            channel=channel_id,
                            oldest=message_ts,  # Start from the original message timestamp
                            inclusive=False,  # Don't include the original message
                        )

                        if not response["ok"]:
                            return None, f"Channel message retrieval failed, Slack response error: {response['error']}"

                        # Return all messages in the channel after the original message
                        # We need to filter out any that are part of the thread we're already getting
                        messages = []
                        for msg in response["messages"]:
                            # Skip if the message is part of the thread we're already retrieving
                            if "thread_ts" in msg and msg["thread_ts"] == message_ts:
                                continue
                            messages.append(msg)

                        return messages, None
                    except SlackApiError as e:
                        return None, f"Channel message retrieval failed, Slack API exception: {e.response['error']}"
                    except Exception as e:
                        return None, f"Channel message retrieval failed, exception: {e}"

                # Function to get all replies (both thread and channel)
                async def get_all_replies() -> Tuple[
                    Optional[list[dict[str, Any]]], Optional[list[dict[str, Any]]], Optional[str]
                ]:
                    thread_replies, thread_error = await get_thread_replies()
                    if thread_error:
                        return None, None, thread_error

                    channel_messages: list[dict[str, Any]] = []
                    channel_error = None
                    if include_channel_messages:
                        channel_results, channel_error = await get_channel_messages()
                        if channel_error:
                            return thread_replies, None, channel_error
                        channel_messages = channel_results if channel_results is not None else []

                    return thread_replies, channel_messages, None

                # If no waiting is required, just get replies and return
                if min_replies is None:
                    thread_replies, channel_messages, error = await get_all_replies()
                    if error:
                        return error

                    thread_replies_list: list[dict[str, Any]] = [] if thread_replies is None else thread_replies
                    channel_messages_list: list[dict[str, Any]] = [] if channel_messages is None else channel_messages

                    # Combine replies for counting but keep them separate in the result
                    total_reply_count = len(thread_replies_list) + len(channel_messages_list)

                    return {
                        "parent_message_ts": message_ts,
                        "total_reply_count": total_reply_count,
                        "thread_replies": thread_replies_list,
                        "thread_reply_count": len(thread_replies_list),
                        "channel_messages": channel_messages_list if include_channel_messages else None,
                        "channel_message_count": len(channel_messages_list) if include_channel_messages else None,
                    }

                # Wait for the required number of replies with timeout,
                # polling every poll_interval seconds.
                start_time = datetime.now()
                end_time = start_time + timedelta(seconds=timeout_seconds)

                while datetime.now() < end_time:
                    thread_replies, channel_messages, error = await get_all_replies()
                    if error:
                        return error

                    thread_replies_current: list[dict[str, Any]] = [] if thread_replies is None else thread_replies
                    channel_messages_current: list[dict[str, Any]] = (
                        [] if channel_messages is None else channel_messages
                    )

                    # Combine replies for counting
                    total_reply_count = len(thread_replies_current) + len(channel_messages_current)

                    # If we have enough total replies, return them
                    if total_reply_count >= min_replies:
                        return {
                            "parent_message_ts": message_ts,
                            "total_reply_count": total_reply_count,
                            "thread_replies": thread_replies_current,
                            "thread_reply_count": len(thread_replies_current),
                            "channel_messages": channel_messages_current if include_channel_messages else None,
                            "channel_message_count": len(channel_messages_current)
                            if include_channel_messages
                            else None,
                            "waited_seconds": (datetime.now() - start_time).total_seconds(),
                        }

                    # Wait before checking again
                    await asyncio.sleep(poll_interval)

                # If we reach here, we timed out waiting for replies;
                # return whatever is available, flagged with timed_out=True.
                thread_replies, channel_messages, error = await get_all_replies()
                if error:
                    return error

                # Combine replies for counting
                total_reply_count = len(thread_replies or []) + len(channel_messages or [])

                return {
                    "parent_message_ts": message_ts,
                    "total_reply_count": total_reply_count,
                    "thread_replies": thread_replies or [],
                    "thread_reply_count": len(thread_replies or []),
                    "channel_messages": channel_messages or [] if include_channel_messages else None,
                    "channel_message_count": len(channel_messages or []) if include_channel_messages else None,
                    "timed_out": True,
                    "waited_seconds": timeout_seconds,
                    "requested_replies": min_replies,
                }
            except SlackApiError as e:
                return f"Reply retrieval failed, Slack API exception: {e.response['error']} (See https://api.slack.com/automation/cli/errors#{e.response['error']})"
            except Exception as e:
                return f"Reply retrieval failed, exception: {e}"

        super().__init__(
            name="slack_retrieve_replies",
            description="Retrieves replies to a specific Slack message, checking both thread replies and messages in the channel after the original message.",
            func_or_tool=slack_retrieve_replies,
        )

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .telegram import TelegramRetrieveTool, TelegramSendTool
__all__ = ["TelegramRetrieveTool", "TelegramSendTool"]

View File

@@ -0,0 +1,275 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from typing import Annotated, Any, Union
from .....doc_utils import export_module
from .....import_utils import optional_import_block, require_optional_import
from .... import Tool
from ....dependency_injection import Depends, on
__all__ = ["TelegramRetrieveTool", "TelegramSendTool"]
with optional_import_block():
from telethon import TelegramClient
from telethon.tl.types import InputMessagesFilterEmpty, Message, PeerChannel, PeerChat, PeerUser
MAX_MESSAGE_LENGTH = 4096
@require_optional_import(["telethon", "telethon.tl.types"], "commsagent-telegram")
@export_module("autogen.tools.experimental")
class BaseTelegramTool:
    """Base class for Telegram tools containing shared functionality."""

    def __init__(self, api_id: str, api_hash: str, session_name: str) -> None:
        # Credentials come from https://my.telegram.org/apps; session_name is
        # the Telethon session identifier passed to TelegramClient.
        self._api_id = api_id
        self._api_hash = api_hash
        self._session_name = session_name

    def _get_client(self) -> "TelegramClient":  # type: ignore[no-any-unimported]
        """Get a fresh TelegramClient instance."""
        return TelegramClient(self._session_name, self._api_id, self._api_hash)

    @staticmethod
    def _get_peer_from_id(chat_id: str) -> Union["PeerChat", "PeerChannel", "PeerUser"]:  # type: ignore[no-any-unimported]
        """Convert a chat ID string to appropriate Peer type.

        Raises:
            ValueError: If chat_id is not a valid integer string.
        """
        try:
            # Convert string to integer
            id_int = int(chat_id)

            # Channel/Supergroup: -100 prefix
            if str(chat_id).startswith("-100"):
                channel_id = int(str(chat_id)[4:])  # Remove -100 prefix
                return PeerChannel(channel_id)
            # Group: negative number without -100 prefix
            elif id_int < 0:
                group_id = -id_int  # Remove the negative sign
                return PeerChat(group_id)
            # User/Bot: positive number
            else:
                return PeerUser(id_int)
        except ValueError as e:
            raise ValueError(f"Invalid chat_id format: {chat_id}. Error: {str(e)}")

    async def _initialize_entity(self, client: "TelegramClient", chat_id: str) -> Any:  # type: ignore[no-any-unimported]
        """Initialize and cache the entity by trying different methods.

        Tries direct resolution first, then falls back to scanning the
        account's dialogs for a matching entity id.

        Raises:
            ValueError: If the entity cannot be resolved by either method.
        """
        peer = self._get_peer_from_id(chat_id)

        try:
            # Try direct entity resolution first
            entity = await client.get_entity(peer)
            return entity
        except ValueError:
            try:
                # Get all dialogs (conversations)
                async for dialog in client.iter_dialogs():
                    # For users/bots, we need to find the dialog with the user
                    # NOTE(review): this condition parses as
                    # (isinstance(...) and id == user_id) or (id == channel/chat id)
                    # due to and/or precedence — confirm that is the intent and
                    # consider adding explicit parentheses.
                    if (
                        isinstance(peer, PeerUser)
                        and dialog.entity.id == peer.user_id
                        or dialog.entity.id == getattr(peer, "channel_id", getattr(peer, "chat_id", None))
                    ):
                        return dialog.entity

                # If we get here, we didn't find the entity in dialogs
                raise ValueError(f"Could not find entity {chat_id} in dialogs")
            except Exception as e:
                raise ValueError(
                    f"Could not initialize entity for {chat_id}. "
                    f"Make sure you have access to this chat. Error: {str(e)}"
                )
@require_optional_import(["telethon"], "commsagent-telegram")
@export_module("autogen.tools.experimental")
class TelegramSendTool(BaseTelegramTool, Tool):
    """Sends a message to a Telegram channel, group, or user."""

    def __init__(self, *, api_id: str, api_hash: str, chat_id: str) -> None:
        """
        Initialize the TelegramSendTool.

        Args:
            api_id: Telegram API ID from https://my.telegram.org/apps.
            api_hash: Telegram API hash from https://my.telegram.org/apps.
            chat_id: The ID of the destination (Channel, Group, or User ID).
        """
        BaseTelegramTool.__init__(self, api_id, api_hash, "telegram_send_session")

        async def telegram_send_message(
            message: Annotated[str, "Message to send to the chat."],
            chat_id: Annotated[str, Depends(on(chat_id))],
        ) -> Any:
            """
            Sends a message to a Telegram chat.

            Args:
                message: The message to send.
                chat_id: The ID of the destination. (uses dependency injection)

            Returns:
                A success string containing the sent message ID, or a failure
                string describing the exception.
            """
            try:
                client = self._get_client()
                async with client:
                    # Initialize and cache the entity
                    entity = await self._initialize_entity(client, chat_id)

                    if len(message) > MAX_MESSAGE_LENGTH:
                        # Split into chunks just under Telegram's message limit.
                        chunks = [
                            message[i : i + (MAX_MESSAGE_LENGTH - 1)]
                            for i in range(0, len(message), (MAX_MESSAGE_LENGTH - 1))
                        ]
                        first_message: Union[Message, None] = None  # type: ignore[no-any-unimported]

                        for i, chunk in enumerate(chunks):
                            # Later chunks are sent as replies to the first one
                            # so they appear as a connected sequence.
                            sent = await client.send_message(
                                entity=entity,
                                message=chunk,
                                parse_mode="html",
                                reply_to=first_message.id if first_message else None,
                            )

                            # Store the first message to chain replies
                            if i == 0:
                                first_message = sent
                                sent_message_id = str(sent.id)

                        return (
                            f"Message sent successfully ({len(chunks)} chunks, first ID: {sent_message_id}):\n{message}"
                        )
                    else:
                        sent = await client.send_message(entity=entity, message=message, parse_mode="html")
                        return f"Message sent successfully (ID: {sent.id}):\n{message}"
            except Exception as e:
                return f"Message send failed, exception: {str(e)}"

        # NOTE(review): description says "bot channel, group, or channel" —
        # the repeated "channel" looks like it was meant to be "user"; confirm
        # before changing this LLM-facing string.
        Tool.__init__(
            self,
            name="telegram_send",
            description="Sends a message to a personal channel, bot channel, group, or channel.",
            func_or_tool=telegram_send_message,
        )
@require_optional_import(["telethon"], "commsagent-telegram")
@export_module("autogen.tools.experimental")
class TelegramRetrieveTool(BaseTelegramTool, Tool):
    """Retrieves messages from a Telegram channel."""

    def __init__(self, *, api_id: str, api_hash: str, chat_id: str) -> None:
        """
        Initialize the TelegramRetrieveTool.

        Args:
            api_id: Telegram API ID from https://my.telegram.org/apps.
            api_hash: Telegram API hash from https://my.telegram.org/apps.
            chat_id: The ID of the chat to retrieve messages from (Channel, Group, Bot Chat ID).
        """
        BaseTelegramTool.__init__(self, api_id, api_hash, "telegram_retrieve_session")
        self._chat_id = chat_id

        async def telegram_retrieve_messages(
            chat_id: Annotated[str, Depends(on(chat_id))],
            messages_since: Annotated[
                Union[str, None],
                "Date to retrieve messages from (ISO format) OR message ID. If None, retrieves latest messages.",
            ] = None,
            maximum_messages: Annotated[
                Union[int, None], "Maximum number of messages to retrieve. If None, retrieves all messages since date."
            ] = None,
            search: Annotated[Union[str, None], "Optional string to search for in messages."] = None,
        ) -> Any:
            """
            Retrieves messages from a Telegram chat.

            Args:
                chat_id: The ID of the chat. (uses dependency injection)
                messages_since: ISO format date string OR message ID to retrieve messages from.
                maximum_messages: Maximum number of messages to retrieve.
                search: Optional string to search for in messages.

            Returns:
                A dict with message_count / messages / start_time on success,
                a dict with an "error" key for invalid input, or a failure
                string if an exception occurs.
            """
            try:
                client = self._get_client()
                async with client:
                    # Initialize and cache the entity
                    entity = await self._initialize_entity(client, chat_id)

                    # Base parameters for Telethon's iter_messages.
                    params = {
                        "entity": entity,
                        "limit": maximum_messages if maximum_messages else None,
                        "search": search if search else None,
                        "filter": InputMessagesFilterEmpty(),
                        "wait_time": None,  # No wait time between requests
                    }

                    # messages_since may be a numeric message ID or an ISO date.
                    if messages_since:
                        try:
                            # Try to parse as message ID first
                            msg_id = int(messages_since)
                            params["min_id"] = msg_id
                        except ValueError:
                            # Not a message ID, try as ISO date
                            try:
                                date = datetime.fromisoformat(messages_since.replace("Z", "+00:00"))
                                params["offset_date"] = date
                                params["reverse"] = (
                                    True  # Need this because the date gets messages before a certain date by default
                                )
                            except ValueError:
                                return {
                                    "error": "Invalid messages_since format. Please provide either a message ID or ISO format date (e.g., '2025-01-25T00:00:00Z')"
                                }

                    # Retrieve messages.
                    # (Removed a leftover debug print that fired for bot/user
                    # chats; it only wrote noise to stdout and had no effect
                    # on retrieval.)
                    messages = []
                    count = 0

                    async for message in client.iter_messages(**params):
                        count += 1
                        messages.append({
                            "id": str(message.id),
                            "date": message.date.isoformat(),
                            "from_id": str(message.from_id) if message.from_id else None,
                            "text": message.text,
                            "reply_to_msg_id": str(message.reply_to_msg_id) if message.reply_to_msg_id else None,
                            "forward_from": str(message.forward.from_id) if message.forward else None,
                            "edit_date": message.edit_date.isoformat() if message.edit_date else None,
                            "media": bool(message.media),
                            "entities": [
                                {"type": e.__class__.__name__, "offset": e.offset, "length": e.length}
                                for e in message.entities
                            ]
                            if message.entities
                            else None,
                        })

                        # Check if we've hit the maximum
                        if maximum_messages and len(messages) >= maximum_messages:
                            break

                    return {
                        "message_count": len(messages),
                        "messages": messages,
                        "start_time": messages_since or "latest",
                    }
            except Exception as e:
                return f"Message retrieval failed, exception: {str(e)}"

        Tool.__init__(
            self,
            name="telegram_retrieve",
            description="Retrieves messages from a Telegram chat based on datetime/message ID and/or number of latest messages.",
            func_or_tool=telegram_retrieve_messages,
        )

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .perplexity_search import PerplexitySearchTool
__all__ = ["PerplexitySearchTool"]

View File

@@ -0,0 +1,260 @@
"""
Module: perplexity_search_tool
Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
SPDX-License-Identifier: Apache-2.0
This module provides classes for interacting with the Perplexity AI search API.
It defines data models for responses and a tool for executing web and conversational searches.
"""
import json
import os
from typing import Any, Optional, Union
import requests
from pydantic import BaseModel, ValidationError
from autogen.tools import Tool
class Message(BaseModel):
    """
    Represents a single message in the chat conversation sent to / returned by
    the Perplexity API.

    Attributes:
        role (str): The role of the message sender (e.g., "system", "user").
        content (str): The text content of the message.
    """

    role: str
    content: str
class Usage(BaseModel):
    """
    Model representing token usage details reported by the Perplexity API.

    Attributes:
        prompt_tokens (int): The number of tokens used for the prompt.
        completion_tokens (int): The number of tokens generated in the completion.
        total_tokens (int): The total number of tokens (prompt + completion).
        search_context_size (str): The size context used in the search (e.g., "high").
    """

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    search_context_size: str
class Choice(BaseModel):
    """
    Represents one choice in the response from the Perplexity API.

    Attributes:
        index (int): The index of this choice.
        finish_reason (str): The reason why the API finished generating this choice.
        message (Message): The message object containing the response text.
    """

    index: int
    finish_reason: str
    message: Message
class PerplexityChatCompletionResponse(BaseModel):
    """
    Represents the full chat completion response from the Perplexity API.

    Attributes:
        id (str): Unique identifier for the response.
        model (str): The model name used for generating the response.
        created (int): Timestamp when the response was created.
        usage (Usage): Token usage details.
        citations (list[str]): List of citation strings included in the response.
        object (str): Type of the response object.
        choices (list[Choice]): List of choices returned by the API.
    """

    id: str
    model: str
    created: int
    usage: Usage
    citations: list[str]
    object: str
    choices: list[Choice]
class SearchResponse(BaseModel):
    """
    Represents the response from a search query.

    All fields are required but nullable: exactly one of content/citations or
    error is expected to be populated in practice.

    Attributes:
        content (Union[str, None]): The textual content returned from the search.
        citations (Union[list[str], None]): A list of citation URLs relevant to the search result.
        error (Union[str, None]): An error message if the search failed.
    """

    content: Union[str, None]
    citations: Union[list[str], None]
    error: Union[str, None]
class PerplexitySearchTool(Tool):
"""
Tool for interacting with the Perplexity AI search API.
This tool uses the Perplexity API to perform web search, news search,
and conversational search, returning concise and precise responses.
Attributes:
url (str): API endpoint URL.
model (str): Name of the model to be used.
api_key (str): API key for authenticating with the Perplexity API.
max_tokens (int): Maximum tokens allowed for the API response.
search_domain_filters (Optional[list[str]]): Optional list of domain filters for the search.
"""
    def __init__(
        self,
        model: str = "sonar",
        api_key: Optional[str] = None,
        max_tokens: int = 1000,
        search_domain_filter: Optional[list[str]] = None,
    ):
        """
        Initializes a new instance of the PerplexitySearchTool.

        Args:
            model (str, optional): The model to use. Defaults to "sonar".
            api_key (Optional[str], optional): API key for authentication.
                Falls back to the PERPLEXITY_API_KEY environment variable when not given.
            max_tokens (int, optional): Maximum number of tokens for the response. Defaults to 1000.
            search_domain_filter (Optional[list[str]], optional): list of domain filters to restrict search.

        Raises:
            ValueError: If the API key is missing, the model is empty, max_tokens is not positive,
                or if search_domain_filter is not a list when provided.
        """
        # Environment variable is the fallback credential source.
        self.api_key = api_key or os.getenv("PERPLEXITY_API_KEY")
        # Validate before storing any other configuration.
        self._validate_tool_config(model, self.api_key, max_tokens, search_domain_filter)
        self.url = "https://api.perplexity.ai/chat/completions"
        self.model = model
        self.max_tokens = max_tokens
        self.search_domain_filters = search_domain_filter
        super().__init__(
            name="perplexity-search",
            description="Perplexity AI search tool for web search, news search, and conversational search "
            "for finding answers to everyday questions, conducting in-depth research and analysis.",
            func_or_tool=self.search,
        )
@staticmethod
def _validate_tool_config(
model: str, api_key: Union[str, None], max_tokens: int, search_domain_filter: Union[list[str], None]
) -> None:
"""
Validates the configuration parameters for the search tool.
Args:
model (str): The model to use.
api_key (Union[str, None]): The API key for authentication.
max_tokens (int): Maximum tokens allowed.
search_domain_filter (Union[list[str], None]): Domain filters for search.
Raises:
ValueError: If the API key is missing, model is empty, max_tokens is not positive,
or search_domain_filter is not a list.
"""
if not api_key:
raise ValueError("Perplexity API key is missing")
if not model:
raise ValueError("model cannot be empty")
if max_tokens <= 0:
raise ValueError("max_tokens must be positive")
if search_domain_filter is not None and not isinstance(search_domain_filter, list):
raise ValueError("search_domain_filter must be a list")
def _execute_query(self, payload: dict[str, Any]) -> "PerplexityChatCompletionResponse":
"""
Executes a query by sending a POST request to the Perplexity API.
Args:
payload (dict[str, Any]): The payload to send in the API request.
Returns:
PerplexityChatCompletionResponse: Parsed response from the Perplexity API.
Raises:
RuntimeError: If there is a network error, HTTP error, JSON parsing error, or if the response
cannot be parsed into a PerplexityChatCompletionResponse.
"""
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
response = requests.request("POST", self.url, json=payload, headers=headers, timeout=10)
try:
response.raise_for_status()
except requests.exceptions.Timeout as e:
raise RuntimeError(
f"Perplexity API => Request timed out: {response.text}. Status code: {response.status_code}"
) from e
except requests.exceptions.HTTPError as e:
raise RuntimeError(
f"Perplexity API => HTTP error occurred: {response.text}. Status code: {response.status_code}"
) from e
except requests.exceptions.RequestException as e:
raise RuntimeError(
f"Perplexity API => Error during request: {response.text}. Status code: {response.status_code}"
) from e
try:
response_json = response.json()
except json.JSONDecodeError as e:
raise RuntimeError(f"Perplexity API => Invalid JSON response received. Error: {e}") from e
try:
# This may raise a pydantic.ValidationError if the response structure is not as expected.
perp_resp = PerplexityChatCompletionResponse(**response_json)
except ValidationError as e:
raise RuntimeError("Perplexity API => Validation error when parsing API response: " + str(e)) from e
except Exception as e:
raise RuntimeError(
"Perplexity API => Failed to parse API response into PerplexityChatCompletionResponse: " + str(e)
) from e
return perp_resp
def search(self, query: str) -> "SearchResponse":
"""
Perform a search query using the Perplexity AI API.
Constructs the payload, executes the query, and parses the response to return
a concise search result along with any provided citations.
Args:
query (str): The search query.
Returns:
SearchResponse: A model containing the search result content and citations.
Raises:
ValueError: If the search query is invalid.
RuntimeError: If there is an error during the search process.
"""
if not query or not isinstance(query, str):
raise ValueError("A valid non-empty query string must be provided.")
payload = {
"model": self.model,
"messages": [{"role": "system", "content": "Be precise and concise."}, {"role": "user", "content": query}],
"max_tokens": self.max_tokens,
"search_domain_filter": self.search_domain_filters,
"web_search_options": {"search_context_size": "high"},
}
try:
perplexity_response = self._execute_query(payload)
content = perplexity_response.choices[0].message.content
citations = perplexity_response.citations
return SearchResponse(content=content, citations=citations, error=None)
except Exception as e:
# Return a SearchResponse with an error message if something goes wrong.
return SearchResponse(content=None, citations=None, error=f"{e}")

View File

@@ -0,0 +1,10 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from .reliable import ReliableTool, ReliableToolError, SuccessfulExecutionParameters, ToolExecutionDetails
__all__ = ["ReliableTool", "ReliableToolError", "SuccessfulExecutionParameters", "ToolExecutionDetails"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .tavily_search import TavilySearchTool
__all__ = ["TavilySearchTool"]

View File

@@ -0,0 +1,183 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Annotated, Any, Optional, Union
from ....doc_utils import export_module
from ....import_utils import optional_import_block, require_optional_import
from ....llm_config import LLMConfig
from ... import Depends, Tool
from ...dependency_injection import on
with optional_import_block():
from tavily import TavilyClient
@require_optional_import(
    [
        "tavily",
    ],
    "tavily",
)
def _execute_tavily_query(
    query: str,
    tavily_api_key: str,
    search_depth: str = "basic",
    topic: str = "general",
    include_answer: str = "basic",
    include_raw_content: bool = False,
    include_domains: Optional[list[str]] = None,
    num_results: int = 5,
) -> Any:
    """
    Execute a search query using the Tavily API.

    Args:
        query (str): The search query string.
        tavily_api_key (str): The API key for Tavily.
        search_depth (str, optional): The depth of the search ('basic' or 'advanced'). Defaults to "basic".
        topic (str, optional): The topic of the search. Defaults to "general".
        include_answer (str, optional): Whether to include an AI-generated answer ('basic' or 'advanced'). Defaults to "basic".
        include_raw_content (bool, optional): Whether to include raw content in the results. Defaults to False.
        include_domains (Optional[list[str]], optional): A list of domains to include in the search.
            Defaults to None, which is treated as no domain restriction (empty list).
        num_results (int, optional): The maximum number of results to return. Defaults to 5.

    Returns:
        Any: The raw response object from the Tavily API client.
    """
    tavily_client = TavilyClient(api_key=tavily_api_key)
    # BUG FIX: the previous default `include_domains: list[str] = []` was a mutable default
    # argument; use None as the sentinel and substitute a fresh list per call.
    return tavily_client.search(
        query=query,
        search_depth=search_depth,
        topic=topic,
        include_answer=include_answer,
        include_raw_content=include_raw_content,
        include_domains=include_domains if include_domains is not None else [],
        max_results=num_results,
    )
def _tavily_search(
    query: str,
    tavily_api_key: str,
    search_depth: str = "basic",
    topic: str = "general",
    include_answer: str = "basic",
    include_raw_content: bool = False,
    include_domains: Optional[list[str]] = None,
    num_results: int = 5,
) -> list[dict[str, Any]]:
    """
    Perform a Tavily search and format the results.

    This function takes search parameters, executes the query using `_execute_tavily_query`,
    and formats the results into a list of dictionaries containing title, link, and snippet.

    Args:
        query (str): The search query string.
        tavily_api_key (str): The API key for Tavily.
        search_depth (str, optional): The depth of the search ('basic' or 'advanced'). Defaults to "basic".
        topic (str, optional): The topic of the search. Defaults to "general".
        include_answer (str, optional): Whether to include an AI-generated answer ('basic' or 'advanced'). Defaults to "basic".
        include_raw_content (bool, optional): Whether to include raw content in the results. Defaults to False.
        include_domains (Optional[list[str]], optional): A list of domains to include in the search.
            Defaults to None, which is treated as no domain restriction (empty list).
        num_results (int, optional): The maximum number of results to return. Defaults to 5.

    Returns:
        list[dict[str, Any]]: A list of dictionaries, where each dictionary represents a search result
            with keys 'title', 'link', and 'snippet'. Returns an empty list if no results are found.
    """
    # BUG FIX: avoid a mutable default argument; None is the sentinel for "no filter".
    # A concrete list is always forwarded so the downstream call behaves identically.
    res = _execute_tavily_query(
        query=query,
        tavily_api_key=tavily_api_key,
        search_depth=search_depth,
        topic=topic,
        include_answer=include_answer,
        include_raw_content=include_raw_content,
        include_domains=include_domains if include_domains is not None else [],
        num_results=num_results,
    )
    # Normalize Tavily's response shape into the tool's uniform result records.
    return [
        {"title": item.get("title", ""), "link": item.get("url", ""), "snippet": item.get("content", "")}
        for item in res.get("results", [])
    ]
@export_module("autogen.tools.experimental")
class TavilySearchTool(Tool):
    """
    TavilySearchTool is a tool that uses the Tavily Search API to perform a search.
    This tool allows agents to leverage the Tavily search engine for information retrieval.
    It requires a Tavily API key, which can be provided during initialization or set as
    an environment variable `TAVILY_API_KEY`.
    Attributes:
        tavily_api_key (str): The API key used for authenticating with the Tavily API.
    """
    def __init__(
        self, *, llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None, tavily_api_key: Optional[str] = None
    ):
        """
        Initializes the TavilySearchTool.
        Args:
            llm_config (Optional[Union[LLMConfig, dict[str, Any]]]): LLM configuration. (Currently unused but kept for potential future integration).
            tavily_api_key (Optional[str]): The API key for the Tavily Search API. If not provided,
                it attempts to read from the `TAVILY_API_KEY` environment variable.
        Raises:
            ValueError: If `tavily_api_key` is not provided either directly or via the environment variable.
        """
        # Explicit argument wins; fall back to the environment variable.
        self.tavily_api_key = tavily_api_key or os.getenv("TAVILY_API_KEY")
        if self.tavily_api_key is None:
            raise ValueError("tavily_api_key must be provided either as an argument or via TAVILY_API_KEY env var")
        # The resolved key is injected into the nested tool function below via
        # Depends(on(...)), so the LLM never has to (and cannot) supply it as a call argument.
        def tavily_search(
            query: Annotated[str, "The search query."],
            tavily_api_key: Annotated[Optional[str], Depends(on(self.tavily_api_key))],
            search_depth: Annotated[Optional[str], "Either 'advanced' or 'basic'"] = "basic",
            include_answer: Annotated[Optional[str], "Either 'advanced' or 'basic'"] = "basic",
            include_raw_content: Annotated[Optional[bool], "Include the raw contents"] = False,
            # NOTE(review): mutable default [] is shared across calls; harmless here since it is
            # never mutated, but consider `= None` — confirm the generated tool schema first.
            include_domains: Annotated[Optional[list[str]], "Specific web domains to search"] = [],
            num_results: Annotated[int, "The number of results to return."] = 5,
        ) -> list[dict[str, Any]]:
            """
            Performs a search using the Tavily API and returns formatted results.
            Args:
                query: The search query string.
                tavily_api_key: The API key for Tavily (injected dependency).
                search_depth: The depth of the search ('basic' or 'advanced'). Defaults to "basic".
                include_answer: Whether to include an AI-generated answer ('basic' or 'advanced'). Defaults to "basic".
                include_raw_content: Whether to include raw content in the results. Defaults to False.
                include_domains: A list of domains to include in the search. Defaults to [].
                num_results: The maximum number of results to return. Defaults to 5.
            Returns:
                A list of dictionaries, each containing 'title', 'link', and 'snippet' of a search result.
            Raises:
                ValueError: If the Tavily API key is not available.
            """
            if tavily_api_key is None:
                raise ValueError("Tavily API key is missing.")
            # The `or` fallbacks guard against a caller explicitly passing None for optional args.
            return _tavily_search(
                query=query,
                tavily_api_key=tavily_api_key,
                search_depth=search_depth or "basic",
                include_answer=include_answer or "basic",
                include_raw_content=include_raw_content or False,
                include_domains=include_domains or [],
                num_results=num_results,
            )
        # Register the nested function as this tool's callable; its signature defines
        # the schema exposed to the LLM.
        super().__init__(
            name="tavily_search",
            description="Use the Tavily Search API to perform a search.",
            func_or_tool=tavily_search,
        )

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .web_search_preview import WebSearchPreviewTool
__all__ = ["WebSearchPreviewTool"]

View File

@@ -0,0 +1,114 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import copy
import logging
import os
from typing import Annotated, Any, Literal, Optional, Type, Union
from pydantic import BaseModel
from ....doc_utils import export_module
from ....import_utils import optional_import_block, require_optional_import
from ....llm_config import LLMConfig
from ... import Tool
with optional_import_block():
from openai import OpenAI
from openai.types.responses import WebSearchToolParam
from openai.types.responses.web_search_tool import UserLocation
@require_optional_import("openai>=1.66.2", "openai")
@export_module("autogen.tools.experimental")
class WebSearchPreviewTool(Tool):
    """WebSearchPreviewTool is a tool that uses OpenAI's web_search_preview tool to perform a search."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        search_context_size: Literal["low", "medium", "high"] = "medium",
        user_location: Optional[dict[str, str]] = None,
        instructions: Optional[str] = None,
        text_format: Optional[Type[BaseModel]] = None,
    ):
        """Initialize the WebSearchPreviewTool.

        Args:
            llm_config: The LLM configuration to use. This should be a dictionary
                containing the model name and other parameters.
            search_context_size: The size of the search context. One of `low`, `medium`, or `high`.
                `medium` is the default.
            user_location: The location of the user. This should be a dictionary containing
                the city, country, region, and timezone.
            instructions: Inserts a system (or developer) message as the first item in the model's context.
            text_format: The format of the text to be returned. This should be a subclass of `BaseModel`.
                The default is `None`, which means the text will be returned as a string.

        Raises:
            ValueError: If `llm_config` has no 'config_list' key, or no OpenAI 'gpt-4*' model is found.
        """
        self.web_search_tool_param = WebSearchToolParam(
            type="web_search_preview",
            search_context_size=search_context_size,
            user_location=UserLocation(**user_location) if user_location else None,  # type: ignore[typeddict-item]
        )
        self.instructions = instructions
        self.text_format = text_format

        if isinstance(llm_config, LLMConfig):
            llm_config = llm_config.model_dump()
        # Deep copy so we never mutate the caller's configuration dict.
        llm_config = copy.deepcopy(llm_config)
        if "config_list" not in llm_config:
            raise ValueError("llm_config must contain 'config_list' key")

        # Find first OpenAI model which starts with "gpt-4"
        self.model = None
        self.api_key = None
        for model in llm_config["config_list"]:
            if model["model"].startswith("gpt-4") and model.get("api_type", "openai") == "openai":
                self.model = model["model"]
                self.api_key = model.get("api_key", os.getenv("OPENAI_API_KEY"))
                break
        if self.model is None:
            raise ValueError(
                "No OpenAI model starting with 'gpt-4' found in llm_config, other models do not support web_search_preview"
            )
        if not self.model.startswith("gpt-4.1") and not self.model.startswith("gpt-4o-search-preview"):
            logging.warning(
                f"We recommend using a model starting with 'gpt-4.1' or 'gpt-4o-search-preview' for web_search_preview, but found {self.model}. "
                "This may result in suboptimal performance."
            )

        def web_search_preview(
            query: Annotated[str, "The search query. Add all relevant context to the query."],
        ) -> Union[str, Optional[BaseModel]]:
            """Run a single web-search query through the OpenAI Responses API."""
            # BUG FIX: pass the API key selected from llm_config. Previously `OpenAI()` was
            # constructed with no arguments, silently ignoring any configured api_key and
            # relying solely on the OPENAI_API_KEY environment variable. (When self.api_key
            # is None the client still falls back to the environment variable.)
            client = OpenAI(api_key=self.api_key)
            if not self.text_format:
                # Plain-text search result.
                response = client.responses.create(
                    model=self.model,  # type: ignore[arg-type]
                    tools=[self.web_search_tool_param],
                    input=query,
                    instructions=self.instructions,
                )
                return response.output_text
            # Structured output parsed into the caller-supplied BaseModel subclass.
            response = client.responses.parse(
                model=self.model,  # type: ignore[arg-type]
                tools=[self.web_search_tool_param],
                input=query,
                instructions=self.instructions,
                text_format=self.text_format,
            )
            return response.output_parsed

        super().__init__(
            name="web_search_preview",
            description="Tool used to perform a web search. It can be used as google search or directly searching a specific website.",
            func_or_tool=web_search_preview,
        )

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .wikipedia import WikipediaPageLoadTool, WikipediaQueryRunTool
__all__ = ["WikipediaPageLoadTool", "WikipediaQueryRunTool"]

View File

@@ -0,0 +1,287 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Union
import requests
from pydantic import BaseModel
from autogen.import_utils import optional_import_block, require_optional_import
from autogen.tools import Tool
with optional_import_block():
import wikipediaapi
# Maximum allowed length for a query string.
MAX_QUERY_LENGTH = 300
# Maximum number of pages to retrieve from a search.
MAX_PAGE_RETRIEVE = 100
# Maximum number of characters to return from a Wikipedia page.
MAX_ARTICLE_LENGTH = 10000
class Document(BaseModel):
    """Pydantic model representing a Wikipedia document.
    Attributes:
        page_content (str): Textual content of the Wikipedia page
            (possibly truncated).
        metadata (dict[str, str]): Additional info, including:
            - source URL
            - title
            - pageid
            - timestamp
            - word count
            - size
    """
    # Page text, truncated by the caller (see WikipediaPageLoadTool.truncate).
    page_content: str
    # All metadata values are stored as strings (numeric fields are str()-converted by callers).
    metadata: dict[str, str]
class WikipediaClient:
    """Client for interacting with the Wikipedia API.

    Supports searching and page retrieval on a specified language edition.

    Public methods:
        search(query: str, limit: int) -> list[dict[str, Any]]
        get_page(title: str) -> Optional[wikipediaapi.WikipediaPage]

    Attributes:
        base_url (str): URL of the MediaWiki API endpoint.
        headers (dict[str, str]): HTTP headers, including User-Agent.
        wiki (wikipediaapi.Wikipedia): Low-level Wikipedia API client.
    """

    def __init__(self, language: str = "en", tool_name: str = "wikipedia-client") -> None:
        """Initialize the WikipediaClient.

        Args:
            language (str): ISO code of the Wikipedia edition (e.g., 'en', 'es').
            tool_name (str): Identifier for User-Agent header.
        """
        self.base_url = f"https://{language}.wikipedia.org/w/api.php"
        self.headers = {"User-Agent": f"autogen.Agent ({tool_name})"}
        self.wiki = wikipediaapi.Wikipedia(
            language=language,
            extract_format=wikipediaapi.ExtractFormat.WIKI,
            user_agent=f"autogen.Agent ({tool_name})",
        )

    def search(self, query: str, limit: int = 3) -> Any:
        """Search Wikipedia for pages matching a query string.

        Args:
            query (str): The search keywords.
            limit (int): Max number of results to return.

        Returns:
            list[dict[str, Any]]: Each dict has keys:
                - 'title' (str)
                - 'size' (int)
                - 'wordcount' (int)
                - 'timestamp' (str)

        Raises:
            requests.HTTPError: If the HTTP request to the API fails.
            requests.Timeout: If the API does not respond within the timeout.
        """
        params = {
            "action": "query",
            "format": "json",
            "list": "search",
            "srsearch": query,
            "srlimit": str(limit),
            "srprop": "size|wordcount|timestamp",
        }
        # ROBUSTNESS FIX: a bounded timeout prevents the agent from hanging indefinitely
        # on a stalled connection; 10s matches the other search tools in this package.
        response = requests.get(url=self.base_url, params=params, headers=self.headers, timeout=10)
        response.raise_for_status()
        data = response.json()
        search_data = data.get("query", {}).get("search", [])
        return search_data

    def get_page(self, title: str) -> Optional[Any]:
        """Retrieve a WikipediaPage object by title.

        Args:
            title (str): Title of the Wikipedia page.

        Returns:
            wikipediaapi.WikipediaPage | None: The page object if it exists,
                otherwise None.

        Raises:
            wikipediaapi.WikipediaException: On lower-level API errors.
        """
        page = self.wiki.page(title)
        if not page.exists():
            return None
        return page
@require_optional_import(["wikipediaapi"], "wikipedia")
class WikipediaQueryRunTool(Tool):
    """Tool that searches Wikipedia and returns short page summaries.

    Wraps a language-specific `WikipediaClient` and exposes `query_run` as the
    tool callable, returning up to `top_k` formatted summaries per query.

    Public methods:
        query_run(query: str) -> list[str] | str

    Attributes:
        language (str): Language code for the Wikipedia edition (e.g., 'en', 'es').
        top_k (int): Max number of page summaries returned (capped at MAX_PAGE_RETRIEVE).
        verbose (bool): If True, prints debug information to stdout.
        wiki_cli (WikipediaClient): Internal client for Wikipedia API calls.
    """

    def __init__(self, language: str = "en", top_k: int = 3, verbose: bool = False) -> None:
        """Set up the client and register `query_run` as the tool function.

        Args:
            language (str): ISO code of the Wikipedia edition to query.
            top_k (int): Desired number of summaries (capped by MAX_PAGE_RETRIEVE).
            verbose (bool): If True, print debug information during searches.
        """
        self.language = language
        self.tool_name = "wikipedia-query-run"
        self.wiki_cli = WikipediaClient(language, self.tool_name)
        self.top_k = min(top_k, MAX_PAGE_RETRIEVE)
        self.verbose = verbose
        super().__init__(
            name=self.tool_name,
            description="Run a Wikipedia query and return page summaries.",
            func_or_tool=self.query_run,
        )

    def query_run(self, query: str) -> Union[list[str], str]:
        """Search Wikipedia and return formatted page summaries.

        The query is truncated to MAX_QUERY_LENGTH before searching.

        Args:
            query (str): Search term(s) to look up in Wikipedia.

        Returns:
            list[str]: Each element is "Page: <title>\nSummary: <text>".
            str: Error message if no results are found or on exception.

        Note:
            All exceptions are caught and returned as error strings so the
            calling agent never crashes on a failed lookup.
        """
        trimmed = query[:MAX_QUERY_LENGTH]
        try:
            if self.verbose:
                print(f"INFO\t [{self.tool_name}] search query='{trimmed}' top_k={self.top_k}")
            hits = self.wiki_cli.search(trimmed, limit=self.top_k)
            # Keep only pages that exist and actually have a summary.
            summaries = [
                f"Page: {hit['title']}\nSummary: {page.summary}"
                for hit in hits
                if (page := self.wiki_cli.get_page(hit["title"])) is not None and page.summary
            ]
            return summaries if summaries else "No good Wikipedia Search Result was found"
        except Exception as e:
            return f"wikipedia search failed: {str(e)}"
@require_optional_import(["wikipediaapi"], "wikipedia")
class WikipediaPageLoadTool(Tool):
    """
    A tool to load up to N characters of Wikipedia page content along with metadata.

    Uses a language-specific Wikipedia client to find matching articles and returns
    `Document` objects holding truncated page text plus metadata (source URL, title,
    page ID, timestamp, word count, and size). Suitable for agents that need
    structured Wikipedia data for research, summarization, or contextual enrichment.

    Attributes:
        language (str): Wikipedia language code (default: "en").
        top_k (int): Maximum number of pages to retrieve per query (default: 3).
        truncate (int): Maximum number of characters of content per page (default: 4000).
        verbose (bool): If True, prints debug information (default: False).
        tool_name (str): Identifier used in User-Agent header.
        wiki_cli (WikipediaClient): Client for interacting with the Wikipedia API.
    """

    def __init__(self, language: str = "en", top_k: int = 3, truncate: int = 4000, verbose: bool = False) -> None:
        """
        Configure the tool and register `content_search` as its callable.

        Args:
            language (str): The language code for the Wikipedia edition (default is "en").
            top_k (int): The maximum number of pages to retrieve per query (default is 3;
                capped at MAX_PAGE_RETRIEVE).
            truncate (int): The maximum number of characters to extract from each page
                (default is 4000; capped at MAX_ARTICLE_LENGTH).
            verbose (bool): If True, enables verbose/debug logging (default is False).
        """
        self.language = language
        self.top_k = min(top_k, MAX_PAGE_RETRIEVE)
        self.truncate = min(truncate, MAX_ARTICLE_LENGTH)
        self.verbose = verbose
        self.tool_name = "wikipedia-page-load"
        self.wiki_cli = WikipediaClient(language, self.tool_name)
        super().__init__(
            name=self.tool_name,
            description=(
                "Search Wikipedia for relevant pages using a language-specific client. "
                "Returns a list of documents with truncated content and metadata including title, URL, "
                "page ID, timestamp, word count, and page size. Configure number of results with the 'top_k' parameter "
                "and content length with 'truncate'. Useful for research, summarization, or contextual enrichment."
            ),
            func_or_tool=self.content_search,
        )

    def content_search(self, query: str) -> Union[list[Document], str]:
        """
        Execute a Wikipedia search and return page content plus metadata.

        Args:
            query (str): The search term to query Wikipedia.

        Returns:
            Union[list[Document], str]:
                - list[Document]: Documents with up to `truncate` characters of page text
                  and metadata, when matching pages are found.
                - str: Error message if the search fails or no pages are found.

        Notes:
            - All exceptions are caught internally and returned as strings.
            - If no matching pages have text content, returns
              "No good Wikipedia Search Result was found".
        """
        try:
            if self.verbose:
                print(f"INFO\t [{self.tool_name}] search query='{query[:MAX_QUERY_LENGTH]}' top_k={self.top_k}")
            results = self.wiki_cli.search(query[:MAX_QUERY_LENGTH], limit=self.top_k)
            documents: list[Document] = []
            for hit in results:
                page = self.wiki_cli.get_page(hit["title"])
                # Skip pages that do not exist or carry no text content.
                if page is None or not page.text:
                    continue
                metadata = {
                    "source": f"https://{self.language}.wikipedia.org/?curid={hit['pageid']}",
                    "title": hit["title"],
                    "pageid": str(hit["pageid"]),
                    "timestamp": str(hit["timestamp"]),
                    "wordcount": str(hit["wordcount"]),
                    "size": str(hit["size"]),
                }
                documents.append(Document(page_content=page.text[: self.truncate], metadata=metadata))
            return documents if documents else "No good Wikipedia Search Result was found"
        except Exception as e:
            return f"wikipedia search failed: {str(e)}"

View File

@@ -0,0 +1,411 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import functools
import inspect
import json
from logging import getLogger
from typing import Annotated, Any, Callable, ForwardRef, Optional, TypeVar, Union
from packaging.version import parse
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import __version__ as pydantic_version
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import Literal, get_args, get_origin
from ..doc_utils import export_module
from .dependency_injection import Field as AG2Field
if parse(pydantic_version) < parse("2.10.2"):
from pydantic._internal._typing_extra import eval_type_lenient as try_eval_type
else:
from pydantic._internal._typing_extra import try_eval_type
__all__ = ["get_function_schema", "load_basemodels_if_needed", "serialize_to_str"]
logger = getLogger(__name__)
T = TypeVar("T")
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
    """Resolve a parameter annotation to a concrete type.

    Args:
        annotation: The annotation of the parameter
        globalns: The global namespace of the function

    Returns:
        The resolved type annotation of the parameter
    """
    resolved = annotation
    if isinstance(resolved, AG2Field):
        # An AG2 description field wraps the actual annotation in its `description` slot.
        resolved = resolved.description
    if isinstance(resolved, str):
        # String annotations are forward references; evaluate them leniently
        # against the function's global namespace.
        resolved, _ = try_eval_type(ForwardRef(resolved), globalns, globalns)
    return resolved
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Build a signature for *call* with all annotations resolved.

    Args:
        call: The function to get the signature for

    Returns:
        The signature of the function with type annotations resolved via
        `get_typed_annotation`
    """
    raw_signature = inspect.signature(call)
    globalns = getattr(call, "__globals__", {})
    resolved_params = [
        # `replace` keeps name/kind/default intact and swaps only the annotation.
        param.replace(annotation=get_typed_annotation(param.annotation, globalns))
        for param in raw_signature.parameters.values()
    ]
    return inspect.Signature(resolved_params)
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Resolve the return annotation of *call*.

    Args:
        call: The function to get the return annotation for

    Returns:
        The resolved return annotation, or None when the function has no
        return annotation
    """
    annotation = inspect.signature(call).return_annotation
    if annotation is inspect.Signature.empty:
        return None
    return get_typed_annotation(annotation, getattr(call, "__globals__", {}))
def get_param_annotations(typed_signature: inspect.Signature) -> dict[str, Union[Annotated[type[Any], str], type[Any]]]:
    """Collect the type annotations of a function's parameters.

    Args:
        typed_signature: The signature of the function with type annotations

    Returns:
        A dictionary mapping parameter names to their annotations; parameters
        without an annotation are omitted
    """
    annotations: dict[str, Union[Annotated[type[Any], str], type[Any]]] = {}
    for name, param in typed_signature.parameters.items():
        if param.annotation is not inspect.Signature.empty:
            annotations[name] = param.annotation
    return annotations
class Parameters(BaseModel):
    """Parameters of a function as defined by the OpenAI API"""
    # JSON Schema for function parameters is always an "object".
    type: Literal["object"] = "object"
    # Maps each parameter name to its JSON schema (type, description, default, ...).
    properties: dict[str, JsonSchemaValue]
    # Names of parameters without default values; callers must always supply these.
    required: list[str]
class Function(BaseModel):
    """A function as defined by the OpenAI API"""
    # Human-readable explanation the model uses to decide whether to call the function.
    description: Annotated[str, Field(description="Description of the function")]
    # Identifier the model emits when invoking the function.
    name: Annotated[str, Field(description="Name of the function")]
    # JSON-schema description of the function's parameters (see Parameters above in this module).
    parameters: Annotated[Parameters, Field(description="Parameters of the function")]
class ToolFunction(BaseModel):
    """A function under tool as defined by the OpenAI API."""
    # The OpenAI tools schema currently only supports the "function" tool type.
    type: Literal["function"] = "function"
    # The wrapped function definition.
    function: Annotated[Function, Field(description="Function under tool")]
def get_parameter_json_schema(k: str, v: Any, default_values: dict[str, Any]) -> JsonSchemaValue:
    """Build the OpenAI-style JSON schema for a single parameter.

    Args:
        k: The name of the parameter
        v: The type of the parameter
        default_values: The default values of the parameters of the function

    Returns:
        A JSON schema dict for the parameter, including its description and,
        when present, its default value
    """

    def type2description(k: str, v: Union[Annotated[type[Any], str], type[Any]]) -> str:
        # Parameters without Annotated metadata fall back to their own name as description.
        metadata = getattr(v, "__metadata__", None)
        if metadata is None:
            return k
        retval = metadata[0]
        if not isinstance(retval, AG2Field):
            raise ValueError(f"Invalid {retval} for parameter {k}, should be a DescriptionField, got {type(retval)}")
        return retval.description  # type: ignore[return-value]

    schema = TypeAdapter(v).json_schema()
    if k in default_values:
        schema["default"] = default_values[k]
    schema["description"] = type2description(k, v)
    return schema
def get_required_params(typed_signature: inspect.Signature) -> list[str]:
    """Get the required parameters of a function

    A parameter is required when it has no default value. The check uses an
    identity comparison (`is`) against `inspect.Signature.empty` rather than
    `==`, because `==` invokes the default value's own `__eq__` — a default
    such as a numpy array would return an array (ambiguous truth value) and
    crash the comprehension.

    Args:
        typed_signature: The signature of the function as returned by inspect.signature

    Returns:
        A list of the required parameters of the function
    """
    return [k for k, v in typed_signature.parameters.items() if v.default is inspect.Signature.empty]
def get_default_values(typed_signature: inspect.Signature) -> dict[str, Any]:
    """Get default values of parameters of a function

    Uses an identity comparison (`is not`) against `inspect.Signature.empty`
    rather than `!=`, because `!=` invokes the default value's own `__ne__` —
    a default such as a numpy array would return an array (ambiguous truth
    value) and crash the comprehension.

    Args:
        typed_signature: The signature of the function as returned by inspect.signature

    Returns:
        A dictionary of the default values of the parameters of the function
    """
    return {k: v.default for k, v in typed_signature.parameters.items() if v.default is not inspect.Signature.empty}
def get_parameters(
    required: list[str],
    param_annotations: dict[str, Union[Annotated[type[Any], str], type[Any]]],
    default_values: dict[str, Any],
) -> Parameters:
    """Assemble the OpenAI-style `Parameters` model for a function.

    Args:
        required: The required parameters of the function
        param_annotations: The type annotations of the parameters of the function
        default_values: The default values of the parameters of the function

    Returns:
        A Pydantic model for the parameters of the function
    """
    properties: dict[str, Any] = {}
    for name, annotation in param_annotations.items():
        # Parameters without a usable annotation are excluded from the schema.
        if annotation is inspect.Signature.empty:
            continue
        properties[name] = get_parameter_json_schema(name, annotation, default_values)
    return Parameters(properties=properties, required=required)
def get_missing_annotations(typed_signature: inspect.Signature, required: list[str]) -> tuple[set[str], set[str]]:
    """Split a function's unannotated parameters into required and defaulted sets.

    Parameters with default values are not required to be annotated, so they
    are reported separately (callers log a warning for them instead of failing).

    Args:
        typed_signature: The signature of the function with type annotations
        required: The required parameters of the function

    Returns:
        A pair (missing, unannotated_with_default): required parameters lacking
        annotations, and unannotated parameters that do have defaults
    """
    unannotated = {
        name for name, param in typed_signature.parameters.items() if param.annotation is inspect.Signature.empty
    }
    missing = unannotated & set(required)
    return missing, unannotated - missing
@export_module("autogen.tools")
def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> dict[str, Any]:
    """Get a JSON schema for a function as defined by the OpenAI API

    Args:
        f: The function to get the JSON schema for
        name: The name of the function
        description: The description of the function

    Returns:
        A JSON schema for the function

    Raises:
        TypeError: If a required parameter of the function is not annotated

    Examples:
        ```python
        def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1) -> None:
            pass


        get_function_schema(f, description="function f")

        # {'type': 'function',
        #  'function': {'description': 'function f',
        #               'name': 'f',
        #               'parameters': {'type': 'object',
        #                              'properties': {'a': {'type': 'str', 'description': 'Parameter a'},
        #                                             'b': {'type': 'int', 'description': 'b'},
        #                                             'c': {'type': 'float', 'description': 'Parameter c'}},
        #                              'required': ['a']}}}
        ```
    """
    signature = get_typed_signature(f)
    required_params = get_required_params(signature)
    defaults = get_default_values(signature)
    annotations = get_param_annotations(signature)
    missing, unannotated_with_default = get_missing_annotations(signature, required_params)

    # A missing return annotation is tolerated, but worth flagging.
    if get_typed_return_annotation(f) is None:
        logger.warning(
            f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is "
            + "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'."
        )

    # Defaulted-but-unannotated parameters only warrant a warning.
    if unannotated_with_default:
        formatted_defaults = ", ".join(f"'{k}'" for k in sorted(unannotated_with_default))
        logger.warning(
            f"The following parameters of the function '{f.__name__}' with default values are not annotated: "
            + f"{formatted_defaults}."
        )

    # Required parameters without annotations are a hard error.
    if missing:
        formatted_missing = ", ".join(f"'{k}'" for k in sorted(missing))
        raise TypeError(
            f"All parameters of the function '{f.__name__}' without default values must be annotated. "
            + f"The annotations are missing for the following parameters: {formatted_missing}"
        )

    tool_function = ToolFunction(
        function=Function(
            description=description,
            name=name if name else f.__name__,
            parameters=get_parameters(required_params, annotations, default_values=defaults),
        )
    )
    return tool_function.model_dump()
def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[dict[str, Any], type[BaseModel]], BaseModel]]:
    """Get a function to load a parameter if it is a Pydantic model

    Args:
        t: The type annotation of the parameter

    Returns:
        A function to load the parameter if it is a Pydantic model, otherwise None
    """
    origin = get_origin(t)

    # Unwrap Annotated[X, ...] and inspect the underlying type X.
    if origin is Annotated:
        inner_args = get_args(t)
        return get_load_param_if_needed_function(inner_args[0]) if inner_args else None

    # Generic aliases (list[str], dict[str, Any], Union[...]) and non-type values
    # can never be BaseModel subclasses.
    if origin is not None or not isinstance(t, type):
        return None

    if not issubclass(t, BaseModel):
        return None

    def load_base_model(v: dict[str, Any], model_type: type[BaseModel]) -> BaseModel:
        # Instantiate the model from the raw keyword payload.
        return model_type(**v)

    return load_base_model
@export_module("autogen.tools")
def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
    """A decorator to load the parameters of a function if they are Pydantic models

    Args:
        func: The function with annotated parameters

    Returns:
        A function that loads the parameters before calling the original function
    """
    annotations = get_param_annotations(get_typed_signature(func))

    # Keep only the parameters that actually need conversion into a BaseModel.
    loaders = {
        param: loader
        for param, loader in ((p, get_load_param_if_needed_function(t)) for p, t in annotations.items())
        if loader is not None
    }

    def _convert_kwargs(kwargs: dict[str, Any]) -> None:
        # Replace raw dict payloads with instantiated Pydantic models, in place.
        for param, loader in loaders.items():
            kwargs[param] = loader(kwargs[param], annotations[param])

    if inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        async def _a_load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
            _convert_kwargs(kwargs)
            return await func(*args, **kwargs)

        return _a_load_parameters_if_needed

    @functools.wraps(func)
    def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
        _convert_kwargs(kwargs)
        return func(*args, **kwargs)

    return _load_parameters_if_needed
class _SerializableResult(BaseModel):
    """Internal helper model for `serialize_to_str`: funnelling an arbitrary value
    through a Pydantic model lets `model_dump()` handle nested models/types."""

    # The arbitrary value to be serialized.
    result: Any
@export_module("autogen.tools")
def serialize_to_str(x: Any) -> str:
    """Serialize an arbitrary value to a string.

    Strings pass through unchanged and Pydantic models use their own JSON dump.
    Everything else is funnelled through a Pydantic wrapper model, falling back
    to json.dumps and finally plain str() if serialization keeps failing.
    """
    if isinstance(x, str):
        return x
    if isinstance(x, BaseModel):
        return x.model_dump_json()
    try:
        dumped = _SerializableResult(result=x).model_dump()
        return str(dumped["result"])
    except Exception:
        # Best-effort serialization: fall through to the next strategy.
        pass
    try:
        return json.dumps(x, ensure_ascii=False)
    except Exception:
        return str(x)

View File

@@ -0,0 +1,187 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from ..doc_utils import export_module
from ..tools.function_utils import get_function_schema
from .dependency_injection import ChatContext, get_context_params, inject_params
if TYPE_CHECKING:
from ..agentchat.conversable_agent import ConversableAgent
__all__ = ["Tool", "tool"]
@export_module("autogen.tools")
class Tool:
    """A class representing a Tool that can be used by an agent for various tasks.

    This class encapsulates a tool with a name, description, and an executable function.
    The tool can be registered with a ConversableAgent for use either with an LLM or for direct execution.

    Attributes:
        name (str): The name of the tool.
        description (str): The description of the tool.
        func_or_tool (Union[Tool, Callable[..., Any]]): The function or Tool instance to create a Tool from.
        parameters_json_schema (Optional[dict[str, Any]]): A schema describing the parameters that the function accepts. If None, the schema will be generated from the function signature.
    """

    def __init__(
        self,
        *,
        name: Optional[str] = None,
        description: Optional[str] = None,
        func_or_tool: Union["Tool", Callable[..., Any]],
        parameters_json_schema: Optional[dict[str, Any]] = None,
    ) -> None:
        """Create a new Tool object.

        Args:
            name (str): The name of the tool; defaults to the wrapped Tool's name or the function's name.
            description (str): The description of the tool; defaults to the wrapped Tool's description or the function's docstring.
            func_or_tool (Union[Tool, Callable[..., Any]]): The function or Tool instance to create a Tool from.
            parameters_json_schema (Optional[dict[str, Any]]): A schema describing the parameters that the function accepts. If None, the schema will be generated from the function signature.

        Raises:
            ValueError: If func_or_tool is neither a function, a method, nor a Tool instance.
        """
        if isinstance(func_or_tool, Tool):
            self._name: str = name or func_or_tool.name
            self._description: str = description or func_or_tool.description
            self._func: Callable[..., Any] = func_or_tool.func
            self._chat_context_param_names: list[str] = func_or_tool._chat_context_param_names
        elif inspect.isfunction(func_or_tool) or inspect.ismethod(func_or_tool):
            self._chat_context_param_names = get_context_params(func_or_tool, subclass=ChatContext)
            self._func = inject_params(func_or_tool)
            self._name = name or func_or_tool.__name__
            self._description = description or func_or_tool.__doc__ or ""
        else:
            raise ValueError(
                f"Parameter 'func_or_tool' must be a function, method or a Tool instance, it is '{type(func_or_tool)}' instead."
            )

        # Use the resolved name/description (not the raw constructor arguments, which may be
        # None when they are derived from the wrapped function/Tool) so an explicit
        # parameters schema always carries a valid function name and description.
        self._func_schema = (
            {
                "type": "function",
                "function": {
                    "name": self._name,
                    "description": self._description,
                    "parameters": parameters_json_schema,
                },
            }
            if parameters_json_schema
            else None
        )

    @property
    def name(self) -> str:
        """The name of the tool."""
        return self._name

    @property
    def description(self) -> str:
        """The description of the tool."""
        return self._description

    @property
    def func(self) -> Callable[..., Any]:
        """The callable that implements the tool."""
        return self._func

    def register_for_llm(self, agent: "ConversableAgent") -> None:
        """Registers the tool for use with a ConversableAgent's language model (LLM).

        This method registers the tool so that it can be invoked by the agent during
        interactions with the language model.

        Args:
            agent (ConversableAgent): The agent to which the tool will be registered.
        """
        if self._func_schema:
            # An explicit schema bypasses signature-based schema generation.
            agent.update_tool_signature(self._func_schema, is_remove=False)
        else:
            agent.register_for_llm()(self)

    def register_for_execution(self, agent: "ConversableAgent") -> None:
        """Registers the tool for direct execution by a ConversableAgent.

        This method registers the tool so that it can be executed by the agent,
        typically outside of the context of an LLM interaction.

        Args:
            agent (ConversableAgent): The agent to which the tool will be registered.
        """
        agent.register_for_execution()(self)

    def register_tool(self, agent: "ConversableAgent") -> None:
        """Register a tool to be both proposed and executed by an agent.

        Equivalent to calling both `register_for_llm` and `register_for_execution` with the same agent.

        Note: This will not make the agent recommend and execute the call in the one step. If the agent
        recommends the tool, it will need to be the next agent to speak in order to execute the tool.

        Args:
            agent (ConversableAgent): The agent to which the tool will be registered.
        """
        self.register_for_llm(agent)
        self.register_for_execution(agent)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool by calling its underlying function with the provided arguments.

        Args:
            *args: Positional arguments to pass to the tool
            **kwargs: Keyword arguments to pass to the tool

        Returns:
            The result of executing the tool's function.
        """
        return self._func(*args, **kwargs)

    @property
    def tool_schema(self) -> dict[str, Any]:
        """Get the schema for the tool.

        This is the preferred way of handling function calls with OpenAI and compatible frameworks.
        """
        return get_function_schema(self.func, name=self.name, description=self.description)

    @property
    def function_schema(self) -> dict[str, Any]:
        """Get the schema for the function.

        This is the old way of handling function calls with OpenAI and compatible frameworks.
        It is provided for backward compatibility.
        """
        schema = get_function_schema(self.func, name=self.name, description=self.description)
        return schema["function"]  # type: ignore[no-any-return]

    @property
    def realtime_tool_schema(self) -> dict[str, Any]:
        """Get the schema for the tool, flattened for the realtime API.

        This is the preferred way of handling function calls with OpenAI and compatible frameworks.
        """
        schema = get_function_schema(self.func, name=self.name, description=self.description)
        return {"type": schema["type"], **schema["function"]}
@export_module("autogen.tools")
def tool(name: Optional[str] = None, description: Optional[str] = None) -> Callable[[Callable[..., Any]], Tool]:
    """Decorator to create a Tool from a function.

    Args:
        name (str): The name of the tool; defaults to the decorated function's name.
        description (str): The description of the tool; defaults to the function's docstring.

    Returns:
        Callable[[Callable[..., Any]], Tool]: A decorator that creates a Tool from a function.
    """

    def _make_tool(func: Callable[..., Any]) -> Tool:
        # Wrap the decorated function together with any decorator-level overrides.
        return Tool(name=name, description=description, func_or_tool=func)

    return _make_tool

View File

@@ -0,0 +1,86 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING
from ..doc_utils import export_module
from .tool import Tool
if TYPE_CHECKING:
from ..agentchat.conversable_agent import ConversableAgent
__all__ = ["Toolkit"]
@export_module("autogen.tools")
class Toolkit:
    """A class representing a set of tools that can be used by an agent for various tasks."""

    def __init__(self, tools: list[Tool]) -> None:
        """Create a new Toolkit object.

        Args:
            tools (list[Tool]): The list of tools in the toolkit.
        """
        # Keyed by tool name; a later tool with the same name replaces an earlier one.
        self.toolkit = {tool.name: tool for tool in tools}

    @property
    def tools(self) -> list[Tool]:
        """Get the list of tools in the set."""
        return list(self.toolkit.values())

    def register_for_llm(self, agent: "ConversableAgent") -> None:
        """Register the tools in the set with an LLM agent.

        Args:
            agent (ConversableAgent): The LLM agent to register the tools with.
        """
        for tool in self.toolkit.values():
            tool.register_for_llm(agent)

    def register_for_execution(self, agent: "ConversableAgent") -> None:
        """Register the tools in the set with an agent for execution.

        Args:
            agent (ConversableAgent): The agent to register the tools with.
        """
        for tool in self.toolkit.values():
            tool.register_for_execution(agent)

    def get_tool(self, tool_name: str) -> Tool:
        """Get a tool from the set by name.

        Args:
            tool_name (str): The name of the tool to get.

        Returns:
            Tool: The tool with the given name.

        Raises:
            ValueError: If no tool with the given name exists in the toolkit.
        """
        try:
            return self.toolkit[tool_name]
        except KeyError:
            raise ValueError(f"Tool '{tool_name}' not found in Toolkit.") from None

    def set_tool(self, tool: Tool) -> None:
        """Set a tool in the set.

        Args:
            tool (Tool): The tool to add (or replace, if a tool with the same name exists).
        """
        self.toolkit[tool.name] = tool

    def remove_tool(self, tool_name: str) -> None:
        """Remove a tool from the set by name.

        Args:
            tool_name (str): The name of the tool to remove.

        Raises:
            ValueError: If no tool with the given name exists in the toolkit.
        """
        try:
            del self.toolkit[tool_name]
        except KeyError:
            raise ValueError(f"Tool '{tool_name}' not found in Toolkit.") from None

    def __len__(self) -> int:
        """Get the number of tools in the toolkit."""
        return len(self.toolkit)