CoACT initialize (#292)
This commit is contained in:
53
mm_agents/coact/autogen/oai/__init__.py
Normal file
53
mm_agents/coact/autogen/oai/__init__.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from ..cache.cache import Cache
|
||||
from .anthropic import AnthropicLLMConfigEntry
|
||||
from .bedrock import BedrockLLMConfigEntry
|
||||
from .cerebras import CerebrasLLMConfigEntry
|
||||
from .client import AzureOpenAILLMConfigEntry, DeepSeekLLMConfigEntry, ModelClient, OpenAILLMConfigEntry, OpenAIWrapper
|
||||
from .cohere import CohereLLMConfigEntry
|
||||
from .gemini import GeminiLLMConfigEntry
|
||||
from .groq import GroqLLMConfigEntry
|
||||
from .mistral import MistralLLMConfigEntry
|
||||
from .ollama import OllamaLLMConfigEntry
|
||||
from .openai_utils import (
|
||||
config_list_from_dotenv,
|
||||
config_list_from_json,
|
||||
config_list_from_models,
|
||||
config_list_gpt4_gpt35,
|
||||
config_list_openai_aoai,
|
||||
filter_config,
|
||||
get_config_list,
|
||||
get_first_llm_config,
|
||||
)
|
||||
from .together import TogetherLLMConfigEntry
|
||||
|
||||
# Public API of the oai subpackage: every LLM-config entry type, the client
# wrapper classes, the response cache, and the config-list helper functions.
# Kept sorted (classes first, then functions) for easy scanning.
__all__ = [
    "AnthropicLLMConfigEntry",
    "AzureOpenAILLMConfigEntry",
    "BedrockLLMConfigEntry",
    "Cache",
    "CerebrasLLMConfigEntry",
    "CohereLLMConfigEntry",
    "DeepSeekLLMConfigEntry",
    "GeminiLLMConfigEntry",
    "GroqLLMConfigEntry",
    "MistralLLMConfigEntry",
    "ModelClient",
    "OllamaLLMConfigEntry",
    "OpenAILLMConfigEntry",
    "OpenAIWrapper",
    "TogetherLLMConfigEntry",
    "config_list_from_dotenv",
    "config_list_from_json",
    "config_list_from_models",
    "config_list_gpt4_gpt35",
    "config_list_openai_aoai",
    "filter_config",
    "get_config_list",
    "get_first_llm_config",
]
|
||||
714
mm_agents/coact/autogen/oai/anthropic.py
Normal file
714
mm_agents/coact/autogen/oai/anthropic.py
Normal file
@@ -0,0 +1,714 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
"""Create an OpenAI-compatible client for the Anthropic API.
|
||||
|
||||
Example usage:
|
||||
Install the `anthropic` package by running `pip install --upgrade anthropic`.
|
||||
- https://docs.anthropic.com/en/docs/quickstart-guide
|
||||
|
||||
```python
|
||||
import autogen
|
||||
|
||||
config_list = [
|
||||
{
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"api_key": os.getenv("ANTHROPIC_API_KEY"),
|
||||
"api_type": "anthropic",
|
||||
}
|
||||
]
|
||||
|
||||
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
|
||||
```
|
||||
|
||||
Example usage for Anthropic Bedrock:
|
||||
|
||||
Install the `anthropic` package by running `pip install --upgrade anthropic`.
|
||||
- https://docs.anthropic.com/en/docs/quickstart-guide
|
||||
|
||||
```python
|
||||
import autogen
|
||||
|
||||
config_list = [
|
||||
{
|
||||
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"aws_access_key":<accessKey>,
|
||||
"aws_secret_key":<secretKey>,
|
||||
"aws_session_token":<sessionTok>,
|
||||
"aws_region":"us-east-1",
|
||||
"api_type": "anthropic",
|
||||
}
|
||||
]
|
||||
|
||||
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
|
||||
```
|
||||
|
||||
Example usage for Anthropic VertexAI:
|
||||
|
||||
Install the `anthropic` package by running `pip install anthropic[vertex]`.
|
||||
- https://docs.anthropic.com/en/docs/quickstart-guide
|
||||
|
||||
```python
|
||||
|
||||
import autogen
|
||||
config_list = [
|
||||
{
|
||||
"model": "claude-3-5-sonnet-20240620-v1:0",
|
||||
"gcp_project_id": "dummy_project_id",
|
||||
"gcp_region": "us-west-2",
|
||||
"gcp_auth_token": "dummy_auth_token",
|
||||
"api_type": "anthropic",
|
||||
}
|
||||
]
|
||||
|
||||
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
|
||||
    ```
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..import_utils import optional_import_block, require_optional_import
|
||||
from ..llm_config import LLMConfigEntry, register_llm_config
|
||||
from .client_utils import FormatterProtocol, validate_parameter
|
||||
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
|
||||
|
||||
with optional_import_block():
|
||||
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
|
||||
from anthropic import __version__ as anthropic_version
|
||||
from anthropic.types import Message, TextBlock, ToolUseBlock
|
||||
|
||||
TOOL_ENABLED = anthropic_version >= "0.23.1"
|
||||
if TOOL_ENABLED:
|
||||
pass
|
||||
|
||||
|
||||
# Per-model pricing in USD per 1,000 tokens, as (input_cost, output_cost).
# Consumed by _calculate_cost below; models absent from this table produce a
# warning there and a reported cost of 0.0.
ANTHROPIC_PRICING_1k = {
    "claude-3-7-sonnet-20250219": (0.003, 0.015),
    "claude-3-5-sonnet-20241022": (0.003, 0.015),
    "claude-3-5-haiku-20241022": (0.0008, 0.004),
    "claude-3-5-sonnet-20240620": (0.003, 0.015),
    "claude-3-sonnet-20240229": (0.003, 0.015),
    "claude-3-opus-20240229": (0.015, 0.075),
    "claude-3-haiku-20240307": (0.00025, 0.00125),
    "claude-2.1": (0.008, 0.024),
    "claude-2.0": (0.008, 0.024),
    "claude-instant-1.2": (0.008, 0.024),
}
|
||||
|
||||
|
||||
@register_llm_config
class AnthropicLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for the Anthropic API (api_type="anthropic").

    Declares the sampling / request parameters accepted by the Anthropic
    client, plus optional GCP (Vertex AI) credentials. AWS (Bedrock)
    credentials are handled separately by AnthropicClient via kwargs or
    environment variables.
    """

    # Discriminator value used to route a config-list entry to this class.
    api_type: Literal["anthropic"] = "anthropic"
    # Request timeout in seconds; None defers to the client default.
    timeout: Optional[int] = Field(default=None, ge=1)
    temperature: float = Field(default=1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(default=None, ge=1)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    stop_sequences: Optional[list[str]] = None
    # Streaming is declared but currently disabled by AnthropicClient.load_config.
    stream: bool = False
    max_tokens: int = Field(default=4096, ge=1)
    # Optional [input, output] price override per 1k tokens (exactly 2 items).
    price: Optional[list[float]] = Field(default=None, min_length=2, max_length=2)
    # See https://docs.anthropic.com/en/docs/build-with-claude/tool-use/overview#controlling-claudes-output
    tool_choice: Optional[dict] = None
    # Extended-thinking configuration, passed through verbatim when present.
    thinking: Optional[dict] = None

    # Credentials for Anthropic on Vertex AI (all three optional).
    gcp_project_id: Optional[str] = None
    gcp_region: Optional[str] = None
    gcp_auth_token: Optional[str] = None

    def create_client(self):
        # Client construction happens in AnthropicClient, not via this entry.
        raise NotImplementedError("AnthropicLLMConfigEntry.create_client is not implemented.")
|
||||
|
||||
|
||||
@require_optional_import("anthropic", "anthropic")
class AnthropicClient:
    """OpenAI-compatible client for the Anthropic API.

    Supports three transports, selected from the supplied credentials:
    direct Anthropic (api_key), Vertex AI (gcp_* fields) and Bedrock
    (aws_* fields). Responses are converted into OpenAI ChatCompletion
    objects so the rest of the framework can consume them unchanged.
    """

    def __init__(self, **kwargs: Any):
        """Initialize the Anthropic API client.

        Args:
            **kwargs: The configuration parameters for the client
                (api_key, aws_* / gcp_* credentials, base_url, ...).

        Raises:
            ValueError: If neither an API key nor complete AWS/GCP
                credentials can be resolved from kwargs or environment.
        """
        self._api_key = kwargs.get("api_key")
        self._aws_access_key = kwargs.get("aws_access_key")
        self._aws_secret_key = kwargs.get("aws_secret_key")
        self._aws_session_token = kwargs.get("aws_session_token")
        self._aws_region = kwargs.get("aws_region")
        self._gcp_project_id = kwargs.get("gcp_project_id")
        self._gcp_region = kwargs.get("gcp_region")
        self._gcp_auth_token = kwargs.get("gcp_auth_token")
        self._base_url = kwargs.get("base_url")

        # Fall back to environment variables for anything not passed explicitly.
        if not self._api_key:
            self._api_key = os.getenv("ANTHROPIC_API_KEY")

        if not self._aws_access_key:
            self._aws_access_key = os.getenv("AWS_ACCESS_KEY")

        if not self._aws_secret_key:
            self._aws_secret_key = os.getenv("AWS_SECRET_KEY")

        if not self._aws_region:
            self._aws_region = os.getenv("AWS_REGION")

        if not self._gcp_region:
            self._gcp_region = os.getenv("GCP_REGION")

        # Validate that at least one complete credential set is present.
        if self._api_key is None:
            if self._aws_region:
                if self._aws_access_key is None or self._aws_secret_key is None:
                    raise ValueError("API key or AWS credentials are required to use the Anthropic API.")
            elif self._gcp_region:
                if self._gcp_project_id is None or self._gcp_region is None:
                    raise ValueError("API key or GCP credentials are required to use the Anthropic API.")
            else:
                raise ValueError("API key or AWS credentials or GCP credentials are required to use the Anthropic API.")

        # Build the concrete SDK client: api_key wins, then Vertex, then Bedrock.
        if self._api_key is not None:
            client_kwargs = {"api_key": self._api_key}
            if self._base_url:
                client_kwargs["base_url"] = self._base_url
            self._client = Anthropic(**client_kwargs)
        elif self._gcp_region is not None:
            # Forward only the gcp_* attributes that AnthropicVertex actually accepts.
            kw = {}
            for p in inspect.signature(AnthropicVertex).parameters:
                if hasattr(self, f"_gcp_{p}"):
                    kw[p] = getattr(self, f"_gcp_{p}")
            if self._base_url:
                kw["base_url"] = self._base_url
            self._client = AnthropicVertex(**kw)
        else:
            client_kwargs = {
                "aws_access_key": self._aws_access_key,
                "aws_secret_key": self._aws_secret_key,
                "aws_session_token": self._aws_session_token,
                "aws_region": self._aws_region,
            }
            if self._base_url:
                client_kwargs["base_url"] = self._base_url
            self._client = AnthropicBedrock(**client_kwargs)

        self._last_tooluse_status = {}

        # Store the response format, if provided (for structured outputs)
        self._response_format: Optional[type[BaseModel]] = None

    def load_config(self, params: dict[str, Any]):
        """Load and validate the configuration for the Anthropic API client.

        Args:
            params: Raw config dict from the config list.

        Returns:
            dict: Validated parameters accepted by messages.create.
        """
        anthropic_params = {}

        anthropic_params["model"] = params.get("model")
        assert anthropic_params["model"], "Please provide a `model` in the config_list to use the Anthropic API."

        anthropic_params["temperature"] = validate_parameter(
            params, "temperature", (float, int), False, 1.0, (0.0, 1.0), None
        )
        anthropic_params["max_tokens"] = validate_parameter(params, "max_tokens", int, False, 4096, (1, None), None)
        anthropic_params["timeout"] = validate_parameter(params, "timeout", int, True, None, (1, None), None)
        anthropic_params["top_k"] = validate_parameter(params, "top_k", int, True, None, (1, None), None)
        anthropic_params["top_p"] = validate_parameter(params, "top_p", (float, int), True, None, (0.0, 1.0), None)
        anthropic_params["stop_sequences"] = validate_parameter(params, "stop_sequences", list, True, None, None, None)
        anthropic_params["stream"] = validate_parameter(params, "stream", bool, False, False, None, None)
        if "thinking" in params:
            # Extended-thinking config is passed through without validation.
            anthropic_params["thinking"] = params["thinking"]

        if anthropic_params["stream"]:
            warnings.warn(
                "Streaming is not currently supported, streaming will be disabled.",
                UserWarning,
            )
            anthropic_params["stream"] = False

        # Note the Anthropic API supports "tool" for tool_choice but you must specify the tool name so we will ignore that here
        # Dictionary, see options here: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/overview#controlling-claudes-output
        # type = auto, any, tool, none | name = the name of the tool if type=tool
        anthropic_params["tool_choice"] = validate_parameter(params, "tool_choice", dict, True, None, None, None)

        return anthropic_params

    def cost(self, response) -> float:
        """Calculate the cost of the completion using the Anthropic pricing."""
        # The cost is pre-computed in create() and attached to the response.
        return response.cost

    @property
    def api_key(self):
        return self._api_key

    @property
    def aws_access_key(self):
        return self._aws_access_key

    @property
    def aws_secret_key(self):
        return self._aws_secret_key

    @property
    def aws_session_token(self):
        return self._aws_session_token

    @property
    def aws_region(self):
        return self._aws_region

    @property
    def gcp_project_id(self):
        return self._gcp_project_id

    @property
    def gcp_region(self):
        return self._gcp_region

    @property
    def gcp_auth_token(self):
        return self._gcp_auth_token

    def create(self, params: dict[str, Any]) -> ChatCompletion:
        """Creates a completion using the Anthropic API and returns it in OAI format."""
        if "tools" in params:
            converted_functions = self.convert_tools_to_functions(params["tools"])
            params["functions"] = params.get("functions", []) + converted_functions

        # Convert AG2 messages to Anthropic messages
        anthropic_messages = oai_messages_to_anthropic_messages(params)
        anthropic_params = self.load_config(params)

        # If response_format exists, we want structured outputs
        # Anthropic doesn't support response_format, so using Anthropic's "JSON Mode":
        # https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
        if params.get("response_format"):
            self._response_format = params["response_format"]
            self._add_response_format_to_system(params)

        # TODO: support stream
        params = params.copy()
        if "functions" in params:
            tools_configs = params.pop("functions")
            tools_configs = [self.openai_func_to_anthropic(tool) for tool in tools_configs]
            params["tools"] = tools_configs

        # Anthropic doesn't accept None values, so we need to use keyword argument unpacking instead of setting parameters.
        # Copy params we need into anthropic_params
        # Remove any that don't have values
        anthropic_params["messages"] = anthropic_messages
        if "system" in params:
            anthropic_params["system"] = params["system"]
        if "tools" in params:
            anthropic_params["tools"] = params["tools"]
        if anthropic_params["top_k"] is None:
            del anthropic_params["top_k"]
        if anthropic_params["top_p"] is None:
            del anthropic_params["top_p"]
        if anthropic_params["stop_sequences"] is None:
            del anthropic_params["stop_sequences"]
        if anthropic_params["tool_choice"] is None:
            del anthropic_params["tool_choice"]

        response = self._client.messages.create(**anthropic_params)

        tool_calls = []
        message_text = ""
        # Default the finish reason so it is always bound, even if the SDK
        # unexpectedly returns None (previously this could raise NameError below).
        anthropic_finish = "stop"

        if self._response_format:
            try:
                parsed_response = self._extract_json_response(response)
                message_text = _format_json_response(parsed_response)
            except ValueError as e:
                # Handle JSON parsing errors gracefully by surfacing the error text.
                message_text = str(e)
        else:
            if response is not None:
                # If we have tool use as the response, populate completed tool calls for our return OAI response
                if response.stop_reason == "tool_use":
                    anthropic_finish = "tool_calls"
                    for content in response.content:
                        # isinstance (not type ==) is the correct check for content blocks.
                        if isinstance(content, ToolUseBlock):
                            tool_calls.append(
                                ChatCompletionMessageToolCall(
                                    id=content.id,
                                    function={"name": content.name, "arguments": json.dumps(content.input)},
                                    type="function",
                                )
                            )
                else:
                    anthropic_finish = "stop"
                    tool_calls = None

                # Retrieve any text content from the response
                for content in response.content:
                    if isinstance(content, TextBlock):
                        message_text = content.text
                        break

        # Calculate and save the cost onto the response
        prompt_tokens = response.usage.input_tokens
        completion_tokens = response.usage.output_tokens

        # Convert output back to AG2 response format
        message = ChatCompletionMessage(
            role="assistant",
            content=message_text,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=anthropic_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response.id,
            model=anthropic_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
            cost=_calculate_cost(prompt_tokens, completion_tokens, anthropic_params["model"]),
        )

        return response_oai

    def message_retrieval(self, response) -> list:
        """Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    @staticmethod
    def openai_func_to_anthropic(openai_func: dict) -> dict:
        """Convert an OpenAI function spec to Anthropic's tool schema (parameters -> input_schema)."""
        res = openai_func.copy()
        res["input_schema"] = res.pop("parameters")
        return res

    @staticmethod
    def get_usage(response: ChatCompletion) -> dict:
        """Get the usage of tokens and their cost information."""
        return {
            "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
            "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
            "total_tokens": response.usage.total_tokens if response.usage is not None else 0,
            "cost": response.cost if hasattr(response, "cost") else 0.0,
            "model": response.model,
        }

    @staticmethod
    def convert_tools_to_functions(tools: list) -> list:
        """
        Convert tool definitions into Anthropic-compatible functions,
        updating nested $ref paths in property schemas.

        Args:
            tools (list): List of tool definitions.

        Returns:
            list: List of functions with updated $ref paths.
        """

        def update_refs(obj, defs_keys, prop_name):
            """Recursively update $ref values that start with "#/$defs/"."""
            if isinstance(obj, dict):
                for key, value in obj.items():
                    if key == "$ref" and isinstance(value, str) and value.startswith("#/$defs/"):
                        ref_key = value[len("#/$defs/") :]
                        if ref_key in defs_keys:
                            obj[key] = f"#/properties/{prop_name}/$defs/{ref_key}"
                    else:
                        update_refs(value, defs_keys, prop_name)
            elif isinstance(obj, list):
                for item in obj:
                    update_refs(item, defs_keys, prop_name)

        functions = []
        for tool in tools:
            if tool.get("type") == "function" and "function" in tool:
                function = tool["function"]
                parameters = function.get("parameters", {})
                properties = parameters.get("properties", {})
                for prop_name, prop_schema in properties.items():
                    if "$defs" in prop_schema:
                        defs_keys = set(prop_schema["$defs"].keys())
                        update_refs(prop_schema, defs_keys, prop_name)
                functions.append(function)
        return functions

    def _add_response_format_to_system(self, params: dict[str, Any]):
        """Add prompt that will generate properly formatted JSON for structured outputs to system parameter.

        Based on Anthropic's JSON Mode cookbook, we ask the LLM to put the JSON within <json_response> tags.

        Args:
            params (dict): The client parameters
        """
        # NOTE(review): structured output is silently skipped when no system
        # message is present — confirm this is intentional.
        if not params.get("system"):
            return

        # Get the schema of the Pydantic model
        if isinstance(self._response_format, dict):
            schema = self._response_format
        else:
            schema = self._response_format.model_json_schema()

        # Add instructions for JSON formatting
        format_content = f"""Please provide your response as a JSON object that matches the following schema:
{json.dumps(schema, indent=2)}

Format your response as valid JSON within <json_response> tags.
Do not include any text before or after the tags.
Ensure the JSON is properly formatted and matches the schema exactly."""

        # Add formatting to last user message
        params["system"] += "\n\n" + format_content

    def _extract_json_response(self, response: Message) -> Any:
        """Extract and validate JSON response from the output for structured outputs.

        Args:
            response (Message): The response from the API.

        Returns:
            Any: The parsed JSON response.

        Raises:
            ValueError: If no valid JSON can be found or it fails validation.
        """
        if not self._response_format:
            return response

        # Extract content from response
        content = response.content[0].text if response.content else ""

        # Try to extract JSON from tags first
        json_match = re.search(r"<json_response>(.*?)</json_response>", content, re.DOTALL)
        if json_match:
            json_str = json_match.group(1).strip()
        else:
            # Fallback to finding first JSON object
            json_start = content.find("{")
            json_end = content.rfind("}")
            if json_start == -1 or json_end == -1:
                raise ValueError("No valid JSON found in response for Structured Output.")
            json_str = content[json_start : json_end + 1]

        try:
            # Parse JSON and validate against the Pydantic model if Pydantic model was provided
            json_data = json.loads(json_str)
            if isinstance(self._response_format, dict):
                return json_str
            else:
                return self._response_format.model_validate(json_data)

        except Exception as e:
            raise ValueError(f"Failed to parse response as valid JSON matching the schema for Structured Output: {e!s}")
|
||||
|
||||
|
||||
def _format_json_response(response: Any) -> str:
|
||||
"""Formats the JSON response for structured outputs using the format method if it exists."""
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
elif isinstance(response, FormatterProtocol):
|
||||
return response.format()
|
||||
else:
|
||||
return response.model_dump_json()
|
||||
|
||||
|
||||
def process_image_content(content_item: dict[str, Any]) -> dict[str, Any]:
    """Process an OpenAI image content item into Claude format.

    Converts a base64 data-URL image_url item into Anthropic's
    {"type": "image", "source": {...}} block. Non-image items, non-data
    URLs, and malformed data URLs are returned unchanged (previously a
    data URL that failed the regex fell through and returned None).

    Args:
        content_item: One OpenAI-format content entry.

    Returns:
        dict: The converted Anthropic image block, or the original item.
    """
    if content_item["type"] != "image_url":
        return content_item

    url = content_item["image_url"]["url"]
    try:
        # Handle data URLs of the form data:image/<type>;base64,<payload>
        if url.startswith("data:"):
            data_url_pattern = r"data:image/([a-zA-Z]+);base64,(.+)"
            match = re.match(data_url_pattern, url)
            if match:
                media_type, base64_data = match.groups()
                return {
                    "type": "image",
                    "source": {"type": "base64", "media_type": f"image/{media_type}", "data": base64_data},
                }

        # Non-data URL, or a data URL that did not match the expected pattern.
        print("Error processing image.")
        # Return original content if image processing fails
        return content_item

    except Exception as e:
        print(f"Error processing image: {e}")
        # Return original content if image processing fails
        return content_item
|
||||
|
||||
|
||||
def process_message_content(message: dict[str, Any]) -> Union[str, list[dict[str, Any]]]:
    """Normalize a message's content for the Anthropic API.

    Strings (including empty) pass through untouched. List content (mixed
    text and images) is rebuilt entry by entry, converting image_url items
    via process_image_content. Any other content type is returned as-is.
    """
    content = message.get("content", "")

    # Empty or plain-string content needs no conversion.
    if content == "" or isinstance(content, str):
        return content

    # Mixed text/image content: rebuild each recognized entry.
    if isinstance(content, list):
        converted: list[dict[str, Any]] = []
        for entry in content:
            kind = entry["type"]
            if kind == "text":
                converted.append({"type": "text", "text": entry["text"]})
            elif kind == "image_url":
                converted.append(process_image_content(entry))
        return converted

    # Unknown content shape — pass through unchanged.
    return content
|
||||
|
||||
|
||||
@require_optional_import("anthropic", "anthropic")
def oai_messages_to_anthropic_messages(params: dict[str, Any]) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Anthropic format.

    We correct for any specific role orders and types, etc.

    System messages are folded into params["system"] (mutating params);
    tool_calls / tool_call_id messages become tool_use / tool_result blocks
    when tools are configured, and the output is forced to alternate
    user/assistant roles with "Please continue." filler messages.

    Args:
        params: The OAI request parameters; must contain "messages".

    Returns:
        list: Messages in Anthropic-compliant form, ending with a user turn.
    """
    # Track whether we have tools passed in. If not, tool use / result messages should be converted to text messages.
    # Anthropic requires a tools parameter with the tools listed, if there are other messages with tool use or tool results.
    # This can occur when we don't need tool calling, such as for group chat speaker selection.
    has_tools = "tools" in params

    # Convert messages to Anthropic compliant format
    processed_messages = []

    # Used to interweave user messages to ensure user/assistant alternating
    user_continue_message = {"content": "Please continue.", "role": "user"}
    assistant_continue_message = {"content": "Please continue.", "role": "assistant"}

    # Bookkeeping to pair tool_use blocks with their tool_result replies.
    tool_use_messages = 0
    tool_result_messages = 0
    last_tool_use_index = -1
    last_tool_result_index = -1
    for message in params["messages"]:
        if message["role"] == "system":
            content = process_message_content(message)
            if isinstance(content, list):
                # For system messages with images, concatenate only the text portions
                text_content = " ".join(item.get("text", "") for item in content if item.get("type") == "text")
                params["system"] = params.get("system", "") + (" " if "system" in params else "") + text_content
            else:
                params["system"] = params.get("system", "") + ("\n" if "system" in params else "") + content
        else:
            # New messages will be added here, manage role alternations
            expected_role = "user" if len(processed_messages) % 2 == 0 else "assistant"

            if "tool_calls" in message:
                # Map the tool call options to Anthropic's ToolUseBlock
                tool_uses = []
                tool_names = []
                for tool_call in message["tool_calls"]:
                    tool_uses.append(
                        ToolUseBlock(
                            type="tool_use",
                            id=tool_call["id"],
                            name=tool_call["function"]["name"],
                            input=json.loads(tool_call["function"]["arguments"]),
                        )
                    )
                    if has_tools:
                        tool_use_messages += 1
                    tool_names.append(tool_call["function"]["name"])

                if expected_role == "user":
                    # Insert an extra user message as we will append an assistant message
                    processed_messages.append(user_continue_message)

                if has_tools:
                    processed_messages.append({"role": "assistant", "content": tool_uses})
                    last_tool_use_index = len(processed_messages) - 1
                else:
                    # Not using tools, so put in a plain text message
                    processed_messages.append({
                        "role": "assistant",
                        "content": f"Some internal function(s) that could be used: [{', '.join(tool_names)}]",
                    })
            elif "tool_call_id" in message:
                if has_tools:
                    # Map the tool usage call to tool_result for Anthropic
                    tool_result = {
                        "type": "tool_result",
                        "tool_use_id": message["tool_call_id"],
                        "content": message["content"],
                    }

                    # If the previous message also had a tool_result, add it to that
                    # Otherwise append a new message
                    if last_tool_result_index == len(processed_messages) - 1:
                        processed_messages[-1]["content"].append(tool_result)
                    else:
                        if expected_role == "assistant":
                            # Insert an extra assistant message as we will append a user message
                            processed_messages.append(assistant_continue_message)

                        processed_messages.append({"role": "user", "content": [tool_result]})
                        last_tool_result_index = len(processed_messages) - 1

                    tool_result_messages += 1
                else:
                    # Not using tools, so put in a plain text message
                    processed_messages.append({
                        "role": "user",
                        "content": f"Running the function returned: {message['content']}",
                    })
            elif message["content"] == "":
                # Ignoring empty messages
                pass
            else:
                if expected_role != message["role"]:
                    # Inserting the alternating continue message
                    processed_messages.append(
                        user_continue_message if expected_role == "user" else assistant_continue_message
                    )
                # Process messages for images
                processed_content = process_message_content(message)
                processed_message = message.copy()
                processed_message["content"] = processed_content
                processed_messages.append(processed_message)

    # We'll replace the last tool_use if there's no tool_result (occurs if we finish the conversation before running the function)
    if has_tools and tool_use_messages != tool_result_messages:
        processed_messages[last_tool_use_index] = assistant_continue_message

    # name is not a valid field on messages
    for message in processed_messages:
        if "name" in message:
            message.pop("name", None)

    # Note: When using reflection_with_llm we may end up with an "assistant" message as the last message and that may cause a blank response
    # So, if the last role is not user, add a 'user' continue message at the end
    if processed_messages[-1]["role"] != "user":
        processed_messages.append(user_continue_message)

    return processed_messages
|
||||
|
||||
|
||||
def _calculate_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Calculate the cost of the completion using the Anthropic pricing.

    Looks the model up in ANTHROPIC_PRICING_1k ((input, output) USD per
    1k tokens); unknown models produce a warning and a cost of 0.0.
    """
    pricing = ANTHROPIC_PRICING_1k.get(model)
    if pricing is None:
        warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
        return 0.0

    input_cost_per_1k, output_cost_per_1k = pricing
    return (input_tokens / 1000) * input_cost_per_1k + (output_tokens / 1000) * output_cost_per_1k
|
||||
628
mm_agents/coact/autogen/oai/bedrock.py
Normal file
628
mm_agents/coact/autogen/oai/bedrock.py
Normal file
@@ -0,0 +1,628 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
"""Create a compatible client for the Amazon Bedrock Converse API.
|
||||
|
||||
Example usage:
|
||||
Install the `boto3` package by running `pip install --upgrade boto3`.
|
||||
- https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
|
||||
|
||||
```python
|
||||
import autogen
|
||||
|
||||
config_list = [
|
||||
{
|
||||
"api_type": "bedrock",
|
||||
"model": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"aws_region": "us-west-2",
|
||||
"aws_access_key": "",
|
||||
"aws_secret_key": "",
|
||||
"price": [0.003, 0.015],
|
||||
}
|
||||
]
|
||||
|
||||
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
|
||||
```
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import Field, SecretStr, field_serializer
|
||||
|
||||
from ..import_utils import optional_import_block, require_optional_import
|
||||
from ..llm_config import LLMConfigEntry, register_llm_config
|
||||
from .client_utils import validate_parameter
|
||||
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
|
||||
|
||||
with optional_import_block():
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
|
||||
|
||||
@register_llm_config
class BedrockLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for the Amazon Bedrock Converse API."""

    # Discriminator used to route this entry to the Bedrock client.
    api_type: Literal["bedrock"] = "bedrock"
    # AWS region is mandatory for the Bedrock runtime endpoint.
    aws_region: str
    # Optional static credentials; when unset the client falls back to
    # environment variables or the attached service role.
    aws_access_key: Optional[SecretStr] = None
    aws_secret_key: Optional[SecretStr] = None
    aws_session_token: Optional[SecretStr] = None
    aws_profile_name: Optional[str] = None
    # "Base" Converse inference parameters (camelCase mirrors the Bedrock API).
    temperature: Optional[float] = None
    topP: Optional[float] = None  # noqa: N815
    maxTokens: Optional[int] = None  # noqa: N815
    # Model-specific parameters, sent via additionalModelRequestFields.
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    k: Optional[int] = None
    seed: Optional[int] = None
    cache_seed: Optional[int] = None
    # Some Bedrock models (e.g. Mistral Instruct) do not accept a system prompt.
    supports_system_prompts: bool = True
    stream: bool = False
    # Optional [prompt_price_per_1k, completion_price_per_1k] cost override.
    price: Optional[list[float]] = Field(default=None, min_length=2, max_length=2)
    timeout: Optional[int] = None

    @field_serializer("aws_access_key", "aws_secret_key", "aws_session_token", when_used="unless-none")
    def serialize_aws_secrets(self, v: SecretStr) -> str:
        """Serialize AWS secret fields to their plain-text values (skipped when None)."""
        return v.get_secret_value()

    def create_client(self):
        """Client construction is handled elsewhere; this entry is config-only."""
        raise NotImplementedError("BedrockLLMConfigEntry.create_client must be implemented.")
|
||||
|
||||
|
||||
@require_optional_import("boto3", "bedrock")
class BedrockClient:
    """Client for Amazon's Bedrock Converse API."""

    # Maximum retry attempts handed to botocore's "standard" retry mode.
    _retries = 5

    def __init__(self, **kwargs: Any):
        """Initialises BedrockClient for Amazon's Bedrock Converse API.

        Credentials and region are taken from kwargs first, then from the
        AWS_ACCESS_KEY / AWS_SECRET_KEY / AWS_SESSION_TOKEN / AWS_REGION
        environment variables. If no key pair is available, boto3 falls back
        to the attached role of the managed service (Lambda, EC2, ECS, etc.).

        Raises:
            ValueError: If no AWS region can be determined.
        """
        self._aws_access_key = kwargs.get("aws_access_key")
        self._aws_secret_key = kwargs.get("aws_secret_key")
        self._aws_session_token = kwargs.get("aws_session_token")
        self._aws_region = kwargs.get("aws_region")
        self._aws_profile_name = kwargs.get("aws_profile_name")
        self._timeout = kwargs.get("timeout")

        # Environment-variable fallbacks for anything not passed explicitly.
        if not self._aws_access_key:
            self._aws_access_key = os.getenv("AWS_ACCESS_KEY")

        if not self._aws_secret_key:
            self._aws_secret_key = os.getenv("AWS_SECRET_KEY")

        if not self._aws_session_token:
            self._aws_session_token = os.getenv("AWS_SESSION_TOKEN")

        if not self._aws_region:
            self._aws_region = os.getenv("AWS_REGION")

        if self._aws_region is None:
            raise ValueError("Region is required to use the Amazon Bedrock API.")

        if self._timeout is None:
            self._timeout = 60

        # Initialize Bedrock client, session, and runtime
        bedrock_config = Config(
            region_name=self._aws_region,
            signature_version="v4",
            retries={"max_attempts": self._retries, "mode": "standard"},
            read_timeout=self._timeout,
        )

        if "response_format" in kwargs and kwargs["response_format"] is not None:
            warnings.warn("response_format is not supported for Bedrock, it will be ignored.", UserWarning)

        # if haven't got any access_key or secret_key in environment variable or via arguments then
        if (
            self._aws_access_key is None
            or self._aws_access_key == ""
            or self._aws_secret_key is None
            or self._aws_secret_key == ""
        ):
            # attempts to get client from attached role of managed service (lambda, ec2, ecs, etc.)
            self.bedrock_runtime = boto3.client(service_name="bedrock-runtime", config=bedrock_config)
        else:
            # Fix: previously an identical Session was also built unconditionally
            # before this branch and then discarded; only this one is needed.
            session = boto3.Session(
                aws_access_key_id=self._aws_access_key,
                aws_secret_access_key=self._aws_secret_key,
                aws_session_token=self._aws_session_token,
                profile_name=self._aws_profile_name,
            )
            self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)

    def message_retrieval(self, response):
        """Retrieve the messages from the response."""
        return [choice.message for choice in response.choices]

    def parse_custom_params(self, params: dict[str, Any]):
        """Parses custom parameters for logic in this client class.

        Should we separate system messages into its own request parameter? Default is True.
        This is required because not all models support a system prompt (e.g. Mistral Instruct).
        """
        self._supports_system_prompts = params.get("supports_system_prompts", True)

    def parse_params(self, params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
        """Loads the valid parameters required to invoke Bedrock Converse.

        Returns a tuple of (base_params, additional_params). Also records the
        model id and streaming flag on the instance as side effects.
        """
        base_params = {}
        additional_params = {}

        # Amazon Bedrock base model IDs are here:
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
        self._model_id = params.get("model")
        assert self._model_id, "Please provide the 'model` in the config_list to use Amazon Bedrock"

        # Parameters vary based on the model used.
        # As we won't cater for all models and parameters, it's the developer's
        # responsibility to implement the parameters and they will only be
        # included if the developer has it in the config.
        #
        # Important:
        # No defaults will be used (as they can vary per model)
        # No ranges will be used (as they can vary)
        # We will cover all the main parameters but there may be others
        # that need to be added later
        #
        # Here are some pages that show the parameters available for different models
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-chat-completion.html

        # Here are the possible "base" parameters and their suitable types
        base_parameters = [["temperature", (float, int)], ["topP", (float, int)], ["maxTokens", (int)]]

        for param_name, suitable_types in base_parameters:
            if param_name in params:
                base_params[param_name] = validate_parameter(
                    params, param_name, suitable_types, False, None, None, None
                )

        # Here are the possible "model-specific" parameters and their suitable types, known as additional parameters
        additional_parameters = [
            ["top_p", (float, int)],
            ["top_k", (int)],
            ["k", (int)],
            ["seed", (int)],
        ]

        for param_name, suitable_types in additional_parameters:
            if param_name in params:
                additional_params[param_name] = validate_parameter(
                    params, param_name, suitable_types, False, None, None, None
                )

        # Streaming
        self._streaming = params.get("stream", False)

        # For this release we will not support streaming as many models do not support streaming with tool use
        if self._streaming:
            warnings.warn(
                "Streaming is not currently supported, streaming will be disabled.",
                UserWarning,
            )
            self._streaming = False

        return base_params, additional_params

    def create(self, params) -> ChatCompletion:
        """Run Amazon Bedrock inference and return AG2 response.

        Converts the OpenAI-style request in `params` into a Converse API call
        and maps the result back into an OpenAI-compatible ChatCompletion.
        """
        # Set custom client class settings
        self.parse_custom_params(params)

        # Parse the inference parameters
        base_params, additional_params = self.parse_params(params)

        has_tools = "tools" in params
        messages = oai_messages_to_bedrock_messages(params["messages"], has_tools, self._supports_system_prompts)

        if self._supports_system_prompts:
            system_messages = extract_system_messages(params["messages"])

        tool_config = format_tools(params["tools"] if has_tools else [])

        request_args = {"messages": messages, "modelId": self._model_id}

        # Base and additional args
        if len(base_params) > 0:
            request_args["inferenceConfig"] = base_params

        if len(additional_params) > 0:
            request_args["additionalModelRequestFields"] = additional_params

        if self._supports_system_prompts:
            request_args["system"] = system_messages

        if len(tool_config["tools"]) > 0:
            request_args["toolConfig"] = tool_config

        response = self.bedrock_runtime.converse(**request_args)
        if response is None:
            raise RuntimeError(f"Failed to get response from Bedrock after retrying {self._retries} times.")

        finish_reason = convert_stop_reason_to_finish_reason(response["stopReason"])
        response_message = response["output"]["message"]

        tool_calls = format_tool_calls(response_message["content"]) if finish_reason == "tool_calls" else None

        # Take the last text block as the message content.
        text = ""
        for content in response_message["content"]:
            if "text" in content:
                text = content["text"]
            # NOTE: other types of output may be dealt with here

        message = ChatCompletionMessage(role="assistant", content=text, tool_calls=tool_calls)

        response_usage = response["usage"]
        usage = CompletionUsage(
            prompt_tokens=response_usage["inputTokens"],
            completion_tokens=response_usage["outputTokens"],
            total_tokens=response_usage["totalTokens"],
        )

        return ChatCompletion(
            id=response["ResponseMetadata"]["RequestId"],
            choices=[Choice(finish_reason=finish_reason, index=0, message=message)],
            created=int(time.time()),
            model=self._model_id,
            object="chat.completion",
            usage=usage,
        )

    def cost(self, response: ChatCompletion) -> float:
        """Calculate the cost of the response."""
        return calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens, response.model)

    @staticmethod
    def get_usage(response) -> dict:
        """Get the usage of tokens and their cost information."""
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }
|
||||
|
||||
|
||||
def extract_system_messages(messages: list[dict[str, Any]]) -> list:
    """Extract the first system message from the list of messages.

    Bedrock's Converse API takes the system prompt as a separate request
    field, so it is pulled out of the OpenAI-style message list here.

    Args:
        messages (list[dict[str, Any]]): List of messages.

    Returns:
        list: A single-element list of ``{"text": ...}`` blocks in Bedrock's
            system-content format, or an empty list when no system message exists.
    """
    # Fix: removed dead commented-out code that was left behind as a stray
    # string-literal expression after the docstring.
    for message in messages:
        if message.get("role") == "system":
            # Content may be a plain string or an OpenAI content-part list;
            # in the latter case only the first text part is used.
            if isinstance(message["content"], str):
                return [{"text": message.get("content")}]
            else:
                return [{"text": message.get("content")[0]["text"]}]
    return []
