Files
multi_mcp/sci_mcp/core/llm_tools.py
2025-05-09 14:16:33 +08:00

316 lines
10 KiB
Python
Executable File

"""
LLM Tools Module
This module provides decorators and utilities for defining, registering, and managing LLM tools.
It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving
registered tools for use with LLM APIs.
"""
import asyncio
import inspect
import importlib
import pkgutil
import os
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args
import docstring_parser
from pydantic import BaseModel, create_model, Field
# Process-wide registry: tool name -> wrapped tool callable.
# Entries are added by _llm_tool_impl each time a function is decorated.
_TOOL_REGISTRY: Dict[str, Callable] = dict()
# Domain name -> dotted module path used to import and to filter that
# domain's tools.
_DOMAIN_MODULE_MAPPING: Dict[str, str] = dict(
    material='sci_mcp.material_mcp',
    general='sci_mcp.general_mcp',
    biology='sci_mcp.biology_mcp',
    chemistry='sci_mcp.chemistry_mcp',
)
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
    """
    Mark a function as an LLM tool.

    Works both bare (``@llm_tool``) and with arguments
    (``@llm_tool(name=..., description=...)``). In either form the target
    function is handed to ``_llm_tool_impl``, which builds its schemas and
    registers it.

    Args:
        name: Custom tool name; the function name is used when omitted. When the
            decorator is applied bare, this positional slot receives the
            decorated function itself.
        description: Custom tool description; the function's docstring is used
            when omitted.

    Returns:
        The wrapped function when applied bare, otherwise a decorator that
        produces the wrapped function.

    Example:
        @llm_tool(name="weather_lookup", description="Get current weather for a location")
        def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
            '''Get weather information for a specific location.'''
            # Implementation...
            return {"temperature": 22.5, "conditions": "sunny"}
    """
    if not callable(name):
        # Parenthesised form: @llm_tool() or @llm_tool(name="xyz").
        def decorator(func: Callable) -> Callable:
            return _llm_tool_impl(func, name, description)

        return decorator

    # Bare form: `name` is actually the function being decorated, so no custom
    # name or description is available.
    target = name
    return _llm_tool_impl(target, None, None)
def _split_annotated(annotation: Any) -> tuple:
    """
    Split a parameter annotation into its underlying type and inline description.

    Args:
        annotation: The annotation taken from the function signature.

    Returns:
        A ``(type, description)`` tuple. For ``Annotated[T, "text"]`` this is
        ``(T, "text")``; for ``Annotated[T, <non-string>]`` it is ``(T, None)``;
        for any other annotation it is ``(annotation, None)``.
    """
    origin = get_origin(annotation)
    # getattr with a default: not every origin object exposes __name__ (some
    # typing special forms do not on older interpreters), and a plain attribute
    # access would raise for e.g. Optional[...] parameters.
    if getattr(origin, "__name__", None) != "Annotated":
        return annotation, None
    args = get_args(annotation)
    inline_desc = args[1] if len(args) > 1 and isinstance(args[1], str) else None
    return args[0], inline_desc


