CoACT initialize (#292)
This commit is contained in:
411
mm_agents/coact/autogen/tools/function_utils.py
Normal file
411
mm_agents/coact/autogen/tools/function_utils.py
Normal file
@@ -0,0 +1,411 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
from logging import getLogger
|
||||
from typing import Annotated, Any, Callable, ForwardRef, Optional, TypeVar, Union
|
||||
|
||||
from packaging.version import parse
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import __version__ as pydantic_version
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from typing_extensions import Literal, get_args, get_origin
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from .dependency_injection import Field as AG2Field
|
||||
|
||||
if parse(pydantic_version) < parse("2.10.2"):
|
||||
from pydantic._internal._typing_extra import eval_type_lenient as try_eval_type
|
||||
else:
|
||||
from pydantic._internal._typing_extra import try_eval_type
|
||||
|
||||
|
||||
__all__ = ["get_function_schema", "load_basemodels_if_needed", "serialize_to_str"]
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
|
||||
"""Get the type annotation of a parameter.
|
||||
|
||||
Args:
|
||||
annotation: The annotation of the parameter
|
||||
globalns: The global namespace of the function
|
||||
|
||||
Returns:
|
||||
The type annotation of the parameter
|
||||
"""
|
||||
if isinstance(annotation, AG2Field):
|
||||
annotation = annotation.description
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
annotation, _ = try_eval_type(annotation, globalns, globalns)
|
||||
return annotation
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
"""Get the signature of a function with type annotations.
|
||||
|
||||
Args:
|
||||
call: The function to get the signature for
|
||||
|
||||
Returns:
|
||||
The signature of the function with type annotations
|
||||
"""
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=get_typed_annotation(param.annotation, globalns),
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
typed_signature = inspect.Signature(typed_params)
|
||||
return typed_signature
|
||||
|
||||
|
||||
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
|
||||
"""Get the return annotation of a function.
|
||||
|
||||
Args:
|
||||
call: The function to get the return annotation for
|
||||
|
||||
Returns:
|
||||
The return annotation of the function
|
||||
"""
|
||||
signature = inspect.signature(call)
|
||||
annotation = signature.return_annotation
|
||||
|
||||
if annotation is inspect.Signature.empty:
|
||||
return None
|
||||
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
return get_typed_annotation(annotation, globalns)
|
||||
|
||||
|
||||
def get_param_annotations(typed_signature: inspect.Signature) -> dict[str, Union[Annotated[type[Any], str], type[Any]]]:
|
||||
"""Get the type annotations of the parameters of a function
|
||||
|
||||
Args:
|
||||
typed_signature: The signature of the function with type annotations
|
||||
|
||||
Returns:
|
||||
A dictionary of the type annotations of the parameters of the function
|
||||
"""
|
||||
return {
|
||||
k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty
|
||||
}
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
"""Parameters of a function as defined by the OpenAI API"""
|
||||
|
||||
type: Literal["object"] = "object"
|
||||
properties: dict[str, JsonSchemaValue]
|
||||
required: list[str]
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""A function as defined by the OpenAI API"""
|
||||
|
||||
description: Annotated[str, Field(description="Description of the function")]
|
||||
name: Annotated[str, Field(description="Name of the function")]
|
||||
parameters: Annotated[Parameters, Field(description="Parameters of the function")]
|
||||
|
||||
|
||||
class ToolFunction(BaseModel):
|
||||
"""A function under tool as defined by the OpenAI API."""
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
function: Annotated[Function, Field(description="Function under tool")]
|
||||
|
||||
|
||||
def get_parameter_json_schema(k: str, v: Any, default_values: dict[str, Any]) -> JsonSchemaValue:
|
||||
"""Get a JSON schema for a parameter as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
k: The name of the parameter
|
||||
v: The type of the parameter
|
||||
default_values: The default values of the parameters of the function
|
||||
|
||||
Returns:
|
||||
A Pydanitc model for the parameter
|
||||
"""
|
||||
|
||||
def type2description(k: str, v: Union[Annotated[type[Any], str], type[Any]]) -> str:
|
||||
if not hasattr(v, "__metadata__"):
|
||||
return k
|
||||
|
||||
# handles Annotated
|
||||
retval = v.__metadata__[0]
|
||||
if isinstance(retval, AG2Field):
|
||||
return retval.description # type: ignore[return-value]
|
||||
else:
|
||||
raise ValueError(f"Invalid {retval} for parameter {k}, should be a DescriptionField, got {type(retval)}")
|
||||
|
||||
schema = TypeAdapter(v).json_schema()
|
||||
if k in default_values:
|
||||
dv = default_values[k]
|
||||
schema["default"] = dv
|
||||
|
||||
schema["description"] = type2description(k, v)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def get_required_params(typed_signature: inspect.Signature) -> list[str]:
|
||||
"""Get the required parameters of a function
|
||||
|
||||
Args:
|
||||
typed_signature: The signature of the function as returned by inspect.signature
|
||||
|
||||
Returns:
|
||||
A list of the required parameters of the function
|
||||
"""
|
||||
return [k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty]
|
||||
|
||||
|
||||
def get_default_values(typed_signature: inspect.Signature) -> dict[str, Any]:
|
||||
"""Get default values of parameters of a function
|
||||
|
||||
Args:
|
||||
typed_signature: The signature of the function as returned by inspect.signature
|
||||
|
||||
Returns:
|
||||
A dictionary of the default values of the parameters of the function
|
||||
"""
|
||||
return {k: v.default for k, v in typed_signature.parameters.items() if v.default != inspect.Signature.empty}
|
||||
|
||||
|
||||
def get_parameters(
|
||||
required: list[str],
|
||||
param_annotations: dict[str, Union[Annotated[type[Any], str], type[Any]]],
|
||||
default_values: dict[str, Any],
|
||||
) -> Parameters:
|
||||
"""Get the parameters of a function as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
required: The required parameters of the function
|
||||
param_annotations: The type annotations of the parameters of the function
|
||||
default_values: The default values of the parameters of the function
|
||||
|
||||
Returns:
|
||||
A Pydantic model for the parameters of the function
|
||||
"""
|
||||
return Parameters(
|
||||
properties={
|
||||
k: get_parameter_json_schema(k, v, default_values)
|
||||
for k, v in param_annotations.items()
|
||||
if v is not inspect.Signature.empty
|
||||
},
|
||||
required=required,
|
||||
)
|
||||
|
||||
|
||||
def get_missing_annotations(typed_signature: inspect.Signature, required: list[str]) -> tuple[set[str], set[str]]:
|
||||
"""Get the missing annotations of a function
|
||||
|
||||
Ignores the parameters with default values as they are not required to be annotated, but logs a warning.
|
||||
|
||||
Args:
|
||||
typed_signature: The signature of the function with type annotations
|
||||
required: The required parameters of the function
|
||||
|
||||
Returns:
|
||||
A set of the missing annotations of the function
|
||||
"""
|
||||
all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty}
|
||||
missing = all_missing.intersection(set(required))
|
||||
unannotated_with_default = all_missing.difference(missing)
|
||||
return missing, unannotated_with_default
|
||||
|
||||
|
||||
@export_module("autogen.tools")
|
||||
def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> dict[str, Any]:
|
||||
"""Get a JSON schema for a function as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
f: The function to get the JSON schema for
|
||||
name: The name of the function
|
||||
description: The description of the function
|
||||
|
||||
Returns:
|
||||
A JSON schema for the function
|
||||
|
||||
Raises:
|
||||
TypeError: If the function is not annotated
|
||||
|
||||
Examples:
|
||||
```python
|
||||
def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1) -> None:
|
||||
pass
|
||||
|
||||
|
||||
get_function_schema(f, description="function f")
|
||||
|
||||
# {'type': 'function',
|
||||
# 'function': {'description': 'function f',
|
||||
# 'name': 'f',
|
||||
# 'parameters': {'type': 'object',
|
||||
# 'properties': {'a': {'type': 'str', 'description': 'Parameter a'},
|
||||
# 'b': {'type': 'int', 'description': 'b'},
|
||||
# 'c': {'type': 'float', 'description': 'Parameter c'}},
|
||||
# 'required': ['a']}}}
|
||||
```
|
||||
|
||||
"""
|
||||
typed_signature = get_typed_signature(f)
|
||||
required = get_required_params(typed_signature)
|
||||
default_values = get_default_values(typed_signature)
|
||||
param_annotations = get_param_annotations(typed_signature)
|
||||
return_annotation = get_typed_return_annotation(f)
|
||||
missing, unannotated_with_default = get_missing_annotations(typed_signature, required)
|
||||
|
||||
if return_annotation is None:
|
||||
logger.warning(
|
||||
f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is "
|
||||
+ "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'."
|
||||
)
|
||||
|
||||
if unannotated_with_default != set():
|
||||
unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)]
|
||||
logger.warning(
|
||||
f"The following parameters of the function '{f.__name__}' with default values are not annotated: "
|
||||
+ f"{', '.join(unannotated_with_default_s)}."
|
||||
)
|
||||
|
||||
if missing != set():
|
||||
missing_s = [f"'{k}'" for k in sorted(missing)]
|
||||
raise TypeError(
|
||||
f"All parameters of the function '{f.__name__}' without default values must be annotated. "
|
||||
+ f"The annotations are missing for the following parameters: {', '.join(missing_s)}"
|
||||
)
|
||||
|
||||
fname = name if name else f.__name__
|
||||
|
||||
parameters = get_parameters(required, param_annotations, default_values=default_values)
|
||||
|
||||
function = ToolFunction(
|
||||
function=Function(
|
||||
description=description,
|
||||
name=fname,
|
||||
parameters=parameters,
|
||||
)
|
||||
)
|
||||
|
||||
return function.model_dump()
|
||||
|
||||
|
||||
def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[dict[str, Any], type[BaseModel]], BaseModel]]:
|
||||
"""Get a function to load a parameter if it is a Pydantic model
|
||||
|
||||
Args:
|
||||
t: The type annotation of the parameter
|
||||
|
||||
Returns:
|
||||
A function to load the parameter if it is a Pydantic model, otherwise None
|
||||
|
||||
"""
|
||||
origin = get_origin(t)
|
||||
|
||||
if origin is Annotated:
|
||||
args = get_args(t)
|
||||
if args:
|
||||
return get_load_param_if_needed_function(args[0])
|
||||
else:
|
||||
# Invalid Annotated usage
|
||||
return None
|
||||
|
||||
# Handle generic types (list[str], dict[str,Any], Union[...], etc.) or where t is not a type at all
|
||||
# This means it's not a BaseModel subclass
|
||||
if origin is not None or not isinstance(t, type):
|
||||
return None
|
||||
|
||||
def load_base_model(v: dict[str, Any], model_type: type[BaseModel]) -> BaseModel:
|
||||
return model_type(**v)
|
||||
|
||||
# Check if it's a class and a subclass of BaseModel
|
||||
if issubclass(t, BaseModel):
|
||||
return load_base_model
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@export_module("autogen.tools")
|
||||
def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""A decorator to load the parameters of a function if they are Pydantic models
|
||||
|
||||
Args:
|
||||
func: The function with annotated parameters
|
||||
|
||||
Returns:
|
||||
A function that loads the parameters before calling the original function
|
||||
|
||||
"""
|
||||
# get the type annotations of the parameters
|
||||
typed_signature = get_typed_signature(func)
|
||||
param_annotations = get_param_annotations(typed_signature)
|
||||
|
||||
# get functions for loading BaseModels when needed based on the type annotations
|
||||
kwargs_mapping_with_nones = {k: get_load_param_if_needed_function(t) for k, t in param_annotations.items()}
|
||||
|
||||
# remove the None values
|
||||
kwargs_mapping = {k: f for k, f in kwargs_mapping_with_nones.items() if f is not None}
|
||||
|
||||
# a function that loads the parameters before calling the original function
|
||||
@functools.wraps(func)
|
||||
def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
|
||||
# load the BaseModels if needed
|
||||
for k, f in kwargs_mapping.items():
|
||||
kwargs[k] = f(kwargs[k], param_annotations[k])
|
||||
|
||||
# call the original function
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@functools.wraps(func)
|
||||
async def _a_load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
|
||||
# load the BaseModels if needed
|
||||
for k, f in kwargs_mapping.items():
|
||||
kwargs[k] = f(kwargs[k], param_annotations[k])
|
||||
|
||||
# call the original function
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return _a_load_parameters_if_needed
|
||||
else:
|
||||
return _load_parameters_if_needed
|
||||
|
||||
|
||||
class _SerializableResult(BaseModel):
|
||||
result: Any
|
||||
|
||||
|
||||
@export_module("autogen.tools")
|
||||
def serialize_to_str(x: Any) -> str:
|
||||
if isinstance(x, str):
|
||||
return x
|
||||
if isinstance(x, BaseModel):
|
||||
return x.model_dump_json()
|
||||
|
||||
retval_model = _SerializableResult(result=x)
|
||||
try:
|
||||
return str(retval_model.model_dump()["result"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# try json.dumps() and then just return str(x) if that fails too
|
||||
try:
|
||||
return json.dumps(x, ensure_ascii=False)
|
||||
except Exception:
|
||||
return str(x)
|
||||
Reference in New Issue
Block a user