|
||||
|
||||
|
||||
def oai_messages_to_bedrock_messages(
    messages: list[dict[str, Any]], has_tools: bool, supports_system_prompts: bool
) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Bedrock format.

    We correct for any specific role orders and types, etc.
    AWS Bedrock requires messages to alternate between user and assistant roles. This function ensures that the messages
    are in the correct order and format for Bedrock by inserting "Please continue" messages as needed.
    This is the same method as the one in the Autogen Anthropic client

    Args:
        messages: OpenAI-style message dicts (may include system/tool messages).
        has_tools: Whether a tools parameter is being sent with the request; if
            not, tool use / tool result messages are downgraded to plain text.
        supports_system_prompts: Whether the model accepts a separate system
            prompt; if not, system messages are re-labelled as user messages.

    Returns:
        list[dict[str, Any]]: Bedrock-compliant messages, strictly alternating
            user/assistant and always ending with a user message.
    """
    # Track whether we have tools passed in. If not, tool use / result messages should be converted to text messages.
    # Bedrock requires a tools parameter with the tools listed, if there are other messages with tool use or tool results.
    # This can occur when we don't need tool calling, such as for group chat speaker selection

    # Convert messages to Bedrock compliant format

    # Take out system messages if the model supports it, otherwise leave them in.
    if supports_system_prompts:
        messages = [x for x in messages if x["role"] != "system"]
    else:
        # Replace role="system" with role="user"
        for msg in messages:
            if msg["role"] == "system":
                msg["role"] = "user"

    processed_messages = []

    # Used to interweave user messages to ensure user/assistant alternating
    user_continue_message = {"content": [{"text": "Please continue."}], "role": "user"}
    assistant_continue_message = {
        "content": [{"text": "Please continue."}],
        "role": "assistant",
    }

    # Bookkeeping so a dangling tool_use (no matching tool_result) can be
    # neutralised after the loop, and consecutive tool results can be merged.
    tool_use_messages = 0
    tool_result_messages = 0
    last_tool_use_index = -1
    last_tool_result_index = -1
    # user_role_index = 0 if supports_system_prompts else 1  # If system prompts are supported, messages start with user, otherwise they'll be the second message
    for message in messages:
        # New messages will be added here, manage role alternations
        expected_role = "user" if len(processed_messages) % 2 == 0 else "assistant"

        if "tool_calls" in message:
            # Map the tool call options to Bedrock's format
            tool_uses = []
            tool_names = []
            for tool_call in message["tool_calls"]:
                tool_uses.append({
                    "toolUse": {
                        "toolUseId": tool_call["id"],
                        "name": tool_call["function"]["name"],
                        "input": json.loads(tool_call["function"]["arguments"]),
                    }
                })
                if has_tools:
                    tool_use_messages += 1
                tool_names.append(tool_call["function"]["name"])

            if expected_role == "user":
                # Insert an extra user message as we will append an assistant message
                processed_messages.append(user_continue_message)

            if has_tools:
                processed_messages.append({"role": "assistant", "content": tool_uses})
                last_tool_use_index = len(processed_messages) - 1
            else:
                # Not using tools, so put in a plain text message
                processed_messages.append({
                    "role": "assistant",
                    "content": [{"text": f"Some internal function(s) that could be used: [{', '.join(tool_names)}]"}],
                })
        elif "tool_call_id" in message:
            if has_tools:
                # Map the tool usage call to tool_result for Bedrock
                tool_result = {
                    "toolResult": {
                        "toolUseId": message["tool_call_id"],
                        "content": [{"text": message["content"]}],
                    }
                }

                # If the previous message also had a tool_result, add it to that
                # Otherwise append a new message
                if last_tool_result_index == len(processed_messages) - 1:
                    processed_messages[-1]["content"].append(tool_result)
                else:
                    if expected_role == "assistant":
                        # Insert an extra assistant message as we will append a user message
                        processed_messages.append(assistant_continue_message)

                    processed_messages.append({"role": "user", "content": [tool_result]})
                    last_tool_result_index = len(processed_messages) - 1

                tool_result_messages += 1
            else:
                # Not using tools, so put in a plain text message
                processed_messages.append({
                    "role": "user",
                    "content": [{"text": f"Running the function returned: {message['content']}"}],
                })
        elif message["content"] == "":
            # Ignoring empty messages
            pass
        else:
            if expected_role != message["role"] and not (len(processed_messages) == 0 and message["role"] == "system"):
                # Inserting the alternating continue message (ignore if it's the first message and a system message)
                processed_messages.append(
                    user_continue_message if expected_role == "user" else assistant_continue_message
                )

            processed_messages.append({
                "role": message["role"],
                "content": parse_content_parts(message=message),
            })

    # We'll replace the last tool_use if there's no tool_result (occurs if we finish the conversation before running the function)
    if has_tools and tool_use_messages != tool_result_messages:
        processed_messages[last_tool_use_index] = assistant_continue_message

    # name is not a valid field on messages
    for message in processed_messages:
        if "name" in message:
            message.pop("name", None)

    # Note: When using reflection_with_llm we may end up with an "assistant" message as the last message and that may cause a blank response
    # So, if the last role is not user, add a 'user' continue message at the end
    if processed_messages[-1]["role"] != "user":
        processed_messages.append(user_continue_message)

    return processed_messages
|
||||
|
||||
|
||||
def parse_content_parts(
    message: dict[str, Any],
) -> list[dict[str, Any]]:
    """Translate an OpenAI message's content into Bedrock content blocks.

    A plain string becomes a single text block; a content-part list is mapped
    part by part (text and image_url parts are supported, others are dropped).
    """
    raw = message.get("content")
    if isinstance(raw, str):
        # Simple string content maps to one Bedrock text block.
        return [
            {
                "text": raw,
            }
        ]

    blocks = []
    for entry in raw:
        if "text" in entry:
            blocks.append({
                "text": entry.get("text"),
            })
        elif "image_url" in entry:
            # Fetch/decode the image; Bedrock wants the subtype only
            # (strip the leading "image/" from the MIME type).
            payload, mime = parse_image(entry.get("image_url").get("url"))
            blocks.append({
                "image": {
                    "format": mime[6:],  # image/
                    "source": {"bytes": payload},
                },
            })
        else:
            # Unsupported part types are silently skipped.
            continue
    return blocks
|
||||
|
||||
|
||||
def parse_image(image_url: str) -> tuple[bytes, str]:
    """Try to get the raw data from an image url.

    Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html
    returns a tuple of (Image Data, Content Type)
    """
    data_url_pattern = r"^data:(image/[a-z]*);base64,\s*"
    match = re.search(data_url_pattern, image_url)
    # Inline data URL: decode the base64 payload directly.
    # Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp'
    if match:
        encoded = re.sub(data_url_pattern, "", image_url)
        return base64.b64decode(encoded), match.group(1)

    # Otherwise fetch the bytes over HTTP.
    response = requests.get(image_url)
    if response.status_code != 200:
        raise RuntimeError("Unable to access the image url")

    mime = response.headers.get("Content-Type")
    if not mime.startswith("image"):
        # Default to JPEG when the server reports a non-image content type.
        mime = "image/jpeg"
    return response.content, mime
|
||||
|
||||
|
||||
def format_tools(tools: list[dict[str, Any]]) -> dict[Literal["tools"], list[dict[str, Any]]]:
|
||||
converted_schema = {"tools": []}
|
||||
|
||||
for tool in tools:
|
||||
if tool["type"] == "function":
|
||||
function = tool["function"]
|
||||
converted_tool = {
|
||||
"toolSpec": {
|
||||
"name": function["name"],
|
||||
"description": function["description"],
|
||||
"inputSchema": {"json": {"type": "object", "properties": {}, "required": []}},
|
||||
}
|
||||
}
|
||||
|
||||
for prop_name, prop_details in function["parameters"]["properties"].items():
|
||||
converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name] = {
|
||||
"type": prop_details["type"],
|
||||
"description": prop_details.get("description", ""),
|
||||
}
|
||||
if "enum" in prop_details:
|
||||
converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["enum"] = prop_details[
|
||||
"enum"
|
||||
]
|
||||
if "default" in prop_details:
|
||||
converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["default"] = (
|
||||
prop_details["default"]
|
||||
)
|
||||
|
||||
if "required" in function["parameters"]:
|
||||
converted_tool["toolSpec"]["inputSchema"]["json"]["required"] = function["parameters"]["required"]
|
||||
|
||||
converted_schema["tools"].append(converted_tool)
|
||||
|
||||
return converted_schema
|
||||
|
||||
|
||||
def format_tool_calls(content):
    """Converts Converse API response tool calls to AG2 format"""
    # Only "toolUse" blocks in the response content are tool calls;
    # everything else (e.g. text blocks) is skipped.
    return [
        ChatCompletionMessageToolCall(
            id=block["toolUse"]["toolUseId"],
            function={
                "name": block["toolUse"]["name"],
                "arguments": json.dumps(block["toolUse"]["input"]),
            },
            type="function",
        )
        for block in content
        if "toolUse" in block
    ]
|
||||
|
||||
|
||||
def convert_stop_reason_to_finish_reason(
    stop_reason: str,
) -> Literal["stop", "length", "tool_calls", "content_filter"]:
    """Converts Bedrock finish reasons to our finish reasons, according to OpenAI:

    - stop: if the model hit a natural stop point or a provided stop sequence,
    - length: if the maximum number of tokens specified in the request was reached,
    - content_filter: if content was omitted due to a flag from our content filters,
    - tool_calls: if the model called a tool

    Unmapped (but non-empty) stop reasons are passed through lower-cased with a
    warning; an empty/None stop reason warns and returns None.
    """
    if stop_reason:
        finish_reason_mapping = {
            "tool_use": "tool_calls",
            "finished": "stop",
            "end_turn": "stop",
            "max_tokens": "length",
            "stop_sequence": "stop",
            "complete": "stop",
            "content_filtered": "content_filter",
        }
        normalized = stop_reason.lower()
        if normalized not in finish_reason_mapping:
            # Fix: this warning was previously unreachable for unmapped reasons
            # because the function returned from the mapping fallback first.
            warnings.warn(f"Unsupported stop reason: {stop_reason}", UserWarning)
        # Unmapped reasons still pass through lower-cased, as before.
        return finish_reason_mapping.get(normalized, normalized)

    warnings.warn(f"Unsupported stop reason: {stop_reason}", UserWarning)
    return None
|
||||
|
||||
|
||||
# NOTE: As this will be quite dynamic, it's expected that the developer will use the "price" parameter in their config
# These may be removed.
# Mapping: Bedrock model id -> (input price per 1K tokens, output price per 1K tokens) in USD.
PRICES_PER_K_TOKENS = {
    "meta.llama3-8b-instruct-v1:0": (0.0003, 0.0006),
    "meta.llama3-70b-instruct-v1:0": (0.00265, 0.0035),
    "mistral.mistral-7b-instruct-v0:2": (0.00015, 0.0002),
    "mistral.mixtral-8x7b-instruct-v0:1": (0.00045, 0.0007),
    "mistral.mistral-large-2402-v1:0": (0.004, 0.012),
    "mistral.mistral-small-2402-v1:0": (0.001, 0.003),
}
|
||||
|
||||
|
||||
def calculate_cost(input_tokens: int, output_tokens: int, model_id: str) -> float:
    """Calculate the cost of the completion using the Bedrock pricing."""
    pricing = PRICES_PER_K_TOKENS.get(model_id)
    if pricing is None:
        # Unknown model: report zero cost and point the developer at the
        # "price" config override.
        warnings.warn(
            f'Cannot get the costs for {model_id}. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.',
            UserWarning,
        )
        return 0

    in_rate, out_rate = pricing
    return (input_tokens / 1000) * in_rate + (output_tokens / 1000) * out_rate
|
||||
299
mm_agents/coact/autogen/oai/cerebras.py
Normal file
299
mm_agents/coact/autogen/oai/cerebras.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
"""Create an OpenAI-compatible client using Cerebras's API.
|
||||
|
||||
Example:
|
||||
```python
|
||||
llm_config = {
|
||||
"config_list": [{"api_type": "cerebras", "model": "llama3.1-8b", "api_key": os.environ.get("CEREBRAS_API_KEY")}]
|
||||
}
|
||||
|
||||
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
|
||||
```
|
||||
|
||||
Install Cerebras's python library using: pip install --upgrade cerebras_cloud_sdk
|
||||
|
||||
Resources:
|
||||
- https://inference-docs.cerebras.ai/quickstart
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import Field, ValidationInfo, field_validator
|
||||
|
||||
from ..import_utils import optional_import_block, require_optional_import
|
||||
from ..llm_config import LLMConfigEntry, register_llm_config
|
||||
from .client_utils import should_hide_tools, validate_parameter
|
||||
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
|
||||
|
||||
with optional_import_block():
|
||||
from cerebras.cloud.sdk import Cerebras, Stream
|
||||
|
||||
# Mapping: Cerebras model -> (input price per 1K tokens, output price per 1K tokens) in USD.
CEREBRAS_PRICING_1K = {
    # Convert pricing per million to per thousand tokens.
    "llama3.1-8b": (0.10 / 1000, 0.10 / 1000),
    "llama-3.3-70b": (0.85 / 1000, 1.20 / 1000),
}
|
||||
|
||||
|
||||
@register_llm_config
class CerebrasLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for the Cerebras chat-completions API."""

    # Discriminator used to route this entry to the Cerebras client.
    api_type: Literal["cerebras"] = "cerebras"
    max_tokens: Optional[int] = None
    seed: Optional[int] = None
    stream: bool = False
    # Cerebras accepts temperature in [0, 1.5].
    temperature: float = Field(default=1.0, ge=0.0, le=1.5)
    top_p: Optional[float] = None
    # Controls when tool definitions are hidden from the model after tools have run.
    hide_tools: Literal["if_all_run", "if_any_run", "never"] = "never"
    tool_choice: Optional[Literal["none", "auto", "required"]] = None

    @field_validator("top_p", mode="before")
    @classmethod
    def check_top_p(cls, v: Any, info: ValidationInfo) -> Any:
        """Reject configs that set both temperature and top_p (mutually exclusive)."""
        if v is not None and info.data.get("temperature") is not None:
            raise ValueError("temperature and top_p cannot be set at the same time.")
        return v

    def create_client(self):
        """Client construction is handled elsewhere; this entry is config-only."""
        raise NotImplementedError("CerebrasLLMConfigEntry.create_client is not implemented.")
|
||||
|
||||
|
||||
class CerebrasClient:
|
||||
"""Client for Cerebras's API."""
|
||||
|
||||
def __init__(self, api_key=None, **kwargs):
|
||||
"""Requires api_key or environment variable to be set
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for using Cerebras (or environment variable CEREBRAS_API_KEY needs to be set)
|
||||
**kwargs: Additional keyword arguments to pass to the Cerebras client
|
||||
"""
|
||||
# Ensure we have the api_key upon instantiation
|
||||
self.api_key = api_key
|
||||
if not self.api_key:
|
||||
self.api_key = os.getenv("CEREBRAS_API_KEY")
|
||||
|
||||
assert self.api_key, (
|
||||
"Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."
|
||||
)
|
||||
|
||||
if "response_format" in kwargs and kwargs["response_format"] is not None:
|
||||
warnings.warn("response_format is not supported for Crebras, it will be ignored.", UserWarning)
|
||||
|
||||
def message_retrieval(self, response: ChatCompletion) -> list:
|
||||
"""Retrieve and return a list of strings or a list of Choice.Message from the response.
|
||||
|
||||
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
|
||||
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
|
||||
"""
|
||||
return [choice.message for choice in response.choices]
|
||||
|
||||
def cost(self, response: ChatCompletion) -> float:
|
||||
# Note: This field isn't explicitly in `ChatCompletion`, but is injected during chat creation.
|
||||
return response.cost
|
||||
|
||||
@staticmethod
|
||||
def get_usage(response: ChatCompletion) -> dict:
|
||||
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
|
||||
# ... # pragma: no cover
|
||||
return {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
"cost": response.cost,
|
||||
"model": response.model,
|
||||
}
|
||||
|
||||
    def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Loads the parameters for Cerebras API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
        cerebras_params = {}

        # Check that we have what we need to use Cerebras's API
        # We won't enforce the available models as they are likely to change
        cerebras_params["model"] = params.get("model")
        assert cerebras_params["model"], (
            "Please specify the 'model' in your config list entry to nominate the Cerebras model to use."
        )

        # Validate allowed Cerebras parameters
        # https://inference-docs.cerebras.ai/api-reference/chat-completions
        # validate_parameter(params, name, types, allow_None, default, range, allowed_values)
        cerebras_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        cerebras_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
        cerebras_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
        # Temperature defaults to 1 and is clamped to the API's [0, 1.5] range.
        cerebras_params["temperature"] = validate_parameter(
            params, "temperature", (int, float), True, 1, (0, 1.5), None
        )
        cerebras_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
        cerebras_params["tool_choice"] = validate_parameter(
            params, "tool_choice", str, True, None, None, ["none", "auto", "required"]
        )

        return cerebras_params
