CoACT initialize (#292)

This commit is contained in:
Linxin Song
2025-07-30 19:35:20 -07:00
committed by GitHub
parent 862d704b8c
commit b968155757
228 changed files with 42386 additions and 0 deletions

View File

@@ -0,0 +1,22 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .crewai import CrewAIInteroperability
from .interoperability import Interoperability
from .interoperable import Interoperable
from .langchain import LangChainChatModelFactory, LangChainInteroperability
from .litellm import LiteLLmConfigFactory
from .pydantic_ai import PydanticAIInteroperability
from .registry import register_interoperable_class
__all__ = [
"CrewAIInteroperability",
"Interoperability",
"Interoperable",
"LangChainChatModelFactory",
"LangChainInteroperability",
"LiteLLmConfigFactory",
"PydanticAIInteroperability",
"register_interoperable_class",
]

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .crewai import CrewAIInteroperability
__all__ = ["CrewAIInteroperability"]

View File

@@ -0,0 +1,88 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import re
import sys
from typing import Any, Optional
from ...doc_utils import export_module
from ...import_utils import optional_import_block, require_optional_import
from ...tools import Tool
from ..registry import register_interoperable_class
__all__ = ["CrewAIInteroperability"]
def _sanitize_name(s: str) -> str:
return re.sub(r"\W|^(?=\d)", "_", s)
with optional_import_block():
from crewai.tools import BaseTool as CrewAITool
@register_interoperable_class("crewai")
@export_module("autogen.interop")
class CrewAIInteroperability:
"""A class implementing the `Interoperable` protocol for converting CrewAI tools
to a general `Tool` format.
This class takes a `CrewAITool` and converts it into a standard `Tool` object.
"""
@classmethod
@require_optional_import("crewai", "interop-crewai")
def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
"""Converts a given CrewAI tool into a general `Tool` format.
This method ensures that the provided tool is a valid `CrewAITool`, sanitizes
the tool's name, processes its description, and prepares a function to interact
with the tool's arguments. It then returns a standardized `Tool` object.
Args:
tool (Any): The tool to convert, expected to be an instance of `CrewAITool`.
**kwargs (Any): Additional arguments, which are not supported by this method.
Returns:
Tool: A standardized `Tool` object converted from the CrewAI tool.
Raises:
ValueError: If the provided tool is not an instance of `CrewAITool`, or if
any additional arguments are passed.
"""
if not isinstance(tool, CrewAITool):
raise ValueError(f"Expected an instance of `crewai.tools.BaseTool`, got {type(tool)}")
if kwargs:
raise ValueError(f"The CrewAIInteroperability does not support any additional arguments, got {kwargs}")
# needed for type checking
crewai_tool: CrewAITool = tool # type: ignore[no-any-unimported]
name = _sanitize_name(crewai_tool.name)
description = (
crewai_tool.description.split("Tool Description: ")[-1]
+ " (IMPORTANT: When using arguments, put them all in an `args` dictionary)"
)
def func(args: crewai_tool.args_schema) -> Any: # type: ignore[no-any-unimported]
return crewai_tool.run(**args.model_dump())
return Tool(
name=name,
description=description,
func_or_tool=func,
)
@classmethod
def get_unsupported_reason(cls) -> Optional[str]:
if sys.version_info < (3, 10) or sys.version_info >= (3, 13):
return "This submodule is only supported for Python versions 3.10, 3.11, and 3.12"
with optional_import_block() as result:
import crewai.tools # noqa: F401
if not result.is_successful:
return "Please install `interop-crewai` extra to use this module:\n\n\tpip install ag2[interop-crewai]"
return None

View File

@@ -0,0 +1,71 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from ..doc_utils import export_module
from ..tools import Tool
from .interoperable import Interoperable
from .registry import InteroperableRegistry
__all__ = ["Interoperable"]
@export_module("autogen.interop")
class Interoperability:
"""A class to handle interoperability between different tool types.
This class allows the conversion of tools to various interoperability classes and provides functionality
for retrieving and registering interoperability classes.
"""
registry = InteroperableRegistry.get_instance()
@classmethod
def convert_tool(cls, *, tool: Any, type: str, **kwargs: Any) -> Tool:
"""Converts a given tool to an instance of a specified interoperability type.
Args:
tool (Any): The tool object to be converted.
type (str): The type of interoperability to convert the tool to.
**kwargs (Any): Additional arguments to be passed during conversion.
Returns:
Tool: The converted tool.
Raises:
ValueError: If the interoperability class for the provided type is not found.
"""
interop = cls.get_interoperability_class(type)
return interop.convert_tool(tool, **kwargs)
@classmethod
def get_interoperability_class(cls, type: str) -> type[Interoperable]:
"""Retrieves the interoperability class corresponding to the specified type.
Args:
type (str): The type of the interoperability class to retrieve.
Returns:
type[Interoperable]: The interoperability class type.
Raises:
ValueError: If no interoperability class is found for the provided type.
"""
supported_types = cls.registry.get_supported_types()
if type not in supported_types:
supported_types_formatted = ", ".join(["'t'" for t in supported_types])
raise ValueError(
f"Interoperability class {type} is not supported, supported types: {supported_types_formatted}"
)
return cls.registry.get_class(type)
@classmethod
def get_supported_types(cls) -> list[str]:
"""Returns a sorted list of all supported interoperability types.
Returns:
List[str]: A sorted list of strings representing the supported interoperability types.
"""
return sorted(cls.registry.get_supported_types())

View File

@@ -0,0 +1,46 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Protocol, runtime_checkable
from ..doc_utils import export_module
from ..tools import Tool
__all__ = ["Interoperable"]
@runtime_checkable
@export_module("autogen.interop")
class Interoperable(Protocol):
"""A Protocol defining the interoperability interface for tool conversion.
This protocol ensures that any class implementing it provides the method
`convert_tool` to convert a given tool into a desired format or type.
"""
@classmethod
def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
"""Converts a given tool to a desired format or type.
This method should be implemented by any class adhering to the `Interoperable` protocol.
Args:
tool (Any): The tool object to be converted.
**kwargs (Any): Additional parameters to pass during the conversion process.
Returns:
Tool: The converted tool in the desired format or type.
"""
...
@classmethod
def get_unsupported_reason(cls) -> Optional[str]:
"""Returns the reason for the tool being unsupported.
This method should be implemented by any class adhering to the `Interoperable` protocol.
Returns:
str: The reason for the interoperability class being unsupported.
"""
...

View File

@@ -0,0 +1,8 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .langchain_chat_model_factory import LangChainChatModelFactory
from .langchain_tool import LangChainInteroperability
__all__ = ["LangChainChatModelFactory", "LangChainInteroperability"]

View File

@@ -0,0 +1,155 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Callable, TypeVar, Union
from ...doc_utils import export_module
from ...import_utils import optional_import_block, require_optional_import
from ...llm_config import LLMConfig
from ...oai import get_first_llm_config
with optional_import_block():
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
__all__ = ["LangChainChatModelFactory"]
T = TypeVar("T", bound="LangChainChatModelFactory")
@require_optional_import(
["langchain_anthropic", "langchain_google_genai", "langchain_ollama", "langchain_openai", "langchain_core"],
"browser-use",
except_for=["__init__", "register_factory"],
)
@export_module("autogen.interop")
class LangChainChatModelFactory(ABC):
_factories: set["LangChainChatModelFactory"] = set()
@classmethod
def create_base_chat_model(cls, llm_config: Union[LLMConfig, dict[str, Any]]) -> "BaseChatModel": # type: ignore [no-any-unimported]
first_llm_config = get_first_llm_config(llm_config)
for factory in LangChainChatModelFactory._factories:
if factory.accepts(first_llm_config):
return factory.create(first_llm_config)
raise ValueError("Could not find a factory for the given config.")
@classmethod
def register_factory(cls) -> Callable[[type[T]], type[T]]:
def decorator(factory: type[T]) -> type[T]:
cls._factories.add(factory())
return factory
return decorator
@classmethod
def prepare_config(cls, first_llm_config: dict[str, Any]) -> dict[str, Any]:
for pop_keys in ["api_type", "response_format"]:
first_llm_config.pop(pop_keys, None)
return first_llm_config
@classmethod
@abstractmethod
def create(cls, first_llm_config: dict[str, Any]) -> "BaseChatModel": # type: ignore [no-any-unimported]
...
@classmethod
@abstractmethod
def get_api_type(cls) -> str: ...
@classmethod
def accepts(cls, first_llm_config: dict[str, Any]) -> bool:
return first_llm_config.get("api_type", "openai") == cls.get_api_type() # type: ignore [no-any-return]
@LangChainChatModelFactory.register_factory()
class ChatOpenAIFactory(LangChainChatModelFactory):
@classmethod
def create(cls, first_llm_config: dict[str, Any]) -> "ChatOpenAI": # type: ignore [no-any-unimported]
first_llm_config = cls.prepare_config(first_llm_config)
return ChatOpenAI(**first_llm_config)
@classmethod
def get_api_type(cls) -> str:
return "openai"
@LangChainChatModelFactory.register_factory()
class DeepSeekFactory(ChatOpenAIFactory):
@classmethod
def create(cls, first_llm_config: dict[str, Any]) -> "ChatOpenAI": # type: ignore [no-any-unimported]
if "base_url" not in first_llm_config:
raise ValueError("base_url is required for deepseek api type.")
return super().create(first_llm_config)
@classmethod
def get_api_type(cls) -> str:
return "deepseek"
@LangChainChatModelFactory.register_factory()
class ChatAnthropicFactory(LangChainChatModelFactory):
@classmethod
def create(cls, first_llm_config: dict[str, Any]) -> "ChatAnthropic": # type: ignore [no-any-unimported]
first_llm_config = cls.prepare_config(first_llm_config)
return ChatAnthropic(**first_llm_config)
@classmethod
def get_api_type(cls) -> str:
return "anthropic"
@LangChainChatModelFactory.register_factory()
class ChatGoogleGenerativeAIFactory(LangChainChatModelFactory):
@classmethod
def create(cls, first_llm_config: dict[str, Any]) -> "ChatGoogleGenerativeAI": # type: ignore [no-any-unimported]
first_llm_config = cls.prepare_config(first_llm_config)
return ChatGoogleGenerativeAI(**first_llm_config)
@classmethod
def get_api_type(cls) -> str:
return "google"
@LangChainChatModelFactory.register_factory()
class AzureChatOpenAIFactory(LangChainChatModelFactory):
@classmethod
def create(cls, first_llm_config: dict[str, Any]) -> "AzureChatOpenAI": # type: ignore [no-any-unimported]
first_llm_config = cls.prepare_config(first_llm_config)
for param in ["base_url", "api_version"]:
if param not in first_llm_config:
raise ValueError(f"{param} is required for azure api type.")
first_llm_config["azure_endpoint"] = first_llm_config.pop("base_url")
return AzureChatOpenAI(**first_llm_config)
@classmethod
def get_api_type(cls) -> str:
return "azure"
@LangChainChatModelFactory.register_factory()
class ChatOllamaFactory(LangChainChatModelFactory):
@classmethod
def create(cls, first_llm_config: dict[str, Any]) -> "ChatOllama": # type: ignore [no-any-unimported]
first_llm_config = cls.prepare_config(first_llm_config)
first_llm_config["base_url"] = first_llm_config.pop("client_host", None)
if "num_ctx" not in first_llm_config:
# In all Browser Use examples, num_ctx is set to 32000
first_llm_config["num_ctx"] = 32000
return ChatOllama(**first_llm_config)
@classmethod
def get_api_type(cls) -> str:
return "ollama"

View File

@@ -0,0 +1,82 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import sys
from typing import Any, Optional
from ...doc_utils import export_module
from ...import_utils import optional_import_block, require_optional_import
from ...tools import Tool
from ..registry import register_interoperable_class
__all__ = ["LangChainInteroperability"]
with optional_import_block():
from langchain_core.tools import BaseTool as LangchainTool
@register_interoperable_class("langchain")
@export_module("autogen.interop")
class LangChainInteroperability:
"""A class implementing the `Interoperable` protocol for converting Langchain tools
into a general `Tool` format.
This class takes a `LangchainTool` and converts it into a standard `Tool` object,
ensuring compatibility between Langchain tools and other systems that expect
the `Tool` format.
"""
@classmethod
@require_optional_import("langchain_core", "interop-langchain")
def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
"""Converts a given Langchain tool into a general `Tool` format.
This method verifies that the provided tool is a valid `LangchainTool`,
processes the tool's input and description, and returns a standardized
`Tool` object.
Args:
tool (Any): The tool to convert, expected to be an instance of `LangchainTool`.
**kwargs (Any): Additional arguments, which are not supported by this method.
Returns:
Tool: A standardized `Tool` object converted from the Langchain tool.
Raises:
ValueError: If the provided tool is not an instance of `LangchainTool`, or if
any additional arguments are passed.
"""
if not isinstance(tool, LangchainTool):
raise ValueError(f"Expected an instance of `langchain_core.tools.BaseTool`, got {type(tool)}")
if kwargs:
raise ValueError(f"The LangchainInteroperability does not support any additional arguments, got {kwargs}")
# needed for type checking
langchain_tool: LangchainTool = tool # type: ignore[no-any-unimported]
model_type = langchain_tool.get_input_schema()
def func(tool_input: model_type) -> Any: # type: ignore[valid-type]
return langchain_tool.run(tool_input.model_dump()) # type: ignore[attr-defined]
return Tool(
name=langchain_tool.name,
description=langchain_tool.description,
func_or_tool=func,
)
@classmethod
def get_unsupported_reason(cls) -> Optional[str]:
if sys.version_info < (3, 9):
return "This submodule is only supported for Python versions 3.9 and above"
with optional_import_block() as result:
import langchain_core.tools # noqa: F401
if not result.is_successful:
return (
"Please install `interop-langchain` extra to use this module:\n\n\tpip install ag2[interop-langchain]"
)
return None

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .litellm_config_factory import LiteLLmConfigFactory
__all__ = ["LiteLLmConfigFactory"]

View File

@@ -0,0 +1,179 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, TypeVar, Union
from ...doc_utils import export_module
from ...llm_config import LLMConfig
from ...oai import get_first_llm_config
__all__ = ["LiteLLmConfigFactory"]
T = TypeVar("T", bound="LiteLLmConfigFactory")
def get_crawl4ai_version() -> Optional[str]:
"""Get the installed crawl4ai version."""
try:
import crawl4ai
version = getattr(crawl4ai, "__version__", None)
return version if isinstance(version, str) else None
except (ImportError, AttributeError):
return None
def is_crawl4ai_v05_or_higher() -> bool:
"""Check if crawl4ai version is 0.5 or higher."""
version = get_crawl4ai_version()
if version is None:
return False
# Parse version string (e.g., "0.5.0" -> [0, 5, 0])
try:
version_parts = [int(x) for x in version.split(".")]
# Check if version >= 0.5.0
return version_parts >= [0, 5, 0]
except (ValueError, IndexError):
return False
@export_module("autogen.interop")
class LiteLLmConfigFactory(ABC):
_factories: set["LiteLLmConfigFactory"] = set()
@classmethod
def create_lite_llm_config(cls, llm_config: Union[LLMConfig, dict[str, Any]]) -> dict[str, Any]:
"""
Create a lite LLM config compatible with the installed crawl4ai version.
For crawl4ai >=0.5: Returns config with llmConfig parameter
For crawl4ai <0.5: Returns config with provider parameter (legacy)
"""
first_llm_config = get_first_llm_config(llm_config)
for factory in LiteLLmConfigFactory._factories:
if factory.accepts(first_llm_config):
base_config = factory.create(first_llm_config)
# Check crawl4ai version and adapt config accordingly
if is_crawl4ai_v05_or_higher():
return cls._adapt_for_crawl4ai_v05(base_config)
else:
return base_config # Use legacy format
raise ValueError("Could not find a factory for the given config.")
@classmethod
def _adapt_for_crawl4ai_v05(cls, base_config: dict[str, Any]) -> dict[str, Any]:
"""
Adapt the config for crawl4ai >=0.5 by moving deprecated parameters
into an llmConfig object.
"""
adapted_config = base_config.copy()
# Extract deprecated parameters
llm_config_params = {}
if "provider" in adapted_config:
llm_config_params["provider"] = adapted_config.pop("provider")
if "api_token" in adapted_config:
llm_config_params["api_token"] = adapted_config.pop("api_token")
# Add other parameters that should be in llmConfig
for param in ["base_url", "api_base", "api_version"]:
if param in adapted_config:
llm_config_params[param] = adapted_config.pop(param)
# Create the llmConfig object if we have parameters for it
if llm_config_params:
adapted_config["llmConfig"] = llm_config_params
return adapted_config
@classmethod
def register_factory(cls) -> Callable[[type[T]], type[T]]:
def decorator(factory: type[T]) -> type[T]:
cls._factories.add(factory())
return factory
return decorator
@classmethod
def create(cls, first_llm_config: dict[str, Any]) -> dict[str, Any]:
model = first_llm_config.pop("model")
api_type = first_llm_config.pop("api_type", "openai")
first_llm_config["provider"] = f"{api_type}/{model}"
return first_llm_config
@classmethod
@abstractmethod
def get_api_type(cls) -> str: ...
@classmethod
def accepts(cls, first_llm_config: dict[str, Any]) -> bool:
return first_llm_config.get("api_type", "openai") == cls.get_api_type() # type: ignore [no-any-return]
@LiteLLmConfigFactory.register_factory()
class DefaultLiteLLmConfigFactory(LiteLLmConfigFactory):
@classmethod
def get_api_type(cls) -> str:
raise NotImplementedError("DefaultLiteLLmConfigFactory does not have an API type.")
@classmethod
def accepts(cls, first_llm_config: dict[str, Any]) -> bool:
non_base_api_types = ["google", "ollama"]
return first_llm_config.get("api_type", "openai") not in non_base_api_types
@classmethod
def create(cls, first_llm_config: dict[str, Any]) -> dict[str, Any]:
api_type = first_llm_config.get("api_type", "openai")
if api_type != "openai" and "api_key" not in first_llm_config:
raise ValueError("API key is required.")
first_llm_config["api_token"] = first_llm_config.pop("api_key", os.getenv("OPENAI_API_KEY"))
first_llm_config = super().create(first_llm_config)
return first_llm_config
@LiteLLmConfigFactory.register_factory()
class GoogleLiteLLmConfigFactory(LiteLLmConfigFactory):
@classmethod
def get_api_type(cls) -> str:
return "google"
@classmethod
def create(cls, first_llm_config: dict[str, Any]) -> dict[str, Any]:
# api type must be changed before calling super().create
# litellm uses gemini as the api type for google
first_llm_config["api_type"] = "gemini"
first_llm_config["api_token"] = first_llm_config.pop("api_key")
first_llm_config = super().create(first_llm_config)
return first_llm_config
@classmethod
def accepts(cls, first_llm_config: dict[str, Any]) -> bool:
api_type: str = first_llm_config.get("api_type", "")
return api_type == cls.get_api_type() or api_type == "gemini"
@LiteLLmConfigFactory.register_factory()
class OllamaLiteLLmConfigFactory(LiteLLmConfigFactory):
@classmethod
def get_api_type(cls) -> str:
return "ollama"
@classmethod
def create(cls, first_llm_config: dict[str, Any]) -> dict[str, Any]:
first_llm_config = super().create(first_llm_config)
if "client_host" in first_llm_config:
first_llm_config["api_base"] = first_llm_config.pop("client_host")
return first_llm_config

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .pydantic_ai import PydanticAIInteroperability
__all__ = ["PydanticAIInteroperability"]

View File

@@ -0,0 +1,168 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import sys
import warnings
from functools import wraps
from inspect import signature
from typing import Any, Callable, Optional
from ...doc_utils import export_module
from ...import_utils import optional_import_block, require_optional_import
from ...tools import Tool
from ..registry import register_interoperable_class
__all__ = ["PydanticAIInteroperability"]
with optional_import_block():
from pydantic_ai import RunContext
from pydantic_ai.tools import Tool as PydanticAITool
from pydantic_ai.usage import Usage
@register_interoperable_class("pydanticai")
@export_module("autogen.interop")
class PydanticAIInteroperability:
"""A class implementing the `Interoperable` protocol for converting Pydantic AI tools
into a general `Tool` format.
This class takes a `PydanticAITool` and converts it into a standard `Tool` object,
ensuring compatibility between Pydantic AI tools and other systems that expect
the `Tool` format. It also provides a mechanism for injecting context parameters
into the tool's function.
"""
@staticmethod
@require_optional_import("pydantic_ai", "interop-pydantic-ai")
def inject_params(
ctx: Any,
tool: Any,
) -> Callable[..., Any]:
"""Wraps the tool's function to inject context parameters and handle retries.
This method ensures that context parameters are properly passed to the tool
when invoked and that retries are managed according to the tool's settings.
Args:
ctx (Optional[RunContext[Any]]): The run context, which may include dependencies and retry information.
tool (PydanticAITool): The Pydantic AI tool whose function is to be wrapped.
Returns:
Callable[..., Any]: A wrapped function that includes context injection and retry handling.
Raises:
ValueError: If the tool fails after the maximum number of retries.
"""
ctx_typed: Optional[RunContext[Any]] = ctx # type: ignore[no-any-unimported]
tool_typed: PydanticAITool[Any] = tool # type: ignore[no-any-unimported]
max_retries = tool_typed.max_retries if tool_typed.max_retries is not None else 1
f = tool_typed.function
@wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if tool_typed.current_retry >= max_retries:
raise ValueError(f"{tool_typed.name} failed after {max_retries} retries")
try:
if ctx_typed is not None:
kwargs.pop("ctx", None)
ctx_typed.retry = tool_typed.current_retry
result = f(**kwargs, ctx=ctx_typed) # type: ignore[call-arg]
else:
result = f(**kwargs) # type: ignore[call-arg]
tool_typed.current_retry = 0
except Exception as e:
tool_typed.current_retry += 1
raise e
return result
sig = signature(f)
if ctx_typed is not None:
new_params = [param for name, param in sig.parameters.items() if name != "ctx"]
else:
new_params = list(sig.parameters.values())
wrapper.__signature__ = sig.replace(parameters=new_params) # type: ignore[attr-defined]
return wrapper
@classmethod
@require_optional_import("pydantic_ai", "interop-pydantic-ai")
def convert_tool(cls, tool: Any, deps: Any = None, **kwargs: Any) -> Tool:
"""Converts a given Pydantic AI tool into a general `Tool` format.
This method verifies that the provided tool is a valid `PydanticAITool`,
handles context dependencies if necessary, and returns a standardized `Tool` object.
Args:
tool (Any): The tool to convert, expected to be an instance of `PydanticAITool`.
deps (Any, optional): The dependencies to inject into the context, required if
the tool takes a context. Defaults to None.
**kwargs (Any): Additional arguments that are not used in this method.
Returns:
Tool: A standardized `Tool` object converted from the Pydantic AI tool.
Raises:
ValueError: If the provided tool is not an instance of `PydanticAITool`, or if
dependencies are missing for tools that require a context.
UserWarning: If the `deps` argument is provided for a tool that does not take a context.
"""
if not isinstance(tool, PydanticAITool):
raise ValueError(f"Expected an instance of `pydantic_ai.tools.Tool`, got {type(tool)}")
# needed for type checking
pydantic_ai_tool: PydanticAITool[Any] = tool # type: ignore[no-any-unimported]
if tool.takes_ctx and deps is None:
raise ValueError("If the tool takes a context, the `deps` argument must be provided")
if not tool.takes_ctx and deps is not None:
warnings.warn(
"The `deps` argument is provided but will be ignored because the tool does not take a context.",
UserWarning,
)
ctx = (
RunContext(
model=None, # type: ignore [arg-type]
usage=Usage(),
prompt="",
deps=deps,
retry=0,
# All messages send to or returned by a model.
# This is mostly used on pydantic_ai Agent level.
messages=[], # TODO: check in the future if this is needed on Tool level
tool_name=pydantic_ai_tool.name,
)
if tool.takes_ctx
else None
)
func = PydanticAIInteroperability.inject_params(
ctx=ctx,
tool=pydantic_ai_tool,
)
return Tool(
name=pydantic_ai_tool.name,
description=pydantic_ai_tool.description,
func_or_tool=func,
parameters_json_schema=pydantic_ai_tool._parameters_json_schema,
)
@classmethod
def get_unsupported_reason(cls) -> Optional[str]:
if sys.version_info < (3, 9):
return "This submodule is only supported for Python versions 3.9 and above"
with optional_import_block() as result:
import pydantic_ai.tools # noqa: F401
if not result.is_successful:
return "Please install `interop-pydantic-ai` extra to use this module:\n\n\tpip install ag2[interop-pydantic-ai]"
return None

View File

@@ -0,0 +1,69 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, TypeVar
from ..doc_utils import export_module
from .interoperable import Interoperable
__all__ = ["InteroperableRegistry", "register_interoperable_class"]
InteroperableClass = TypeVar("InteroperableClass", bound=type[Interoperable])
class InteroperableRegistry:
def __init__(self) -> None:
self._registry: dict[str, type[Interoperable]] = {}
def register(self, short_name: str, cls: InteroperableClass) -> InteroperableClass:
if short_name in self._registry:
raise ValueError(f"Duplicate registration for {short_name}")
self._registry[short_name] = cls
return cls
def get_short_names(self) -> list[str]:
return sorted(self._registry.keys())
def get_supported_types(self) -> list[str]:
short_names = self.get_short_names()
supported_types = [name for name in short_names if self._registry[name].get_unsupported_reason() is None]
return supported_types
def get_class(self, short_name: str) -> type[Interoperable]:
return self._registry[short_name]
@classmethod
def get_instance(cls) -> "InteroperableRegistry":
return _register
# global registry
_register = InteroperableRegistry()
# register decorator
@export_module("autogen.interop")
def register_interoperable_class(short_name: str) -> Callable[[InteroperableClass], InteroperableClass]:
"""Register an Interoperable class in the global registry.
Returns:
Callable[[InteroperableClass], InteroperableClass]: Decorator function
Example:
```python
@register_interoperable_class("myinterop")
class MyInteroperability(Interoperable):
def convert_tool(self, tool: Any) -> Tool:
# implementation
...
```
"""
def inner(cls: InteroperableClass) -> InteroperableClass:
global _register
return _register.register(short_name, cls)
return inner