214 lines
7.0 KiB
Python
Executable File
214 lines
7.0 KiB
Python
Executable File
"""
|
|
LLM Tools Module
|
|
|
|
This module provides decorators and utilities for defining, registering, and managing LLM tools.
|
|
It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving
|
|
registered tools for use with LLM APIs.
|
|
|
|
"""
|
|
|
|
import asyncio
|
|
import inspect
|
|
import json
|
|
from functools import wraps
|
|
from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args
|
|
import docstring_parser
|
|
from pydantic import BaseModel, create_model, Field
|
|
|
|
# Module-level registry of decorated tools, keyed by tool name.
# Entries are added as a side effect of applying the llm_tool decorator.
_TOOL_REGISTRY: Dict[str, Callable] = {}
|
|
|
|
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
    """
    Mark a function as an LLM tool.

    The decorated function is registered in the module's tool registry and gets a
    generated JSON schema, so it can later be retrieved via get_tools /
    get_tool_schemas. The decorator may be applied bare (``@llm_tool``) or with
    arguments (``@llm_tool(name=..., description=...)``).

    Args:
        name: Optional custom name for the tool. When omitted, the function name is used.
            In the bare ``@llm_tool`` form this slot receives the decorated function itself.
        description: Optional custom description for the tool. When omitted, the
            function's docstring is used. Ignored in the bare ``@llm_tool`` form.

    Returns:
        The wrapped function (bare form), or a decorator producing it (parameterized form).

    Example:
        @llm_tool(name="weather_lookup", description="Get current weather for a location")
        def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
            '''Get weather information for a specific location.'''
            # Implementation...
            return {"temperature": 22.5, "conditions": "sunny"}
    """
    if not callable(name):
        # Parameterized form: @llm_tool() or @llm_tool(name="xyz")
        def decorator(func: Callable) -> Callable:
            return _llm_tool_impl(func, name, description)

        return decorator

    # Bare form: @llm_tool -- the first positional argument is the function being
    # decorated, so no custom name or description applies.
    return _llm_tool_impl(name, None, None)
