""" LLM Tools Module This module provides decorators and utilities for defining, registering, and managing LLM tools. It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving registered tools for use with LLM APIs. """ import asyncio import inspect import importlib import pkgutil import os from functools import wraps from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args import docstring_parser from pydantic import BaseModel, create_model, Field # Registry to store all registered tools _TOOL_REGISTRY = {} # Mapping of domain names to their module paths _DOMAIN_MODULE_MAPPING = { 'material': 'sci_mcp.material_mcp', 'general': 'sci_mcp.general_mcp', 'biology': 'sci_mcp.biology_mcp', 'chemistry': 'sci_mcp.chemistry_mcp' } def llm_tool(name: Optional[str] = None, description: Optional[str] = None): """ Decorator to mark a function as an LLM tool. This decorator registers the function as an LLM tool, generates a JSON schema for it, and makes it available for retrieval through the get_tools function. Args: name: Optional custom name for the tool. If not provided, the function name will be used. description: Optional custom description for the tool. If not provided, the function's docstring will be used. Returns: The decorated function with additional attributes for LLM tool functionality. Example: @llm_tool(name="weather_lookup", description="Get current weather for a location") def get_weather(location: str, units: str = "metric") -> Dict[str, Any]: '''Get weather information for a specific location.''' # Implementation... return {"temperature": 22.5, "conditions": "sunny"} """ # Handle case when decorator is used without parentheses: @llm_tool if callable(name): func = name name = None description = None return _llm_tool_impl(func, name, description) # Handle case when decorator is used with parentheses: @llm_tool() or @llm_tool(name="xyz") def decorator(func: Callable) -> Callable: return _llm_tool_impl(func, name, description) return decorator def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable: """Implementation of the llm_tool decorator.""" # Get function signature and docstring sig = inspect.signature(func) doc = inspect.getdoc(func) or "" parsed_doc = docstring_parser.parse(doc) # Determine tool name tool_name = name or func.__name__ # Determine tool description tool_description = description or doc # Create parameter properties for JSON schema properties = {} required = [] for param_name, param in sig.parameters.items(): # Skip self parameter for methods if param_name == "self": continue param_type = param.annotation param_default = None if param.default is inspect.Parameter.empty else param.default param_required = param.default is inspect.Parameter.empty # Get parameter description from docstring if available param_desc = "" for param_doc in parsed_doc.params: if param_doc.arg_name == param_name: param_desc = param_doc.description break # Handle Annotated types if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated": args = get_args(param_type) param_type = args[0] # The actual type if len(args) > 1 and isinstance(args[1], str): param_desc = args[1] # The description # Create property for parameter param_schema = { "type": _get_json_type(param_type), "description": param_desc, "title": param_name.replace("_", " ").title() } # Add default value if available if param_default is not None: param_schema["default"] = param_default properties[param_name] = param_schema # Add to required list if no default value if param_required: required.append(param_name) # Create OpenAI format JSON schema openai_schema = { "type": "function", "function": { "name": tool_name, "description": tool_description, "parameters": { "type": "object", "properties": properties, "required": required } } } # Create MCP format JSON schema mcp_schema = { "name": tool_name, "description": tool_description, "inputSchema": { "type": "object", "properties": properties, "required": required } } # Create Pydantic model for args schema field_definitions = {} for param_name, param in sig.parameters.items(): if param_name == "self": continue param_type = param.annotation param_default = ... if param.default is inspect.Parameter.empty else param.default # Handle Annotated types if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated": args = get_args(param_type) param_type = args[0] description = args[1] if len(args) > 1 and isinstance(args[1], str) else "" field_definitions[param_name] = (param_type, Field(default=param_default, description=description)) else: field_definitions[param_name] = (param_type, Field(default=param_default)) # Create args schema model model_name = f"{tool_name.title().replace('_', '')}Schema" args_schema = create_model(model_name, **field_definitions) # 根据原始函数是否是异步函数来创建相应类型的包装函数 if asyncio.iscoroutinefunction(func): @wraps(func) async def wrapper(*args, **kwargs): return await func(*args, **kwargs) else: @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) # Attach metadata to function wrapper.is_llm_tool = True wrapper.tool_name = tool_name wrapper.tool_description = tool_description wrapper.openai_schema = openai_schema wrapper.mcp_schema = mcp_schema wrapper.args_schema = args_schema # Register the tool _TOOL_REGISTRY[tool_name] = wrapper return wrapper def get_all_tools() -> Dict[str, Callable]: """ Get all registered LLM tools. Returns: A dictionary mapping tool names to their corresponding functions. """ return _TOOL_REGISTRY def get_all_tool_schemas(schema_type='openai') -> List[Dict[str, Any]]: """ Get JSON schemas for all registered LLM tools. Returns: A list of JSON schemas for all registered tools, suitable for use with LLM APIs. """ return [tool.mcp_schema for tool in _TOOL_REGISTRY.values()] if schema_type == 'mcp' else [tool.openai_schema for tool in _TOOL_REGISTRY.values()] def import_domain_tools(domains: List[str]) -> None: """ Import tools from specified domains to ensure they are registered. This function dynamically imports modules from the specified domains to ensure that all tools decorated with @llm_tool are registered in the _TOOL_REGISTRY. Args: domains: List of domain names (e.g., ['material', 'general']) """ for domain in domains: if domain not in _DOMAIN_MODULE_MAPPING: continue module_path = _DOMAIN_MODULE_MAPPING[domain] try: # Import the base module base_module = importlib.import_module(module_path) base_path = os.path.dirname(base_module.__file__) # Recursively import all submodules for _, name, is_pkg in pkgutil.walk_packages([base_path], f"{module_path}."): try: importlib.import_module(name) except ImportError as e: print(f"Error importing {name}: {e}") except ImportError as e: print(f"Error importing domain {domain}: {e}") def get_domain_tools(domains: List[str]) -> Dict[str, Dict[str, Callable]]: """ Get tools from specified domains. Args: domains: List of domain names (e.g., ['material', 'general']) Returns: A dictionary that maps tool names and their functions """ # First, ensure all tools from the specified domains are imported and registered import_domain_tools(domains) domain_tools = {} for domain in domains: if domain not in _DOMAIN_MODULE_MAPPING: continue domain_module_prefix = _DOMAIN_MODULE_MAPPING[domain] for tool_name, tool_func in _TOOL_REGISTRY.items(): # Check if the tool's module belongs to this domain if hasattr(tool_func, "__module__") and tool_func.__module__.startswith(domain_module_prefix): domain_tools[tool_name] = tool_func return domain_tools def get_domain_tool_schemas(domains: List[str],schema_type='openai') -> Dict[str, List[Dict[str, Any]]]: """ Get JSON schemas for tools from specified domains. Args: domains: List of domain names (e.g., ['material', 'general']) Returns: A dictionary mapping domain names to lists of tool schemas """ # First, get all domain tools domain_tools = get_domain_tools(domains) if schema_type == 'mcp': tools_schema_list = [tool.mcp_schema for tool in domain_tools.values()] else: tools_schema_list = [tool.openai_schema for tool in domain_tools.values()] return tools_schema_list def _get_json_type(python_type: Any) -> str: """ Convert Python type to JSON schema type. Args: python_type: Python type annotation Returns: Corresponding JSON schema type as string """ if python_type is str: return "string" elif python_type is int: return "integer" elif python_type is float: return "number" elif python_type is bool: return "boolean" elif python_type is list or python_type is List: return "array" elif python_type is dict or python_type is Dict: return "object" else: # Default to string for complex types return "string"