def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
    """
    Implementation of the llm_tool decorator.

    Inspects ``func``'s signature and docstring, builds an OpenAI-format schema,
    an MCP-format schema and a Pydantic args model, wraps ``func`` (preserving
    sync/async nature), attaches the metadata to the wrapper and stores the
    wrapper in ``_TOOL_REGISTRY`` under the tool name.

    Args:
        func: The function to expose as a tool.
        name: Tool name; defaults to ``func.__name__``.
        description: Tool description; defaults to the function's docstring.

    Returns:
        The wrapper around ``func`` carrying ``is_llm_tool``, ``tool_name``,
        ``tool_description``, ``openai_schema``, ``mcp_schema`` and
        ``args_schema`` attributes.
    """
    # Get function signature and docstring
    sig = inspect.signature(func)
    doc = inspect.getdoc(func) or ""
    parsed_doc = docstring_parser.parse(doc)

    tool_name = name or func.__name__
    tool_description = description or doc

    # Parameter descriptions from the docstring; the first entry for a name wins.
    # A documented parameter may have no description text, hence the `or ""`.
    doc_descriptions = {}
    for param_doc in parsed_doc.params:
        doc_descriptions.setdefault(param_doc.arg_name, param_doc.description or "")

    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    properties = {}
    required = []
    field_definitions = {}
    for param_name, param in sig.parameters.items():
        # Skip the bound-instance parameter of methods, and *args / **kwargs,
        # which cannot be described as a single named schema property.
        if param_name == "self" or param.kind in variadic_kinds:
            continue

        has_default = param.default is not inspect.Parameter.empty
        # An unannotated parameter accepts anything; handing the
        # `inspect.Parameter.empty` sentinel to Pydantic as a type is not valid.
        annotation = Any if param.annotation is inspect.Parameter.empty else param.annotation
        param_type, inline_desc = _split_annotated(annotation)
        # An inline Annotated[...] description overrides the docstring one.
        param_desc = inline_desc if inline_desc is not None else doc_descriptions.get(param_name, "")

        param_schema = {
            "type": _get_json_type(param_type),
            "description": param_desc,
            "title": param_name.replace("_", " ").title()
        }
        # A default of None is deliberately not emitted in the JSON schema.
        if has_default and param.default is not None:
            param_schema["default"] = param.default
        properties[param_name] = param_schema
        if not has_default:
            required.append(param_name)

        # `...` marks the field as required for Pydantic.
        field_default = param.default if has_default else ...
        field_definitions[param_name] = (
            param_type,
            Field(default=field_default, description=param_desc or None),
        )

    # Create OpenAI format JSON schema
    openai_schema = {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required
            }
        }
    }
    # Create MCP format JSON schema
    mcp_schema = {
        "name": tool_name,
        "description": tool_description,
        "inputSchema": {
            "type": "object",
            "properties": properties,
            "required": required
        }
    }

    # Create args schema model
    model_name = f"{tool_name.title().replace('_', '')}Schema"
    args_schema = create_model(model_name, **field_definitions)

    # Create a wrapper of the matching kind (async or sync) so that the decorated
    # tool is still awaitable exactly when the original function is.
    if inspect.iscoroutinefunction(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
    else:
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

    # Attach metadata to function
    wrapper.is_llm_tool = True
    wrapper.tool_name = tool_name
    wrapper.tool_description = tool_description
    wrapper.openai_schema = openai_schema
    wrapper.mcp_schema = mcp_schema
    wrapper.args_schema = args_schema

    # Register the tool (a later tool with the same name replaces the earlier one)
    _TOOL_REGISTRY[tool_name] = wrapper
    return wrapper
def get_all_tools() -> Dict[str, Callable]:
    """
    Return every tool registered so far.

    Returns:
        The registry dictionary itself (not a copy), keyed by tool name with the
        wrapped tool functions as values.
    """
    registry = _TOOL_REGISTRY
    return registry
def get_all_tool_schemas(schema_type='openai') -> List[Dict[str, Any]]:
    """
    Collect the JSON schema of every registered LLM tool.

    Args:
        schema_type: ``'mcp'`` selects each tool's MCP-format schema; any other
            value selects the OpenAI-format schema.

    Returns:
        A list with one schema per registered tool, in registry order.
    """
    registered = _TOOL_REGISTRY.values()
    if schema_type == 'mcp':
        return [tool.mcp_schema for tool in registered]
    return [tool.openai_schema for tool in registered]
def import_domain_tools(domains: List[str]) -> None:
    """
    Import tools from specified domains to ensure they are registered.

    Each known domain's base module is imported, then every submodule below it
    is imported in turn, so that functions decorated with @llm_tool end up in
    _TOOL_REGISTRY. Domains missing from _DOMAIN_MODULE_MAPPING are skipped.
    Import failures are reported with print() and do not stop the remaining
    imports.

    Args:
        domains: List of domain names (e.g., ['material', 'general'])
    """
    for domain in domains:
        module_path = _DOMAIN_MODULE_MAPPING.get(domain)
        if module_path is None:
            continue
        try:
            # Import the base module
            base_module = importlib.import_module(module_path)
            # Use the package search path rather than dirname(__file__):
            # __file__ is None for namespace packages, and a plain module has no
            # submodules to walk at all.
            search_path = getattr(base_module, "__path__", None)
            if search_path is None:
                continue
            # Recursively import all submodules
            for _, submodule_name, _ in pkgutil.walk_packages(search_path, f"{module_path}."):
                try:
                    importlib.import_module(submodule_name)
                except ImportError as e:
                    print(f"Error importing {submodule_name}: {e}")
        except ImportError as e:
            print(f"Error importing domain {domain}: {e}")
def get_domain_tools(domains: List[str]) -> Dict[str, Callable]:
    """
    Get tools from specified domains.

    The domains' modules are imported first so their tools are registered; the
    registry is then filtered by the module each tool function reports in
    ``__module__``.

    Args:
        domains: List of domain names (e.g., ['material', 'general']). Names
            missing from _DOMAIN_MODULE_MAPPING are ignored.

    Returns:
        A flat dictionary mapping tool names to tool functions for all
        requested domains combined.
    """
    # First, ensure all tools from the specified domains are imported and registered
    import_domain_tools(domains)

    domain_tools: Dict[str, Callable] = {}
    for domain in domains:
        domain_module = _DOMAIN_MODULE_MAPPING.get(domain)
        if domain_module is None:
            continue
        # Match on a dotted boundary so that a sibling package whose name merely
        # starts with the same characters is not attributed to this domain.
        submodule_prefix = f"{domain_module}."
        for tool_name, tool_func in _TOOL_REGISTRY.items():
            # __module__ can be absent or None on some callables.
            tool_module = getattr(tool_func, "__module__", None) or ""
            if tool_module == domain_module or tool_module.startswith(submodule_prefix):
                domain_tools[tool_name] = tool_func
    return domain_tools
def get_domain_tool_schemas(domains: List[str], schema_type='openai') -> List[Dict[str, Any]]:
    """
    Get JSON schemas for tools from specified domains.

    Args:
        domains: List of domain names (e.g., ['material', 'general'])
        schema_type: ``'mcp'`` selects each tool's MCP-format schema; any other
            value selects the OpenAI-format schema.

    Returns:
        A flat list of tool schemas covering all requested domains (the result
        is not grouped by domain).
    """
    # First, get all domain tools
    domain_tools = get_domain_tools(domains)
    schema_attr = "mcp_schema" if schema_type == 'mcp' else "openai_schema"
    return [getattr(tool, schema_attr) for tool in domain_tools.values()]
def _get_json_type(python_type: Any) -> str:
    """
    Convert Python type to JSON schema type.

    Bare types (``list``), typing aliases (``List``) and parameterised generics
    (``List[int]``, ``Dict[str, Any]``) are all recognised. ``Optional[X]`` and
    other unions resolve to the JSON type shared by their non-None members.

    Args:
        python_type: Python type annotation

    Returns:
        Corresponding JSON schema type as string; ``"string"`` for anything
        that is not recognised or for unions whose members disagree.
    """
    # Imported locally so this function does not depend on names outside the
    # module's existing import list.
    import types
    from typing import Union

    origin = get_origin(python_type)

    # Annotated[T, ...] describes T.
    if getattr(origin, "__name__", None) == "Annotated":
        return _get_json_type(get_args(python_type)[0])

    # typing.Union / Optional, plus the `X | Y` form where the interpreter has it.
    pep604_union = getattr(types, "UnionType", Union)
    if origin is Union or origin is pep604_union:
        member_types = {
            _get_json_type(member)
            for member in get_args(python_type)
            if member is not type(None)
        }
        return member_types.pop() if len(member_types) == 1 else "string"

    # For generics the origin is the runtime class (List[int] -> list).
    base_type = origin if origin is not None else python_type
    json_type_table = (
        (str, "string"),
        (int, "integer"),
        (float, "number"),
        (bool, "boolean"),
        (list, "array"),
        (tuple, "array"),
        (set, "array"),
        (frozenset, "array"),
        (dict, "object"),
    )
    # Identity comparison: works for unhashable annotation objects and keeps
    # bool distinct from int.
    for candidate, json_type in json_type_table:
        if base_type is candidate:
            return json_type
    # Default to string for complex types
    return "string"