|
|
|
|
def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
    """Build the schemas for ``func``, wrap it, and register it as an LLM tool.

    Args:
        func: The function being decorated.
        name: Tool name override; falls back to ``func.__name__``.
        description: Tool description override; falls back to the function's docstring.

    Returns:
        A wrapper around ``func`` (an ``async def`` wrapper when ``func`` is a coroutine
        function) carrying the attributes ``is_llm_tool``, ``tool_name``,
        ``tool_description``, ``json_schema`` and ``args_schema``. The wrapper is also
        stored in ``_TOOL_REGISTRY`` under the tool name, replacing any existing entry
        with the same name.
    """
    # Imported locally so this block does not depend on the file-level import list.
    from typing import Annotated

    sig = inspect.signature(func)
    doc = inspect.getdoc(func) or ""
    parsed_doc = docstring_parser.parse(doc)

    tool_name = name or func.__name__
    tool_description = description or doc

    # Docstring descriptions keyed by parameter name. The first entry for a name wins,
    # and a missing description is normalised to "" so the schema never carries None.
    doc_descriptions = {}
    for param_doc in parsed_doc.params:
        doc_descriptions.setdefault(param_doc.arg_name, param_doc.description or "")

    # *args / **kwargs cannot be expressed as named JSON properties.
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    properties = {}
    required = []
    field_definitions = {}

    # Single pass that feeds both the JSON schema and the pydantic model, so the two
    # cannot drift apart.
    for param_name, param in sig.parameters.items():
        # Skip self parameter for methods, and variadic parameters.
        if param_name == "self" or param.kind in variadic_kinds:
            continue

        has_default = param.default is not inspect.Parameter.empty
        param_type = param.annotation
        param_desc = doc_descriptions.get(param_name, "")

        # Unwrap Annotated[T, "description", ...]. Compare the origin by identity:
        # not every typing origin exposes __name__ on every Python version.
        if get_origin(param_type) is Annotated:
            annotated_args = get_args(param_type)
            param_type = annotated_args[0]  # The actual type
            if len(annotated_args) > 1 and isinstance(annotated_args[1], str):
                param_desc = annotated_args[1]  # Overrides the docstring description

        # An unannotated parameter has no usable type; accept anything.
        if param_type is inspect.Parameter.empty:
            param_type = Any

        param_schema = {
            "type": _get_json_type(param_type),
            "description": param_desc,
            "title": param_name.replace("_", " ").title()
        }
        # A default of None is deliberately not emitted into the schema.
        if has_default and param.default is not None:
            param_schema["default"] = param.default
        properties[param_name] = param_schema

        if not has_default:
            required.append(param_name)

        # Ellipsis marks the field as required for pydantic.
        field_definitions[param_name] = (
            param_type,
            Field(
                default=param.default if has_default else ...,
                description=param_desc or None,
            ),
        )

    # Create JSON schema
    schema = {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required
            }
        }
    }

    # Create args schema model
    model_name = f"{tool_name.title().replace('_', '')}Schema"
    args_schema = create_model(model_name, **field_definitions)

    # Create a wrapper of the matching kind depending on whether the original
    # function is asynchronous.
    if asyncio.iscoroutinefunction(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
    else:
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

    # Attach metadata to function
    wrapper.is_llm_tool = True
    wrapper.tool_name = tool_name
    wrapper.tool_description = tool_description
    wrapper.json_schema = schema
    wrapper.args_schema = args_schema

    # Register the tool
    _TOOL_REGISTRY[tool_name] = wrapper

    return wrapper
|
|
|
|
def get_tools() -> Dict[str, Callable]:
    """
    Return every registered LLM tool.

    The mapping handed back is the module's registry object itself rather than a
    copy.

    Returns:
        A dictionary mapping each tool name to its registered function.
    """
    registry = _TOOL_REGISTRY
    return registry
|
|
|
|
def get_tool_schemas() -> List[Dict[str, Any]]:
    """
    Collect the JSON schema of every registered LLM tool.

    Returns:
        A list with one JSON schema per registered tool, in registry order,
        suitable for use with LLM APIs.
    """
    registered = _TOOL_REGISTRY.values()
    return [entry.json_schema for entry in registered]
|
|
|
|
def _get_json_type(python_type: Any) -> str:
    """
    Convert a Python type annotation to a JSON schema type name.

    Bare types (``str``, ``int``, ``list``, ...), their ``typing`` aliases and
    parameterized generics (``List[int]``, ``Dict[str, Any]``, ``list[int]``) are
    all recognised. ``Optional[X]`` (a union whose only non-None member is ``X``)
    is reported as the type of ``X``.

    Args:
        python_type: Python type annotation

    Returns:
        Corresponding JSON schema type as string. Anything unrecognised, including
        unions of several types, falls back to "string".
    """
    # Imported locally so this block does not depend on the file-level import list.
    import types
    from typing import Union

    origin = get_origin(python_type)

    # types.UnionType (the ``X | Y`` syntax) only exists on Python 3.10+.
    pep604_union = getattr(types, "UnionType", None)
    is_union = origin is Union or (pep604_union is not None and origin is pep604_union)
    if is_union:
        members = [arg for arg in get_args(python_type) if arg is not type(None)]
        if len(members) == 1:
            return _get_json_type(members[0])
        return "string"

    # Identity checks (not dict lookup) so unhashable annotations cannot raise.
    type_table = (
        (str, "string"),
        (int, "integer"),
        (float, "number"),
        (bool, "boolean"),
        (list, "array"),
        (List, "array"),
        (tuple, "array"),
        (set, "array"),
        (frozenset, "array"),
        (dict, "object"),
        (Dict, "object"),
    )
    for py_type, json_type in type_table:
        # `origin` covers parameterized generics such as List[int] -> list.
        if python_type is py_type or origin is py_type:
            return json_type

    # Default to string for complex types
    return "string"
|