CoACT initialize (#292)
This commit is contained in:
8
mm_agents/coact/autogen/interop/langchain/__init__.py
Normal file
8
mm_agents/coact/autogen/interop/langchain/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .langchain_chat_model_factory import LangChainChatModelFactory
|
||||
from .langchain_tool import LangChainInteroperability
|
||||
|
||||
__all__ = ["LangChainChatModelFactory", "LangChainInteroperability"]
|
||||
@@ -0,0 +1,155 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, TypeVar, Union
|
||||
|
||||
from ...doc_utils import export_module
|
||||
from ...import_utils import optional_import_block, require_optional_import
|
||||
from ...llm_config import LLMConfig
|
||||
from ...oai import get_first_llm_config
|
||||
|
||||
with optional_import_block():
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
||||
|
||||
|
||||
__all__ = ["LangChainChatModelFactory"]
|
||||
|
||||
T = TypeVar("T", bound="LangChainChatModelFactory")
|
||||
|
||||
|
||||
@require_optional_import(
|
||||
["langchain_anthropic", "langchain_google_genai", "langchain_ollama", "langchain_openai", "langchain_core"],
|
||||
"browser-use",
|
||||
except_for=["__init__", "register_factory"],
|
||||
)
|
||||
@export_module("autogen.interop")
|
||||
class LangChainChatModelFactory(ABC):
|
||||
_factories: set["LangChainChatModelFactory"] = set()
|
||||
|
||||
@classmethod
|
||||
def create_base_chat_model(cls, llm_config: Union[LLMConfig, dict[str, Any]]) -> "BaseChatModel": # type: ignore [no-any-unimported]
|
||||
first_llm_config = get_first_llm_config(llm_config)
|
||||
for factory in LangChainChatModelFactory._factories:
|
||||
if factory.accepts(first_llm_config):
|
||||
return factory.create(first_llm_config)
|
||||
|
||||
raise ValueError("Could not find a factory for the given config.")
|
||||
|
||||
@classmethod
|
||||
def register_factory(cls) -> Callable[[type[T]], type[T]]:
|
||||
def decorator(factory: type[T]) -> type[T]:
|
||||
cls._factories.add(factory())
|
||||
return factory
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def prepare_config(cls, first_llm_config: dict[str, Any]) -> dict[str, Any]:
|
||||
for pop_keys in ["api_type", "response_format"]:
|
||||
first_llm_config.pop(pop_keys, None)
|
||||
return first_llm_config
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def create(cls, first_llm_config: dict[str, Any]) -> "BaseChatModel": # type: ignore [no-any-unimported]
|
||||
...
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_api_type(cls) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def accepts(cls, first_llm_config: dict[str, Any]) -> bool:
|
||||
return first_llm_config.get("api_type", "openai") == cls.get_api_type() # type: ignore [no-any-return]
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
|
||||
class ChatOpenAIFactory(LangChainChatModelFactory):
|
||||
@classmethod
|
||||
def create(cls, first_llm_config: dict[str, Any]) -> "ChatOpenAI": # type: ignore [no-any-unimported]
|
||||
first_llm_config = cls.prepare_config(first_llm_config)
|
||||
|
||||
return ChatOpenAI(**first_llm_config)
|
||||
|
||||
@classmethod
|
||||
def get_api_type(cls) -> str:
|
||||
return "openai"
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
|
||||
class DeepSeekFactory(ChatOpenAIFactory):
|
||||
@classmethod
|
||||
def create(cls, first_llm_config: dict[str, Any]) -> "ChatOpenAI": # type: ignore [no-any-unimported]
|
||||
if "base_url" not in first_llm_config:
|
||||
raise ValueError("base_url is required for deepseek api type.")
|
||||
return super().create(first_llm_config)
|
||||
|
||||
@classmethod
|
||||
def get_api_type(cls) -> str:
|
||||
return "deepseek"
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
|
||||
class ChatAnthropicFactory(LangChainChatModelFactory):
|
||||
@classmethod
|
||||
def create(cls, first_llm_config: dict[str, Any]) -> "ChatAnthropic": # type: ignore [no-any-unimported]
|
||||
first_llm_config = cls.prepare_config(first_llm_config)
|
||||
|
||||
return ChatAnthropic(**first_llm_config)
|
||||
|
||||
@classmethod
|
||||
def get_api_type(cls) -> str:
|
||||
return "anthropic"
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
|
||||
class ChatGoogleGenerativeAIFactory(LangChainChatModelFactory):
|
||||
@classmethod
|
||||
def create(cls, first_llm_config: dict[str, Any]) -> "ChatGoogleGenerativeAI": # type: ignore [no-any-unimported]
|
||||
first_llm_config = cls.prepare_config(first_llm_config)
|
||||
|
||||
return ChatGoogleGenerativeAI(**first_llm_config)
|
||||
|
||||
@classmethod
|
||||
def get_api_type(cls) -> str:
|
||||
return "google"
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
|
||||
class AzureChatOpenAIFactory(LangChainChatModelFactory):
|
||||
@classmethod
|
||||
def create(cls, first_llm_config: dict[str, Any]) -> "AzureChatOpenAI": # type: ignore [no-any-unimported]
|
||||
first_llm_config = cls.prepare_config(first_llm_config)
|
||||
for param in ["base_url", "api_version"]:
|
||||
if param not in first_llm_config:
|
||||
raise ValueError(f"{param} is required for azure api type.")
|
||||
first_llm_config["azure_endpoint"] = first_llm_config.pop("base_url")
|
||||
|
||||
return AzureChatOpenAI(**first_llm_config)
|
||||
|
||||
@classmethod
|
||||
def get_api_type(cls) -> str:
|
||||
return "azure"
|
||||
|
||||
|
||||
@LangChainChatModelFactory.register_factory()
|
||||
class ChatOllamaFactory(LangChainChatModelFactory):
|
||||
@classmethod
|
||||
def create(cls, first_llm_config: dict[str, Any]) -> "ChatOllama": # type: ignore [no-any-unimported]
|
||||
first_llm_config = cls.prepare_config(first_llm_config)
|
||||
first_llm_config["base_url"] = first_llm_config.pop("client_host", None)
|
||||
if "num_ctx" not in first_llm_config:
|
||||
# In all Browser Use examples, num_ctx is set to 32000
|
||||
first_llm_config["num_ctx"] = 32000
|
||||
|
||||
return ChatOllama(**first_llm_config)
|
||||
|
||||
@classmethod
|
||||
def get_api_type(cls) -> str:
|
||||
return "ollama"
|
||||
82
mm_agents/coact/autogen/interop/langchain/langchain_tool.py
Normal file
82
mm_agents/coact/autogen/interop/langchain/langchain_tool.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
from ...doc_utils import export_module
|
||||
from ...import_utils import optional_import_block, require_optional_import
|
||||
from ...tools import Tool
|
||||
from ..registry import register_interoperable_class
|
||||
|
||||
__all__ = ["LangChainInteroperability"]
|
||||
|
||||
with optional_import_block():
|
||||
from langchain_core.tools import BaseTool as LangchainTool
|
||||
|
||||
|
||||
@register_interoperable_class("langchain")
|
||||
@export_module("autogen.interop")
|
||||
class LangChainInteroperability:
|
||||
"""A class implementing the `Interoperable` protocol for converting Langchain tools
|
||||
into a general `Tool` format.
|
||||
|
||||
This class takes a `LangchainTool` and converts it into a standard `Tool` object,
|
||||
ensuring compatibility between Langchain tools and other systems that expect
|
||||
the `Tool` format.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@require_optional_import("langchain_core", "interop-langchain")
|
||||
def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
|
||||
"""Converts a given Langchain tool into a general `Tool` format.
|
||||
|
||||
This method verifies that the provided tool is a valid `LangchainTool`,
|
||||
processes the tool's input and description, and returns a standardized
|
||||
`Tool` object.
|
||||
|
||||
Args:
|
||||
tool (Any): The tool to convert, expected to be an instance of `LangchainTool`.
|
||||
**kwargs (Any): Additional arguments, which are not supported by this method.
|
||||
|
||||
Returns:
|
||||
Tool: A standardized `Tool` object converted from the Langchain tool.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided tool is not an instance of `LangchainTool`, or if
|
||||
any additional arguments are passed.
|
||||
"""
|
||||
if not isinstance(tool, LangchainTool):
|
||||
raise ValueError(f"Expected an instance of `langchain_core.tools.BaseTool`, got {type(tool)}")
|
||||
if kwargs:
|
||||
raise ValueError(f"The LangchainInteroperability does not support any additional arguments, got {kwargs}")
|
||||
|
||||
# needed for type checking
|
||||
langchain_tool: LangchainTool = tool # type: ignore[no-any-unimported]
|
||||
|
||||
model_type = langchain_tool.get_input_schema()
|
||||
|
||||
def func(tool_input: model_type) -> Any: # type: ignore[valid-type]
|
||||
return langchain_tool.run(tool_input.model_dump()) # type: ignore[attr-defined]
|
||||
|
||||
return Tool(
|
||||
name=langchain_tool.name,
|
||||
description=langchain_tool.description,
|
||||
func_or_tool=func,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_unsupported_reason(cls) -> Optional[str]:
|
||||
if sys.version_info < (3, 9):
|
||||
return "This submodule is only supported for Python versions 3.9 and above"
|
||||
|
||||
with optional_import_block() as result:
|
||||
import langchain_core.tools # noqa: F401
|
||||
|
||||
if not result.is_successful:
|
||||
return (
|
||||
"Please install `interop-langchain` extra to use this module:\n\n\tpip install ag2[interop-langchain]"
|
||||
)
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user