CoACT initialize (#292)
This commit is contained in:
22
mm_agents/coact/autogen/interop/__init__.py
Normal file
22
mm_agents/coact/autogen/interop/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .crewai import CrewAIInteroperability
|
||||
from .interoperability import Interoperability
|
||||
from .interoperable import Interoperable
|
||||
from .langchain import LangChainChatModelFactory, LangChainInteroperability
|
||||
from .litellm import LiteLLmConfigFactory
|
||||
from .pydantic_ai import PydanticAIInteroperability
|
||||
from .registry import register_interoperable_class
|
||||
|
||||
__all__ = [
|
||||
"CrewAIInteroperability",
|
||||
"Interoperability",
|
||||
"Interoperable",
|
||||
"LangChainChatModelFactory",
|
||||
"LangChainInteroperability",
|
||||
"LiteLLmConfigFactory",
|
||||
"PydanticAIInteroperability",
|
||||
"register_interoperable_class",
|
||||
]
|
||||
7
mm_agents/coact/autogen/interop/crewai/__init__.py
Normal file
7
mm_agents/coact/autogen/interop/crewai/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .crewai import CrewAIInteroperability
|
||||
|
||||
__all__ = ["CrewAIInteroperability"]
|
||||
88
mm_agents/coact/autogen/interop/crewai/crewai.py
Normal file
88
mm_agents/coact/autogen/interop/crewai/crewai.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
from ...doc_utils import export_module
|
||||
from ...import_utils import optional_import_block, require_optional_import
|
||||
from ...tools import Tool
|
||||
from ..registry import register_interoperable_class
|
||||
|
||||
__all__ = ["CrewAIInteroperability"]
|
||||
|
||||
|
||||
def _sanitize_name(s: str) -> str:
|
||||
return re.sub(r"\W|^(?=\d)", "_", s)
|
||||
|
||||
|
||||
with optional_import_block():
|
||||
from crewai.tools import BaseTool as CrewAITool
|
||||
|
||||
|
||||
@register_interoperable_class("crewai")
@export_module("autogen.interop")
class CrewAIInteroperability:
    """Interoperability adapter that turns CrewAI tools into AG2 `Tool`s.

    Implements the `Interoperable` protocol: a `crewai.tools.BaseTool`
    instance is wrapped into the framework-agnostic `Tool` object.
    """

    @classmethod
    @require_optional_import("crewai", "interop-crewai")
    def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
        """Convert a CrewAI tool into the general `Tool` format.

        Validates the input, sanitizes the tool name, augments the
        description, and wraps the tool's `run` method behind a function
        that accepts a single `args` model.

        Args:
            tool (Any): The tool to convert; must be a `CrewAITool` instance.
            **kwargs (Any): Not supported; passing anything raises.

        Returns:
            Tool: The converted, framework-agnostic tool.

        Raises:
            ValueError: If `tool` is not a `CrewAITool`, or if any additional
                arguments are passed.
        """
        if not isinstance(tool, CrewAITool):
            raise ValueError(f"Expected an instance of `crewai.tools.BaseTool`, got {type(tool)}")
        if kwargs:
            raise ValueError(f"The CrewAIInteroperability does not support any additional arguments, got {kwargs}")

        # Narrow the type for checkers; `tool` is known to be a CrewAITool here.
        typed_tool: CrewAITool = tool  # type: ignore[no-any-unimported]

        safe_name = _sanitize_name(typed_tool.name)
        # Strip the "Tool Description: " prefix CrewAI prepends, then remind the
        # LLM to wrap all arguments in an `args` dict (required by `func` below).
        hint = " (IMPORTANT: When using arguments, put them all in an `args` dictionary)"
        tool_description = typed_tool.description.split("Tool Description: ")[-1] + hint

        def func(args: typed_tool.args_schema) -> Any:  # type: ignore[no-any-unimported]
            return typed_tool.run(**args.model_dump())

        return Tool(
            name=safe_name,
            description=tool_description,
            func_or_tool=func,
        )

    @classmethod
    def get_unsupported_reason(cls) -> Optional[str]:
        """Return why this adapter is unavailable, or None if it is usable."""
        # CrewAI itself only supports these interpreter versions.
        if sys.version_info < (3, 10) or sys.version_info >= (3, 13):
            return "This submodule is only supported for Python versions 3.10, 3.11, and 3.12"

        with optional_import_block() as result:
            import crewai.tools  # noqa: F401

        if not result.is_successful:
            return "Please install `interop-crewai` extra to use this module:\n\n\tpip install ag2[interop-crewai]"

        return None
|
||||
71
mm_agents/coact/autogen/interop/interoperability.py
Normal file
71
mm_agents/coact/autogen/interop/interoperability.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Any
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from ..tools import Tool
|
||||
from .interoperable import Interoperable
|
||||
from .registry import InteroperableRegistry
|
||||
|
||||
# BUG FIX: this module defines `Interoperability` (the `Interoperable`
# protocol lives in interoperable.py), so export the class defined here.
__all__ = ["Interoperability"]
|
||||
|
||||
|
||||
@export_module("autogen.interop")
class Interoperability:
    """A class to handle interoperability between different tool types.

    This class allows the conversion of tools to various interoperability
    classes and provides functionality for retrieving and registering
    interoperability classes.
    """

    # Shared singleton registry mapping type names to interoperability classes.
    registry = InteroperableRegistry.get_instance()

    @classmethod
    def convert_tool(cls, *, tool: Any, type: str, **kwargs: Any) -> Tool:
        """Converts a given tool to an instance of a specified interoperability type.

        Args:
            tool (Any): The tool object to be converted.
            type (str): The type of interoperability to convert the tool to.
            **kwargs (Any): Additional arguments to be passed during conversion.

        Returns:
            Tool: The converted tool.

        Raises:
            ValueError: If the interoperability class for the provided type is not found.
        """
        interop = cls.get_interoperability_class(type)
        return interop.convert_tool(tool, **kwargs)

    @classmethod
    def get_interoperability_class(cls, type: str) -> type[Interoperable]:
        """Retrieves the interoperability class corresponding to the specified type.

        Args:
            type (str): The type of the interoperability class to retrieve.

        Returns:
            type[Interoperable]: The interoperability class type.

        Raises:
            ValueError: If no interoperability class is found for the provided type.
        """
        supported_types = cls.registry.get_supported_types()
        if type not in supported_types:
            # BUG FIX: the original joined the literal string "'t'" for every
            # entry, so the error message never showed the actual type names.
            supported_types_formatted = ", ".join(f"'{t}'" for t in supported_types)
            raise ValueError(
                f"Interoperability class {type} is not supported, supported types: {supported_types_formatted}"
            )

        return cls.registry.get_class(type)

    @classmethod
    def get_supported_types(cls) -> list[str]:
        """Returns a sorted list of all supported interoperability types.

        Returns:
            List[str]: A sorted list of strings representing the supported interoperability types.
        """
        return sorted(cls.registry.get_supported_types())
|
||||
46
mm_agents/coact/autogen/interop/interoperable.py
Normal file
46
mm_agents/coact/autogen/interop/interoperable.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Optional, Protocol, runtime_checkable
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from ..tools import Tool
|
||||
|
||||
__all__ = ["Interoperable"]
|
||||
|
||||
|
||||
@runtime_checkable
@export_module("autogen.interop")
class Interoperable(Protocol):
    """Protocol for tool-conversion interoperability adapters.

    Implementations expose `convert_tool` to translate a foreign tool object
    into the framework's `Tool` type, and `get_unsupported_reason` to report
    why the adapter cannot be used in the current environment.
    """

    @classmethod
    def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
        """Convert `tool` into a `Tool` instance.

        Args:
            tool (Any): The foreign tool object to convert.
            **kwargs (Any): Implementation-specific conversion options.

        Returns:
            Tool: The converted tool in the desired format or type.
        """
        ...

    @classmethod
    def get_unsupported_reason(cls) -> Optional[str]:
        """Explain why this adapter is unsupported.

        Returns:
            Optional[str]: A human-readable reason, or None when the adapter
            is fully supported in the current environment.
        """
        ...
|
||||
8
mm_agents/coact/autogen/interop/langchain/__init__.py
Normal file
8
mm_agents/coact/autogen/interop/langchain/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .langchain_chat_model_factory import LangChainChatModelFactory
|
||||
from .langchain_tool import LangChainInteroperability
|
||||
|
||||
__all__ = ["LangChainChatModelFactory", "LangChainInteroperability"]
|
||||
@@ -0,0 +1,155 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, TypeVar, Union
|
||||
|
||||
from ...doc_utils import export_module
|
||||
from ...import_utils import optional_import_block, require_optional_import
|
||||
from ...llm_config import LLMConfig
|
||||
from ...oai import get_first_llm_config
|
||||
|
||||
with optional_import_block():
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
||||
|
||||
|
||||
__all__ = ["LangChainChatModelFactory"]
|
||||
|
||||
T = TypeVar("T", bound="LangChainChatModelFactory")
|
||||
|
||||
|
||||
@require_optional_import(
    ["langchain_anthropic", "langchain_google_genai", "langchain_ollama", "langchain_openai", "langchain_core"],
    "browser-use",
    except_for=["__init__", "register_factory"],
)
@export_module("autogen.interop")
class LangChainChatModelFactory(ABC):
    """Abstract factory producing LangChain chat models from an AG2 LLM config.

    Concrete subclasses register themselves via `register_factory` and are
    selected at runtime by matching the config's `api_type`.
    """

    # All registered factory instances; populated by `register_factory`.
    _factories: set["LangChainChatModelFactory"] = set()

    @classmethod
    def create_base_chat_model(cls, llm_config: Union[LLMConfig, dict[str, Any]]) -> "BaseChatModel":  # type: ignore [no-any-unimported]
        """Build a chat model from the first entry of `llm_config`.

        Raises:
            ValueError: When no registered factory accepts the config.
        """
        first_llm_config = get_first_llm_config(llm_config)
        for candidate in LangChainChatModelFactory._factories:
            if candidate.accepts(first_llm_config):
                return candidate.create(first_llm_config)

        raise ValueError("Could not find a factory for the given config.")

    @classmethod
    def register_factory(cls) -> Callable[[type[T]], type[T]]:
        """Class decorator that instantiates and registers a factory subclass."""

        def decorator(factory: type[T]) -> type[T]:
            cls._factories.add(factory())
            return factory

        return decorator

    @classmethod
    def prepare_config(cls, first_llm_config: dict[str, Any]) -> dict[str, Any]:
        """Drop AG2-only keys that LangChain constructors do not understand."""
        first_llm_config.pop("api_type", None)
        first_llm_config.pop("response_format", None)
        return first_llm_config

    @classmethod
    @abstractmethod
    def create(cls, first_llm_config: dict[str, Any]) -> "BaseChatModel":  # type: ignore [no-any-unimported]
        """Instantiate the concrete LangChain chat model."""
        ...

    @classmethod
    @abstractmethod
    def get_api_type(cls) -> str:
        """Return the `api_type` value this factory handles."""
        ...

    @classmethod
    def accepts(cls, first_llm_config: dict[str, Any]) -> bool:
        """Whether this factory handles the config's `api_type` (default "openai")."""
        return first_llm_config.get("api_type", "openai") == cls.get_api_type()  # type: ignore [no-any-return]
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
class ChatOpenAIFactory(LangChainChatModelFactory):
    """Factory for OpenAI-backed LangChain chat models (`api_type="openai"`)."""

    @classmethod
    def create(cls, first_llm_config: dict[str, Any]) -> "ChatOpenAI":  # type: ignore [no-any-unimported]
        """Build a `ChatOpenAI` from the cleaned config."""
        return ChatOpenAI(**cls.prepare_config(first_llm_config))

    @classmethod
    def get_api_type(cls) -> str:
        return "openai"
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
class DeepSeekFactory(ChatOpenAIFactory):
    """Factory for DeepSeek models, served through the OpenAI-compatible client."""

    @classmethod
    def create(cls, first_llm_config: dict[str, Any]) -> "ChatOpenAI":  # type: ignore [no-any-unimported]
        """Build the model; DeepSeek always needs an explicit endpoint."""
        if "base_url" not in first_llm_config:
            raise ValueError("base_url is required for deepseek api type.")
        return super().create(first_llm_config)

    @classmethod
    def get_api_type(cls) -> str:
        return "deepseek"
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
class ChatAnthropicFactory(LangChainChatModelFactory):
    """Factory for Anthropic chat models (`api_type="anthropic"`)."""

    @classmethod
    def create(cls, first_llm_config: dict[str, Any]) -> "ChatAnthropic":  # type: ignore [no-any-unimported]
        """Build a `ChatAnthropic` from the cleaned config."""
        return ChatAnthropic(**cls.prepare_config(first_llm_config))

    @classmethod
    def get_api_type(cls) -> str:
        return "anthropic"
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
class ChatGoogleGenerativeAIFactory(LangChainChatModelFactory):
    """Factory for Google Generative AI chat models (`api_type="google"`)."""

    @classmethod
    def create(cls, first_llm_config: dict[str, Any]) -> "ChatGoogleGenerativeAI":  # type: ignore [no-any-unimported]
        """Build a `ChatGoogleGenerativeAI` from the cleaned config."""
        return ChatGoogleGenerativeAI(**cls.prepare_config(first_llm_config))

    @classmethod
    def get_api_type(cls) -> str:
        return "google"
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
class AzureChatOpenAIFactory(LangChainChatModelFactory):
    """Factory for Azure OpenAI chat models (`api_type="azure"`)."""

    @classmethod
    def create(cls, first_llm_config: dict[str, Any]) -> "AzureChatOpenAI":  # type: ignore [no-any-unimported]
        """Build an `AzureChatOpenAI`; requires `base_url` and `api_version`."""
        config = cls.prepare_config(first_llm_config)
        for param in ("base_url", "api_version"):
            if param not in config:
                raise ValueError(f"{param} is required for azure api type.")
        # LangChain's Azure client expects the endpoint under `azure_endpoint`.
        config["azure_endpoint"] = config.pop("base_url")

        return AzureChatOpenAI(**config)

    @classmethod
    def get_api_type(cls) -> str:
        return "azure"
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
class ChatOllamaFactory(LangChainChatModelFactory):
    """Factory for Ollama chat models (`api_type="ollama"`)."""

    @classmethod
    def create(cls, first_llm_config: dict[str, Any]) -> "ChatOllama":  # type: ignore [no-any-unimported]
        """Build a `ChatOllama`, mapping AG2's `client_host` to `base_url`."""
        config = cls.prepare_config(first_llm_config)
        config["base_url"] = config.pop("client_host", None)
        # In all Browser Use examples, num_ctx is set to 32000
        config.setdefault("num_ctx", 32000)

        return ChatOllama(**config)

    @classmethod
    def get_api_type(cls) -> str:
        return "ollama"
|
||||
82
mm_agents/coact/autogen/interop/langchain/langchain_tool.py
Normal file
82
mm_agents/coact/autogen/interop/langchain/langchain_tool.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
from ...doc_utils import export_module
|
||||
from ...import_utils import optional_import_block, require_optional_import
|
||||
from ...tools import Tool
|
||||
from ..registry import register_interoperable_class
|
||||
|
||||
__all__ = ["LangChainInteroperability"]
|
||||
|
||||
with optional_import_block():
|
||||
from langchain_core.tools import BaseTool as LangchainTool
|
||||
|
||||
|
||||
@register_interoperable_class("langchain")
@export_module("autogen.interop")
class LangChainInteroperability:
    """Interoperability adapter that turns LangChain tools into AG2 `Tool`s.

    Implements the `Interoperable` protocol: a `langchain_core.tools.BaseTool`
    instance is wrapped into the framework-agnostic `Tool` object.
    """

    @classmethod
    @require_optional_import("langchain_core", "interop-langchain")
    def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
        """Convert a LangChain tool into the general `Tool` format.

        Args:
            tool (Any): The tool to convert; must be a `LangchainTool` instance.
            **kwargs (Any): Not supported; passing anything raises.

        Returns:
            Tool: The converted, framework-agnostic tool.

        Raises:
            ValueError: If `tool` is not a `LangchainTool`, or if any
                additional arguments are passed.
        """
        if not isinstance(tool, LangchainTool):
            raise ValueError(f"Expected an instance of `langchain_core.tools.BaseTool`, got {type(tool)}")
        if kwargs:
            raise ValueError(f"The LangchainInteroperability does not support any additional arguments, got {kwargs}")

        # Narrow the type for checkers; `tool` is known to be a LangchainTool here.
        typed_tool: LangchainTool = tool  # type: ignore[no-any-unimported]

        # The tool's pydantic input schema becomes the wrapper's annotation so
        # the framework can derive the function signature from it.
        input_model = typed_tool.get_input_schema()

        def func(tool_input: input_model) -> Any:  # type: ignore[valid-type]
            return typed_tool.run(tool_input.model_dump())  # type: ignore[attr-defined]

        return Tool(
            name=typed_tool.name,
            description=typed_tool.description,
            func_or_tool=func,
        )

    @classmethod
    def get_unsupported_reason(cls) -> Optional[str]:
        """Return why this adapter is unavailable, or None if it is usable."""
        if sys.version_info < (3, 9):
            return "This submodule is only supported for Python versions 3.9 and above"

        with optional_import_block() as result:
            import langchain_core.tools  # noqa: F401

        if not result.is_successful:
            return (
                "Please install `interop-langchain` extra to use this module:\n\n\tpip install ag2[interop-langchain]"
            )

        return None
|
||||
7
mm_agents/coact/autogen/interop/litellm/__init__.py
Normal file
7
mm_agents/coact/autogen/interop/litellm/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .litellm_config_factory import LiteLLmConfigFactory
|
||||
|
||||
__all__ = ["LiteLLmConfigFactory"]
|
||||
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
from ...doc_utils import export_module
|
||||
from ...llm_config import LLMConfig
|
||||
from ...oai import get_first_llm_config
|
||||
|
||||
__all__ = ["LiteLLmConfigFactory"]
|
||||
|
||||
T = TypeVar("T", bound="LiteLLmConfigFactory")
|
||||
|
||||
|
||||
def get_crawl4ai_version() -> Optional[str]:
    """Return the installed crawl4ai version string, or None if unavailable."""
    try:
        import crawl4ai
    except ImportError:
        return None

    # Treat a missing or non-string __version__ the same as "not installed".
    version = getattr(crawl4ai, "__version__", None)
    return version if isinstance(version, str) else None
|
||||
|
||||
|
||||
def is_crawl4ai_v05_or_higher() -> bool:
    """Check if crawl4ai version is 0.5 or higher.

    Returns False when crawl4ai is not installed or the version string
    cannot be parsed as dotted integers.
    """
    version = get_crawl4ai_version()
    if version is None:
        return False

    # Parse version string (e.g., "0.5.0" -> [0, 5, 0])
    try:
        version_parts = [int(x) for x in version.split(".")]
    except (ValueError, IndexError):
        return False

    # BUG FIX: pad short versions so "0.5" compares like "0.5.0";
    # the original [0, 5] >= [0, 5, 0] evaluated to False.
    version_parts += [0] * (3 - len(version_parts))
    return version_parts >= [0, 5, 0]
|
||||
|
||||
|
||||
@export_module("autogen.interop")
class LiteLLmConfigFactory(ABC):
    """Abstract factory turning an AG2 LLM config into a LiteLLM config dict.

    Concrete subclasses register themselves via `register_factory` and are
    selected at runtime by matching the config's `api_type`.
    """

    # All registered factory instances; populated by `register_factory`.
    _factories: set["LiteLLmConfigFactory"] = set()

    @classmethod
    def create_lite_llm_config(cls, llm_config: Union[LLMConfig, dict[str, Any]]) -> dict[str, Any]:
        """Create a lite LLM config compatible with the installed crawl4ai version.

        For crawl4ai >=0.5: returns a config with an `llmConfig` parameter.
        For crawl4ai <0.5: returns a config with a `provider` parameter (legacy).

        Raises:
            ValueError: When no registered factory accepts the config.
        """
        first_llm_config = get_first_llm_config(llm_config)
        for factory in LiteLLmConfigFactory._factories:
            if not factory.accepts(first_llm_config):
                continue
            base_config = factory.create(first_llm_config)
            # Newer crawl4ai releases moved several top-level keys into llmConfig.
            if is_crawl4ai_v05_or_higher():
                return cls._adapt_for_crawl4ai_v05(base_config)
            return base_config  # legacy format

        raise ValueError("Could not find a factory for the given config.")

    @classmethod
    def _adapt_for_crawl4ai_v05(cls, base_config: dict[str, Any]) -> dict[str, Any]:
        """Move parameters deprecated at top level (crawl4ai >=0.5) into `llmConfig`."""
        adapted_config = base_config.copy()

        # Collect every key that crawl4ai >=0.5 expects inside llmConfig.
        llm_config_params = {}
        for param in ("provider", "api_token", "base_url", "api_base", "api_version"):
            if param in adapted_config:
                llm_config_params[param] = adapted_config.pop(param)

        # Only add the nested object when something was actually moved.
        if llm_config_params:
            adapted_config["llmConfig"] = llm_config_params

        return adapted_config

    @classmethod
    def register_factory(cls) -> Callable[[type[T]], type[T]]:
        """Class decorator that instantiates and registers a factory subclass."""

        def decorator(factory: type[T]) -> type[T]:
            cls._factories.add(factory())
            return factory

        return decorator

    @classmethod
    def create(cls, first_llm_config: dict[str, Any]) -> dict[str, Any]:
        """Fold `api_type` and `model` into LiteLLM's `provider` string."""
        model = first_llm_config.pop("model")
        api_type = first_llm_config.pop("api_type", "openai")

        first_llm_config["provider"] = f"{api_type}/{model}"
        return first_llm_config

    @classmethod
    @abstractmethod
    def get_api_type(cls) -> str:
        """Return the `api_type` value this factory handles."""
        ...

    @classmethod
    def accepts(cls, first_llm_config: dict[str, Any]) -> bool:
        """Whether this factory handles the config's `api_type` (default "openai")."""
        return first_llm_config.get("api_type", "openai") == cls.get_api_type()  # type: ignore [no-any-return]
|
||||
|
||||
|
||||
@LiteLLmConfigFactory.register_factory()
class DefaultLiteLLmConfigFactory(LiteLLmConfigFactory):
    """Fallback factory for every api type without a dedicated factory."""

    @classmethod
    def get_api_type(cls) -> str:
        raise NotImplementedError("DefaultLiteLLmConfigFactory does not have an API type.")

    @classmethod
    def accepts(cls, first_llm_config: dict[str, Any]) -> bool:
        """Accept anything except api types that have their own factory."""
        return first_llm_config.get("api_type", "openai") not in ("google", "ollama")

    @classmethod
    def create(cls, first_llm_config: dict[str, Any]) -> dict[str, Any]:
        """Map `api_key` to `api_token`; only OpenAI may fall back to the env var."""
        api_type = first_llm_config.get("api_type", "openai")
        if api_type != "openai" and "api_key" not in first_llm_config:
            raise ValueError("API key is required.")
        first_llm_config["api_token"] = first_llm_config.pop("api_key", os.getenv("OPENAI_API_KEY"))

        return super().create(first_llm_config)
|
||||
|
||||
|
||||
@LiteLLmConfigFactory.register_factory()
class GoogleLiteLLmConfigFactory(LiteLLmConfigFactory):
    """Factory for Google/Gemini configs."""

    @classmethod
    def get_api_type(cls) -> str:
        return "google"

    @classmethod
    def create(cls, first_llm_config: dict[str, Any]) -> dict[str, Any]:
        """Build the config; litellm names the Google api type "gemini"."""
        # api type must be changed before calling super().create
        first_llm_config["api_type"] = "gemini"
        first_llm_config["api_token"] = first_llm_config.pop("api_key")
        return super().create(first_llm_config)

    @classmethod
    def accepts(cls, first_llm_config: dict[str, Any]) -> bool:
        """Accept both the AG2 name ("google") and litellm's name ("gemini")."""
        return first_llm_config.get("api_type", "") in (cls.get_api_type(), "gemini")
|
||||
|
||||
|
||||
@LiteLLmConfigFactory.register_factory()
class OllamaLiteLLmConfigFactory(LiteLLmConfigFactory):
    """Factory for Ollama configs."""

    @classmethod
    def get_api_type(cls) -> str:
        return "ollama"

    @classmethod
    def create(cls, first_llm_config: dict[str, Any]) -> dict[str, Any]:
        """Build the config, mapping AG2's `client_host` to litellm's `api_base`."""
        result = super().create(first_llm_config)
        if "client_host" in result:
            result["api_base"] = result.pop("client_host")

        return result
|
||||
7
mm_agents/coact/autogen/interop/pydantic_ai/__init__.py
Normal file
7
mm_agents/coact/autogen/interop/pydantic_ai/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .pydantic_ai import PydanticAIInteroperability
|
||||
|
||||
__all__ = ["PydanticAIInteroperability"]
|
||||
168
mm_agents/coact/autogen/interop/pydantic_ai/pydantic_ai.py
Normal file
168
mm_agents/coact/autogen/interop/pydantic_ai/pydantic_ai.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
from functools import wraps
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from ...doc_utils import export_module
|
||||
from ...import_utils import optional_import_block, require_optional_import
|
||||
from ...tools import Tool
|
||||
from ..registry import register_interoperable_class
|
||||
|
||||
__all__ = ["PydanticAIInteroperability"]
|
||||
|
||||
with optional_import_block():
|
||||
from pydantic_ai import RunContext
|
||||
from pydantic_ai.tools import Tool as PydanticAITool
|
||||
from pydantic_ai.usage import Usage
|
||||
|
||||
|
||||
@register_interoperable_class("pydanticai")
|
||||
@export_module("autogen.interop")
|
||||
class PydanticAIInteroperability:
|
||||
"""A class implementing the `Interoperable` protocol for converting Pydantic AI tools
|
||||
into a general `Tool` format.
|
||||
|
||||
This class takes a `PydanticAITool` and converts it into a standard `Tool` object,
|
||||
ensuring compatibility between Pydantic AI tools and other systems that expect
|
||||
the `Tool` format. It also provides a mechanism for injecting context parameters
|
||||
into the tool's function.
|
||||
"""
|
||||
|
||||
    @staticmethod
    @require_optional_import("pydantic_ai", "interop-pydantic-ai")
    def inject_params(
        ctx: Any,
        tool: Any,
    ) -> Callable[..., Any]:
        """Wraps the tool's function to inject context parameters and handle retries.

        This method ensures that context parameters are properly passed to the tool
        when invoked and that retries are managed according to the tool's settings.

        Args:
            ctx (Optional[RunContext[Any]]): The run context, which may include dependencies and retry information.
            tool (PydanticAITool): The Pydantic AI tool whose function is to be wrapped.

        Returns:
            Callable[..., Any]: A wrapped function that includes context injection and retry handling.

        Raises:
            ValueError: If the tool fails after the maximum number of retries.
        """
        # Narrow the untyped parameters for checkers; callers pass a RunContext
        # (or None) and a PydanticAITool.
        ctx_typed: Optional[RunContext[Any]] = ctx  # type: ignore[no-any-unimported]
        tool_typed: PydanticAITool[Any] = tool  # type: ignore[no-any-unimported]

        # A tool without an explicit max_retries gets a single attempt.
        max_retries = tool_typed.max_retries if tool_typed.max_retries is not None else 1
        f = tool_typed.function

        @wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Retry state lives on the tool itself and persists across calls.
            if tool_typed.current_retry >= max_retries:
                raise ValueError(f"{tool_typed.name} failed after {max_retries} retries")

            try:
                if ctx_typed is not None:
                    # Drop any caller-supplied ctx; the bound context wins, and
                    # its retry counter is synced before each invocation.
                    kwargs.pop("ctx", None)
                    ctx_typed.retry = tool_typed.current_retry
                    result = f(**kwargs, ctx=ctx_typed)  # type: ignore[call-arg]
                else:
                    result = f(**kwargs)  # type: ignore[call-arg]
                # A successful call resets the retry counter.
                tool_typed.current_retry = 0
            except Exception as e:
                tool_typed.current_retry += 1
                raise e

            return result

        # Rebuild the visible signature: when a context is injected, hide the
        # `ctx` parameter so callers (and LLM schemas) never see it.
        sig = signature(f)
        if ctx_typed is not None:
            new_params = [param for name, param in sig.parameters.items() if name != "ctx"]
        else:
            new_params = list(sig.parameters.values())

        wrapper.__signature__ = sig.replace(parameters=new_params)  # type: ignore[attr-defined]

        return wrapper
|
||||
|
||||
@classmethod
@require_optional_import("pydantic_ai", "interop-pydantic-ai")
def convert_tool(cls, tool: Any, deps: Any = None, **kwargs: Any) -> Tool:
    """Convert a Pydantic AI tool into the general `Tool` format.

    Validates that *tool* really is a `PydanticAITool`, builds a
    `RunContext` carrying *deps* when the tool expects one, and wraps the
    tool's function so that the context is injected on each invocation.

    Args:
        tool (Any): The tool to convert; must be a `PydanticAITool` instance.
        deps (Any, optional): Dependencies injected into the run context.
            Required when the tool takes a context. Defaults to None.
        **kwargs (Any): Accepted for interface compatibility; unused here.

    Returns:
        Tool: A standardized `Tool` object wrapping the Pydantic AI tool.

    Raises:
        ValueError: If *tool* is not a `PydanticAITool`, or if `deps` is
            missing for a tool that requires a context.

    Warns:
        UserWarning: If `deps` is supplied for a tool that does not take a
            context (the argument is ignored in that case).
    """
    if not isinstance(tool, PydanticAITool):
        raise ValueError(f"Expected an instance of `pydantic_ai.tools.Tool`, got {type(tool)}")

    # Typed alias of the same object, purely for static type checking.
    pydantic_ai_tool: PydanticAITool[Any] = tool  # type: ignore[no-any-unimported]

    if pydantic_ai_tool.takes_ctx and deps is None:
        raise ValueError("If the tool takes a context, the `deps` argument must be provided")
    if not pydantic_ai_tool.takes_ctx and deps is not None:
        warnings.warn(
            "The `deps` argument is provided but will be ignored because the tool does not take a context.",
            UserWarning,
        )

    run_ctx = None
    if pydantic_ai_tool.takes_ctx:
        run_ctx = RunContext(
            model=None,  # type: ignore [arg-type]
            usage=Usage(),
            prompt="",
            deps=deps,
            retry=0,
            # All messages send to or returned by a model.
            # This is mostly used on pydantic_ai Agent level.
            messages=[],  # TODO: check in the future if this is needed on Tool level
            tool_name=pydantic_ai_tool.name,
        )

    wrapped_func = PydanticAIInteroperability.inject_params(
        ctx=run_ctx,
        tool=pydantic_ai_tool,
    )

    return Tool(
        name=pydantic_ai_tool.name,
        description=pydantic_ai_tool.description,
        func_or_tool=wrapped_func,
        parameters_json_schema=pydantic_ai_tool._parameters_json_schema,
    )
|
||||
|
||||
@classmethod
def get_unsupported_reason(cls) -> Optional[str]:
    """Explain why this interoperability module cannot be used, if at all.

    Returns:
        Optional[str]: A human-readable reason when the module is unusable
        (unsupported Python version or missing optional dependency),
        otherwise None.
    """
    if sys.version_info < (3, 9):
        return "This submodule is only supported for Python versions 3.9 and above"

    with optional_import_block() as import_result:
        import pydantic_ai.tools  # noqa: F401

    if import_result.is_successful:
        return None

    return "Please install `interop-pydantic-ai` extra to use this module:\n\n\tpip install ag2[interop-pydantic-ai]"
|
||||
69
mm_agents/coact/autogen/interop/registry.py
Normal file
69
mm_agents/coact/autogen/interop/registry.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from .interoperable import Interoperable
|
||||
|
||||
# Public API of this module.
__all__ = ["InteroperableRegistry", "register_interoperable_class"]

# Type variable bound to concrete Interoperable classes, so the register
# decorator can return the exact class type it received.
InteroperableClass = TypeVar("InteroperableClass", bound=type[Interoperable])
|
||||
|
||||
|
||||
class InteroperableRegistry:
    """Registry mapping short names to their `Interoperable` classes."""

    def __init__(self) -> None:
        # Maps a short name (e.g. "langchain") to its Interoperable class.
        self._registry: dict[str, type[Interoperable]] = {}

    def register(self, short_name: str, cls: InteroperableClass) -> InteroperableClass:
        """Store *cls* under *short_name* and return it unchanged.

        Raises:
            ValueError: If *short_name* is already registered.
        """
        if short_name in self._registry:
            raise ValueError(f"Duplicate registration for {short_name}")

        self._registry[short_name] = cls
        return cls

    def get_short_names(self) -> list[str]:
        """Return every registered short name, sorted alphabetically."""
        return sorted(self._registry)

    def get_supported_types(self) -> list[str]:
        """Return the short names whose classes are usable in this environment."""
        return [
            short_name
            for short_name in self.get_short_names()
            if self._registry[short_name].get_unsupported_reason() is None
        ]

    def get_class(self, short_name: str) -> type[Interoperable]:
        """Look up the class registered under *short_name*.

        Raises:
            KeyError: If *short_name* is not registered.
        """
        return self._registry[short_name]

    @classmethod
    def get_instance(cls) -> "InteroperableRegistry":
        """Return the process-wide singleton registry."""
        return _register
|
||||
|
||||
|
||||
# Process-wide singleton registry, returned by InteroperableRegistry.get_instance().
_register = InteroperableRegistry()
|
||||
|
||||
|
||||
# register decorator
|
||||
@export_module("autogen.interop")
def register_interoperable_class(short_name: str) -> Callable[[InteroperableClass], InteroperableClass]:
    """Register an Interoperable class in the global registry.

    Args:
        short_name (str): Key under which the class is stored in the global
            registry; must not already be registered.

    Returns:
        Callable[[InteroperableClass], InteroperableClass]: Decorator function

    Raises:
        ValueError: If `short_name` is already registered (raised when the
            decorator is applied, via `InteroperableRegistry.register`).

    Example:
        ```python
        @register_interoperable_class("myinterop")
        class MyInteroperability(Interoperable):
            def convert_tool(self, tool: Any) -> Tool:
                # implementation
                ...
        ```
    """

    def inner(cls: InteroperableClass) -> InteroperableClass:
        # `_register` is only read here, never rebound, so no `global`
        # declaration is needed.
        return _register.register(short_name, cls)

    return inner
|
||||
Reference in New Issue
Block a user