"""
LLM Tools Module

This module provides decorators and utilities for defining, registering, and managing LLM tools.
It allows marking functions as LLM tools, generating JSON schemas for them (in OpenAI
function-calling and MCP formats), and retrieving registered tools for use with LLM APIs.
"""
|
|
|
|
import asyncio
import importlib
import inspect
import os
import pkgutil
import types
from functools import wraps
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Union,
    get_args,
    get_origin,
    get_type_hints,
)

import docstring_parser
from pydantic import BaseModel, Field, create_model
|
|
|
|
# Every function decorated with @llm_tool, keyed by its tool name.
# Populated as a side effect of decoration (i.e. when the defining module is imported).
_TOOL_REGISTRY: Dict[str, Callable] = {}

# Domain name -> dotted path of the package holding that domain's tools.
# import_domain_tools imports these packages; get_domain_tools filters the
# registry by these module paths.
_DOMAIN_MODULE_MAPPING: Dict[str, str] = dict(
    material='sci_mcp.material_mcp',
    general='sci_mcp.general_mcp',
    biology='sci_mcp.biology_mcp',
    chemistry='sci_mcp.chemistry_mcp',
)
|
|
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
    """
    Mark a function as an LLM tool and register it.

    Works both bare (``@llm_tool``) and called (``@llm_tool()`` or
    ``@llm_tool(name="xyz", description="...")``). The decorated function is
    wrapped, given schema attributes (OpenAI schema, MCP schema, pydantic
    argument model), and added to the module's tool registry, from which it
    can be retrieved with e.g. ``get_all_tools``.

    Args:
        name: Custom tool name. Defaults to the function's ``__name__``.
        description: Custom tool description. Defaults to the function's
            docstring. Ignored in the bare ``@llm_tool`` form.

    Returns:
        In the bare form, the wrapped function. In the called form, a
        decorator that wraps and registers the function it is applied to.

    Example:
        @llm_tool(name="weather_lookup", description="Get current weather for a location")
        def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
            '''Get weather information for a specific location.'''
            # Implementation...
            return {"temperature": 22.5, "conditions": "sunny"}
    """
    if not callable(name):
        # Called form: remember the options and apply them once the target function arrives.
        def decorator(func: Callable) -> Callable:
            return _llm_tool_impl(func, name, description)

        return decorator

    # Bare form: Python passed the decorated function itself in as ``name``.
    target = name
    return _llm_tool_impl(target, None, None)