|
||||
|
||||
    @require_optional_import("cerebras", "cerebras")
    def create(self, params: dict) -> ChatCompletion:
        """Run a (possibly streaming) chat completion against Cerebras and return it in OpenAI format.

        Args:
            params: AG2 request parameters; must include ``messages`` and ``model``,
                may include ``tools``, ``hide_tools``, ``stream``, etc.

        Returns:
            ChatCompletion: OpenAI-format response with a dynamically attached ``cost``.
        """
        messages = params.get("messages", [])

        # Convert AG2 messages to Cerebras messages
        cerebras_messages = oai_messages_to_cerebras_messages(messages)

        # Parse parameters to the Cerebras API's parameters
        cerebras_params = self.parse_params(params)

        # Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(cerebras_messages, params["tools"], hide_tools):
                cerebras_params["tools"] = params["tools"]

        cerebras_params["messages"] = cerebras_messages

        # We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
        client = Cerebras(api_key=self.api_key, max_retries=5)

        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0

        # Streaming tool call recommendations
        streaming_tool_calls = []

        ans = None
        response = client.chat.completions.create(**cerebras_params)

        if cerebras_params["stream"]:
            # Read in the chunks as they stream, taking in tool_calls which may be across
            # multiple chunks if more than one suggested
            ans = ""
            for chunk in response:
                # Grab first choice, which _should_ always be generated.
                ans = ans + (getattr(chunk.choices[0].delta, "content", None) or "")

                # NOTE(review): membership test / subscript on the delta object —
                # presumably the Cerebras SDK delta supports dict-style access; confirm.
                if "tool_calls" in chunk.choices[0].delta:
                    # We have a tool call recommendation
                    for tool_call in chunk.choices[0].delta["tool_calls"]:
                        streaming_tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id=tool_call["id"],
                                function={
                                    "name": tool_call["function"]["name"],
                                    "arguments": tool_call["function"]["arguments"],
                                },
                                type="function",
                            )
                        )

                # Usage is only populated on the final chunk (which carries a finish_reason).
                if chunk.choices[0].finish_reason:
                    prompt_tokens = chunk.usage.prompt_tokens
                    completion_tokens = chunk.usage.completion_tokens
                    total_tokens = chunk.usage.total_tokens
        else:
            # Non-streaming finished
            ans: str = response.choices[0].message.content

            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens

        if response is not None:
            if isinstance(response, Stream):
                # Streaming response
                # `chunk` is the last chunk from the loop above; its finish_reason
                # tells us whether the model stopped to make tool calls.
                if chunk.choices[0].finish_reason == "tool_calls":
                    cerebras_finish = "tool_calls"
                    tool_calls = streaming_tool_calls
                else:
                    cerebras_finish = "stop"
                    tool_calls = None

                response_content = ans
                response_id = chunk.id
            else:
                # Non-streaming response
                # If we have tool calls as the response, populate completed tool calls for our return OAI response
                if response.choices[0].finish_reason == "tool_calls":
                    cerebras_finish = "tool_calls"
                    tool_calls = []
                    for tool_call in response.choices[0].message.tool_calls:
                        tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id=tool_call.id,
                                function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                                type="function",
                            )
                        )
                else:
                    cerebras_finish = "stop"
                    tool_calls = None

                response_content = response.choices[0].message.content
                response_id = response.id

        # 3. convert output
        message = ChatCompletionMessage(
            role="assistant",
            content=response_content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=cerebras_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response_id,
            model=cerebras_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            # Note: This seems to be a field that isn't in the schema of `ChatCompletion`, so Pydantic
            # just adds it dynamically.
            cost=calculate_cerebras_cost(prompt_tokens, completion_tokens, cerebras_params["model"]),
        )

        return response_oai
|
||||
|
||||
|
||||
def oai_messages_to_cerebras_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Cerebras's format.

    We correct for any specific role orders and types.

    Args:
        messages: Messages in OpenAI format; the input list is not mutated.

    Returns:
        A deep copy of the messages with the OpenAI-only "name" field removed.
    """
    cerebras_messages = copy.deepcopy(messages)

    # Remove the name field — Cerebras does not accept it.
    # dict.pop with a default never raises, so no membership pre-check is needed.
    for message in cerebras_messages:
        message.pop("name", None)

    return cerebras_messages
|
||||
|
||||
|
||||
def calculate_cerebras_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Calculate the cost of the completion using the Cerebras pricing.

    Unknown models cost 0.0 and emit a UserWarning.
    """
    if model not in CEREBRAS_PRICING_1K:
        warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
        return 0.0

    def _ceil_6dp(value: float) -> float:
        # Round *up* to six decimal places (same arithmetic as the per-1K pricing math).
        return math.ceil(value * 1e6) / 1e6

    per_k_input, per_k_output = CEREBRAS_PRICING_1K[model]
    input_cost = _ceil_6dp((input_tokens / 1000) * per_k_input)
    output_cost = _ceil_6dp((output_tokens / 1000) * per_k_output)
    return _ceil_6dp(input_cost + output_cost)
|
||||
1444
mm_agents/coact/autogen/oai/client.py
Normal file
1444
mm_agents/coact/autogen/oai/client.py
Normal file
File diff suppressed because it is too large
Load Diff
169
mm_agents/coact/autogen/oai/client_utils.py
Normal file
169
mm_agents/coact/autogen/oai/client_utils.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
"""Utilities for client classes"""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Optional, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
class FormatterProtocol(Protocol):
    """Structured Output classes with a format method"""

    # Implementers render themselves to a display string. Because the protocol is
    # @runtime_checkable, callers can test conformance with isinstance().
    def format(self) -> str: ...
|
||||
|
||||
|
||||
def validate_parameter(
    params: dict[str, Any],
    param_name: str,
    allowed_types: tuple[Any, ...],
    allow_None: bool,  # noqa: N803
    default_value: Any,
    numerical_bound: Optional[tuple[Optional[float], Optional[float]]],
    allowed_values: Optional[list[Any]],
) -> Any:
    """Validates a given config parameter, checking its type, values, and setting defaults

    Parameters:
        params (Dict[str, Any]): Dictionary containing parameters to validate.
        param_name (str): The name of the parameter to validate.
        allowed_types (Tuple): Tuple of acceptable types for the parameter.
        allow_None (bool): Whether the parameter can be `None`.
        default_value (Any): The default value to use if the parameter is invalid or missing.
        numerical_bound (Optional[Tuple[Optional[float], Optional[float]]]):
            A tuple specifying the lower and upper bounds for numerical parameters.
            Each bound can be `None` if not applicable.
        allowed_values (Optional[List[Any]]): A list of acceptable values for the parameter.
            Can be `None` if no specific values are required.

    Returns:
        Any: The validated parameter value or the default value if validation fails.

    Raises:
        TypeError: If `allowed_values` is provided but is not a list.

    Example Usage:
    ```python
    # Validating a numerical parameter within specific bounds
    params = {"temperature": 0.5, "safety_model": "Meta-Llama/Llama-Guard-7b"}
    temperature = validate_parameter(params, "temperature", (int, float), True, 0.7, (0, 1), None)
    # Result: 0.5

    # Validating a parameter that can be one of a list of allowed values
    model = validate_parameter(
        params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
    )
    # If "safety_model" is missing or invalid in params, defaults to None
    ```
    """
    if allowed_values is not None and not isinstance(allowed_values, list):
        raise TypeError(f"allowed_values should be a list or None, got {type(allowed_values).__name__}")

    param_value = params.get(param_name, default_value)
    warning = ""

    if param_value is None:
        # A None value is only acceptable when explicitly allowed.
        # (Collapses the original redundant two-branch None handling.)
        if not allow_None:
            warning = "cannot be None"
    elif not isinstance(param_value, allowed_types):
        # Check types and list possible types if invalid
        if isinstance(allowed_types, tuple):
            formatted_types = "(" + ", ".join(f"{t.__name__}" for t in allowed_types) + ")"
        else:
            formatted_types = f"{allowed_types.__name__}"
        warning = f"must be of type {formatted_types}{' or None' if allow_None else ''}"
    elif numerical_bound:
        # Check the value fits in possible bounds
        lower_bound, upper_bound = numerical_bound
        if (lower_bound is not None and param_value < lower_bound) or (
            upper_bound is not None and param_value > upper_bound
        ):
            warning = "has numerical bounds"
            if lower_bound is not None:
                warning += f", >= {lower_bound!s}"
            if upper_bound is not None:
                if lower_bound is not None:
                    warning += " and"
                warning += f" <= {upper_bound!s}"
            if allow_None:
                warning += ", or can be None"

    elif allowed_values:  # noqa: SIM102
        # Check if the value matches any allowed values
        # (param_value cannot be None here — the first branch already handled it.)
        if not (allow_None and param_value is None) and param_value not in allowed_values:
            warning = f"must be one of these values [{allowed_values}]{', or can be None' if allow_None else ''}"

    # If we failed any checks, warn and set to default value
    if warning:
        warnings.warn(
            f"Config error - {param_name} {warning}, defaulting to {default_value}.",
            UserWarning,
        )
        param_value = default_value

    return param_value
|
||||
|
||||
|
||||
def should_hide_tools(messages: list[dict[str, Any]], tools: list[dict[str, Any]], hide_tools_param: str) -> bool:
    """Determines if tools should be hidden. This function is used to hide tools when they have been run, minimising the chance of the LLM choosing them when they shouldn't.

    Parameters:
        messages (List[Dict[str, Any]]): List of messages
        tools (List[Dict[str, Any]]): List of tools
        hide_tools_param (str): "hide_tools" parameter value. Can be "if_all_run" (hide tools if all tools have been run), "if_any_run" (hide tools if any of the tools have been run), "never" (never hide tools). Default is "never".

    Returns:
        bool: Indicates whether the tools should be excluded from the response create request

    Raises:
        TypeError: If `hide_tools_param` is not one of the three recognised values.

    Example Usage:
    ```python
    # Validating a numerical parameter within specific bounds
    messages = params.get("messages", [])
    tools = params.get("tools", None)
    hide_tools = should_hide_tools(messages, tools, params["hide_tools"])
    ```
    """
    if hide_tools_param == "never" or tools is None or len(tools) == 0:
        return False
    elif hide_tools_param == "if_any_run":
        # Return True if any tool_call_id exists, indicating a tool call has been executed. False otherwise.
        # (Generator expression — no throwaway list materialized.)
        return any("tool_call_id" in dictionary for dictionary in messages)
    elif hide_tools_param == "if_all_run":
        # Return True if all tools have been executed at least once. False otherwise.

        # Get the list of tool names
        check_tool_names = [item["function"]["name"] for item in tools]

        # Prepare a list of tool call ids and related function names
        tool_call_ids = {}

        # Loop through the messages and check if the tools have been run, removing them as we go
        for message in messages:
            if "tool_calls" in message:
                # Register the tool ids and the function names (there could be multiple tool calls)
                for tool_call in message["tool_calls"]:
                    tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]
            elif "tool_call_id" in message:
                # Tool called, get the name of the function based on the id
                # NOTE(review): raises KeyError if a tool result arrives without a
                # preceding tool_calls message — presumably the conversation always
                # pairs them; confirm.
                tool_name_called = tool_call_ids[message["tool_call_id"]]

                # If we had not yet called the tool, check and remove it to indicate we have
                if tool_name_called in check_tool_names:
                    check_tool_names.remove(tool_name_called)

        # Return True if all tools have been called at least once (accounted for)
        return len(check_tool_names) == 0
    else:
        raise TypeError(
            f"hide_tools_param is not a valid value ['if_all_run','if_any_run','never'], got '{hide_tools_param}'"
        )
|
||||
|
||||
|
||||
# Logging format (originally from FLAML)
logging_formatter = logging.Formatter(
    fmt="[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s",
    datefmt="%m-%d %H:%M:%S",
)
|
||||
479
mm_agents/coact/autogen/oai/cohere.py
Normal file
479
mm_agents/coact/autogen/oai/cohere.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
"""Create an OpenAI-compatible client using Cohere's API.
|
||||
|
||||
Example:
|
||||
```python
|
||||
llm_config={
|
||||
"config_list": [{
|
||||
"api_type": "cohere",
|
||||
"model": "command-r-plus",
|
||||
"api_key": os.environ.get("COHERE_API_KEY")
|
||||
"client_name": "autogen-cohere", # Optional parameter
|
||||
}
|
||||
]}
|
||||
|
||||
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
|
||||
```
|
||||
|
||||
Install Cohere's python library using: pip install --upgrade cohere
|
||||
|
||||
Resources:
|
||||
- https://docs.cohere.com/reference/chat
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Literal, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogen.oai.client_utils import FormatterProtocol, logging_formatter, validate_parameter
|
||||
|
||||
from ..import_utils import optional_import_block, require_optional_import
|
||||
from ..llm_config import LLMConfigEntry, register_llm_config
|
||||
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
|
||||
|
||||
with optional_import_block():
|
||||
from cohere import ClientV2 as CohereV2
|
||||
from cohere.types import ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
if not logger.handlers:
    # Add the console handler.
    # Only attach when no handler exists yet, so re-imports do not duplicate output.
    _ch = logging.StreamHandler(stream=sys.stdout)
    _ch.setFormatter(logging_formatter)
    logger.addHandler(_ch)
|
||||
|
||||
|
||||
# (input, output) USD cost per 1,000 tokens for known Cohere models; consumed
# by `calculate_cohere_cost`. Unlisted models fall back to a 0.0 cost with a warning.
COHERE_PRICING_1K = {
    "command-r-plus": (0.003, 0.015),
    "command-r": (0.0005, 0.0015),
    "command-nightly": (0.00025, 0.00125),
    "command": (0.015, 0.075),
    "command-light": (0.008, 0.024),
    "command-light-nightly": (0.008, 0.024),
}
|
||||
|
||||
|
||||
@register_llm_config
class CohereLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for Cohere chat models (``api_type="cohere"``).

    Field bounds mirror the ranges enforced by ``CohereClient.parse_params``;
    see https://docs.cohere.com/reference/chat
    """

    api_type: Literal["cohere"] = "cohere"
    temperature: float = Field(default=0.3, ge=0)
    max_tokens: Optional[int] = Field(default=None, ge=0)
    # Top-k sampling; 0 disables it.
    k: int = Field(default=0, ge=0, le=500)
    # Nucleus (top-p) sampling probability mass.
    p: float = Field(default=0.75, ge=0.01, le=0.99)
    seed: Optional[int] = None
    frequency_penalty: float = Field(default=0, ge=0, le=1)
    presence_penalty: float = Field(default=0, ge=0, le=1)
    # Client name reported to Cohere; `CohereClient.create` falls back to "AG2".
    client_name: Optional[str] = None
    strict_tools: bool = False
    stream: bool = False
    tool_choice: Optional[Literal["NONE", "REQUIRED"]] = None

    def create_client(self):
        """Always raises: this entry does not construct its own client."""
        raise NotImplementedError("CohereLLMConfigEntry.create_client is not implemented.")
|
||||
|
||||
|
||||
class CohereClient:
    """Client for Cohere's API.

    Exposes ``create`` / ``message_retrieval`` / ``cost`` / ``get_usage`` so the
    surrounding OpenAI-compatible wrapper can drive Cohere like any other backend.
    """

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            **kwargs: The keyword arguments to pass to the Cohere API.
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key")
        if not self.api_key:
            self.api_key = os.getenv("COHERE_API_KEY")

        # NOTE(review): `assert` is stripped under `python -O`; behavior kept as-is.
        assert self.api_key, (
            "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable."
        )

        # Store the response format, if provided (for structured outputs).
        # Populated later by parse_params when a response_format is requested.
        self._response_format: Optional[Type[BaseModel]] = None
|
||||
|
||||
def message_retrieval(self, response) -> list:
|
||||
"""Retrieve and return a list of strings or a list of Choice.Message from the response.
|
||||
|
||||
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
|
||||
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
|
||||
"""
|
||||
return [choice.message for choice in response.choices]
|
||||
|
||||
    def cost(self, response) -> float:
        """Return the cost attached to a completion built by ``create``."""
        return response.cost
