CoACT initialize (#292)
This commit is contained in:
169
mm_agents/coact/autogen/oai/client_utils.py
Normal file
169
mm_agents/coact/autogen/oai/client_utils.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
"""Utilities for client classes"""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Optional, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class FormatterProtocol(Protocol):
|
||||
"""Structured Output classes with a format method"""
|
||||
|
||||
def format(self) -> str: ...
|
||||
|
||||
|
||||
def validate_parameter(
|
||||
params: dict[str, Any],
|
||||
param_name: str,
|
||||
allowed_types: tuple[Any, ...],
|
||||
allow_None: bool, # noqa: N803
|
||||
default_value: Any,
|
||||
numerical_bound: Optional[tuple[Optional[float], Optional[float]]],
|
||||
allowed_values: Optional[list[Any]],
|
||||
) -> Any:
|
||||
"""Validates a given config parameter, checking its type, values, and setting defaults
|
||||
Parameters:
|
||||
params (Dict[str, Any]): Dictionary containing parameters to validate.
|
||||
param_name (str): The name of the parameter to validate.
|
||||
allowed_types (Tuple): Tuple of acceptable types for the parameter.
|
||||
allow_None (bool): Whether the parameter can be `None`.
|
||||
default_value (Any): The default value to use if the parameter is invalid or missing.
|
||||
numerical_bound (Optional[Tuple[Optional[float], Optional[float]]]):
|
||||
A tuple specifying the lower and upper bounds for numerical parameters.
|
||||
Each bound can be `None` if not applicable.
|
||||
allowed_values (Optional[List[Any]]): A list of acceptable values for the parameter.
|
||||
Can be `None` if no specific values are required.
|
||||
|
||||
Returns:
|
||||
Any: The validated parameter value or the default value if validation fails.
|
||||
|
||||
Raises:
|
||||
TypeError: If `allowed_values` is provided but is not a list.
|
||||
|
||||
Example Usage:
|
||||
```python
|
||||
# Validating a numerical parameter within specific bounds
|
||||
params = {"temperature": 0.5, "safety_model": "Meta-Llama/Llama-Guard-7b"}
|
||||
temperature = validate_parameter(params, "temperature", (int, float), True, 0.7, (0, 1), None)
|
||||
# Result: 0.5
|
||||
|
||||
# Validating a parameter that can be one of a list of allowed values
|
||||
model = validate_parameter(
|
||||
params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
|
||||
)
|
||||
# If "safety_model" is missing or invalid in params, defaults to "default"
|
||||
```
|
||||
"""
|
||||
if allowed_values is not None and not isinstance(allowed_values, list):
|
||||
raise TypeError(f"allowed_values should be a list or None, got {type(allowed_values).__name__}")
|
||||
|
||||
param_value = params.get(param_name, default_value)
|
||||
warning = ""
|
||||
|
||||
if param_value is None and allow_None:
|
||||
pass
|
||||
elif param_value is None:
|
||||
if not allow_None:
|
||||
warning = "cannot be None"
|
||||
elif not isinstance(param_value, allowed_types):
|
||||
# Check types and list possible types if invalid
|
||||
if isinstance(allowed_types, tuple):
|
||||
formatted_types = "(" + ", ".join(f"{t.__name__}" for t in allowed_types) + ")"
|
||||
else:
|
||||
formatted_types = f"{allowed_types.__name__}"
|
||||
warning = f"must be of type {formatted_types}{' or None' if allow_None else ''}"
|
||||
elif numerical_bound:
|
||||
# Check the value fits in possible bounds
|
||||
lower_bound, upper_bound = numerical_bound
|
||||
if (lower_bound is not None and param_value < lower_bound) or (
|
||||
upper_bound is not None and param_value > upper_bound
|
||||
):
|
||||
warning = "has numerical bounds"
|
||||
if lower_bound is not None:
|
||||
warning += f", >= {lower_bound!s}"
|
||||
if upper_bound is not None:
|
||||
if lower_bound is not None:
|
||||
warning += " and"
|
||||
warning += f" <= {upper_bound!s}"
|
||||
if allow_None:
|
||||
warning += ", or can be None"
|
||||
|
||||
elif allowed_values: # noqa: SIM102
|
||||
# Check if the value matches any allowed values
|
||||
if not (allow_None and param_value is None) and param_value not in allowed_values:
|
||||
warning = f"must be one of these values [{allowed_values}]{', or can be None' if allow_None else ''}"
|
||||
|
||||
# If we failed any checks, warn and set to default value
|
||||
if warning:
|
||||
warnings.warn(
|
||||
f"Config error - {param_name} {warning}, defaulting to {default_value}.",
|
||||
UserWarning,
|
||||
)
|
||||
param_value = default_value
|
||||
|
||||
return param_value
|
||||
|
||||
|
||||
def should_hide_tools(messages: list[dict[str, Any]], tools: list[dict[str, Any]], hide_tools_param: str) -> bool:
|
||||
"""Determines if tools should be hidden. This function is used to hide tools when they have been run, minimising the chance of the LLM choosing them when they shouldn't.
|
||||
Parameters:
|
||||
messages (List[Dict[str, Any]]): List of messages
|
||||
tools (List[Dict[str, Any]]): List of tools
|
||||
hide_tools_param (str): "hide_tools" parameter value. Can be "if_all_run" (hide tools if all tools have been run), "if_any_run" (hide tools if any of the tools have been run), "never" (never hide tools). Default is "never".
|
||||
|
||||
Returns:
|
||||
bool: Indicates whether the tools should be excluded from the response create request
|
||||
|
||||
Example Usage:
|
||||
```python
|
||||
# Validating a numerical parameter within specific bounds
|
||||
messages = params.get("messages", [])
|
||||
tools = params.get("tools", None)
|
||||
hide_tools = should_hide_tools(messages, tools, params["hide_tools"])
|
||||
"""
|
||||
if hide_tools_param == "never" or tools is None or len(tools) == 0:
|
||||
return False
|
||||
elif hide_tools_param == "if_any_run":
|
||||
# Return True if any tool_call_id exists, indicating a tool call has been executed. False otherwise.
|
||||
return any(["tool_call_id" in dictionary for dictionary in messages])
|
||||
elif hide_tools_param == "if_all_run":
|
||||
# Return True if all tools have been executed at least once. False otherwise.
|
||||
|
||||
# Get the list of tool names
|
||||
check_tool_names = [item["function"]["name"] for item in tools]
|
||||
|
||||
# Prepare a list of tool call ids and related function names
|
||||
tool_call_ids = {}
|
||||
|
||||
# Loop through the messages and check if the tools have been run, removing them as we go
|
||||
for message in messages:
|
||||
if "tool_calls" in message:
|
||||
# Register the tool ids and the function names (there could be multiple tool calls)
|
||||
for tool_call in message["tool_calls"]:
|
||||
tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]
|
||||
elif "tool_call_id" in message:
|
||||
# Tool called, get the name of the function based on the id
|
||||
tool_name_called = tool_call_ids[message["tool_call_id"]]
|
||||
|
||||
# If we had not yet called the tool, check and remove it to indicate we have
|
||||
if tool_name_called in check_tool_names:
|
||||
check_tool_names.remove(tool_name_called)
|
||||
|
||||
# Return True if all tools have been called at least once (accounted for)
|
||||
return len(check_tool_names) == 0
|
||||
else:
|
||||
raise TypeError(
|
||||
f"hide_tools_param is not a valid value ['if_all_run','if_any_run','never'], got '{hide_tools_param}'"
|
||||
)
|
||||
|
||||
|
||||
# Logging format (originally from FLAML)
|
||||
logging_formatter = logging.Formatter(
|
||||
"[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s", "%m-%d %H:%M:%S"
|
||||
)
|
||||
Reference in New Issue
Block a user