|
|
|
|
def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
    """
    Implementation of the llm_tool decorator.

    Inspects ``func``'s signature and docstring, builds an OpenAI-format
    schema, an MCP-format schema and a pydantic argument model for it, wraps
    it, and registers the wrapper in ``_TOOL_REGISTRY``.

    Args:
        func: The function (plain or ``async def``) to expose as a tool.
        name: Tool name; ``func.__name__`` is used when falsy.
        description: Tool description; the full docstring of ``func`` (or
            ``""`` if it has none) is used when falsy.

    Returns:
        A wrapper around ``func``. It is a coroutine function when ``func``
        is one. It carries the attributes ``is_llm_tool``, ``tool_name``,
        ``tool_description``, ``openai_schema``, ``mcp_schema`` and
        ``args_schema``. Registering under an existing name replaces the
        previously registered tool.
    """
    sig = inspect.signature(func)
    doc = inspect.getdoc(func) or ""
    parsed_doc = docstring_parser.parse(doc)

    tool_name = name or func.__name__
    tool_description = description or doc

    # Per-parameter descriptions from the docstring's Args section. The first
    # entry wins if a name is documented twice; ``or ""`` keeps a missing
    # (None) description out of the generated schema.
    doc_param_descriptions: Dict[str, str] = {}
    for param_doc in parsed_doc.params:
        doc_param_descriptions.setdefault(param_doc.arg_name, param_doc.description or "")

    properties: Dict[str, Dict[str, Any]] = {}
    required: List[str] = []
    field_definitions: Dict[str, Any] = {}

    # One pass over the signature feeds both the JSON schemas and the pydantic model.
    for param_name, param in sig.parameters.items():
        # ``self`` is supplied by Python. *args/**kwargs cannot be passed as
        # named JSON properties. None of them become tool arguments.
        if param_name == "self" or param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue

        has_default = param.default is not inspect.Parameter.empty
        # An unannotated parameter would otherwise pass the
        # ``inspect.Parameter.empty`` sentinel to pydantic as a field type.
        param_type = Any if param.annotation is inspect.Parameter.empty else param.annotation
        param_desc = doc_param_descriptions.get(param_name, "")
        # Only Annotated[...] string metadata is forwarded to the pydantic Field.
        field_description: Optional[str] = None

        # Unwrap Annotated[T, "description", ...]. getattr() is used because
        # some origins (e.g. typing.Union on Python 3.9) have no __name__.
        if getattr(get_origin(param_type), "__name__", None) == "Annotated":
            annotated_args = get_args(param_type)
            param_type = annotated_args[0]
            field_description = ""
            if len(annotated_args) > 1 and isinstance(annotated_args[1], str):
                # An Annotated description takes precedence over the docstring one.
                param_desc = annotated_args[1]
                field_description = annotated_args[1]

        param_schema: Dict[str, Any] = {
            "type": _get_json_type(param_type),
            "description": param_desc,
            "title": param_name.replace("_", " ").title(),
        }
        # A default of None is not written into the JSON schema; the
        # parameter is still left out of ``required``.
        if has_default and param.default is not None:
            param_schema["default"] = param.default
        properties[param_name] = param_schema

        if not has_default:
            required.append(param_name)

        # Ellipsis marks the pydantic field as required.
        field_default = param.default if has_default else ...
        if field_description is None:
            field_definitions[param_name] = (param_type, Field(default=field_default))
        else:
            field_definitions[param_name] = (
                param_type,
                Field(default=field_default, description=field_description),
            )

    # OpenAI function-calling format.
    openai_schema = {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }

    # MCP tool format.
    mcp_schema = {
        "name": tool_name,
        "description": tool_description,
        "inputSchema": {
            "type": "object",
            "properties": properties,
            "required": required,
        },
    }

    # Pydantic model describing (and able to validate) the tool's arguments.
    model_name = f"{tool_name.title().replace('_', '')}Schema"
    args_schema = create_model(model_name, **field_definitions)

    # Build a wrapper of the same kind (sync vs. async) as the original function.
    if inspect.iscoroutinefunction(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
    else:
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

    # Attach metadata to the wrapper.
    wrapper.is_llm_tool = True
    wrapper.tool_name = tool_name
    wrapper.tool_description = tool_description
    wrapper.openai_schema = openai_schema
    wrapper.mcp_schema = mcp_schema
    wrapper.args_schema = args_schema

    # Register the tool (a later tool with the same name overwrites this entry).
    _TOOL_REGISTRY[tool_name] = wrapper

    return wrapper
|
|
|
|
def get_all_tools() -> Dict[str, Callable]:
    """
    Return every tool registered through @llm_tool so far.

    The returned dict is the module's live registry, not a copy. Tools
    registered afterwards will show up in it, and mutating it changes the
    registry itself.

    Returns:
        Mapping of tool name to the wrapped tool function.
    """
    registry = _TOOL_REGISTRY
    return registry
|
|
|
|
def get_all_tool_schemas(schema_type='openai') -> List[Dict[str, Any]]:
    """
    Collect the JSON schema of every registered LLM tool.

    Args:
        schema_type: ``'mcp'`` selects each tool's MCP-format schema. Any
            other value (default ``'openai'``) selects the OpenAI
            function-calling schema.

    Returns:
        One schema dict per registered tool, in the registry's iteration order.
    """
    attribute = 'mcp_schema' if schema_type == 'mcp' else 'openai_schema'
    schemas = []
    for tool in _TOOL_REGISTRY.values():
        schemas.append(getattr(tool, attribute))
    return schemas
|
|
|
|
def import_domain_tools(domains: List[str]) -> None:
    """
    Import every module of the requested domains so their tools get registered.

    Importing a module runs its ``@llm_tool`` decorators, which add the tools
    to ``_TOOL_REGISTRY``. For each known domain, the mapped package is
    imported and then all of its submodules are imported recursively.

    Args:
        domains: Domain names (e.g. ``['material', 'general']``). Names that
            are not keys of ``_DOMAIN_MODULE_MAPPING`` are skipped silently.

    Import failures (``ImportError``) are reported with ``print`` and do not
    stop the remaining imports.
    """
    for domain in domains:
        module_path = _DOMAIN_MODULE_MAPPING.get(domain)
        if module_path is None:
            continue

        try:
            # Import the base package.
            base_module = importlib.import_module(module_path)

            # Walk the package's own search path rather than dirname(__file__).
            # __file__ is None for namespace packages. For a plain module it
            # would point at the parent directory and import siblings under
            # the wrong prefix. A plain module has no __path__ and no submodules.
            search_path = getattr(base_module, "__path__", None)
            if search_path is None:
                continue

            # Recursively import all submodules.
            for _, name, _ in pkgutil.walk_packages(search_path, f"{module_path}."):
                try:
                    importlib.import_module(name)
                except ImportError as e:
                    print(f"Error importing {name}: {e}")
        except ImportError as e:
            print(f"Error importing domain {domain}: {e}")
|
|
|
|
def get_domain_tools(domains: List[str]) -> Dict[str, Callable]:
    """
    Return the registered tools that belong to the given domains.

    The domains' modules are first imported via ``import_domain_tools`` so
    their tools are registered. A tool belongs to a domain when its
    ``__module__`` is the domain's mapped package or a submodule of it.

    Args:
        domains: Domain names (e.g. ``['material', 'general']``); unknown
            names are ignored.

    Returns:
        Mapping of tool name to tool function. The mapping is ordered by the
        requested domains, then by registry order within each domain.
    """
    # First, ensure all tools from the specified domains are imported and registered.
    import_domain_tools(domains)

    domain_tools: Dict[str, Callable] = {}
    for domain in domains:
        prefix = _DOMAIN_MODULE_MAPPING.get(domain)
        if prefix is None:
            continue

        for tool_name, tool_func in _TOOL_REGISTRY.items():
            # __module__ may be missing or None for unusual callables.
            module_name = getattr(tool_func, "__module__", None) or ""
            # Match whole dotted components, so that 'pkg.material_mcp' does
            # not also claim tools from a sibling such as 'pkg.material_mcp_extra'.
            if module_name == prefix or module_name.startswith(prefix + "."):
                domain_tools[tool_name] = tool_func

    return domain_tools
|
|
|
|
def get_domain_tool_schemas(domains: List[str], schema_type='openai') -> List[Dict[str, Any]]:
    """
    Return the JSON schemas of the tools belonging to the given domains.

    Args:
        domains: Domain names (e.g. ``['material', 'general']``); unknown
            names are ignored.
        schema_type: ``'mcp'`` returns MCP-format schemas. Any other value
            (default ``'openai'``) returns OpenAI function-calling schemas.

    Returns:
        A flat list with one schema per matching tool. It is not grouped by
        domain and follows the order of ``get_domain_tools``.
    """
    domain_tools = get_domain_tools(domains)
    attribute = 'mcp_schema' if schema_type == 'mcp' else 'openai_schema'
    return [getattr(tool, attribute) for tool in domain_tools.values()]
|
|
|
|
|
|
def _get_json_type(python_type: Any) -> str:
    """
    Convert a Python type annotation to a JSON schema ``type`` string.

    Handles bare and parameterised containers (``list``, ``List[int]``,
    ``tuple``, ``set``, ``dict``, ``Dict[str, Any]``, ...). It also unwraps
    ``Optional[X]`` / ``X | None`` to the type of ``X``.

    Args:
        python_type: Python type annotation.

    Returns:
        One of ``"string"``, ``"integer"``, ``"number"``, ``"boolean"``,
        ``"array"`` or ``"object"``. ``"string"`` is the fallback for
        anything unrecognised (``Any``, custom classes, multi-member unions, ...).
    """
    origin = get_origin(python_type)

    # Optional[X] / Union[X, None] / X | None: describe the single non-None member.
    # types.UnionType (PEP 604 ``X | Y``) only exists on Python 3.10+.
    if origin is Union or origin is getattr(types, "UnionType", Union):
        members = [arg for arg in get_args(python_type) if arg is not type(None)]
        return _get_json_type(members[0]) if len(members) == 1 else "string"

    # Parameterised generics (List[int], dict[str, Any], Tuple[int, ...]) and
    # the bare typing aliases (List, Dict) are classified by their runtime class.
    base = python_type if origin is None else origin

    if base is str:
        return "string"
    if base is int:
        return "integer"
    if base is float:
        return "number"
    if base is bool:
        return "boolean"
    if base is list or base is tuple or base is set or base is frozenset:
        return "array"
    if base is dict:
        return "object"
    # Default to string for complex / unknown types.
    return "string"
|