|
||||
|
||||
@staticmethod
|
||||
def get_usage(response) -> dict:
|
||||
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
|
||||
# ... # pragma: no cover
|
||||
return {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
"cost": response.cost,
|
||||
"model": response.model,
|
||||
}
|
||||
|
||||
    def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Loads the parameters for Cohere API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults

        Side effect: when a ``response_format`` is supplied, it is stored on
        ``self._response_format`` for post-processing in ``create``.
        """
        cohere_params = {}

        # Check that we have what we need to use Cohere's API
        # We won't enforce the available models as they are likely to change
        cohere_params["model"] = params.get("model")
        assert cohere_params["model"], (
            "Please specify the 'model' in your config list entry to nominate the Cohere model to use."
        )

        # Handle structured output response format from Pydantic model
        if "response_format" in params and params["response_format"] is not None:
            self._response_format = params.get("response_format")

            response_format = params["response_format"]

            # Check if it's a Pydantic model
            if hasattr(response_format, "model_json_schema"):
                # Get the JSON schema from the Pydantic model
                schema = response_format.model_json_schema()

                def resolve_ref(ref: str, defs: dict) -> dict:
                    """Resolve a $ref to its actual schema definition"""
                    # Extract the definition name from "#/$defs/Name"
                    def_name = ref.split("/")[-1]
                    return defs[def_name]

                def ensure_type_fields(obj: dict, defs: dict) -> dict:
                    """Recursively ensure all objects in the schema have a type and properties field"""
                    if isinstance(obj, dict):
                        # If it has a $ref, replace it with the actual definition
                        if "$ref" in obj:
                            ref_def = resolve_ref(obj["$ref"], defs)
                            # Merge the reference definition with any existing fields
                            obj = {**ref_def, **obj}
                            # Remove the $ref as we've replaced it
                            del obj["$ref"]

                        # Process each value recursively
                        return {
                            k: ensure_type_fields(v, defs) if isinstance(v, (dict, list)) else v for k, v in obj.items()
                        }
                    elif isinstance(obj, list):
                        return [ensure_type_fields(item, defs) for item in obj]
                    return obj

                # Make a copy of $defs before processing
                defs = schema.get("$defs", {})

                # Process the schema
                processed_schema = ensure_type_fields(schema, defs)

                cohere_params["response_format"] = {"type": "json_object", "json_schema": processed_schema}
            else:
                raise ValueError("response_format must be a Pydantic BaseModel")

        # Handle strict tools parameter for structured outputs with tools
        if "tools" in params:
            cohere_params["strict_tools"] = validate_parameter(params, "strict_tools", bool, False, False, None, None)

        # Validate allowed Cohere parameters
        # https://docs.cohere.com/reference/chat
        # Each parameter is only forwarded when the caller supplied it, so Cohere's
        # own server-side defaults apply otherwise.
        if "temperature" in params:
            cohere_params["temperature"] = validate_parameter(
                params, "temperature", (int, float), False, 0.3, (0, None), None
            )

        if "max_tokens" in params:
            cohere_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)

        if "k" in params:
            cohere_params["k"] = validate_parameter(params, "k", int, False, 0, (0, 500), None)

        if "p" in params:
            cohere_params["p"] = validate_parameter(params, "p", (int, float), False, 0.75, (0.01, 0.99), None)

        if "seed" in params:
            cohere_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)

        if "frequency_penalty" in params:
            cohere_params["frequency_penalty"] = validate_parameter(
                params, "frequency_penalty", (int, float), True, 0, (0, 1), None
            )

        if "presence_penalty" in params:
            cohere_params["presence_penalty"] = validate_parameter(
                params, "presence_penalty", (int, float), True, 0, (0, 1), None
            )

        if "tool_choice" in params:
            cohere_params["tool_choice"] = validate_parameter(
                params, "tool_choice", str, True, None, None, ["NONE", "REQUIRED"]
            )

        return cohere_params
|
||||
|
||||
@require_optional_import("cohere", "cohere")
|
||||
def create(self, params: dict) -> ChatCompletion:
|
||||
messages = params.get("messages", [])
|
||||
client_name = params.get("client_name") or "AG2"
|
||||
cohere_tool_names = set()
|
||||
tool_calls_modified_ids = set()
|
||||
|
||||
# Parse parameters to the Cohere API's parameters
|
||||
cohere_params = self.parse_params(params)
|
||||
|
||||
cohere_params["messages"] = messages
|
||||
|
||||
if "tools" in params:
|
||||
cohere_tool_names = set([tool["function"]["name"] for tool in params["tools"]])
|
||||
cohere_params["tools"] = params["tools"]
|
||||
|
||||
# Strip out name
|
||||
for message in cohere_params["messages"]:
|
||||
message_name = message.pop("name", "")
|
||||
# Extract and prepend name to content or tool_plan if available
|
||||
message["content"] = (
|
||||
f"{message_name}: {(message.get('content') or message.get('tool_plan'))}"
|
||||
if message_name
|
||||
else (message.get("content") or message.get("tool_plan"))
|
||||
)
|
||||
|
||||
# Handle tool calls
|
||||
if message.get("tool_calls") is not None and len(message["tool_calls"]) > 0:
|
||||
message["tool_plan"] = message.get("tool_plan", message["content"])
|
||||
del message["content"] # Remove content as tool_plan is prioritized
|
||||
|
||||
# If tool call name is missing or not recognized, modify role and content
|
||||
for tool_call in message["tool_calls"] or []:
|
||||
if (not tool_call.get("function", {}).get("name")) or tool_call.get("function", {}).get(
|
||||
"name"
|
||||
) not in cohere_tool_names:
|
||||
message["role"] = "assistant"
|
||||
message["content"] = f"{message.pop('tool_plan', '')}{str(message['tool_calls'])}"
|
||||
tool_calls_modified_ids = tool_calls_modified_ids.union(
|
||||
set([tool_call.get("id") for tool_call in message["tool_calls"]])
|
||||
)
|
||||
del message["tool_calls"]
|
||||
break
|
||||
|
||||
# Adjust role if message comes from a tool with a modified ID
|
||||
if message.get("role") == "tool":
|
||||
tool_id = message.get("tool_call_id")
|
||||
if tool_id in tool_calls_modified_ids:
|
||||
message["role"] = "user"
|
||||
del message["tool_call_id"] # Remove the tool call ID
|
||||
|
||||
# We use chat model by default
|
||||
client = CohereV2(api_key=self.api_key, client_name=client_name)
|
||||
|
||||
# Token counts will be returned
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
total_tokens = 0
|
||||
|
||||
# Stream if in parameters
|
||||
streaming = params.get("stream")
|
||||
cohere_finish = "stop"
|
||||
tool_calls = None
|
||||
ans = None
|
||||
if streaming:
|
||||
response = client.chat_stream(**cohere_params)
|
||||
# Streaming...
|
||||
ans = ""
|
||||
plan = ""
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
for chunk in response:
|
||||
if chunk.type == "content-delta":
|
||||
ans = ans + chunk.delta.message.content.text
|
||||
elif chunk.type == "tool-plan-delta":
|
||||
plan = plan + chunk.delta.message.tool_plan
|
||||
elif chunk.type == "tool-call-start":
|
||||
cohere_finish = "tool_calls"
|
||||
|
||||
# Initialize a new tool call
|
||||
tool_call = chunk.delta.message.tool_calls
|
||||
current_tool = {
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {"name": tool_call.function.name, "arguments": ""},
|
||||
}
|
||||
elif chunk.type == "tool-call-delta":
|
||||
# Progressively build the arguments as they stream in
|
||||
if current_tool is not None:
|
||||
current_tool["function"]["arguments"] += chunk.delta.message.tool_calls.function.arguments
|
||||
elif chunk.type == "tool-call-end":
|
||||
# Append the finished tool call to the list
|
||||
if current_tool is not None:
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(ChatCompletionMessageToolCall(**current_tool))
|
||||
current_tool = None
|
||||
elif chunk.type == "message-start":
|
||||
response_id = chunk.id
|
||||
elif chunk.type == "message-end":
|
||||
prompt_tokens = (
|
||||
chunk.delta.usage.billed_units.input_tokens
|
||||
) # Note total (billed+non-billed) available with ...usage.tokens...
|
||||
completion_tokens = chunk.delta.usage.billed_units.output_tokens
|
||||
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
else:
|
||||
response = client.chat(**cohere_params)
|
||||
|
||||
if response.message.tool_calls is not None:
|
||||
ans = response.message.tool_plan
|
||||
cohere_finish = "tool_calls"
|
||||
tool_calls = []
|
||||
for tool_call in response.message.tool_calls:
|
||||
# if parameters are null, clear them out (Cohere can return a string "null" if no parameter values)
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=tool_call.id,
|
||||
function={
|
||||
"name": tool_call.function.name,
|
||||
"arguments": (
|
||||
"" if tool_call.function.arguments is None else tool_call.function.arguments
|
||||
),
|
||||
},
|
||||
type="function",
|
||||
)
|
||||
)
|
||||
else:
|
||||
ans: str = response.message.content[0].text
|
||||
|
||||
# Not using billed_units, but that may be better for cost purposes
|
||||
prompt_tokens = (
|
||||
response.usage.billed_units.input_tokens
|
||||
) # Note total (billed+non-billed) available with ...usage.tokens...
|
||||
completion_tokens = response.usage.billed_units.output_tokens
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
response_id = response.id
|
||||
|
||||
# Clean up structured output if needed
|
||||
if self._response_format:
|
||||
# ans = clean_return_response_format(ans)
|
||||
try:
|
||||
parsed_response = self._convert_json_response(ans)
|
||||
ans = _format_json_response(parsed_response, ans)
|
||||
except ValueError as e:
|
||||
ans = str(e)
|
||||
|
||||
# 3. convert output
|
||||
message = ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=ans,
|
||||
function_call=None,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
choices = [Choice(finish_reason=cohere_finish, index=0, message=message)]
|
||||
|
||||
response_oai = ChatCompletion(
|
||||
id=response_id,
|
||||
model=cohere_params["model"],
|
||||
created=int(time.time()),
|
||||
object="chat.completion",
|
||||
choices=choices,
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
),
|
||||
cost=calculate_cohere_cost(prompt_tokens, completion_tokens, cohere_params["model"]),
|
||||
)
|
||||
|
||||
return response_oai
|
||||
|
||||
def _convert_json_response(self, response: str) -> Any:
|
||||
"""Extract and validate JSON response from the output for structured outputs.
|
||||
Args:
|
||||
response (str): The response from the API.
|
||||
Returns:
|
||||
Any: The parsed JSON response.
|
||||
"""
|
||||
if not self._response_format:
|
||||
return response
|
||||
|
||||
try:
|
||||
# Parse JSON and validate against the Pydantic model
|
||||
json_data = json.loads(response)
|
||||
return self._response_format.model_validate(json_data)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to parse response as valid JSON matching the schema for Structured Output: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def _format_json_response(response: Any, original_answer: str) -> str:
    """Formats the JSON response for structured outputs using the format method if it exists."""
    if isinstance(response, FormatterProtocol):
        # The parsed model carries its own presentation logic.
        return response.format()
    # Otherwise fall back to a normalized version of the raw answer text.
    return clean_return_response_format(original_answer)
|
||||
|
||||
|
||||
def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> list[dict[str, Any]]:
    """Build Cohere ``ToolResult`` entries for the tool call(s) matching *tool_call_id*."""
    tool_results = []

    for candidate in all_tool_calls:
        if candidate["id"] != tool_call_id:
            continue
        raw_arguments = candidate["function"]["arguments"]
        call = {
            "name": candidate["function"]["name"],
            # An empty arguments string means "no parameters".
            "parameters": json.loads(raw_arguments if raw_arguments != "" else "{}"),
        }
        tool_results.append(ToolResult(call=call, outputs=[{"value": content_output}]))
    return tool_results
|
||||
|
||||
|
||||
def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Calculate the cost of the completion using the Cohere pricing.

    Unknown models cost 0.0 and emit a UserWarning.
    """
    if model not in COHERE_PRICING_1K:
        warnings.warn(f"Cost calculation not available for {model} model", UserWarning)
        return 0.0

    per_k_input, per_k_output = COHERE_PRICING_1K[model]
    return (input_tokens / 1000) * per_k_input + (output_tokens / 1000) * per_k_output
|
||||
|
||||
|
||||
def clean_return_response_format(response_str: str) -> str:
    """Normalize a JSON string by round-tripping it through the json library."""
    # Decoding resolves escapes; re-encoding emits compact, consistent formatting.
    return json.dumps(json.loads(response_str))
|
||||
|
||||
|
||||
class CohereError(Exception):
    """Base class for other Cohere exceptions."""
|
||||
|
||||
|
||||
class CohereRateLimitError(CohereError):
    """Raised when rate limit is exceeded."""
