构建mars_toolkit,删除tools_for_ms
This commit is contained in:
213
mars_toolkit/core/llm_tools.py
Normal file
213
mars_toolkit/core/llm_tools.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
LLM Tools Module
|
||||
|
||||
This module provides decorators and utilities for defining, registering, and managing LLM tools.
|
||||
It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving
|
||||
registered tools for use with LLM APIs.
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args
|
||||
import docstring_parser
|
||||
from pydantic import BaseModel, create_model, Field
|
||||
|
||||
# Registry to store all registered tools
|
||||
_TOOL_REGISTRY = {}
|
||||
|
||||
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
|
||||
"""
|
||||
Decorator to mark a function as an LLM tool.
|
||||
|
||||
This decorator registers the function as an LLM tool, generates a JSON schema for it,
|
||||
and makes it available for retrieval through the get_tools function.
|
||||
|
||||
Args:
|
||||
name: Optional custom name for the tool. If not provided, the function name will be used.
|
||||
description: Optional custom description for the tool. If not provided, the function's
|
||||
docstring will be used.
|
||||
|
||||
Returns:
|
||||
The decorated function with additional attributes for LLM tool functionality.
|
||||
|
||||
Example:
|
||||
@llm_tool(name="weather_lookup", description="Get current weather for a location")
|
||||
def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
|
||||
'''Get weather information for a specific location.'''
|
||||
# Implementation...
|
||||
return {"temperature": 22.5, "conditions": "sunny"}
|
||||
"""
|
||||
# Handle case when decorator is used without parentheses: @llm_tool
|
||||
if callable(name):
|
||||
func = name
|
||||
name = None
|
||||
description = None
|
||||
return _llm_tool_impl(func, name, description)
|
||||
|
||||
# Handle case when decorator is used with parentheses: @llm_tool() or @llm_tool(name="xyz")
|
||||
def decorator(func: Callable) -> Callable:
|
||||
return _llm_tool_impl(func, name, description)
|
||||
|
||||
return decorator
|
||||
|
||||
def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
|
||||
"""Implementation of the llm_tool decorator."""
|
||||
# Get function signature and docstring
|
||||
sig = inspect.signature(func)
|
||||
doc = inspect.getdoc(func) or ""
|
||||
parsed_doc = docstring_parser.parse(doc)
|
||||
|
||||
# Determine tool name
|
||||
tool_name = name or func.__name__
|
||||
|
||||
# Determine tool description
|
||||
tool_description = description or doc
|
||||
|
||||
# Create parameter properties for JSON schema
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param_name, param in sig.parameters.items():
|
||||
# Skip self parameter for methods
|
||||
if param_name == "self":
|
||||
continue
|
||||
|
||||
param_type = param.annotation
|
||||
param_default = None if param.default is inspect.Parameter.empty else param.default
|
||||
param_required = param.default is inspect.Parameter.empty
|
||||
|
||||
# Get parameter description from docstring if available
|
||||
param_desc = ""
|
||||
for param_doc in parsed_doc.params:
|
||||
if param_doc.arg_name == param_name:
|
||||
param_desc = param_doc.description
|
||||
break
|
||||
|
||||
# Handle Annotated types
|
||||
if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
|
||||
args = get_args(param_type)
|
||||
param_type = args[0] # The actual type
|
||||
if len(args) > 1 and isinstance(args[1], str):
|
||||
param_desc = args[1] # The description
|
||||
|
||||
# Create property for parameter
|
||||
param_schema = {
|
||||
"type": _get_json_type(param_type),
|
||||
"description": param_desc,
|
||||
"title": param_name.replace("_", " ").title()
|
||||
}
|
||||
|
||||
# Add default value if available
|
||||
if param_default is not None:
|
||||
param_schema["default"] = param_default
|
||||
|
||||
properties[param_name] = param_schema
|
||||
|
||||
# Add to required list if no default value
|
||||
if param_required:
|
||||
required.append(param_name)
|
||||
|
||||
# Create JSON schema
|
||||
schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"description": tool_description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Create Pydantic model for args schema
|
||||
field_definitions = {}
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param_name == "self":
|
||||
continue
|
||||
|
||||
param_type = param.annotation
|
||||
param_default = ... if param.default is inspect.Parameter.empty else param.default
|
||||
|
||||
# Handle Annotated types
|
||||
if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
|
||||
args = get_args(param_type)
|
||||
param_type = args[0]
|
||||
description = args[1] if len(args) > 1 and isinstance(args[1], str) else ""
|
||||
field_definitions[param_name] = (param_type, Field(default=param_default, description=description))
|
||||
else:
|
||||
field_definitions[param_name] = (param_type, Field(default=param_default))
|
||||
|
||||
# Create args schema model
|
||||
model_name = f"{tool_name.title().replace('_', '')}Schema"
|
||||
args_schema = create_model(model_name, **field_definitions)
|
||||
|
||||
# 根据原始函数是否是异步函数来创建相应类型的包装函数
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await func(*args, **kwargs)
|
||||
else:
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Attach metadata to function
|
||||
wrapper.is_llm_tool = True
|
||||
wrapper.tool_name = tool_name
|
||||
wrapper.tool_description = tool_description
|
||||
wrapper.json_schema = schema
|
||||
wrapper.args_schema = args_schema
|
||||
|
||||
# Register the tool
|
||||
_TOOL_REGISTRY[tool_name] = wrapper
|
||||
|
||||
return wrapper
|
||||
|
||||
def get_tools() -> Dict[str, Callable]:
|
||||
"""
|
||||
Get all registered LLM tools.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping tool names to their corresponding functions.
|
||||
"""
|
||||
return _TOOL_REGISTRY
|
||||
|
||||
def get_tool_schemas() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get JSON schemas for all registered LLM tools.
|
||||
|
||||
Returns:
|
||||
A list of JSON schemas for all registered tools, suitable for use with LLM APIs.
|
||||
"""
|
||||
return [tool.json_schema for tool in _TOOL_REGISTRY.values()]
|
||||
|
||||
def _get_json_type(python_type: Any) -> str:
|
||||
"""
|
||||
Convert Python type to JSON schema type.
|
||||
|
||||
Args:
|
||||
python_type: Python type annotation
|
||||
|
||||
Returns:
|
||||
Corresponding JSON schema type as string
|
||||
"""
|
||||
if python_type is str:
|
||||
return "string"
|
||||
elif python_type is int:
|
||||
return "integer"
|
||||
elif python_type is float:
|
||||
return "number"
|
||||
elif python_type is bool:
|
||||
return "boolean"
|
||||
elif python_type is list or python_type is List:
|
||||
return "array"
|
||||
elif python_type is dict or python_type is Dict:
|
||||
return "object"
|
||||
else:
|
||||
# Default to string for complex types
|
||||
return "string"
|
||||
Reference in New Issue
Block a user