|
||||
1007
mm_agents/coact/autogen/oai/gemini.py
Normal file
1007
mm_agents/coact/autogen/oai/gemini.py
Normal file
File diff suppressed because it is too large
Load Diff
156
mm_agents/coact/autogen/oai/gemini_types.py
Normal file
156
mm_agents/coact/autogen/oai/gemini_types.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import enum
|
||||
import warnings
|
||||
from typing import Any, Optional, Type, TypeVar, Union, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel as BaseModel
|
||||
from pydantic import ConfigDict, Field, alias_generators
|
||||
|
||||
|
||||
def _remove_extra_fields(model: Any, response: dict[str, object]) -> None:
|
||||
"""Removes extra fields from the response that are not in the model.
|
||||
|
||||
Mutates the response in place.
|
||||
"""
|
||||
|
||||
key_values = list(response.items())
|
||||
|
||||
for key, value in key_values:
|
||||
# Need to convert to snake case to match model fields names
|
||||
# ex: UsageMetadata
|
||||
alias_map = {field_info.alias: key for key, field_info in model.model_fields.items()}
|
||||
|
||||
if key not in model.model_fields and key not in alias_map:
|
||||
response.pop(key)
|
||||
continue
|
||||
|
||||
key = alias_map.get(key, key)
|
||||
|
||||
annotation = model.model_fields[key].annotation
|
||||
|
||||
# Get the BaseModel if Optional
|
||||
if get_origin(annotation) is Union:
|
||||
annotation = get_args(annotation)[0]
|
||||
|
||||
# if dict, assume BaseModel but also check that field type is not dict
|
||||
# example: FunctionCall.args
|
||||
if isinstance(value, dict) and get_origin(annotation) is not dict:
|
||||
_remove_extra_fields(annotation, value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
# assume a list of dict is list of BaseModel
|
||||
if isinstance(item, dict):
|
||||
_remove_extra_fields(get_args(annotation)[0], item)
|
||||
|
||||
|
||||
# Generic type variable bound to pydantic BaseModel; used to type
# CommonBaseModel._from_response so it returns an instance of the subclass.
T = TypeVar("T", bound="BaseModel")
|
||||
|
||||
|
||||
class CommonBaseModel(BaseModel):
    """Shared pydantic base for Gemini API types.

    Provides camelCase aliasing for wire-format keys, forbids unknown fields,
    and offers forward-compatible construction from raw API responses.
    """

    model_config = ConfigDict(
        # Wire format uses camelCase; Python fields stay snake_case.
        alias_generator=alias_generators.to_camel,
        populate_by_name=True,
        from_attributes=True,
        protected_namespaces=(),
        # Unknown keys are rejected at validation time; _from_response strips
        # them beforehand to stay forward compatible.
        extra="forbid",
        # This allows us to use arbitrary types in the model. E.g. PIL.Image.
        arbitrary_types_allowed=True,
        ser_json_bytes="base64",
        val_json_bytes="base64",
        ignored_types=(TypeVar,),
    )

    @classmethod
    def _from_response(cls: Type[T], *, response: dict[str, object], kwargs: dict[str, object]) -> T:
        """Validate *response* into an instance of *cls*, dropping unknown fields.

        NOTE(review): *kwargs* is not used in this body — confirm against callers.
        """
        # To maintain forward compatibility, we need to remove extra fields from
        # the response.
        # We will provide another mechanism to allow users to access these fields.
        _remove_extra_fields(cls, response)
        validated_response = cls.model_validate(response)
        return validated_response

    def to_json_dict(self) -> dict[str, object]:
        """Serialise to a JSON-compatible dict, omitting None-valued fields."""
        return self.model_dump(exclude_none=True, mode="json")
|
||||
|
||||
|
||||
class CaseInSensitiveEnum(str, enum.Enum):
    """Case insensitive enum.

    Lookup falls back to upper/lower-cased member names; completely unknown
    values produce a warning and a synthetic pseudo-member instead of raising.
    """

    @classmethod
    def _missing_(cls, value: Any) -> Optional["CaseInSensitiveEnum"]:
        # NOTE(review): assumes *value* is a str — a non-string would raise
        # AttributeError on .upper() before any fallback applies; confirm callers.
        try:
            return cls[value.upper()]  # Try to access directly with uppercase
        except KeyError:
            try:
                return cls[value.lower()]  # Try to access directly with lowercase
            except KeyError:
                warnings.warn(f"{value} is not a valid {cls.__name__}")
                try:
                    # Creating a enum instance based on the value
                    # We need to use super() to avoid infinite recursion.
                    unknown_enum_val = super().__new__(cls, value)
                    unknown_enum_val._name_ = str(value)  # pylint: disable=protected-access
                    unknown_enum_val._value_ = value  # pylint: disable=protected-access
                    return unknown_enum_val
                except:  # noqa: E722
                    # Value could not even be coerced into a pseudo-member.
                    return None
|
||||
|
||||
|
||||
class FunctionCallingConfigMode(CaseInSensitiveEnum):
    """Config for the function calling config mode."""

    # String values mirror the wire-format enum names; case-insensitive lookup
    # is provided by CaseInSensitiveEnum.
    MODE_UNSPECIFIED = "MODE_UNSPECIFIED"
    AUTO = "AUTO"
    ANY = "ANY"
    NONE = "NONE"
|
||||
|
||||
|
||||
class LatLng(CommonBaseModel):
    """An object that represents a latitude/longitude pair.

    This is expressed as a pair of doubles to represent degrees latitude and
    degrees longitude. Unless specified otherwise, this object must conform to the
    <a href="https://en.wikipedia.org/wiki/World_Geodetic_System#1984_version">
    WGS84 standard</a>. Values must be within normalized ranges.
    """

    # Both fields are optional; range constraints are documented but not
    # enforced by pydantic validators here.
    latitude: Optional[float] = Field(
        default=None,
        description="""The latitude in degrees. It must be in the range [-90.0, +90.0].""",
    )
    longitude: Optional[float] = Field(
        default=None,
        description="""The longitude in degrees. It must be in the range [-180.0, +180.0]""",
    )
|
||||
|
||||
|
||||
class FunctionCallingConfig(CommonBaseModel):
    """Function calling config."""

    # Mode governing whether/which functions may be called; see
    # FunctionCallingConfigMode for the accepted values.
    mode: Optional[FunctionCallingConfigMode] = Field(default=None, description="""Optional. Function calling mode.""")
    allowed_function_names: Optional[list[str]] = Field(
        default=None,
        description="""Optional. Function names to call. Only set when the Mode is ANY. Function names should match [FunctionDeclaration.name]. With mode set to ANY, model will predict a function call from the set of function names provided.""",
    )
|
||||
|
||||
|
||||
class RetrievalConfig(CommonBaseModel):
    """Retrieval config."""

    # Optional user context forwarded with retrieval requests.
    lat_lng: Optional[LatLng] = Field(default=None, description="""Optional. The location of the user.""")
    language_code: Optional[str] = Field(default=None, description="""The language code of the user.""")
|
||||
|
||||
|
||||
class ToolConfig(CommonBaseModel):
    """Tool config.

    This config is shared for all tools provided in the request.
    """

    function_calling_config: Optional[FunctionCallingConfig] = Field(
        default=None, description="""Optional. Function calling config."""
    )
    retrieval_config: Optional[RetrievalConfig] = Field(default=None, description="""Optional. Retrieval config.""")
|
||||
305
mm_agents/coact/autogen/oai/groq.py
Normal file
305
mm_agents/coact/autogen/oai/groq.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
"""Create an OpenAI-compatible client using Groq's API.
|
||||
|
||||
Example:
|
||||
```python
|
||||
llm_config = {
|
||||
"config_list": [{"api_type": "groq", "model": "mixtral-8x7b-32768", "api_key": os.environ.get("GROQ_API_KEY")}]
|
||||
}
|
||||
|
||||
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
|
||||
```
|
||||
|
||||
Install Groq's python library using: pip install --upgrade groq
|
||||
|
||||
Resources:
|
||||
- https://console.groq.com/docs/quickstart
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..import_utils import optional_import_block, require_optional_import
|
||||
from ..llm_config import LLMConfigEntry, register_llm_config
|
||||
from .client_utils import should_hide_tools, validate_parameter
|
||||
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
|
||||
|
||||
with optional_import_block():
|
||||
from groq import Groq, Stream
|
||||
|
||||
# Cost per thousand tokens - Input / Output (NOTE: Convert $/Million to $/K)
# Pricing snapshot; models missing from this table fall back to a warning and
# zero cost in calculate_groq_cost().
GROQ_PRICING_1K = {
    "llama3-70b-8192": (0.00059, 0.00079),
    "mixtral-8x7b-32768": (0.00024, 0.00024),
    "llama3-8b-8192": (0.00005, 0.00008),
    "gemma-7b-it": (0.00007, 0.00007),
}
|
||||
|
||||
|
||||
@register_llm_config
class GroqLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for the Groq API (``api_type="groq"``).

    Optional sampling/penalty parameters default to None, meaning the API-side
    default applies. Fix: the original annotated these as plain ``float``/``int``
    while giving ``default=None``; annotations widened to ``Optional[...]`` to
    match the defaults (consistent with MistralLLMConfigEntry).
    """

    api_type: Literal["groq"] = "groq"
    frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
    max_tokens: Optional[int] = Field(default=None, ge=0)
    presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
    seed: Optional[int] = Field(default=None)
    stream: bool = Field(default=False)
    temperature: float = Field(default=1, ge=0, le=2)
    top_p: Optional[float] = Field(default=None)
    # Controls when tool definitions are withheld from the request.
    hide_tools: Literal["if_all_run", "if_any_run", "never"] = "never"
    tool_choice: Optional[Literal["none", "auto", "required"]] = None

    def create_client(self):
        """Not implemented for this vendored config entry."""
        raise NotImplementedError("GroqLLMConfigEntry.create_client is not implemented.")
|
||||
|
||||
|
||||
class GroqClient:
    """Client for Groq's API.

    Implements the AG2 ModelClient surface (message_retrieval, cost, get_usage,
    create) on top of the ``groq`` SDK, returning OpenAI-compatible objects.
    """

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            **kwargs: Additional parameters to pass to the Groq API
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key")
        if not self.api_key:
            self.api_key = os.getenv("GROQ_API_KEY")

        assert self.api_key, (
            "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."
        )

        if "response_format" in kwargs and kwargs["response_format"] is not None:
            warnings.warn("response_format is not supported for Groq API, it will be ignored.", UserWarning)
        # Optional override of the Groq endpoint.
        self.base_url = kwargs.get("base_url")

    def message_retrieval(self, response) -> list:
        """Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        """Return the pre-computed cost attached to a create() response."""
        return response.cost

    @staticmethod
    def get_usage(response) -> dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        # ...  # pragma: no cover
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Loads the parameters for Groq API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
        groq_params = {}

        # Check that we have what we need to use Groq's API
        # We won't enforce the available models as they are likely to change
        groq_params["model"] = params.get("model")
        assert groq_params["model"], (
            "Please specify the 'model' in your config list entry to nominate the Groq model to use."
        )

        # Validate allowed Groq parameters
        # https://console.groq.com/docs/api-reference#chat
        groq_params["frequency_penalty"] = validate_parameter(
            params, "frequency_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        groq_params["presence_penalty"] = validate_parameter(
            params, "presence_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
        groq_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
        groq_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 2), None)
        groq_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
        if "tool_choice" in params:
            groq_params["tool_choice"] = validate_parameter(
                params, "tool_choice", str, True, None, None, ["none", "auto", "required"]
            )

        # Groq parameters not supported by their models yet, ignoring
        # logit_bias, logprobs, top_logprobs

        # Groq parameters we are ignoring:
        # n (must be 1), response_format (to enforce JSON but needs prompting as well), user,
        # parallel_tool_calls (defaults to True), stop
        # function_call (deprecated), functions (deprecated)
        # tool_choice (none if no tools, auto if there are tools)

        return groq_params

    @require_optional_import("groq", "groq")
    def create(self, params: dict) -> ChatCompletion:
        """Call the Groq chat completions API and convert the result to an
        OpenAI-compatible ChatCompletion (handles both streaming and non-streaming)."""
        messages = params.get("messages", [])

        # Convert AG2 messages to Groq messages
        groq_messages = oai_messages_to_groq_messages(messages)

        # Parse parameters to the Groq API's parameters
        groq_params = self.parse_params(params)

        # Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(groq_messages, params["tools"], hide_tools):
                groq_params["tools"] = params["tools"]

        groq_params["messages"] = groq_messages

        # We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
        client = Groq(api_key=self.api_key, max_retries=5, base_url=self.base_url)

        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0

        # Streaming tool call recommendations
        streaming_tool_calls = []

        ans = None
        response = client.chat.completions.create(**groq_params)
        if groq_params["stream"]:
            # Read in the chunks as they stream, taking in tool_calls which may be across
            # multiple chunks if more than one suggested
            ans = ""
            for chunk in response:
                ans = ans + (chunk.choices[0].delta.content or "")

                if chunk.choices[0].delta.tool_calls:
                    # We have a tool call recommendation
                    for tool_call in chunk.choices[0].delta.tool_calls:
                        streaming_tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id=tool_call.id,
                                function={
                                    "name": tool_call.function.name,
                                    "arguments": tool_call.function.arguments,
                                },
                                type="function",
                            )
                        )

                if chunk.choices[0].finish_reason:
                    prompt_tokens = chunk.x_groq.usage.prompt_tokens
                    completion_tokens = chunk.x_groq.usage.completion_tokens
                    total_tokens = chunk.x_groq.usage.total_tokens
        else:
            # Non-streaming finished
            ans: str = response.choices[0].message.content
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens

        # NOTE(review): create() either returns or raises, so the else branch
        # below appears unreachable; kept as-is.
        if response is not None:
            if isinstance(response, Stream):
                # Streaming response
                # NOTE(review): assumes the stream yielded at least one chunk;
                # 'chunk' would be unbound for an empty stream — confirm upstream.
                if chunk.choices[0].finish_reason == "tool_calls":
                    groq_finish = "tool_calls"
                    tool_calls = streaming_tool_calls
                else:
                    groq_finish = "stop"
                    tool_calls = None

                response_content = ans
                response_id = chunk.id
            else:
                # Non-streaming response
                # If we have tool calls as the response, populate completed tool calls for our return OAI response
                if response.choices[0].finish_reason == "tool_calls":
                    groq_finish = "tool_calls"
                    tool_calls = []
                    for tool_call in response.choices[0].message.tool_calls:
                        tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id=tool_call.id,
                                function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                                type="function",
                            )
                        )
                else:
                    groq_finish = "stop"
                    tool_calls = None

                response_content = response.choices[0].message.content
                response_id = response.id
        else:
            raise RuntimeError("Failed to get response from Groq after retrying 5 times.")

        # 3. convert output
        message = ChatCompletionMessage(
            role="assistant",
            content=response_content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=groq_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response_id,
            model=groq_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            cost=calculate_groq_cost(prompt_tokens, completion_tokens, groq_params["model"]),
        )

        return response_oai
|
||||
|
||||
|
||||
def oai_messages_to_groq_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Groq's format.

    The input list is left untouched; a deep copy is returned with the
    OAI-only "name" field stripped from every message.
    """
    converted = copy.deepcopy(messages)

    # Groq does not accept the optional "name" field, so drop it.
    for entry in converted:
        entry.pop("name", None)

    return converted
|
||||
|
||||
|
||||
def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Calculate the cost of the completion using the Groq pricing."""
    if model not in GROQ_PRICING_1K:
        # Unknown model: warn and report zero cost.
        warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
        return 0.0

    in_rate, out_rate = GROQ_PRICING_1K[model]
    return (input_tokens / 1000) * in_rate + (output_tokens / 1000) * out_rate
|
||||
303
mm_agents/coact/autogen/oai/mistral.py
Normal file
303
mm_agents/coact/autogen/oai/mistral.py
Normal file
@@ -0,0 +1,303 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
"""Create an OpenAI-compatible client using Mistral.AI's API.
|
||||
|
||||
Example:
|
||||
```python
|
||||
llm_config = {
|
||||
"config_list": [
|
||||
{"api_type": "mistral", "model": "open-mixtral-8x22b", "api_key": os.environ.get("MISTRAL_API_KEY")}
|
||||
]
|
||||
}
|
||||
|
||||
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
|
||||
```
|
||||
|
||||
Install Mistral.AI python library using: pip install --upgrade mistralai
|
||||
|
||||
Resources:
|
||||
- https://docs.mistral.ai/getting-started/quickstart/
|
||||
|
||||
NOTE: Requires mistralai package version >= 1.0.1
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..import_utils import optional_import_block, require_optional_import
|
||||
from ..llm_config import LLMConfigEntry, register_llm_config
|
||||
from .client_utils import should_hide_tools, validate_parameter
|
||||
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
|
||||
|
||||
with optional_import_block():
|
||||
# Mistral libraries
|
||||
# pip install mistralai
|
||||
from mistralai import (
|
||||
AssistantMessage,
|
||||
Function,
|
||||
FunctionCall,
|
||||
Mistral,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
|
||||
@register_llm_config
class MistralLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for the Mistral.AI API (``api_type="mistral"``)."""

    api_type: Literal["mistral"] = "mistral"
    # Sampling parameters; None leaves the API-side default in effect.
    temperature: float = Field(default=0.7)
    top_p: Optional[float] = None
    max_tokens: Optional[int] = Field(default=None, ge=0)
    safe_prompt: bool = False
    random_seed: Optional[int] = None
    # Streaming is declared here but disabled in MistralAIClient.parse_params.
    stream: bool = False
    # Controls when tool definitions are withheld from the request.
    hide_tools: Literal["if_all_run", "if_any_run", "never"] = "never"
    tool_choice: Optional[Literal["none", "auto", "any"]] = None

    def create_client(self):
        """Not implemented for this vendored config entry."""
        raise NotImplementedError("MistralLLMConfigEntry.create_client is not implemented.")
|
||||
|
||||
|
||||
@require_optional_import("mistralai", "mistral")
class MistralAIClient:
    """Client for Mistral.AI's API.

    Implements the AG2 ModelClient surface (message_retrieval, cost, get_usage,
    create) on top of the ``mistralai`` SDK, returning OpenAI-compatible objects.
    """

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            **kwargs: Additional keyword arguments to pass to the Mistral client.
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key")
        if not self.api_key:
            self.api_key = os.getenv("MISTRAL_API_KEY", None)

        assert self.api_key, (
            "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable."
        )

        if "response_format" in kwargs and kwargs["response_format"] is not None:
            warnings.warn("response_format is not supported for Mistral.AI, it will be ignored.", UserWarning)

        self._client = Mistral(api_key=self.api_key)

    def message_retrieval(self, response: ChatCompletion) -> Union[list[str], list[ChatCompletionMessage]]:
        """Retrieve the messages from the response."""
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        """Return the pre-computed cost attached to a create() response."""
        return response.cost

    @require_optional_import("mistralai", "mistral")
    def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Loads the parameters for Mistral.AI API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
        mistral_params = {}

        # 1. Validate models
        mistral_params["model"] = params.get("model")
        assert mistral_params["model"], (
            "Please specify the 'model' in your config list entry to nominate the Mistral.ai model to use."
        )

        # 2. Validate allowed Mistral.AI parameters
        mistral_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 0.7, None, None)
        mistral_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
        mistral_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        mistral_params["safe_prompt"] = validate_parameter(
            params, "safe_prompt", bool, False, False, None, [True, False]
        )
        # NOTE(review): `False` is passed where sibling calls pass a bounds
        # tuple or None — verify validate_parameter treats falsy as "no bound".
        mistral_params["random_seed"] = validate_parameter(params, "random_seed", int, True, None, False, None)
        mistral_params["tool_choice"] = validate_parameter(
            params, "tool_choice", str, False, None, None, ["none", "auto", "any"]
        )

        # TODO
        if params.get("stream", False):
            warnings.warn(
                "Streaming is not currently supported, streaming will be disabled.",
                UserWarning,
            )

        # 3. Convert messages to Mistral format
        mistral_messages = []
        tool_call_ids = {}  # tool call ids to function name mapping
        for message in params["messages"]:
            if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"] is not None:
                # Convert OAI ToolCall to Mistral ToolCall
                mistral_messages_tools = []
                for toolcall in message["tool_calls"]:
                    mistral_messages_tools.append(
                        ToolCall(
                            id=toolcall["id"],
                            function=FunctionCall(
                                name=toolcall["function"]["name"],
                                # NOTE(review): raises json.JSONDecodeError if
                                # "arguments" is an empty string (the Cohere
                                # helper guards this case) — confirm inputs.
                                arguments=json.loads(toolcall["function"]["arguments"]),
                            ),
                        )
                    )

                mistral_messages.append(AssistantMessage(content="", tool_calls=mistral_messages_tools))

                # Map tool call id to the function name
                for tool_call in message["tool_calls"]:
                    tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]

            elif message["role"] == "system":
                if len(mistral_messages) > 0 and mistral_messages[-1].role == "assistant":
                    # System messages can't appear after an Assistant message, so use a UserMessage
                    mistral_messages.append(UserMessage(content=message["content"]))
                else:
                    mistral_messages.append(SystemMessage(content=message["content"]))
            elif message["role"] == "assistant":
                mistral_messages.append(AssistantMessage(content=message["content"]))
            elif message["role"] == "user":
                mistral_messages.append(UserMessage(content=message["content"]))

            elif message["role"] == "tool":
                # Indicates the result of a tool call, the name is the function name called
                mistral_messages.append(
                    ToolMessage(
                        name=tool_call_ids[message["tool_call_id"]],
                        content=message["content"],
                        tool_call_id=message["tool_call_id"],
                    )
                )
            else:
                warnings.warn(f"Unknown message role {message['role']}", UserWarning)

        # 4. Last message needs to be user or tool, if not, add a "please continue" message
        # NOTE(review): assumes at least one message was converted; an empty
        # params["messages"] would raise IndexError here.
        if not isinstance(mistral_messages[-1], UserMessage) and not isinstance(mistral_messages[-1], ToolMessage):
            mistral_messages.append(UserMessage(content="Please continue."))

        mistral_params["messages"] = mistral_messages

        # 5. Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(params["messages"], params["tools"], hide_tools):
                mistral_params["tools"] = tool_def_to_mistral(params["tools"])

        return mistral_params

    @require_optional_import("mistralai", "mistral")
    def create(self, params: dict[str, Any]) -> ChatCompletion:
        """Call the Mistral.AI chat API and convert the result to an
        OpenAI-compatible ChatCompletion."""
        # 1. Parse parameters to Mistral.AI API's parameters
        mistral_params = self.parse_params(params)

        # 2. Call Mistral.AI API
        mistral_response = self._client.chat.complete(**mistral_params)
        # TODO: Handle streaming

        # 3. Convert Mistral response to OAI compatible format
        if mistral_response.choices[0].finish_reason == "tool_calls":
            mistral_finish = "tool_calls"
            tool_calls = []
            for tool_call in mistral_response.choices[0].message.tool_calls:
                tool_calls.append(
                    ChatCompletionMessageToolCall(
                        id=tool_call.id,
                        function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                        type="function",
                    )
                )
        else:
            mistral_finish = "stop"
            tool_calls = None

        message = ChatCompletionMessage(
            role="assistant",
            content=mistral_response.choices[0].message.content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=mistral_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=mistral_response.id,
            model=mistral_response.model,
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=mistral_response.usage.prompt_tokens,
                completion_tokens=mistral_response.usage.completion_tokens,
                total_tokens=mistral_response.usage.prompt_tokens + mistral_response.usage.completion_tokens,
            ),
            cost=calculate_mistral_cost(
                mistral_response.usage.prompt_tokens, mistral_response.usage.completion_tokens, mistral_response.model
            ),
        )

        return response_oai

    @staticmethod
    def get_usage(response: ChatCompletion) -> dict:
        """Return a usage summary dict (token counts, cost, model) for *response*."""
        return {
            "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
            "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
            "total_tokens": (
                response.usage.prompt_tokens + response.usage.completion_tokens if response.usage is not None else 0
            ),
            "cost": response.cost if hasattr(response, "cost") else 0,
            "model": response.model,
        }
|
||||
|
||||
|
||||
@require_optional_import("mistralai", "mistral")
def tool_def_to_mistral(tool_definitions: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Converts AG2 tool definition to a mistral tool format"""
    converted = []

    for tool in tool_definitions:
        spec = tool["function"]
        converted.append(
            {
                "type": "function",
                "function": Function(
                    name=spec["name"],
                    description=spec["description"],
                    parameters=spec["parameters"],
                ),
            }
        )

    return converted
|
||||
|
||||
|
||||
def calculate_mistral_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """Calculate the cost of the mistral response.

    Args:
        input_tokens: Number of prompt tokens billed.
        output_tokens: Number of completion tokens billed.
        model_name: Mistral model identifier used for the request.

    Returns:
        Estimated cost in USD; 0.0 (with a warning) for unknown models.
    """
    # Prices per 1 thousand tokens
    # https://mistral.ai/technology/
    model_cost_map = {
        "open-mistral-7b": {"input": 0.00025, "output": 0.00025},
        "open-mixtral-8x7b": {"input": 0.0007, "output": 0.0007},
        "open-mixtral-8x22b": {"input": 0.002, "output": 0.006},
        "mistral-small-latest": {"input": 0.001, "output": 0.003},
        "mistral-medium-latest": {"input": 0.00275, "output": 0.0081},
        "mistral-large-latest": {"input": 0.0003, "output": 0.0003},
        "mistral-large-2407": {"input": 0.0003, "output": 0.0003},
        "open-mistral-nemo-2407": {"input": 0.0003, "output": 0.0003},
        "codestral-2405": {"input": 0.001, "output": 0.003},
    }

    # Ensure we have the model they are using and return the total cost
    if model_name in model_cost_map:
        costs = model_cost_map[model_name]
        return (input_tokens * costs["input"] / 1000) + (output_tokens * costs["output"] / 1000)

    warnings.warn(f"Cost calculation is not implemented for model {model_name}, will return $0.", UserWarning)
    # Fix: return a float to match the annotation (the original returned int 0).
    return 0.0
|
||||
11
mm_agents/coact/autogen/oai/oai_models/__init__.py
Normal file
11
mm_agents/coact/autogen/oai/oai_models/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .chat_completion import ChatCompletionExtended as ChatCompletion
|
||||
from .chat_completion import Choice
|
||||
from .chat_completion_message import ChatCompletionMessage
|
||||
from .chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
from .completion_usage import CompletionUsage
|
||||
|
||||
__all__ = ["ChatCompletion", "ChatCompletionMessage", "ChatCompletionMessageToolCall", "Choice", "CompletionUsage"]
|
||||
16
mm_agents/coact/autogen/oai/oai_models/_models.py
Normal file
16
mm_agents/coact/autogen/oai/oai_models/_models.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Taken over from https://github.com/openai/openai-python/blob/main/src/openai/_models.py
|
||||
|
||||
import pydantic
|
||||
import pydantic.generics
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import ClassVar
|
||||
|
||||
__all__ = ["BaseModel"]
|
||||
|
||||
|
||||
class BaseModel(pydantic.BaseModel):
    # Mirror of openai-python's lenient base model: `extra="allow"` keeps any
    # fields the API returns beyond those declared on subclasses, so newer
    # server responses don't fail validation.
    model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")
|
||||
87
mm_agents/coact/autogen/oai/oai_models/chat_completion.py
Normal file
87
mm_agents/coact/autogen/oai/oai_models/chat_completion.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Taken over from https://github.com/openai/openai-python/blob/3e69750d47df4f0759d4a28ddc68e4b38756d9ca/src/openai/types/chat/chat_completion.py
|
||||
|
||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ._models import BaseModel
|
||||
from .chat_completion_message import ChatCompletionMessage
|
||||
from .chat_completion_token_logprob import ChatCompletionTokenLogprob
|
||||
from .completion_usage import CompletionUsage
|
||||
|
||||
__all__ = ["ChatCompletion", "Choice", "ChoiceLogprobs"]
|
||||
|
||||
|
||||
class ChoiceLogprobs(BaseModel):
    """Log probability information attached to a single completion choice."""

    # Message content tokens with log probability information.
    content: Optional[List[ChatCompletionTokenLogprob]] = None

    # Message refusal tokens with log probability information.
    refusal: Optional[List[ChatCompletionTokenLogprob]] = None
|
||||
|
||||
|
||||
class Choice(BaseModel):
    """A single completion choice within a chat completion response."""

    # Why the model stopped generating: "stop" (natural stop point or stop
    # sequence), "length" (request max tokens reached), "content_filter"
    # (content omitted by a filter flag), "tool_calls" (model called a tool),
    # or "function_call" (deprecated function call).
    finish_reason: Literal["stop", "length", "tool_calls", "content_filter", "function_call"]

    # Index of this choice in the list of choices.
    index: int

    # Log probability information for the choice.
    logprobs: Optional[ChoiceLogprobs] = None

    # The chat completion message generated by the model.
    message: ChatCompletionMessage
|
||||
|
||||
|
||||
class ChatCompletion(BaseModel):
    """A chat completion response, mirroring the OpenAI API schema."""

    # Unique identifier for the chat completion.
    id: str

    # Completion choices; can be more than one when `n` > 1.
    choices: List[Choice]

    # Unix timestamp (seconds) of when the completion was created.
    created: int

    # Model used for the chat completion.
    model: str

    # Object type; always "chat.completion".
    object: Literal["chat.completion"]

    # Service tier used for processing the request.
    service_tier: Optional[Literal["auto", "default", "flex", "scale"]] = None

    # Fingerprint of the backend configuration the model ran with; combine
    # with the `seed` request parameter to detect backend changes that may
    # affect determinism.
    system_fingerprint: Optional[str] = None

    # Usage statistics for the completion request.
    usage: Optional[CompletionUsage] = None
|
||||
|
||||
|
||||
class ChatCompletionExtended(ChatCompletion):
    """AG2 extension of ChatCompletion carrying client-side bookkeeping fields."""

    # Optional hook to pull messages back out of a raw client response.
    message_retrieval_function: Optional[Callable[[Any, "ChatCompletion"], list[ChatCompletionMessage]]] = None
    # Presumably identifies the LLM config entry that produced this response — confirm against client code.
    config_id: Optional[str] = None
    # Optional predicate applied to the response — NOTE(review): semantics defined by the caller; verify.
    pass_filter: Optional[Callable[..., bool]] = None
    # Computed cost (USD) of this completion, if known.
    cost: Optional[float] = None
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Taken over from https://github.com/openai/openai-python/blob/3e69750d47df4f0759d4a28ddc68e4b38756d9ca/src/openai/types/chat/chat_completion_audio.py
|
||||
|
||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
|
||||
from ._models import BaseModel
|
||||
|
||||
__all__ = ["ChatCompletionAudio"]
|
||||
|
||||
|
||||
class ChatCompletionAudio(BaseModel):
    """Audio portion of a chat completion response (audio output modality)."""

    # Unique identifier for this audio response.
    id: str

    # Base64-encoded audio bytes generated by the model, in the format
    # specified in the request.
    data: str

    # Unix timestamp (seconds) after which this audio response is no longer
    # accessible on the server for use in multi-turn conversations.
    expires_at: int

    # Transcript of the audio generated by the model.
    transcript: str
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Taken over from https://github.com/openai/openai-python/blob/16a10604fbd0d82c1382b84b417a1d6a2d33a7f1/src/openai/types/chat/chat_completion_message.py
|
||||
|
||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ._models import BaseModel
|
||||
from .chat_completion_audio import ChatCompletionAudio
|
||||
from .chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
|
||||
__all__ = ["Annotation", "AnnotationURLCitation", "ChatCompletionMessage", "FunctionCall"]
|
||||
|
||||
|
||||
class AnnotationURLCitation(BaseModel):
    """A URL citation referenced from a message annotation (web search)."""

    # Index of the last character of the URL citation in the message.
    end_index: int

    # Index of the first character of the URL citation in the message.
    start_index: int

    # Title of the web resource.
    title: str

    # URL of the web resource.
    url: str
|
||||
|
||||
|
||||
class Annotation(BaseModel):
    """A message annotation; currently only URL citations are defined."""

    # Annotation type; always "url_citation".
    type: Literal["url_citation"]

    # The cited URL, produced when using web search.
    url_citation: AnnotationURLCitation
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
    """Deprecated single-function call emitted by the model (see `tool_calls`)."""

    # JSON-encoded arguments produced by the model. The model does not always
    # emit valid JSON and may hallucinate parameters not defined by your
    # function schema — validate the arguments before calling the function.
    arguments: str

    # Name of the function to call.
    name: str
|
||||
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
    """An assistant message generated by the model within a completion."""

    # The contents of the message.
    content: Optional[str] = None

    # The refusal message generated by the model.
    refusal: Optional[str] = None

    # Role of the author of this message; always "assistant".
    role: Literal["assistant"]

    # Annotations for the message, when applicable, as when using the
    # web search tool:
    # https://platform.openai.com/docs/guides/tools-web-search?api-mode=chat
    annotations: Optional[List[Annotation]] = None

    # When the audio output modality is requested, data about the audio
    # response from the model (https://platform.openai.com/docs/guides/audio).
    audio: Optional[ChatCompletionAudio] = None

    # Deprecated and replaced by `tool_calls`: the name and arguments of a
    # function that should be called, as generated by the model.
    function_call: Optional[FunctionCall] = None

    # Tool calls generated by the model, such as function calls.
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
|
||||
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Taken over from https://github.com/openai/openai-python/blob/3e69750d47df4f0759d4a28ddc68e4b38756d9ca/src/openai/types/chat/chat_completion_message_tool_call.py
|
||||
|
||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ._models import BaseModel
|
||||
|
||||
__all__ = ["ChatCompletionMessageToolCall", "Function"]
|
||||
|
||||
|
||||
class Function(BaseModel):
    """The function invoked by a tool call."""

    # JSON-encoded arguments produced by the model. The model does not always
    # emit valid JSON and may hallucinate parameters not defined by your
    # function schema — validate the arguments before calling the function.
    arguments: str

    # Name of the function to call.
    name: str
|
||||
|
||||
|
||||
class ChatCompletionMessageToolCall(BaseModel):
    """A tool call emitted by the model within a chat completion message."""

    # The ID of the tool call.
    id: str

    # The function that the model called.
    function: Function

    # Tool type; currently only "function" is supported.
    type: Literal["function"]
|
||||
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Taken over from https://github.com/openai/openai-python/blob/3e69750d47df4f0759d4a28ddc68e4b38756d9ca/src/openai/types/chat/chat_completion_token_logprob.py
|
||||
|
||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from ._models import BaseModel
|
||||
|
||||
__all__ = ["ChatCompletionTokenLogprob", "TopLogprob"]
|
||||
|
||||
|
||||
class TopLogprob(BaseModel):
    """One candidate token with its log probability at a given position."""

    # The token.
    token: str

    # UTF-8 byte values of the token; useful when characters span multiple
    # tokens and byte representations must be combined to recover the text.
    # `None` when the token has no bytes representation.
    bytes: Optional[List[int]] = None

    # Log probability of this token when it is within the top 20 most likely
    # tokens; otherwise -9999.0 signals a very unlikely token.
    logprob: float
|
||||
|
||||
|
||||
class ChatCompletionTokenLogprob(BaseModel):
    """A generated token plus its log probability and top alternatives."""

    # The token.
    token: str

    # UTF-8 byte values of the token; useful when characters span multiple
    # tokens and byte representations must be combined to recover the text.
    # `None` when the token has no bytes representation.
    bytes: Optional[List[int]] = None

    # Log probability of this token when it is within the top 20 most likely
    # tokens; otherwise -9999.0 signals a very unlikely token.
    logprob: float

    # Most likely tokens and their log probability at this position; in rare
    # cases fewer than the requested `top_logprobs` are returned.
    top_logprobs: List[TopLogprob]
|
||||
60
mm_agents/coact/autogen/oai/oai_models/completion_usage.py
Normal file
60
mm_agents/coact/autogen/oai/oai_models/completion_usage.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Taken over from https://github.com/openai/openai-python/blob/3e69750d47df4f0759d4a28ddc68e4b38756d9ca/src/openai/types/completion_usage.py
|
||||
|
||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ._models import BaseModel
|
||||
|
||||
__all__ = ["CompletionTokensDetails", "CompletionUsage", "PromptTokensDetails"]
|
||||
|
||||
|
||||
class CompletionTokensDetails(BaseModel):
    """Breakdown of tokens used in the generated completion."""

    # With Predicted Outputs: prediction tokens that appeared in the completion.
    accepted_prediction_tokens: Optional[int] = None

    # Audio input tokens generated by the model.
    audio_tokens: Optional[int] = None

    # Tokens generated by the model for reasoning.
    reasoning_tokens: Optional[int] = None

    # With Predicted Outputs: prediction tokens that did NOT appear in the
    # completion; like reasoning tokens, still counted in total completion
    # tokens for billing, output, and context-window limits.
    rejected_prediction_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class PromptTokensDetails(BaseModel):
    """Breakdown of tokens present in the prompt."""

    # Audio input tokens present in the prompt.
    audio_tokens: Optional[int] = None

    # Cached tokens present in the prompt.
    cached_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class CompletionUsage(BaseModel):
    """Usage statistics for a completion request."""

    # Number of tokens in the generated completion.
    completion_tokens: int

    # Number of tokens in the prompt.
    prompt_tokens: int

    # Total tokens used in the request (prompt + completion).
    total_tokens: int

    # Breakdown of tokens used in the completion.
    completion_tokens_details: Optional[CompletionTokensDetails] = None

    # Breakdown of tokens used in the prompt.
    prompt_tokens_details: Optional[PromptTokensDetails] = None
|
||||
643
mm_agents/coact/autogen/oai/ollama.py
Normal file
643
mm_agents/coact/autogen/oai/ollama.py
Normal file
@@ -0,0 +1,643 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
"""Create an OpenAI-compatible client using Ollama's API.
|
||||
|
||||
Example:
|
||||
```python
|
||||
llm_config = {"config_list": [{"api_type": "ollama", "model": "mistral:7b-instruct-v0.3-q6_K"}]}
|
||||
|
||||
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
|
||||
```
|
||||
|
||||
Install Ollama's python library using: pip install --upgrade ollama
|
||||
Install fix-busted-json library: pip install --upgrade fix-busted-json
|
||||
|
||||
Resources:
|
||||
- https://github.com/ollama/ollama-python
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
from ..import_utils import optional_import_block, require_optional_import
|
||||
from ..llm_config import LLMConfigEntry, register_llm_config
|
||||
from .client_utils import FormatterProtocol, should_hide_tools, validate_parameter
|
||||
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
|
||||
|
||||
with optional_import_block():
|
||||
import ollama
|
||||
from fix_busted_json import repair_json
|
||||
from ollama import Client
|
||||
|
||||
|
||||
@register_llm_config
class OllamaLLMConfigEntry(LLMConfigEntry):
    # Config entry for the Ollama backend, selected via api_type="ollama".
    api_type: Literal["ollama"] = "ollama"
    # Optional URL of a non-default/remote Ollama server; converted to str
    # when constructing the client.
    client_host: Optional[HttpUrl] = None
    # Whether to stream responses (disabled downstream when native tools are used).
    stream: bool = False
    num_predict: int = Field(
        default=-1,
        description="Maximum number of tokens to predict, note: -1 is infinite (default), -2 is fill context.",
    )
    # Context window size used to generate the next token.
    num_ctx: int = Field(default=2048)
    # Sampling/penalty parameters forwarded to Ollama's options dict.
    repeat_penalty: float = Field(default=1.1)
    seed: int = Field(default=0)
    temperature: float = Field(default=0.8)
    top_k: int = Field(default=40)
    top_p: float = Field(default=0.9)
    # Second-level tool filtering: when to hide already-run tools from the model.
    hide_tools: Literal["if_all_run", "if_any_run", "never"] = "never"

    def create_client(self):
        # Client construction for Ollama is handled elsewhere; this entry only
        # carries/validates configuration.
        raise NotImplementedError("OllamaLLMConfigEntry.create_client is not implemented.")
|
||||
|
||||
|
||||
class OllamaClient:
|
||||
"""Client for Ollama's API."""
|
||||
|
||||
# Defaults for manual tool calling
|
||||
# Instruction is added to the first system message and provides directions to follow a two step
|
||||
# process
|
||||
# 1. (before tools have been called) Return JSON with the functions to call
|
||||
# 2. (directly after tools have been called) Return Text describing the results of the function calls in text format
|
||||
|
||||
# Override using "manual_tool_call_instruction" config parameter
|
||||
TOOL_CALL_MANUAL_INSTRUCTION = (
|
||||
"You are to follow a strict two step process that will occur over "
|
||||
"a number of interactions, so pay attention to what step you are in based on the full "
|
||||
"conversation. We will be taking turns so only do one step at a time so don't perform step "
|
||||
"2 until step 1 is complete and I've told you the result. The first step is to choose one "
|
||||
"or more functions based on the request given and return only JSON with the functions and "
|
||||
"arguments to use. The second step is to analyse the given output of the function and summarise "
|
||||
"it returning only TEXT and not Python or JSON. "
|
||||
"For argument values, be sure numbers aren't strings, they should not have double quotes around them. "
|
||||
"In terms of your response format, for step 1 return only JSON and NO OTHER text, "
|
||||
"for step 2 return only text and NO JSON/Python/Markdown. "
|
||||
'The format for running a function is [{"name": "function_name1", "arguments":{"argument_name": "argument_value"}},{"name": "function_name2", "arguments":{"argument_name": "argument_value"}}] '
|
||||
'Make sure the keys "name" and "arguments" are as described. '
|
||||
"If you don't get the format correct, try again. "
|
||||
"The following functions are available to you:[FUNCTIONS_LIST]"
|
||||
)
|
||||
|
||||
# Appended to the last user message if no tools have been called
|
||||
# Override using "manual_tool_call_step1" config parameter
|
||||
TOOL_CALL_MANUAL_STEP1 = " (proceed with step 1)"
|
||||
|
||||
# Appended to the user message after tools have been executed. Will create a 'user' message if one doesn't exist.
|
||||
# Override using "manual_tool_call_step2" config parameter
|
||||
TOOL_CALL_MANUAL_STEP2 = " (proceed with step 2)"
|
||||
|
||||
    def __init__(self, response_format: Optional[Union[BaseModel, dict[str, Any]]] = None, **kwargs):
        """Note that no api_key or environment variable is required for Ollama.

        Args:
            response_format: Optional structured-output format (Pydantic model
                or JSON-schema dict) applied to completions.
            **kwargs: Accepted and ignored here — presumably for interface
                compatibility with other client classes; confirm against callers.
        """
        # Store the response format, if provided (for structured outputs)
        self._response_format: Optional[Union[BaseModel, dict[str, Any]]] = response_format
|
||||
|
||||
def message_retrieval(self, response) -> list:
|
||||
"""Retrieve and return a list of strings or a list of Choice.Message from the response.
|
||||
|
||||
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
|
||||
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
|
||||
"""
|
||||
return [choice.message for choice in response.choices]
|
||||
|
||||
def cost(self, response) -> float:
|
||||
return response.cost
|
||||
|
||||
@staticmethod
|
||||
def get_usage(response) -> dict:
|
||||
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
|
||||
# ... # pragma: no cover
|
||||
return {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
"cost": response.cost,
|
||||
"model": response.model,
|
||||
}
|
||||
|
||||
    def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Loads the parameters for Ollama API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
        ollama_params = {}

        # Check that we have what we need to use Ollama's API
        # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion

        # The main parameters are model, prompt, stream, and options
        # Options is a dictionary of parameters for the model
        # There are other, advanced, parameters such as format, system (to override system message), template, raw, etc. - not used

        # We won't enforce the available models
        ollama_params["model"] = params.get("model")
        assert ollama_params["model"], (
            "Please specify the 'model' in your config list entry to nominate the Ollama model to use."
        )

        ollama_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)

        # Build up the options dictionary
        # https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
        # Each option is only forwarded when the caller supplied it.
        options_dict = {}

        if "num_predict" in params:
            # Maximum number of tokens to predict, note: -1 is infinite, -2 is fill context, 128 is default
            options_dict["num_predict"] = validate_parameter(params, "num_predict", int, False, 128, None, None)

        if "num_ctx" in params:
            # Set size of context window used to generate next token, 2048 is default
            options_dict["num_ctx"] = validate_parameter(params, "num_ctx", int, False, 2048, None, None)

        if "repeat_penalty" in params:
            options_dict["repeat_penalty"] = validate_parameter(
                params, "repeat_penalty", (int, float), False, 1.1, None, None
            )

        if "seed" in params:
            options_dict["seed"] = validate_parameter(params, "seed", int, False, 42, None, None)

        if "temperature" in params:
            options_dict["temperature"] = validate_parameter(
                params, "temperature", (int, float), False, 0.8, None, None
            )

        if "top_k" in params:
            options_dict["top_k"] = validate_parameter(params, "top_k", int, False, 40, None, None)

        if "top_p" in params:
            options_dict["top_p"] = validate_parameter(params, "top_p", (int, float), False, 0.9, None, None)

        # Native tool calling: pass the tools through only when not hidden.
        # NOTE: relies on self._native_tool_calls / _tools_in_conversation /
        # _should_hide_tools being set by create() before this is called.
        if self._native_tool_calls and self._tools_in_conversation and not self._should_hide_tools:
            ollama_params["tools"] = params["tools"]

        # Ollama doesn't support streaming with tools natively
        if ollama_params["stream"] and self._native_tool_calls:
            warnings.warn(
                "Streaming is not supported when using tools and 'Native' tool calling, streaming will be disabled.",
                UserWarning,
            )

            ollama_params["stream"] = False

        if not self._native_tool_calls and self._tools_in_conversation:
            # For manual tool calling we have injected the available tools into the prompt
            # and we don't want to force JSON mode
            ollama_params["format"] = ""  # Don't force JSON for manual tool calling mode

        if len(options_dict) != 0:
            ollama_params["options"] = options_dict

        # Structured outputs (see https://ollama.com/blog/structured-outputs)
        if not self._response_format and params.get("response_format"):
            self._response_format = params["response_format"]

        if self._response_format:
            if isinstance(self._response_format, dict):
                ollama_params["format"] = self._response_format
            else:
                # Keep self._response_format as a Pydantic model for when process the response
                ollama_params["format"] = self._response_format.model_json_schema()

        return ollama_params
|
||||
|
||||
    @require_optional_import(["ollama", "fix_busted_json"], "ollama")
    def create(self, params: dict) -> ChatCompletion:
        """Run a chat completion against Ollama and return an OpenAI-style ChatCompletion.

        Handles native vs manual tool calling, streaming vs non-streaming
        responses, and optional structured-output post-processing.
        """
        messages = params.get("messages", [])

        # Are tools involved in this conversation?
        self._tools_in_conversation = "tools" in params

        # We provide second-level filtering out of tools to avoid LLMs re-calling tools continuously
        if self._tools_in_conversation:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            self._should_hide_tools = should_hide_tools(messages, params["tools"], hide_tools)
        else:
            self._should_hide_tools = False

        # Are we using native Ollama tool calling, otherwise we're doing manual tool calling
        # We allow the user to decide if they want to use Ollama's tool calling
        # or for tool calling to be handled manually through text messages
        # Default is True = Ollama's tool calling
        self._native_tool_calls = validate_parameter(params, "native_tool_calls", bool, False, True, None, None)

        if not self._native_tool_calls:
            # Load defaults
            self._manual_tool_call_instruction = validate_parameter(
                params, "manual_tool_call_instruction", str, False, self.TOOL_CALL_MANUAL_INSTRUCTION, None, None
            )
            self._manual_tool_call_step1 = validate_parameter(
                params, "manual_tool_call_step1", str, False, self.TOOL_CALL_MANUAL_STEP1, None, None
            )
            self._manual_tool_call_step2 = validate_parameter(
                params, "manual_tool_call_step2", str, False, self.TOOL_CALL_MANUAL_STEP2, None, None
            )

        # Convert AG2 messages to Ollama messages
        ollama_messages = self.oai_messages_to_ollama_messages(
            messages,
            (
                params["tools"]
                if (not self._native_tool_calls and self._tools_in_conversation) and not self._should_hide_tools
                else None
            ),
        )

        # Parse parameters to the Ollama API's parameters
        ollama_params = self.parse_params(params)

        ollama_params["messages"] = ollama_messages

        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0

        ans = None
        if "client_host" in params:
            # Convert client_host to string from HttpUrl
            client = Client(host=str(params["client_host"]))
            response = client.chat(**ollama_params)
        else:
            response = ollama.chat(**ollama_params)

        if ollama_params["stream"]:
            # Read in the chunks as they stream, taking in tool_calls which may be across
            # multiple chunks if more than one suggested
            ans = ""
            for chunk in response:
                ans = ans + (chunk["message"]["content"] or "")

                if "done_reason" in chunk:
                    # Final chunk carries the token counts.
                    prompt_tokens = chunk.get("prompt_eval_count", 0)
                    completion_tokens = chunk.get("eval_count", 0)
                    total_tokens = prompt_tokens + completion_tokens
        else:
            # Non-streaming finished
            ans: str = response["message"]["content"]

            prompt_tokens = response.get("prompt_eval_count", 0)
            completion_tokens = response.get("eval_count", 0)
            total_tokens = prompt_tokens + completion_tokens

        if response is not None:
            # Defaults
            ollama_finish = "stop"
            tool_calls = None

            # Id and streaming text into response
            # NOTE(review): in the streaming branch `chunk` is the last chunk of
            # the loop above — an empty stream would leave it undefined; confirm
            # Ollama always yields at least one chunk.
            if ollama_params["stream"]:
                response_content = ans
                response_id = chunk["created_at"]
            else:
                response_content = response["message"]["content"]
                response_id = response["created_at"]

            # Process tools in the response
            if self._tools_in_conversation:
                if self._native_tool_calls:
                    if not ollama_params["stream"]:
                        response_content = response["message"]["content"]

                        # Native tool calling
                        if "tool_calls" in response["message"]:
                            ollama_finish = "tool_calls"
                            tool_calls = []
                            # Ollama returns no call ids, so synthesize sequential ones
                            # from a random starting point.
                            random_id = random.randint(0, 10000)
                            for tool_call in response["message"]["tool_calls"]:
                                tool_calls.append(
                                    ChatCompletionMessageToolCall(
                                        id=f"ollama_func_{random_id}",
                                        function={
                                            "name": tool_call["function"]["name"],
                                            "arguments": json.dumps(tool_call["function"]["arguments"]),
                                        },
                                        type="function",
                                    )
                                )

                                random_id += 1

                elif not self._native_tool_calls:
                    # Try to convert the response to a tool call object
                    response_toolcalls = response_to_tool_call(ans)

                    # If we can, then we've got tool call(s)
                    if response_toolcalls is not None:
                        ollama_finish = "tool_calls"
                        tool_calls = []
                        random_id = random.randint(0, 10000)

                        for json_function in response_toolcalls:
                            tool_calls.append(
                                ChatCompletionMessageToolCall(
                                    id=f"ollama_manual_func_{random_id}",
                                    function={
                                        "name": json_function["name"],
                                        "arguments": (
                                            json.dumps(json_function["arguments"])
                                            if "arguments" in json_function
                                            else "{}"
                                        ),
                                    },
                                    type="function",
                                )
                            )

                            random_id += 1

                        # Blank the message content
                        response_content = ""

            if ollama_finish == "stop":  # noqa: SIM102
                # Not a tool call, so let's check if we need to process structured output
                if self._response_format and response_content:
                    try:
                        parsed_response = self._convert_json_response(response_content)
                        response_content = _format_json_response(parsed_response, response_content)
                    except ValueError as e:
                        # Surface the parsing failure as the message content.
                        response_content = str(e)
        else:
            raise RuntimeError("Failed to get response from Ollama.")

        # Convert response to AG2 response
        message = ChatCompletionMessage(
            role="assistant",
            content=response_content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=ollama_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response_id,
            model=ollama_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            cost=0,  # Local models, FREE!
        )

        return response_oai
|
||||
|
||||
    def oai_messages_to_ollama_messages(self, messages: list[dict[str, Any]], tools: list) -> list[dict[str, Any]]:
        """Convert messages from OAI format to Ollama's format.

        We correct for any specific role orders and types, and convert tools to messages (as Ollama can't use tool messages)

        Args:
            messages: Conversation history in OpenAI chat format (list of role/content dicts).
            tools: Tool/function definitions; only used when manual (non-native) tool calling is active.

        Returns:
            A new list of Ollama-compatible messages; the input list is not mutated (deep-copied first).
        """
        # Work on a deep copy so callers' message dicts are never mutated.
        ollama_messages = copy.deepcopy(messages)

        # Remove the name field (Ollama's chat API does not use per-message names)
        for message in ollama_messages:
            if "name" in message:
                message.pop("name", None)

        # Having a 'system' message on the end does not work well with Ollama, so we change it to 'user'
        # 'system' messages on the end are typical of the summarisation message: summary_method="reflection_with_llm"
        if len(ollama_messages) > 1 and ollama_messages[-1]["role"] == "system":
            ollama_messages[-1]["role"] = "user"

        # Process messages for tool calling manually (only when native tool calling is disabled)
        if tools is not None and not self._native_tool_calls:
            # 1. We need to append instructions to the starting system message on function calling
            # 2. If we have not yet called tools we append "step 1 instruction" to the latest user message
            # 3. If we have already called tools we append "step 2 instruction" to the latest user message

            have_tool_calls = False
            have_tool_results = False
            last_tool_result_index = -1

            # One pass to discover where we are in the call -> result cycle.
            for i, message in enumerate(ollama_messages):
                if "tool_calls" in message:
                    have_tool_calls = True
                if "tool_call_id" in message:
                    have_tool_results = True
                    last_tool_result_index = i

            tool_result_is_last_msg = have_tool_results and last_tool_result_index == len(ollama_messages) - 1

            if ollama_messages[0]["role"] == "system":
                manual_instruction = self._manual_tool_call_instruction

                # Build a string of the functions available
                functions_string = ""
                for function in tools:
                    functions_string += f"""\n{function}\n"""

                # Replace single quotes with double questions - Not sure why this helps the LLM perform
                # better, but it seems to. Monitor and remove if not necessary.
                functions_string = functions_string.replace("'", '"')

                manual_instruction = manual_instruction.replace("[FUNCTIONS_LIST]", functions_string)

                # Update the system message with the instructions and functions
                ollama_messages[0]["content"] = ollama_messages[0]["content"] + manual_instruction.rstrip()

            # If we are still in the function calling or evaluating process, append the steps instruction
            if (not have_tool_calls or tool_result_is_last_msg) and ollama_messages[0]["role"] == "system":
                # NOTE: we require a system message to exist for the manual steps texts
                # Append the manual step instructions
                content_to_append = (
                    self._manual_tool_call_step1 if not have_tool_results else self._manual_tool_call_step2
                )

                if content_to_append != "":
                    # Append the relevant tool call instruction to the latest user message
                    if ollama_messages[-1]["role"] == "user":
                        ollama_messages[-1]["content"] = ollama_messages[-1]["content"] + content_to_append
                    else:
                        ollama_messages.append({"role": "user", "content": content_to_append})

        # Convert tool call and tool result messages to normal text messages for Ollama
        # (runs even when native tool calls are enabled — presumably intentional; TODO confirm)
        for i, message in enumerate(ollama_messages):
            if "tool_calls" in message:
                # Recommended tool calls
                content = "Run the following function(s):"
                for tool_call in message["tool_calls"]:
                    content = content + "\n" + str(tool_call)
                ollama_messages[i] = {"role": "assistant", "content": content}
            if "tool_call_id" in message:
                # Executed tool results
                message["result"] = message["content"]
                del message["content"]
                del message["role"]
                content = "The following function was run: " + str(message)
                ollama_messages[i] = {"role": "user", "content": content}

        # As we are changing messages, let's merge if they have two user messages on the end and the last one is tool call step instructions
        if (
            len(ollama_messages) >= 2
            and not self._native_tool_calls
            and ollama_messages[-2]["role"] == "user"
            and ollama_messages[-1]["role"] == "user"
            and (
                ollama_messages[-1]["content"] == self._manual_tool_call_step1
                or ollama_messages[-1]["content"] == self._manual_tool_call_step2
            )
        ):
            ollama_messages[-2]["content"] = ollama_messages[-2]["content"] + ollama_messages[-1]["content"]
            del ollama_messages[-1]

        # Ensure the last message is a user / system message, if not, add a user message
        # NOTE(review): an empty `messages` input would raise IndexError here — confirm callers never pass [].
        if ollama_messages[-1]["role"] != "user" and ollama_messages[-1]["role"] != "system":
            ollama_messages.append({"role": "user", "content": "Please continue."})

        return ollama_messages
|
||||
|
||||
def _convert_json_response(self, response: str) -> Any:
|
||||
"""Extract and validate JSON response from the output for structured outputs.
|
||||
|
||||
Args:
|
||||
response (str): The response from the API.
|
||||
|
||||
Returns:
|
||||
Any: The parsed JSON response.
|
||||
"""
|
||||
if not self._response_format:
|
||||
return response
|
||||
|
||||
try:
|
||||
# Parse JSON and validate against the Pydantic model if Pydantic model was provided
|
||||
if isinstance(self._response_format, dict):
|
||||
return response
|
||||
else:
|
||||
return self._response_format.model_validate_json(response)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse response as valid JSON matching the schema for Structured Output: {e!s}")
|
||||
|
||||
|
||||
def _format_json_response(response: Any, original_answer: str) -> str:
    """Formats the JSON response for structured outputs using the format method if it exists."""
    # Structured-output objects implementing FormatterProtocol render themselves;
    # anything else falls back to the model's original text answer.
    if isinstance(response, FormatterProtocol):
        return response.format()
    return original_answer
|
||||
|
||||
|
||||
@require_optional_import("fix_busted_json", "ollama")
def response_to_tool_call(response_string: str) -> Any:
    """Attempts to convert the response to an object, aimed to align with function format `[{},{}]`

    Args:
        response_string: Raw LLM text that may embed one or more tool-call JSON objects.

    Returns:
        A validated list of tool-call dicts (each with 'name' and optional 'arguments'),
        or None if no valid tool call could be extracted.
    """
    # We try and detect the list[dict[str, Any]] format:
    # Pattern 1 is [{},{}]
    # Pattern 2 is {} (without the [], so could be a single function call)
    patterns = [r"\[[\s\S]*?\]", r"\{[\s\S]*\}"]

    for i, pattern in enumerate(patterns):
        # Search for the pattern in the input string
        matches = re.findall(pattern, response_string.strip())

        for match in matches:
            # It has matched, extract it and load it
            json_str = match.strip()
            data_object = None

            try:
                # Attempt to convert it as is
                data_object = json.loads(json_str)
            except Exception:
                try:
                    # If that fails, attempt to repair it
                    if i == 0:
                        # Enclose to a JSON object for repairing, which is restored upon fix
                        fixed_json = repair_json("{'temp':" + json_str + "}")
                        data_object = json.loads(fixed_json)
                        data_object = data_object["temp"]
                    else:
                        fixed_json = repair_json(json_str)
                        data_object = json.loads(fixed_json)
                except json.JSONDecodeError as e:
                    if e.msg == "Invalid \\escape":
                        # Handle Mistral/Mixtral trying to escape underlines with \\
                        try:
                            json_str = json_str.replace("\\_", "_")
                            if i == 0:
                                fixed_json = repair_json("{'temp':" + json_str + "}")
                                data_object = json.loads(fixed_json)
                                data_object = data_object["temp"]
                            else:
                                # FIX: pattern 2 (a bare object) must not be wrapped in the
                                # {'temp': ...} envelope. The previous code wrapped it but never
                                # unwrapped the result, producing {"temp": {...}} which then
                                # failed tool-call validation. Mirror the non-escape path above.
                                fixed_json = repair_json(json_str)
                                data_object = json.loads(fixed_json)
                        except Exception:
                            pass
                except Exception:
                    pass

            if data_object is not None:
                data_object = _object_to_tool_call(data_object)

            if data_object is not None:
                return data_object

    # There's no tool call in the response
    return None
|
||||
|
||||
|
||||
def _object_to_tool_call(data_object: Any) -> list[dict[str, Any]]:
    """Attempts to convert an object to a valid tool call object List[Dict] and returns it, if it can, otherwise None

    Args:
        data_object: A dict, a list of dicts, or a list of strings that parse to dicts
            (all originating from LLM output — treated as untrusted).

    Returns:
        A list of tool-call dicts if every item validates, otherwise None.
    """
    import ast  # local import: only needed for the string-item fallback below

    # If it's a dictionary and not a list then wrap in a list
    if isinstance(data_object, dict):
        data_object = [data_object]

    # Validate that the data is a list of dictionaries
    if isinstance(data_object, list) and all(isinstance(item, dict) for item in data_object):
        # Perfect format, a list of dictionaries

        # Check that each dictionary has at least 'name', optionally 'arguments' and no other keys
        is_invalid = False
        for item in data_object:
            if not is_valid_tool_call_item(item):
                is_invalid = True
                break

        # All passed, name and (optionally) arguments exist for all entries.
        if not is_invalid:
            return data_object
    elif isinstance(data_object, list):
        # If it's a list but the items are not dictionaries, check if they are strings that can be converted to dictionaries
        data_copy = data_object.copy()
        is_invalid = False
        for i, item in enumerate(data_copy):
            try:
                # SECURITY FIX: the items come straight from model output, so they are
                # untrusted. ast.literal_eval only evaluates Python literals, whereas the
                # previous eval() would execute arbitrary expressions embedded in the
                # response. Non-literal strings now simply fail validation (returns None).
                new_item = ast.literal_eval(item)
                if isinstance(new_item, dict):
                    if is_valid_tool_call_item(new_item):
                        data_object[i] = new_item
                    else:
                        is_invalid = True
                        break
                else:
                    is_invalid = True
                    break
            except Exception:
                is_invalid = True
                break

        if not is_invalid:
            return data_object

    return None
|
||||
|
||||
|
||||
def is_valid_tool_call_item(call_item: dict) -> bool:
    """Check that a dictionary item has at least 'name', optionally 'arguments' and no other keys to match a tool call JSON"""
    # 'name' must be present and a string.
    name_ok = isinstance(call_item.get("name"), str)

    # Any key beyond 'name'/'arguments' disqualifies the item.
    unexpected_keys = set(call_item) - {"name", "arguments"}

    return name_ok and not unexpected_keys
|
||||
881
mm_agents/coact/autogen/oai/openai_utils.py
Normal file
881
mm_agents/coact/autogen/oai/openai_utils.py
Normal file
@@ -0,0 +1,881 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
from packaging.version import parse
|
||||
from pydantic_core import to_jsonable_python
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import OpenAI
|
||||
from openai.types.beta.assistant import Assistant
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from ..llm_config import LLMConfig
|
||||
|
||||
# Keys that never affect the semantic identity of a config; they are stripped
# before a config is hashed/serialized for caching (see get_key below).
NON_CACHE_KEY = [
    "api_key",
    "base_url",
    "api_type",
    "api_version",
    "azure_ad_token",
    "azure_ad_token_provider",
    "credentials",
]
# Fallback Azure OpenAI API version applied when a config does not supply one.
DEFAULT_AZURE_API_VERSION = "2024-02-01"

# The below pricing is for 1K tokens. Whenever there is an update in the LLM's pricing,
# Please convert it to 1K tokens and update in the below dictionary in the format: (input_token_price, output_token_price).
# NOTE(review): values are usually (prompt, completion) tuples, but legacy/base models
# carry a single float — consumers must handle both shapes.
OAI_PRICE1K = {
    # https://openai.com/api/pricing/
    # o1
    "o1-preview-2024-09-12": (0.0015, 0.0060),
    "o1-preview": (0.0015, 0.0060),
    "o1-mini-2024-09-12": (0.0003, 0.0012),
    "o1-mini": (0.0003, 0.0012),
    "o1": (0.0015, 0.0060),
    "o1-2024-12-17": (0.0015, 0.0060),
    # o1 pro
    "o1-pro": (0.15, 0.6),  # $150 / $600!
    "o1-pro-2025-03-19": (0.15, 0.6),
    # o3
    "o3": (0.0011, 0.0044),
    "o3-mini-2025-01-31": (0.0011, 0.0044),
    # gpt-4o
    "gpt-4o": (0.005, 0.015),
    "gpt-4o-2024-05-13": (0.005, 0.015),
    "gpt-4o-2024-08-06": (0.0025, 0.01),
    "gpt-4o-2024-11-20": (0.0025, 0.01),
    # gpt-4o-mini
    "gpt-4o-mini": (0.000150, 0.000600),
    "gpt-4o-mini-2024-07-18": (0.000150, 0.000600),
    # gpt-4-turbo
    "gpt-4-turbo-2024-04-09": (0.01, 0.03),
    # gpt-4
    "gpt-4": (0.03, 0.06),
    "gpt-4-32k": (0.06, 0.12),
    # gpt-4.1
    "gpt-4.1": (0.002, 0.008),
    "gpt-4.1-2025-04-14": (0.002, 0.008),
    # gpt-4.1 mini
    "gpt-4.1-mini": (0.0004, 0.0016),
    "gpt-4.1-mini-2025-04-14": (0.0004, 0.0016),
    # gpt-4.1 nano
    "gpt-4.1-nano": (0.0001, 0.0004),
    "gpt-4.1-nano-2025-04-14": (0.0001, 0.0004),
    # gpt-3.5 turbo
    "gpt-3.5-turbo": (0.0005, 0.0015),  # default is 0125
    "gpt-3.5-turbo-0125": (0.0005, 0.0015),  # 16k
    "gpt-3.5-turbo-instruct": (0.0015, 0.002),
    # base model
    "davinci-002": 0.002,
    "babbage-002": 0.0004,
    # old model
    "gpt-4-0125-preview": (0.01, 0.03),
    "gpt-4-1106-preview": (0.01, 0.03),
    "gpt-4-1106-vision-preview": (0.01, 0.03),  # TODO: support vision pricing of images
    "gpt-3.5-turbo-1106": (0.001, 0.002),
    "gpt-3.5-turbo-0613": (0.0015, 0.002),
    # "gpt-3.5-turbo-16k": (0.003, 0.004),
    "gpt-3.5-turbo-16k-0613": (0.003, 0.004),
    "gpt-3.5-turbo-0301": (0.0015, 0.002),
    "text-ada-001": 0.0004,
    "text-babbage-001": 0.0005,
    "text-curie-001": 0.002,
    "code-cushman-001": 0.024,
    "code-davinci-002": 0.1,
    "text-davinci-002": 0.02,
    "text-davinci-003": 0.02,
    "gpt-4-0314": (0.03, 0.06),  # deprecate in Sep
    "gpt-4-32k-0314": (0.06, 0.12),  # deprecate in Sep
    "gpt-4-0613": (0.03, 0.06),
    "gpt-4-32k-0613": (0.06, 0.12),
    "gpt-4-turbo-preview": (0.01, 0.03),
    # https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/#pricing
    "gpt-35-turbo": (0.0005, 0.0015),  # what's the default? using 0125 here.
    "gpt-35-turbo-0125": (0.0005, 0.0015),
    "gpt-35-turbo-instruct": (0.0015, 0.002),
    "gpt-35-turbo-1106": (0.001, 0.002),
    "gpt-35-turbo-0613": (0.0015, 0.002),
    "gpt-35-turbo-0301": (0.0015, 0.002),
    "gpt-35-turbo-16k": (0.003, 0.004),
    "gpt-35-turbo-16k-0613": (0.003, 0.004),
    # deepseek
    "deepseek-chat": (0.00027, 0.0011),
}
|
||||
|
||||
|
||||
def get_key(config: dict[str, Any]) -> str:
    """Get a unique identifier of a configuration.

    Credential/endpoint keys listed in ``NON_CACHE_KEY`` are stripped first,
    copying the config once so the caller's dict is never mutated.

    Args:
        config (dict or list): A configuration.

    Returns:
        A JSON-able representation of the stripped config, usable as a dict/cache key.
    """
    stripped = None  # lazily created copy; None means nothing stripped yet
    for non_cache_key in NON_CACHE_KEY:
        if non_cache_key in config:
            if stripped is None:
                stripped = config.copy()
            stripped.pop(non_cache_key)
    return to_jsonable_python(config if stripped is None else stripped)
|
||||
|
||||
|
||||
def is_valid_api_key(api_key: str) -> bool:
    """Determine if input is valid OpenAI API key. As of 2024-09-24 there's no official definition of the key structure
    so we will allow anything starting with "sk-" and having at least 48 alphanumeric (plus underscore and dash) characters.
    Keys are known to start with "sk-", "sk-proj", "sk-None", and "sk-svcaat"

    Args:
        api_key (str): An input string to be validated.

    Returns:
        bool: A boolean that indicates if input is valid OpenAI API key.
    """
    # "sk-" prefix followed by 48+ characters from [A-Za-z0-9_-]; fullmatch anchors both ends.
    pattern = re.compile(r"^sk-[A-Za-z0-9_-]{48,}$")
    return pattern.fullmatch(api_key) is not None
|
||||
|
||||
|
||||
@export_module("autogen")
def get_config_list(
    api_keys: list[str],
    base_urls: Optional[list[str]] = None,
    api_type: Optional[str] = None,
    api_version: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Get a list of configs for OpenAI API client.

    Blank/whitespace-only keys are skipped. `api_type` and `api_version`, when
    given, are applied to every produced config.

    Args:
        api_keys (list): The api keys for openai api calls.
        base_urls (list, optional): The api bases for openai api calls. If provided, should match the length of api_keys.
        api_type (str, optional): The api type for openai api calls.
        api_version (str, optional): The api version for openai api calls.

    Returns:
        list: A list of configs for OpenAI API calls, one per non-blank key.

    Example:
        ```python
        api_keys = ["key1", "key2", "key3"]
        base_urls = ["https://api.service1.com", "https://api.service2.com", "https://api.service3.com"]
        config_list = get_config_list(api_keys, base_urls, "azure", "2024-02-01")
        ```
    """
    if base_urls is not None:
        assert len(api_keys) == len(base_urls), "The length of api_keys must match the length of base_urls"

    configs: list[dict[str, Any]] = []
    for index, key in enumerate(api_keys):
        if not key.strip():
            # Ignore empty entries (e.g. from splitting an env var on newlines).
            continue
        entry: dict[str, Any] = {"api_key": key}
        if base_urls:
            entry["base_url"] = base_urls[index]
        if api_type:
            entry["api_type"] = api_type
        if api_version:
            entry["api_version"] = api_version
        configs.append(entry)
    return configs
|
||||
|
||||
|
||||
@export_module("autogen")
def get_first_llm_config(
    llm_config: Union[LLMConfig, dict[str, Any]],
) -> dict[str, Any]:
    """Get the first LLM config from the given LLM config.

    Args:
        llm_config (dict): The LLM config.

    Returns:
        dict: The first LLM config (a deep copy; the input is never mutated).

    Raises:
        ValueError: If the LLM config is invalid.
    """
    # Deep-copy so the returned config can be mutated freely by the caller.
    llm_config = deepcopy(llm_config)

    if "config_list" not in llm_config:
        # A flat config with a 'model' key is itself a valid single config.
        if "model" in llm_config:
            return llm_config  # type: ignore [return-value]
        raise ValueError("llm_config must be a valid config dictionary.")

    config_list = llm_config["config_list"]
    if len(config_list) == 0:
        raise ValueError("Config list must contain at least one config.")

    first = config_list[0]
    # Entries may be plain dicts or pydantic config-entry models.
    if isinstance(first, dict):
        return first
    return first.model_dump()  # type: ignore [no-any-return]
|
||||
|
||||
|
||||
@export_module("autogen")
def config_list_openai_aoai(
    key_file_path: Optional[str] = ".",
    openai_api_key_file: Optional[str] = "key_openai.txt",
    aoai_api_key_file: Optional[str] = "key_aoai.txt",
    openai_api_base_file: Optional[str] = "base_openai.txt",
    aoai_api_base_file: Optional[str] = "base_aoai.txt",
    exclude: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Get a list of configs for OpenAI API client (including Azure or local model deployments that support OpenAI's chat completion API).

    This function constructs configurations by reading API keys and base URLs from environment variables or text files.
    It supports configurations for both OpenAI and Azure OpenAI services, allowing for the exclusion of one or the other.
    When text files are used, the environment variables will be overwritten.
    To prevent text files from being used, set the corresponding file name to None.
    Or set key_file_path to None to disallow reading from text files.

    Args:
        key_file_path (str, optional): The directory path where the API key files are located. Defaults to the current directory.
        openai_api_key_file (str, optional): The filename containing the OpenAI API key. Defaults to 'key_openai.txt'.
        aoai_api_key_file (str, optional): The filename containing the Azure OpenAI API key. Defaults to 'key_aoai.txt'.
        openai_api_base_file (str, optional): The filename containing the OpenAI API base URL. Defaults to 'base_openai.txt'.
        aoai_api_base_file (str, optional): The filename containing the Azure OpenAI API base URL. Defaults to 'base_aoai.txt'.
        exclude (str, optional): The API type to exclude from the configuration list. Can be 'openai' or 'aoai'. Defaults to None.

    Returns:
        List[Dict]: A list of configuration dictionaries. Each dictionary contains keys for 'api_key',
        and optionally 'base_url', 'api_type', and 'api_version'.

    Raises:
        FileNotFoundError: If the specified key files are not found and the corresponding API key is not set in the environment variables.

    Example:
        # To generate configurations excluding Azure OpenAI:
        configs = config_list_openai_aoai(exclude='aoai')

    File samples:
        - key_aoai.txt

        ```
        aoai-12345abcdef67890ghijklmnopqr
        aoai-09876zyxwvuts54321fedcba
        ```

        - base_aoai.txt

        ```
        https://api.azure.com/v1
        https://api.azure2.com/v1
        ```

    Notes:
        - The function checks for API keys and base URLs in the following environment variables: 'OPENAI_API_KEY', 'AZURE_OPENAI_API_KEY',
          'OPENAI_API_BASE' and 'AZURE_OPENAI_API_BASE'. If these are not found, it attempts to read from the specified files in the
          'key_file_path' directory.
        - The API version for Azure configurations is set to DEFAULT_AZURE_API_VERSION by default.
        - If 'exclude' is set to 'openai', only Azure OpenAI configurations are returned, and vice versa.
        - The function assumes that the API keys and base URLs in the environment variables are separated by new lines if there are
          multiple entries.
    """
    # NOTE(review): a successful file read clobbers any pre-existing env var of the
    # same name for the whole process — intentional per the docstring, but worth
    # confirming for shared environments.
    if exclude != "openai" and key_file_path is not None:
        # skip if key_file_path is None
        if openai_api_key_file is not None:
            # skip if openai_api_key_file is None
            try:
                with open(f"{key_file_path}/{openai_api_key_file}") as key_file:
                    os.environ["OPENAI_API_KEY"] = key_file.read().strip()
            except FileNotFoundError:
                logging.info(
                    "OPENAI_API_KEY is not found in os.environ "
                    "and key_openai.txt is not found in the specified path. You can specify the api_key in the config_list."
                )
        if openai_api_base_file is not None:
            # skip if openai_api_base_file is None
            try:
                with open(f"{key_file_path}/{openai_api_base_file}") as key_file:
                    os.environ["OPENAI_API_BASE"] = key_file.read().strip()
            except FileNotFoundError:
                logging.info(
                    "OPENAI_API_BASE is not found in os.environ "
                    "and base_openai.txt is not found in the specified path. You can specify the base_url in the config_list."
                )
    if exclude != "aoai" and key_file_path is not None:
        # skip if key_file_path is None
        if aoai_api_key_file is not None:
            try:
                with open(f"{key_file_path}/{aoai_api_key_file}") as key_file:
                    os.environ["AZURE_OPENAI_API_KEY"] = key_file.read().strip()
            except FileNotFoundError:
                logging.info(
                    "AZURE_OPENAI_API_KEY is not found in os.environ "
                    "and key_aoai.txt is not found in the specified path. You can specify the api_key in the config_list."
                )
        if aoai_api_base_file is not None:
            try:
                with open(f"{key_file_path}/{aoai_api_base_file}") as key_file:
                    os.environ["AZURE_OPENAI_API_BASE"] = key_file.read().strip()
            except FileNotFoundError:
                logging.info(
                    "AZURE_OPENAI_API_BASE is not found in os.environ "
                    "and base_aoai.txt is not found in the specified path. You can specify the base_url in the config_list."
                )
    aoai_config = (
        get_config_list(
            # Assuming Azure OpenAI api keys in os.environ["AZURE_OPENAI_API_KEY"], in separated lines
            api_keys=os.environ.get("AZURE_OPENAI_API_KEY", "").split("\n"),
            # Assuming Azure OpenAI api bases in os.environ["AZURE_OPENAI_API_BASE"], in separated lines
            base_urls=os.environ.get("AZURE_OPENAI_API_BASE", "").split("\n"),
            api_type="azure",
            api_version=DEFAULT_AZURE_API_VERSION,
        )
        if exclude != "aoai"
        else []
    )
    # process openai base urls
    base_urls_env_var = os.environ.get("OPENAI_API_BASE", None)
    base_urls = base_urls_env_var if base_urls_env_var is None else base_urls_env_var.split("\n")
    openai_config = (
        get_config_list(
            # Assuming OpenAI API_KEY in os.environ["OPENAI_API_KEY"]
            api_keys=os.environ.get("OPENAI_API_KEY", "").split("\n"),
            base_urls=base_urls,
        )
        if exclude != "openai"
        else []
    )
    # OpenAI configs are listed before Azure configs.
    config_list = openai_config + aoai_config
    return config_list
|
||||
|
||||
|
||||
@export_module("autogen")
def config_list_from_models(
    key_file_path: Optional[str] = ".",
    openai_api_key_file: Optional[str] = "key_openai.txt",
    aoai_api_key_file: Optional[str] = "key_aoai.txt",
    aoai_api_base_file: Optional[str] = "base_aoai.txt",
    exclude: Optional[str] = None,
    model_list: Optional[list[str]] = None,
) -> list[dict[str, Any]]:
    """Get a list of configs for API calls with models specified in the model list.

    Extends `config_list_openai_aoai` by cloning each endpoint config once per
    model name, adding a 'model' key. Useful when all endpoints serve the same
    set of models.

    Args:
        key_file_path (str, optional): The path to the key files.
        openai_api_key_file (str, optional): The file name of the OpenAI API key.
        aoai_api_key_file (str, optional): The file name of the Azure OpenAI API key.
        aoai_api_base_file (str, optional): The file name of the Azure OpenAI API base.
        exclude (str, optional): The API type to exclude, "openai" or "aoai".
        model_list (list, optional): The list of model names to include in the configs.

    Returns:
        list: A list of configs for OpenAI API calls, each including model information,
        e.g. [{'api_key': '...', 'base_url': '...', 'model': 'gpt-4'}, ...].
    """
    base_configs = config_list_openai_aoai(
        key_file_path=key_file_path,
        openai_api_key_file=openai_api_key_file,
        aoai_api_key_file=aoai_api_key_file,
        aoai_api_base_file=aoai_api_base_file,
        exclude=exclude,
    )
    if not model_list:
        return base_configs

    # One config per (model, endpoint) pair; the model name wins over any
    # pre-existing 'model' key in the endpoint config.
    expanded: list[dict[str, Any]] = []
    for model in model_list:
        for base in base_configs:
            expanded.append({**base, "model": model})
    return expanded
|
||||
|
||||
|
||||
@export_module("autogen")
def config_list_gpt4_gpt35(
    key_file_path: Optional[str] = ".",
    openai_api_key_file: Optional[str] = "key_openai.txt",
    aoai_api_key_file: Optional[str] = "key_aoai.txt",
    aoai_api_base_file: Optional[str] = "base_aoai.txt",
    exclude: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Get a list of configs for 'gpt-4' followed by 'gpt-3.5-turbo' API calls.

    Args:
        key_file_path (str, optional): The path to the key files.
        openai_api_key_file (str, optional): The file name of the openai api key.
        aoai_api_key_file (str, optional): The file name of the azure openai api key.
        aoai_api_base_file (str, optional): The file name of the azure openai api base.
        exclude (str, optional): The api type to exclude, "openai" or "aoai".

    Returns:
        list: A list of configs for openai api calls.
    """
    # Thin convenience wrapper with a fixed model list.
    return config_list_from_models(
        key_file_path=key_file_path,
        openai_api_key_file=openai_api_key_file,
        aoai_api_key_file=aoai_api_key_file,
        aoai_api_base_file=aoai_api_base_file,
        exclude=exclude,
        model_list=["gpt-4", "gpt-3.5-turbo"],
    )
|
||||
|
||||
|
||||
@export_module("autogen")
def filter_config(
    config_list: list[dict[str, Any]],
    filter_dict: Optional[dict[str, Union[list[Union[str, None]], set[Union[str, None]]]]],
    exclude: bool = False,
) -> list[dict[str, Any]]:
    """Filter `config_list` against the criteria in `filter_dict`.

    A config is kept when, for every key in `filter_dict`, its value satisfies the
    acceptable values (scalar fields match by membership/equality; list fields match
    on non-empty intersection). With `exclude=True` the test is inverted per key.

    Args:
        config_list (list of dict): A list of configuration dictionaries to be filtered.
        filter_dict (dict): Maps a config field name to its acceptable values (list/set,
            or a single scalar value).
        exclude (bool): If False (the default value), configs that match the filter will be included in the returned
            list. If True, configs that match the filter will be excluded in the returned list.

    Returns:
        list of dict: The configuration dictionaries that meet all the criteria.

    Note:
        - If `filter_dict` is empty or None, no filtering is applied and `config_list` is returned as is.
        - A config missing a filtered key is a non-match, unless None is among the acceptable values.
    """
    # Deprecated entry point — LLMConfig.where() is the supported API; it calls us
    # internally, which is why the caller's frame name is checked.
    if inspect.stack()[1].function != "where":
        warnings.warn(
            "filter_config is deprecated and will be removed in a future release. "
            'Please use the "autogen.LLMConfig.from_json(path="OAI_CONFIG_LIST").where(model="gpt-4o")" method instead.',
            DeprecationWarning,
        )

    if not filter_dict:
        return config_list

    kept: list[dict[str, Any]] = []
    for cfg in config_list:
        matches_all = all(
            _satisfies_criteria(cfg.get(field), accepted) != exclude
            for field, accepted in filter_dict.items()
        )
        if matches_all:
            kept.append(cfg)
    return kept
|
||||
|
||||
|
||||
def _satisfies_criteria(value: Any, criteria_values: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if isinstance(value, list):
|
||||
return bool(set(value) & set(criteria_values)) # Non-empty intersection
|
||||
else:
|
||||
# In filter_dict, filter could be either a list of values or a single value.
|
||||
# For example, filter_dict = {"model": ["gpt-3.5-turbo"]} or {"model": "gpt-3.5-turbo"}
|
||||
if isinstance(criteria_values, list):
|
||||
return value in criteria_values
|
||||
return bool(value == criteria_values)
|
||||
|
||||
|
||||
@export_module("autogen")
def config_list_from_json(
    env_or_file: str,
    file_location: Optional[str] = "",
    filter_dict: Optional[dict[str, Union[list[Union[str, None]], set[Union[str, None]]]]] = None,
) -> list[dict[str, Any]]:
    """Retrieves a list of API configurations from a JSON stored in an environment variable or a file.

    This function attempts to parse JSON data from the given `env_or_file` parameter. If `env_or_file` is an
    environment variable containing JSON data, it will be used directly. Otherwise, it is assumed to be a filename,
    and the function will attempt to read the file from the specified `file_location`.

    The `filter_dict` parameter allows for filtering the configurations based on specified criteria. Each key in the
    `filter_dict` corresponds to a field in the configuration dictionaries, and the associated value is a list or set
    of acceptable values for that field. If a field is missing in a configuration and `None` is included in the list
    of acceptable values for that field, the configuration will still be considered a match.

    Args:
        env_or_file (str): The name of the environment variable, the filename, or the environment variable of the filename
            that containing the JSON data.
        file_location (str, optional): The directory path where the file is located, if `env_or_file` is a filename.
        filter_dict (dict, optional): A dictionary specifying the filtering criteria for the configurations, with
            keys representing field names and values being lists or sets of acceptable values for those fields.

    Example:
        ```python
        # Suppose we have an environment variable 'CONFIG_JSON' with the following content:
        # '[{"model": "gpt-3.5-turbo", "api_type": "azure"}, {"model": "gpt-4"}]'

        # We can retrieve a filtered list of configurations like this:
        filter_criteria = {"model": ["gpt-3.5-turbo"]}
        configs = config_list_from_json("CONFIG_JSON", filter_dict=filter_criteria)
        # The 'configs' variable will now contain only the configurations that match the filter criteria.
        ```

    Returns:
        List[Dict]: A list of configuration dictionaries that match the filtering criteria specified in `filter_dict`.

    Raises:
        FileNotFoundError: if env_or_file is neither found as an environment variable nor a file
    """
    # Emit the deprecation warning unless called from LLMConfig.from_json.
    if inspect.stack()[1].function != "from_json":
        warnings.warn(
            "config_list_from_json is deprecated and will be removed in a future release. "
            'Please use the "autogen.LLMConfig.from_json(path="OAI_CONFIG_LIST")" method instead.',
            DeprecationWarning,
        )

    env_str = os.environ.get(env_or_file)

    if env_str:
        # The environment variable exists. We should use information from it.
        if os.path.exists(env_str):
            # It is a file location, and we need to load the json from the file.
            with open(env_str) as file:
                json_str = file.read()
        else:
            # Else, it should be a JSON string by itself.
            json_str = env_str
        config_list = json.loads(json_str)
    else:
        # The environment variable does not exist.
        # So, `env_or_file` is a filename. We should use the file location.
        config_list_path = os.path.join(file_location, env_or_file) if file_location is not None else env_or_file

        with open(config_list_path) as json_file:
            config_list = json.load(json_file)

    # Fix: apply the filter exactly once. The previous code called filter_config
    # twice in a row, doing redundant work and emitting a duplicate
    # DeprecationWarning from filter_config.
    return filter_config(config_list, filter_dict)
|
||||
|
||||
|
||||
def get_config(
    api_key: Optional[str],
    base_url: Optional[str] = None,
    api_type: Optional[str] = None,
    api_version: Optional[str] = None,
) -> dict[str, Any]:
    """Constructs a configuration dictionary for a single model with the provided API configurations.

    Only truthy optional values are included in the result. For each optional
    value, an environment variable of the same name is consulted first and the
    literal value is used as the fallback (NOTE(review): this env-var
    indirection means a value like "MY_BASE_URL" is resolved through
    `os.getenv` — confirm callers rely on this).

    Example:
        ```python
        config = get_config(api_key="sk-abcdef1234567890", base_url="https://api.openai.com", api_version="v1")
        # The 'config' variable will now contain:
        # {
        #     "api_key": "sk-abcdef1234567890",
        #     "base_url": "https://api.openai.com",
        #     "api_version": "v1"
        # }
        ```

    Args:
        api_key (str): The API key for authenticating API requests.
        base_url (Optional[str]): The base URL of the API. If not provided, defaults to None.
        api_type (Optional[str]): The type of API. If not provided, defaults to None.
        api_version (Optional[str]): The version of the API. If not provided, defaults to None.

    Returns:
        Dict: A dictionary containing the provided API configurations.
    """
    config: dict[str, Any] = {"api_key": api_key}
    # Optional fields share identical handling, so drive them from a table.
    optional_fields = {"base_url": base_url, "api_type": api_type, "api_version": api_version}
    for field_name, field_value in optional_fields.items():
        if field_value:
            config[field_name] = os.getenv(field_value, default=field_value)
    return config
|
||||
|
||||
|
||||
@export_module("autogen")
def config_list_from_dotenv(
    dotenv_file_path: Optional[str] = None,
    model_api_key_map: Optional[dict[str, Any]] = None,
    filter_dict: Optional[dict[str, Union[list[Union[str, None]], set[Union[str, None]]]]] = None,
) -> list[dict[str, Union[str, set[str]]]]:
    """Load API configurations from a specified .env file or environment variables and construct a list of configurations.

    This function will:
    - Load API keys from a provided .env file or from existing environment variables.
    - Create a configuration dictionary for each model using the API keys and additional configurations.
    - Filter and return the configurations based on provided filters.

    model_api_key_map will default to `{"gpt-4": "OPENAI_API_KEY", "gpt-3.5-turbo": "OPENAI_API_KEY"}` if none

    Args:
        dotenv_file_path (str, optional): The path to the .env file. Defaults to None.
        model_api_key_map (str/dict, optional): A dictionary mapping models to their API key configurations.
            If a string is provided as configuration, it is considered as an environment
            variable name storing the API key.
            If a dict is provided, it should contain at least 'api_key_env_var' key,
            and optionally other API configurations like 'base_url', 'api_type', and 'api_version'.
            Defaults to a basic map with 'gpt-4' and 'gpt-3.5-turbo' mapped to 'OPENAI_API_KEY'.
        filter_dict (dict, optional): A dictionary containing the models to be loaded.
            Containing a 'model' key mapped to a set of model names to be loaded.
            Defaults to None, which loads all found configurations.

    Returns:
        List[Dict[str, Union[str, Set[str]]]]: A list of configuration dictionaries for each model.

    Raises:
        FileNotFoundError: If the specified .env file does not exist.
    """
    if dotenv_file_path:
        dotenv_path = Path(dotenv_file_path)
        if dotenv_path.exists():
            load_dotenv(dotenv_path)
        else:
            logging.warning(f"The specified .env file {dotenv_path} does not exist.")
    else:
        dotenv_path_str = find_dotenv()
        if not dotenv_path_str:
            logging.warning("No .env file found. Loading configurations from environment variables.")
        dotenv_path = Path(dotenv_path_str)
        load_dotenv(dotenv_path)

    # Ensure the model_api_key_map is not None to prevent TypeErrors during key assignment.
    model_api_key_map = model_api_key_map or {}

    # Ensure default models are always considered
    default_models = ["gpt-4", "gpt-3.5-turbo"]

    for model in default_models:
        # Only assign default API key if the model is not present in the map.
        # If model is present but set to invalid/empty, do not overwrite.
        if model not in model_api_key_map:
            model_api_key_map[model] = "OPENAI_API_KEY"

    env_var = []
    # Loop over the models and create configuration dictionaries
    for model, config in model_api_key_map.items():
        if isinstance(config, str):
            # A bare string names the environment variable holding the API key.
            api_key_env_var = config
            config_dict = get_config(api_key=os.getenv(api_key_env_var))
        elif isinstance(config, dict):
            api_key = os.getenv(config.get("api_key_env_var", "OPENAI_API_KEY"))
            config_without_key_var = {k: v for k, v in config.items() if k != "api_key_env_var"}
            config_dict = get_config(api_key=api_key, **config_without_key_var)
        else:
            logging.warning(f"Unsupported type {type(config)} for model {model} configuration")
            # Fix: skip this entry. Previously this branch fell through without
            # `continue`, raising NameError on the first iteration (config_dict
            # undefined) or silently reusing the previous model's config_dict.
            continue

        if not config_dict["api_key"] or config_dict["api_key"].strip() == "":
            logging.warning(
                f"API key not found or empty for model {model}. Please ensure path to .env file is correct."
            )
            continue  # Skip this configuration and continue with the next

        # Add model to the configuration and append to the list
        config_dict["model"] = model
        env_var.append(config_dict)

    # Round-trip through a temp JSON file so config_list_from_json handles
    # parsing and filtering uniformly.
    fd, temp_name = tempfile.mkstemp()
    try:
        with os.fdopen(fd, "w+") as temp:
            env_var_str = json.dumps(env_var)
            temp.write(env_var_str)
            temp.flush()

            # Assuming config_list_from_json is a valid function from your code
            config_list = config_list_from_json(env_or_file=temp_name, filter_dict=filter_dict)
    finally:
        # The file is deleted after using its name (to prevent windows build from breaking)
        os.remove(temp_name)

    if len(config_list) == 0:
        logging.error("No configurations loaded.")
        return []

    logging.info(f"Models available: {[config['model'] for config in config_list]}")
    return config_list
|
||||
|
||||
|
||||
def retrieve_assistants_by_name(client: "OpenAI", name: str) -> list["Assistant"]:
    """Return every assistant whose name equals *name* from the OpenAI Assistants API."""
    listing = client.beta.assistants.list()
    # Preserve the API's ordering; duplicates with the same name are all kept.
    return [assistant for assistant in listing.data if assistant.name == name]
|
||||
|
||||
|
||||
def detect_gpt_assistant_api_version() -> str:
    """Detect the openai assistant API version.

    Returns:
        "v1" when the installed `openai` package is older than 1.21, otherwise "v2".
    """
    installed_version = importlib.metadata.version("openai")
    # Assistants API v2 shipped with openai 1.21.
    return "v2" if parse(installed_version) >= parse("1.21") else "v1"
|
||||
|
||||
|
||||
def create_gpt_vector_store(client: "OpenAI", name: str, fild_ids: list[str]) -> Any:
    """Create an OpenAI vector store for a GPT assistant and attach the given files.

    Args:
        client: OpenAI client instance.
        name: Name for the new vector store.
        fild_ids: IDs of previously uploaded files to attach (parameter name kept as-is for API compatibility).

    Returns:
        The created vector store once its file batch has completed.

    Raises:
        AttributeError: If the vector store could not be created (e.g. outdated openai package).
        ValueError: If the file batch does not reach the "completed" status.
    """
    try:
        store = client.vector_stores.create(name=name)
    except Exception as err:
        raise AttributeError(f"Failed to create vector store, please install the latest OpenAI python package: {err}")

    # Attach the files; create_and_poll waits on the batch for us.
    batch = client.vector_stores.file_batches.create_and_poll(vector_store_id=store.id, file_ids=fild_ids)

    # One extra poll round in case the batch is still being processed.
    if batch.status == "in_progress":
        time.sleep(1)
        logging.debug(f"file batch status: {batch.file_counts}")
        batch = client.vector_stores.file_batches.poll(vector_store_id=store.id, batch_id=batch.id)

    if batch.status != "completed":
        raise ValueError(f"Failed to upload files to vector store {store.id}:{batch.status}")
    return store
|
||||
|
||||
|
||||
def create_gpt_assistant(
    client: "OpenAI", name: str, instructions: str, model: str, assistant_config: dict[str, Any]
) -> "Assistant":
    """Create an OpenAI GPT assistant, adapting the config to the detected Assistants API version.

    Args:
        client: OpenAI client instance.
        name: Assistant name.
        instructions: System instructions for the assistant.
        model: Model name to back the assistant.
        assistant_config: Extra config; may contain "tools", "tool_resources" (v2 only),
            and "file_ids" (v1 style, translated for v2).

    Returns:
        The created assistant object.

    Raises:
        ValueError: On v2 if both `tool_resources['file_search']` and `file_ids` are given;
            on v1 if `tool_resources` or a `file_search` tool is given.
    """
    assistant_create_kwargs = {}
    gpt_assistant_api_version = detect_gpt_assistant_api_version()
    tools = assistant_config.get("tools", [])

    if gpt_assistant_api_version == "v2":
        tool_resources = assistant_config.get("tool_resources", {})
        file_ids = assistant_config.get("file_ids")
        # file_ids is the V1 way of attaching files; it conflicts with an explicit
        # V2 file_search tool_resources entry.
        if tool_resources.get("file_search") is not None and file_ids is not None:
            raise ValueError(
                "Cannot specify both `tool_resources['file_search']` tool and `file_ids` in the assistant config."
            )

        # Designed for backwards compatibility for the V1 API
        # Instead of V1 AssistantFile, files are attached to Assistants using the tool_resources object.
        for tool in tools:
            # NOTE: mutates the caller-supplied tool dicts in place (renames
            # "retrieval" -> "file_search").
            if tool["type"] == "retrieval":
                tool["type"] = "file_search"
                if file_ids is not None:
                    # create a vector store for the file search tool
                    vs = create_gpt_vector_store(client, f"{name}-vectorestore", file_ids)
                    tool_resources["file_search"] = {
                        "vector_store_ids": [vs.id],
                    }
            elif tool["type"] == "code_interpreter" and file_ids is not None:
                tool_resources["code_interpreter"] = {
                    "file_ids": file_ids,
                }

        assistant_create_kwargs["tools"] = tools
        if len(tool_resources) > 0:
            assistant_create_kwargs["tool_resources"] = tool_resources
    else:
        # not support forwards compatibility
        if "tool_resources" in assistant_config:
            raise ValueError("`tool_resources` argument are not supported in the openai assistant V1 API.")
        if any(tool["type"] == "file_search" for tool in tools):
            raise ValueError(
                "`file_search` tool are not supported in the openai assistant V1 API, please use `retrieval`."
            )
        assistant_create_kwargs["tools"] = tools
        assistant_create_kwargs["file_ids"] = assistant_config.get("file_ids", [])

    logging.info(f"Creating assistant with config: {assistant_create_kwargs}")
    return client.beta.assistants.create(name=name, instructions=instructions, model=model, **assistant_create_kwargs)
|
||||
|
||||
|
||||
def update_gpt_assistant(client: "OpenAI", assistant_id: str, assistant_config: dict[str, Any]) -> "Assistant":
    """Update an existing OpenAI GPT assistant from the given config.

    Only keys present (non-None) in `assistant_config` are sent to the API.
    `tool_resources` is honoured on the v2 Assistants API; `file_ids` on v1.
    """
    api_version = detect_gpt_assistant_api_version()
    update_kwargs: dict[str, Any] = {}

    # "tools" and "instructions" are version-independent.
    for key in ("tools", "instructions"):
        if assistant_config.get(key) is not None:
            update_kwargs[key] = assistant_config[key]

    if api_version == "v2":
        if assistant_config.get("tool_resources") is not None:
            update_kwargs["tool_resources"] = assistant_config["tool_resources"]
    elif assistant_config.get("file_ids") is not None:
        update_kwargs["file_ids"] = assistant_config["file_ids"]

    return client.beta.assistants.update(assistant_id=assistant_id, **update_kwargs)
|
||||
|
||||
|
||||
def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
|
||||
if isinstance(config_value, list):
|
||||
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
|
||||
else:
|
||||
return config_value in acceptable_values
|
||||
370
mm_agents/coact/autogen/oai/together.py
Normal file
370
mm_agents/coact/autogen/oai/together.py
Normal file
@@ -0,0 +1,370 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
"""Create an OpenAI-compatible client using Together.AI's API.
|
||||
|
||||
Example:
|
||||
```python
|
||||
llm_config = {
|
||||
"config_list": [
|
||||
{
|
||||
"api_type": "together",
|
||||
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
"api_key": os.environ.get("TOGETHER_API_KEY"),
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
|
||||
```
|
||||
|
||||
Install Together.AI python library using: pip install --upgrade together
|
||||
|
||||
Resources:
|
||||
- https://docs.together.ai/docs/inference-python
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..import_utils import optional_import_block, require_optional_import
|
||||
from ..llm_config import LLMConfigEntry, register_llm_config
|
||||
from .client_utils import should_hide_tools, validate_parameter
|
||||
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
|
||||
|
||||
with optional_import_block():
|
||||
from together import Together
|
||||
|
||||
|
||||
@register_llm_config
class TogetherLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for Together.AI (api_type "together").

    Field bounds (ge/le/min_length/max_length) are enforced by pydantic at
    construction time. Sampling fields default to None, i.e. "use the API's
    own default".
    """

    api_type: Literal["together"] = "together"
    # Maximum tokens to generate; Together's completion parameter.
    max_tokens: int = Field(default=512, ge=0)
    stream: bool = False
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    top_k: Optional[int] = Field(default=None)
    repetition_penalty: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
    min_p: Optional[float] = Field(default=None, ge=0, le=1)
    safety_model: Optional[str] = None
    # When to hide tools from the model; consumed via should_hide_tools in TogetherClient.create.
    hide_tools: Literal["if_all_run", "if_any_run", "never"] = "never"
    # [input_price_per_1k, output_price_per_1k] override; must be exactly two floats.
    price: Optional[list[float]] = Field(default=None, min_length=2, max_length=2)
    tool_choice: Optional[Union[str, dict[str, Union[str, dict[str, str]]]]] = (
        None  # dict is the tool to call: {"type": "function", "function": {"name": "my_function"}}
    )

    def create_client(self):
        # Client construction is handled elsewhere (TogetherClient); this entry
        # only validates/carries configuration.
        raise NotImplementedError("TogetherLLMConfigEntry.create_client is not implemented.")
|
||||
|
||||
|
||||
class TogetherClient:
    """Client for Together.AI's API."""

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            **kwargs: Additional keyword arguments to pass to the client.
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key")
        if not self.api_key:
            # Fall back to the TOGETHER_API_KEY environment variable.
            self.api_key = os.getenv("TOGETHER_API_KEY")

        if "response_format" in kwargs and kwargs["response_format"] is not None:
            warnings.warn("response_format is not supported for Together.AI, it will be ignored.", UserWarning)

        assert self.api_key, (
            "Please include the api_key in your config list entry for Together.AI or set the TOGETHER_API_KEY env variable."
        )

    def message_retrieval(self, response) -> list:
        """Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        # Cost is precomputed in `create` via calculate_together_cost and stored
        # on the ChatCompletion object.
        return response.cost

    @staticmethod
    def get_usage(response) -> dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        # ... # pragma: no cover
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Loads the parameters for Together.AI API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
        together_params = {}

        # Check that we have what we need to use Together.AI's API
        together_params["model"] = params.get("model")
        assert together_params["model"], (
            "Please specify the 'model' in your config list entry to nominate the Together.AI model to use."
        )

        # Validate allowed Together.AI parameters
        # https://github.com/togethercomputer/together-python/blob/94ffb30daf0ac3e078be986af7228f85f79bde99/src/together/resources/completions.py#L44
        together_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, 512, (0, None), None)
        together_params["stream"] = validate_parameter(params, "stream", bool, False, False, None, None)
        together_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, None, None, None)
        together_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
        together_params["top_k"] = validate_parameter(params, "top_k", int, True, None, None, None)
        together_params["repetition_penalty"] = validate_parameter(
            params, "repetition_penalty", float, True, None, None, None
        )
        together_params["presence_penalty"] = validate_parameter(
            params, "presence_penalty", (int, float), True, None, (-2, 2), None
        )
        together_params["frequency_penalty"] = validate_parameter(
            params, "frequency_penalty", (int, float), True, None, (-2, 2), None
        )
        together_params["min_p"] = validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None)
        together_params["safety_model"] = validate_parameter(
            params, "safety_model", str, True, None, None, None
        )  # We won't enforce the available models as they are likely to change

        # Check if they want to stream and use tools, which isn't currently supported (TODO)
        if together_params["stream"] and "tools" in params:
            warnings.warn(
                "Streaming is not supported when using tools, streaming will be disabled.",
                UserWarning,
            )

            together_params["stream"] = False

        if "tool_choice" in params:
            together_params["tool_choice"] = params["tool_choice"]

        return together_params

    @require_optional_import("together", "together")
    def create(self, params: dict) -> ChatCompletion:
        """Send a chat completion request to Together.AI and return an OpenAI-style ChatCompletion."""
        messages = params.get("messages", [])

        # Convert AG2 messages to Together.AI messages
        together_messages = oai_messages_to_together_messages(messages)

        # Parse parameters to Together.AI API's parameters
        together_params = self.parse_params(params)

        # Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(together_messages, params["tools"], hide_tools):
                together_params["tools"] = params["tools"]

        together_params["messages"] = together_messages

        # We use chat model by default
        client = Together(api_key=self.api_key)

        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0

        response = client.chat.completions.create(**together_params)
        if together_params["stream"]:
            # Read in the chunks as they stream
            ans = ""
            for chunk in response:
                ans = ans + (chunk.choices[0].delta.content or "")

                prompt_tokens = chunk.usage.prompt_tokens
                completion_tokens = chunk.usage.completion_tokens
                total_tokens = chunk.usage.total_tokens
        else:
            ans: str = response.choices[0].message.content

            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens

        # NOTE(review): in the streaming branch `response` is an iterator that
        # has already been consumed above, yet `response.choices[0]` is accessed
        # below and the accumulated `ans` is never used for the final message —
        # confirm the streamed path against the Together SDK.
        if response.choices[0].finish_reason == "tool_calls":
            together_finish = "tool_calls"
            tool_calls = []
            for tool_call in response.choices[0].message.tool_calls:
                tool_calls.append(
                    ChatCompletionMessageToolCall(
                        id=tool_call.id,
                        function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                        type="function",
                    )
                )
        else:
            together_finish = "stop"
            tool_calls = None

        # 3. convert output
        message = ChatCompletionMessage(
            role="assistant",
            content=response.choices[0].message.content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=together_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response.id,
            model=together_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            cost=calculate_together_cost(prompt_tokens, completion_tokens, together_params["model"]),
        )

        return response_oai
|
||||
|
||||
|
||||
def oai_messages_to_together_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Together.AI format.

    We correct for any specific role orders and types. The input list is not
    modified; a deep copy is returned.
    """
    converted = copy.deepcopy(messages)

    # Together.AI has no 'tool' role (produced when a function is executed);
    # remap those messages to 'user'.
    for message in converted:
        if message.get("role") == "tool":
            message["role"] = "user"

    return converted
|
||||
|
||||
|
||||
# MODELS AND COSTS
# Known Together.AI chat/language/code models mapped to their size in billions
# of parameters; sizes drive the tiered pricing lookup in calculate_together_cost.
chat_lang_code_model_sizes = {
    "zero-one-ai/Yi-34B-Chat": 34,
    "allenai/OLMo-7B-Instruct": 7,
    "allenai/OLMo-7B-Twin-2T": 7,
    "allenai/OLMo-7B": 7,
    "Austism/chronos-hermes-13b": 13,
    "deepseek-ai/deepseek-coder-33b-instruct": 33,
    "deepseek-ai/deepseek-llm-67b-chat": 67,
    "garage-bAInd/Platypus2-70B-instruct": 70,
    "google/gemma-2b-it": 2,
    "google/gemma-7b-it": 7,
    "Gryphe/MythoMax-L2-13b": 13,
    "lmsys/vicuna-13b-v1.5": 13,
    "lmsys/vicuna-7b-v1.5": 7,
    "codellama/CodeLlama-13b-Instruct-hf": 13,
    "codellama/CodeLlama-34b-Instruct-hf": 34,
    "codellama/CodeLlama-70b-Instruct-hf": 70,
    "codellama/CodeLlama-7b-Instruct-hf": 7,
    "meta-llama/Llama-2-70b-chat-hf": 70,
    "meta-llama/Llama-2-13b-chat-hf": 13,
    "meta-llama/Llama-2-7b-chat-hf": 7,
    "meta-llama/Llama-3-8b-chat-hf": 8,
    "meta-llama/Llama-3-70b-chat-hf": 70,
    "mistralai/Mistral-7B-Instruct-v0.1": 7,
    "mistralai/Mistral-7B-Instruct-v0.2": 7,
    "mistralai/Mistral-7B-Instruct-v0.3": 7,
    "NousResearch/Nous-Capybara-7B-V1p9": 7,
    "NousResearch/Nous-Hermes-llama-2-7b": 7,
    "NousResearch/Nous-Hermes-Llama2-13b": 13,
    "NousResearch/Nous-Hermes-2-Yi-34B": 34,
    "openchat/openchat-3.5-1210": 7,
    "Open-Orca/Mistral-7B-OpenOrca": 7,
    "Qwen/Qwen1.5-0.5B-Chat": 0.5,
    "Qwen/Qwen1.5-1.8B-Chat": 1.8,
    "Qwen/Qwen1.5-4B-Chat": 4,
    "Qwen/Qwen1.5-7B-Chat": 7,
    "Qwen/Qwen1.5-14B-Chat": 14,
    "Qwen/Qwen1.5-32B-Chat": 32,
    "Qwen/Qwen1.5-72B-Chat": 72,
    "Qwen/Qwen1.5-110B-Chat": 110,
    "Qwen/Qwen2-72B-Instruct": 72,
    "snorkelai/Snorkel-Mistral-PairRM-DPO": 7,
    "togethercomputer/alpaca-7b": 7,
    "teknium/OpenHermes-2-Mistral-7B": 7,
    "teknium/OpenHermes-2p5-Mistral-7B": 7,
    "togethercomputer/Llama-2-7B-32K-Instruct": 7,
    "togethercomputer/RedPajama-INCITE-Chat-3B-v1": 3,
    "togethercomputer/RedPajama-INCITE-7B-Chat": 7,
    "togethercomputer/StripedHyena-Nous-7B": 7,
    "Undi95/ReMM-SLERP-L2-13B": 13,
    "Undi95/Toppy-M-7B": 7,
    "WizardLM/WizardLM-13B-V1.2": 13,
    "upstage/SOLAR-10.7B-Instruct-v1.0": 11,
}

# Cost per million tokens based on up to X Billion parameters, e.g. up 4B is $0.1/million
chat_lang_code_model_costs = {4: 0.1, 8: 0.2, 21: 0.3, 41: 0.8, 80: 0.9, 110: 1.8}

# Mixture-of-experts (and other specially priced) models mapped to their total
# parameter count in billions.
mixture_model_sizes = {
    "cognitivecomputations/dolphin-2.5-mixtral-8x7b": 56,
    "databricks/dbrx-instruct": 132,
    "mistralai/Mixtral-8x7B-Instruct-v0.1": 47,
    "mistralai/Mixtral-8x22B-Instruct-v0.1": 141,
    "NousResearch/Nous-Hermes-2-Mistral-7B-DPO": 7,
    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 47,
    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT": 47,
    "Snowflake/snowflake-arctic-instruct": 480,
}

# Cost per million tokens based on up to X Billion parameters, e.g. up 56B is $0.6/million
mixture_costs = {56: 0.6, 176: 1.2, 480: 2.4}
|
||||
|
||||
|
||||
def calculate_together_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """Cost calculation for inference.

    Looks the model up in the known size tables (chat/language/code first, then
    mixture models) and applies the matching per-million-token tier price.
    Returns 0 with a warning when the model or its size tier is unknown.
    """
    in_chat_table = model_name in chat_lang_code_model_sizes

    if not in_chat_table and model_name not in mixture_model_sizes:
        # Model is not in our list of models, can't determine the cost
        warnings.warn(
            "The model isn't catered for costing, to apply costs you can use the 'price' key on your config_list.",
            UserWarning,
        )
        return 0

    # Chat/language/code models take priority over the mixture table.
    if in_chat_table:
        size_in_b = chat_lang_code_model_sizes[model_name]
        tier_prices = chat_lang_code_model_costs
    else:
        size_in_b = mixture_model_sizes[model_name]
        tier_prices = mixture_costs

    # First tier whose upper bound covers the model size (tiers are ascending).
    cost_per_mil = next((price for top_size, price in tier_prices.items() if size_in_b <= top_size), 0)

    if cost_per_mil == 0:
        warnings.warn("Model size doesn't align with cost structure.", UserWarning)

    return cost_per_mil * ((input_tokens + output_tokens) / 1e6)
|
||||
Reference in New Issue
Block a user