CoACT initialize (#292)

This commit is contained in:
Linxin Song
2025-07-30 19:35:20 -07:00
committed by GitHub
parent 862d704b8c
commit b968155757
228 changed files with 42386 additions and 0 deletions

View File

@@ -0,0 +1,53 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from ..cache.cache import Cache
from .anthropic import AnthropicLLMConfigEntry
from .bedrock import BedrockLLMConfigEntry
from .cerebras import CerebrasLLMConfigEntry
from .client import AzureOpenAILLMConfigEntry, DeepSeekLLMConfigEntry, ModelClient, OpenAILLMConfigEntry, OpenAIWrapper
from .cohere import CohereLLMConfigEntry
from .gemini import GeminiLLMConfigEntry
from .groq import GroqLLMConfigEntry
from .mistral import MistralLLMConfigEntry
from .ollama import OllamaLLMConfigEntry
from .openai_utils import (
config_list_from_dotenv,
config_list_from_json,
config_list_from_models,
config_list_gpt4_gpt35,
config_list_openai_aoai,
filter_config,
get_config_list,
get_first_llm_config,
)
from .together import TogetherLLMConfigEntry
# Public API of this package: one LLMConfigEntry per provider, the model
# client protocol and wrapper, the Cache, and the config-list helper functions.
__all__ = [
    "AnthropicLLMConfigEntry",
    "AzureOpenAILLMConfigEntry",
    "BedrockLLMConfigEntry",
    "Cache",
    "CerebrasLLMConfigEntry",
    "CohereLLMConfigEntry",
    "DeepSeekLLMConfigEntry",
    "GeminiLLMConfigEntry",
    "GroqLLMConfigEntry",
    "MistralLLMConfigEntry",
    "ModelClient",
    "OllamaLLMConfigEntry",
    "OpenAILLMConfigEntry",
    "OpenAIWrapper",
    "TogetherLLMConfigEntry",
    "config_list_from_dotenv",
    "config_list_from_json",
    "config_list_from_models",
    "config_list_gpt4_gpt35",
    "config_list_openai_aoai",
    "filter_config",
    "get_config_list",
    "get_first_llm_config",
]

View File

@@ -0,0 +1,714 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Create an OpenAI-compatible client for the Anthropic API.
Example usage:
Install the `anthropic` package by running `pip install --upgrade anthropic`.
- https://docs.anthropic.com/en/docs/quickstart-guide
```python
import autogen
config_list = [
{
"model": "claude-3-sonnet-20240229",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
"api_type": "anthropic",
}
]
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
```
Example usage for Anthropic Bedrock:
Install the `anthropic` package by running `pip install --upgrade anthropic`.
- https://docs.anthropic.com/en/docs/quickstart-guide
```python
import autogen
config_list = [
{
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"aws_access_key": "<accessKey>",
"aws_secret_key": "<secretKey>",
"aws_session_token": "<sessionTok>",
"aws_region":"us-east-1",
"api_type": "anthropic",
}
]
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
```
Example usage for Anthropic VertexAI:
Install the `anthropic` package by running `pip install anthropic[vertex]`.
- https://docs.anthropic.com/en/docs/quickstart-guide
```python
import autogen
config_list = [
{
"model": "claude-3-5-sonnet-20240620-v1:0",
"gcp_project_id": "dummy_project_id",
"gcp_region": "us-west-2",
"gcp_auth_token": "dummy_auth_token",
"api_type": "anthropic",
}
]
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
```
"""
from __future__ import annotations
import inspect
import json
import os
import re
import time
import warnings
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, Field
from ..import_utils import optional_import_block, require_optional_import
from ..llm_config import LLMConfigEntry, register_llm_config
from .client_utils import FormatterProtocol, validate_parameter
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
with optional_import_block():
    from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
    from anthropic import __version__ as anthropic_version
    from anthropic.types import Message, TextBlock, ToolUseBlock

    # Tool use (function calling) was added in anthropic 0.23.1. Compare as a
    # numeric version tuple: the original plain string comparison misorders
    # versions (e.g. "0.3.0" > "0.23.1" lexicographically while 0.3.0 < 0.23.1).
    TOOL_ENABLED = tuple(int(part) for part in re.findall(r"\d+", anthropic_version)[:3]) >= (0, 23, 1)
# Pricing table keyed by model ID: (input price, output price) in USD per
# 1,000 tokens. Consumed by _calculate_cost; models missing from this table
# produce a warning and a cost of 0.0 there.
ANTHROPIC_PRICING_1k = {
    "claude-3-7-sonnet-20250219": (0.003, 0.015),
    "claude-3-5-sonnet-20241022": (0.003, 0.015),
    "claude-3-5-haiku-20241022": (0.0008, 0.004),
    "claude-3-5-sonnet-20240620": (0.003, 0.015),
    "claude-3-sonnet-20240229": (0.003, 0.015),
    "claude-3-opus-20240229": (0.015, 0.075),
    "claude-3-haiku-20240307": (0.00025, 0.00125),
    "claude-2.1": (0.008, 0.024),
    "claude-2.0": (0.008, 0.024),
    "claude-instant-1.2": (0.008, 0.024),
}
@register_llm_config
class AnthropicLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for the Anthropic API (api_type "anthropic")."""

    api_type: Literal["anthropic"] = "anthropic"
    # Request timeout in seconds; must be >= 1 when provided.
    timeout: Optional[int] = Field(default=None, ge=1)
    # Anthropic temperature range is 0.0-1.0 (unlike OpenAI's 0-2).
    temperature: float = Field(default=1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(default=None, ge=1)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    stop_sequences: Optional[list[str]] = None
    # Streaming is currently disabled by AnthropicClient.load_config.
    stream: bool = False
    max_tokens: int = Field(default=4096, ge=1)
    # Exactly two entries when set: [input, output] price per 1k tokens.
    price: Optional[list[float]] = Field(default=None, min_length=2, max_length=2)
    # See https://docs.anthropic.com/en/docs/build-with-claude/tool-use/overview#controlling-claudes-output
    tool_choice: Optional[dict] = None
    # Extended-thinking configuration, passed through to the API verbatim.
    thinking: Optional[dict] = None
    # GCP Vertex AI credentials (alternative to api_key / AWS credentials).
    gcp_project_id: Optional[str] = None
    gcp_region: Optional[str] = None
    gcp_auth_token: Optional[str] = None

    def create_client(self):
        # Client construction is handled elsewhere (see AnthropicClient);
        # this entry intentionally does not build a client itself.
        raise NotImplementedError("AnthropicLLMConfigEntry.create_client is not implemented.")
@require_optional_import("anthropic", "anthropic")
class AnthropicClient:
    """OpenAI-compatible client wrapping the Anthropic Messages API.

    Supports three backends, selected in __init__ from the supplied kwargs and
    environment variables: the plain Anthropic API (api_key), Anthropic on AWS
    Bedrock (aws_* credentials), and Anthropic on GCP Vertex AI (gcp_*
    credentials).
    """

    def __init__(self, **kwargs: Any):
        """Initialize the Anthropic API client.

        Args:
            **kwargs: The configuration parameters for the client.

        Raises:
            ValueError: If neither an API key nor a complete set of AWS or
                GCP credentials can be resolved.
        """
        self._api_key = kwargs.get("api_key")
        self._aws_access_key = kwargs.get("aws_access_key")
        self._aws_secret_key = kwargs.get("aws_secret_key")
        self._aws_session_token = kwargs.get("aws_session_token")
        self._aws_region = kwargs.get("aws_region")
        self._gcp_project_id = kwargs.get("gcp_project_id")
        self._gcp_region = kwargs.get("gcp_region")
        self._gcp_auth_token = kwargs.get("gcp_auth_token")
        self._base_url = kwargs.get("base_url")

        # Fall back to environment variables for anything not passed in.
        if not self._api_key:
            self._api_key = os.getenv("ANTHROPIC_API_KEY")

        if not self._aws_access_key:
            self._aws_access_key = os.getenv("AWS_ACCESS_KEY")

        if not self._aws_secret_key:
            self._aws_secret_key = os.getenv("AWS_SECRET_KEY")

        if not self._aws_region:
            self._aws_region = os.getenv("AWS_REGION")

        if not self._gcp_region:
            self._gcp_region = os.getenv("GCP_REGION")

        # Without an API key we need complete AWS or GCP credentials instead.
        if self._api_key is None:
            if self._aws_region:
                if self._aws_access_key is None or self._aws_secret_key is None:
                    raise ValueError("API key or AWS credentials are required to use the Anthropic API.")
            elif self._gcp_region:
                if self._gcp_project_id is None or self._gcp_region is None:
                    raise ValueError("API key or GCP credentials are required to use the Anthropic API.")
            else:
                raise ValueError("API key or AWS credentials or GCP credentials are required to use the Anthropic API.")

        # Backend selection: api_key -> Anthropic, gcp_region -> Vertex AI,
        # otherwise AWS Bedrock.
        if self._api_key is not None:
            client_kwargs = {"api_key": self._api_key}
            if self._base_url:
                client_kwargs["base_url"] = self._base_url
            self._client = Anthropic(**client_kwargs)
        elif self._gcp_region is not None:
            # Forward every AnthropicVertex constructor parameter for which a
            # matching _gcp_<name> attribute exists on this instance.
            kw = {}
            for i, p in enumerate(inspect.signature(AnthropicVertex).parameters):
                if hasattr(self, f"_gcp_{p}"):
                    kw[p] = getattr(self, f"_gcp_{p}")
            if self._base_url:
                kw["base_url"] = self._base_url
            self._client = AnthropicVertex(**kw)
        else:
            client_kwargs = {
                "aws_access_key": self._aws_access_key,
                "aws_secret_key": self._aws_secret_key,
                "aws_session_token": self._aws_session_token,
                "aws_region": self._aws_region,
            }
            if self._base_url:
                client_kwargs["base_url"] = self._base_url
            self._client = AnthropicBedrock(**client_kwargs)

        self._last_tooluse_status = {}

        # Store the response format, if provided (for structured outputs)
        self._response_format: Optional[type[BaseModel]] = None

    def load_config(self, params: dict[str, Any]):
        """Load the configuration for the Anthropic API client.

        Validates/normalizes the inference parameters from `params` and
        returns a dict suitable for keyword-expansion into messages.create.
        """
        anthropic_params = {}

        anthropic_params["model"] = params.get("model")
        assert anthropic_params["model"], "Please provide a `model` in the config_list to use the Anthropic API."

        anthropic_params["temperature"] = validate_parameter(
            params, "temperature", (float, int), False, 1.0, (0.0, 1.0), None
        )
        anthropic_params["max_tokens"] = validate_parameter(params, "max_tokens", int, False, 4096, (1, None), None)
        anthropic_params["timeout"] = validate_parameter(params, "timeout", int, True, None, (1, None), None)
        anthropic_params["top_k"] = validate_parameter(params, "top_k", int, True, None, (1, None), None)
        anthropic_params["top_p"] = validate_parameter(params, "top_p", (float, int), True, None, (0.0, 1.0), None)
        anthropic_params["stop_sequences"] = validate_parameter(params, "stop_sequences", list, True, None, None, None)
        anthropic_params["stream"] = validate_parameter(params, "stream", bool, False, False, None, None)

        # Extended-thinking configuration is passed through unvalidated.
        if "thinking" in params:
            anthropic_params["thinking"] = params["thinking"]

        if anthropic_params["stream"]:
            warnings.warn(
                "Streaming is not currently supported, streaming will be disabled.",
                UserWarning,
            )
            anthropic_params["stream"] = False

        # Note the Anthropic API supports "tool" for tool_choice but you must specify the tool name so we will ignore that here
        # Dictionary, see options here: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/overview#controlling-claudes-output
        # type = auto, any, tool, none | name = the name of the tool if type=tool
        anthropic_params["tool_choice"] = validate_parameter(params, "tool_choice", dict, True, None, None, None)

        return anthropic_params

    def cost(self, response) -> float:
        """Calculate the cost of the completion using the Anthropic pricing."""
        # The cost was computed in create() and attached to the response.
        return response.cost

    # Read-only accessors for the resolved credentials/configuration.
    @property
    def api_key(self):
        return self._api_key

    @property
    def aws_access_key(self):
        return self._aws_access_key

    @property
    def aws_secret_key(self):
        return self._aws_secret_key

    @property
    def aws_session_token(self):
        return self._aws_session_token

    @property
    def aws_region(self):
        return self._aws_region

    @property
    def gcp_project_id(self):
        return self._gcp_project_id

    @property
    def gcp_region(self):
        return self._gcp_region

    @property
    def gcp_auth_token(self):
        return self._gcp_auth_token

    def create(self, params: dict[str, Any]) -> ChatCompletion:
        """Creates a completion using the Anthropic API.

        Converts the OAI-format request in `params` to Anthropic format, calls
        the Messages API, and converts the result back to a ChatCompletion
        (with token usage and a computed `cost` attached).
        """
        if "tools" in params:
            converted_functions = self.convert_tools_to_functions(params["tools"])
            params["functions"] = params.get("functions", []) + converted_functions

        # Convert AG2 messages to Anthropic messages
        anthropic_messages = oai_messages_to_anthropic_messages(params)
        anthropic_params = self.load_config(params)

        # If response_format exists, we want structured outputs
        # Anthropic doesn't support response_format, so using Anthropic's "JSON Mode":
        # https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
        if params.get("response_format"):
            self._response_format = params["response_format"]
            self._add_response_format_to_system(params)

        # TODO: support stream
        params = params.copy()
        if "functions" in params:
            tools_configs = params.pop("functions")
            tools_configs = [self.openai_func_to_anthropic(tool) for tool in tools_configs]
            params["tools"] = tools_configs

        # Anthropic doesn't accept None values, so we need to use keyword argument unpacking instead of setting parameters.
        # Copy params we need into anthropic_params
        # Remove any that don't have values
        anthropic_params["messages"] = anthropic_messages
        if "system" in params:
            anthropic_params["system"] = params["system"]
        if "tools" in params:
            anthropic_params["tools"] = params["tools"]
        if anthropic_params["top_k"] is None:
            del anthropic_params["top_k"]
        if anthropic_params["top_p"] is None:
            del anthropic_params["top_p"]
        if anthropic_params["stop_sequences"] is None:
            del anthropic_params["stop_sequences"]
        if anthropic_params["tool_choice"] is None:
            del anthropic_params["tool_choice"]

        response = self._client.messages.create(**anthropic_params)

        tool_calls = []
        message_text = ""
        if self._response_format:
            # Structured output path: extract/validate JSON; on failure the
            # error text becomes the message content.
            try:
                parsed_response = self._extract_json_response(response)
                message_text = _format_json_response(parsed_response)
            except ValueError as e:
                message_text = str(e)

            anthropic_finish = "stop"
        else:
            # NOTE(review): assumes messages.create returns a non-None Message;
            # if it ever returned None, the usage access below would fail.
            if response is not None:
                # If we have tool use as the response, populate completed tool calls for our return OAI response
                if response.stop_reason == "tool_use":
                    anthropic_finish = "tool_calls"
                    for content in response.content:
                        if type(content) == ToolUseBlock:
                            tool_calls.append(
                                ChatCompletionMessageToolCall(
                                    id=content.id,
                                    function={"name": content.name, "arguments": json.dumps(content.input)},
                                    type="function",
                                )
                            )
                else:
                    anthropic_finish = "stop"
                    tool_calls = None

                # Retrieve any text content from the response
                for content in response.content:
                    if type(content) == TextBlock:
                        message_text = content.text
                        break

        # Calculate and save the cost onto the response
        prompt_tokens = response.usage.input_tokens
        completion_tokens = response.usage.output_tokens

        # Convert output back to AG2 response format
        message = ChatCompletionMessage(
            role="assistant",
            content=message_text,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=anthropic_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response.id,
            model=anthropic_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
            cost=_calculate_cost(prompt_tokens, completion_tokens, anthropic_params["model"]),
        )

        return response_oai

    def message_retrieval(self, response) -> list:
        """Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    @staticmethod
    def openai_func_to_anthropic(openai_func: dict) -> dict:
        """Rename the OpenAI "parameters" key to Anthropic's "input_schema"."""
        res = openai_func.copy()
        res["input_schema"] = res.pop("parameters")
        return res

    @staticmethod
    def get_usage(response: ChatCompletion) -> dict:
        """Get the usage of tokens and their cost information."""
        return {
            "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
            "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
            "total_tokens": response.usage.total_tokens if response.usage is not None else 0,
            "cost": response.cost if hasattr(response, "cost") else 0.0,
            "model": response.model,
        }

    @staticmethod
    def convert_tools_to_functions(tools: list) -> list:
        """
        Convert tool definitions into Anthropic-compatible functions,
        updating nested $ref paths in property schemas.

        Args:
            tools (list): List of tool definitions.

        Returns:
            list: List of functions with updated $ref paths.
        """

        def update_refs(obj, defs_keys, prop_name):
            """Recursively update $ref values that start with "#/$defs/"."""
            if isinstance(obj, dict):
                for key, value in obj.items():
                    if key == "$ref" and isinstance(value, str) and value.startswith("#/$defs/"):
                        ref_key = value[len("#/$defs/") :]
                        if ref_key in defs_keys:
                            # Re-root the reference under the property's own $defs.
                            obj[key] = f"#/properties/{prop_name}/$defs/{ref_key}"
                    else:
                        update_refs(value, defs_keys, prop_name)
            elif isinstance(obj, list):
                for item in obj:
                    update_refs(item, defs_keys, prop_name)

        functions = []
        for tool in tools:
            if tool.get("type") == "function" and "function" in tool:
                function = tool["function"]
                parameters = function.get("parameters", {})
                properties = parameters.get("properties", {})
                for prop_name, prop_schema in properties.items():
                    if "$defs" in prop_schema:
                        defs_keys = set(prop_schema["$defs"].keys())
                        update_refs(prop_schema, defs_keys, prop_name)
                functions.append(function)
        return functions

    def _add_response_format_to_system(self, params: dict[str, Any]):
        """Add prompt that will generate properly formatted JSON for structured outputs to system parameter.

        Based on Anthropic's JSON Mode cookbook, we ask the LLM to put the JSON within <json_response> tags.

        Args:
            params (dict): The client parameters
        """
        # No system prompt to append to: structured-output instructions are
        # silently skipped in that case.
        if not params.get("system"):
            return

        # Get the schema of the Pydantic model
        if isinstance(self._response_format, dict):
            schema = self._response_format
        else:
            schema = self._response_format.model_json_schema()

        # Add instructions for JSON formatting
        format_content = f"""Please provide your response as a JSON object that matches the following schema:
{json.dumps(schema, indent=2)}
Format your response as valid JSON within <json_response> tags.
Do not include any text before or after the tags.
Ensure the JSON is properly formatted and matches the schema exactly."""

        # Add formatting to last user message
        params["system"] += "\n\n" + format_content

    def _extract_json_response(self, response: Message) -> Any:
        """Extract and validate JSON response from the output for structured outputs.

        Args:
            response (Message): The response from the API.

        Returns:
            Any: The parsed JSON response.

        Raises:
            ValueError: If no JSON is found or it fails schema validation.
        """
        if not self._response_format:
            return response

        # Extract content from response
        content = response.content[0].text if response.content else ""

        # Try to extract JSON from tags first
        json_match = re.search(r"<json_response>(.*?)</json_response>", content, re.DOTALL)
        if json_match:
            json_str = json_match.group(1).strip()
        else:
            # Fallback to finding first JSON object
            json_start = content.find("{")
            json_end = content.rfind("}")
            if json_start == -1 or json_end == -1:
                raise ValueError("No valid JSON found in response for Structured Output.")
            json_str = content[json_start : json_end + 1]

        try:
            # Parse JSON and validate against the Pydantic model if Pydantic model was provided
            json_data = json.loads(json_str)
            if isinstance(self._response_format, dict):
                # Dict schemas are not validated; return the raw JSON string.
                return json_str
            else:
                return self._response_format.model_validate(json_data)
        except Exception as e:
            raise ValueError(f"Failed to parse response as valid JSON matching the schema for Structured Output: {e!s}")
def _format_json_response(response: Any) -> str:
"""Formats the JSON response for structured outputs using the format method if it exists."""
if isinstance(response, str):
return response
elif isinstance(response, FormatterProtocol):
return response.format()
else:
return response.model_dump_json()
def process_image_content(content_item: dict[str, Any]) -> dict[str, Any]:
    """Process an OpenAI image content item into Claude format.

    Converts an OpenAI-style ``{"type": "image_url", ...}`` item whose URL is
    a base64 data URL into Anthropic's base64 image source block. Any other
    item type, a non-data URL, or a processing failure returns the original
    item unchanged (best-effort; never raises).

    Args:
        content_item: One element of an OpenAI message content list.

    Returns:
        The Anthropic-format image block, or the original item if it could
        not be converted.
    """
    if content_item["type"] != "image_url":
        return content_item

    url = content_item["image_url"]["url"]
    try:
        # Handle data URLs, e.g. "data:image/png;base64,...."
        if url.startswith("data:"):
            data_url_pattern = r"data:image/([a-zA-Z]+);base64,(.+)"
            match = re.match(data_url_pattern, url)
            if match:
                media_type, base64_data = match.groups()
                return {
                    "type": "image",
                    "source": {"type": "base64", "media_type": f"image/{media_type}", "data": base64_data},
                }
            else:
                print("Error processing image.")
                # Return original content if image processing fails
                return content_item
    except Exception as e:
        # Fixed duplicated word in the original message ("image image").
        print(f"Error processing image: {e}")
        # Return original content if image processing fails
        return content_item

    # Fix: non-data URLs previously fell off the end of the try block and
    # implicitly returned None; pass them through unchanged instead.
    return content_item
def process_message_content(message: dict[str, Any]) -> Union[str, list[dict[str, Any]]]:
    """Process message content, handling both string and list formats with images."""
    content = message.get("content", "")

    # Plain-string content (including the empty string) needs no conversion.
    if isinstance(content, str):
        return content

    # A list means mixed content: keep text parts, convert image parts.
    if isinstance(content, list):
        converted: list[dict[str, Any]] = []
        for part in content:
            if part["type"] == "text":
                converted.append({"type": "text", "text": part["text"]})
            elif part["type"] == "image_url":
                converted.append(process_image_content(part))
        return converted

    # Anything else is passed through untouched.
    return content
@require_optional_import("anthropic", "anthropic")
def oai_messages_to_anthropic_messages(params: dict[str, Any]) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Anthropic format.

    We correct for any specific role orders and types, etc.

    NOTE: mutates params["system"] — system messages are folded into it —
    and reads params["messages"] / the presence of params["tools"].
    """
    # Track whether we have tools passed in. If not, tool use / result messages should be converted to text messages.
    # Anthropic requires a tools parameter with the tools listed, if there are other messages with tool use or tool results.
    # This can occur when we don't need tool calling, such as for group chat speaker selection.
    has_tools = "tools" in params

    # Convert messages to Anthropic compliant format
    processed_messages = []

    # Used to interweave user messages to ensure user/assistant alternating
    user_continue_message = {"content": "Please continue.", "role": "user"}
    assistant_continue_message = {"content": "Please continue.", "role": "assistant"}

    tool_use_messages = 0
    tool_result_messages = 0
    last_tool_use_index = -1
    last_tool_result_index = -1

    for message in params["messages"]:
        if message["role"] == "system":
            content = process_message_content(message)
            if isinstance(content, list):
                # For system messages with images, concatenate only the text portions
                text_content = " ".join(item.get("text", "") for item in content if item.get("type") == "text")
                params["system"] = params.get("system", "") + (" " if "system" in params else "") + text_content
            else:
                params["system"] = params.get("system", "") + ("\n" if "system" in params else "") + content
        else:
            # New messages will be added here, manage role alternations
            expected_role = "user" if len(processed_messages) % 2 == 0 else "assistant"

            if "tool_calls" in message:
                # Map the tool call options to Anthropic's ToolUseBlock
                tool_uses = []
                tool_names = []
                for tool_call in message["tool_calls"]:
                    tool_uses.append(
                        ToolUseBlock(
                            type="tool_use",
                            id=tool_call["id"],
                            name=tool_call["function"]["name"],
                            input=json.loads(tool_call["function"]["arguments"]),
                        )
                    )
                    # Only count tool uses when tools are in play; the count is
                    # compared against tool results after the loop.
                    if has_tools:
                        tool_use_messages += 1
                    tool_names.append(tool_call["function"]["name"])

                if expected_role == "user":
                    # Insert an extra user message as we will append an assistant message
                    processed_messages.append(user_continue_message)

                if has_tools:
                    processed_messages.append({"role": "assistant", "content": tool_uses})
                    last_tool_use_index = len(processed_messages) - 1
                else:
                    # Not using tools, so put in a plain text message
                    processed_messages.append({
                        "role": "assistant",
                        "content": f"Some internal function(s) that could be used: [{', '.join(tool_names)}]",
                    })
            elif "tool_call_id" in message:
                if has_tools:
                    # Map the tool usage call to tool_result for Anthropic
                    tool_result = {
                        "type": "tool_result",
                        "tool_use_id": message["tool_call_id"],
                        "content": message["content"],
                    }

                    # If the previous message also had a tool_result, add it to that
                    # Otherwise append a new message
                    if last_tool_result_index == len(processed_messages) - 1:
                        processed_messages[-1]["content"].append(tool_result)
                    else:
                        if expected_role == "assistant":
                            # Insert an extra assistant message as we will append a user message
                            processed_messages.append(assistant_continue_message)

                        processed_messages.append({"role": "user", "content": [tool_result]})
                        last_tool_result_index = len(processed_messages) - 1

                    tool_result_messages += 1
                else:
                    # Not using tools, so put in a plain text message
                    processed_messages.append({
                        "role": "user",
                        "content": f"Running the function returned: {message['content']}",
                    })
            elif message["content"] == "":
                # Ignoring empty messages
                pass
            else:
                if expected_role != message["role"]:
                    # Inserting the alternating continue message
                    processed_messages.append(
                        user_continue_message if expected_role == "user" else assistant_continue_message
                    )

                # Process messages for images
                processed_content = process_message_content(message)
                processed_message = message.copy()
                processed_message["content"] = processed_content
                processed_messages.append(processed_message)

    # We'll replace the last tool_use if there's no tool_result (occurs if we finish the conversation before running the function)
    if has_tools and tool_use_messages != tool_result_messages:
        processed_messages[last_tool_use_index] = assistant_continue_message

    # name is not a valid field on messages
    for message in processed_messages:
        if "name" in message:
            message.pop("name", None)

    # Note: When using reflection_with_llm we may end up with an "assistant" message as the last message and that may cause a blank response
    # So, if the last role is not user, add a 'user' continue message at the end
    if processed_messages[-1]["role"] != "user":
        processed_messages.append(user_continue_message)

    return processed_messages
def _calculate_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Calculate the cost of the completion using the Anthropic pricing."""
    pricing = ANTHROPIC_PRICING_1k.get(model)
    if pricing is None:
        # Unknown model: warn and report a zero cost rather than failing.
        warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
        return 0.0

    input_rate_per_1k, output_rate_per_1k = pricing
    return (input_tokens / 1000) * input_rate_per_1k + (output_tokens / 1000) * output_rate_per_1k

View File

@@ -0,0 +1,628 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Create a compatible client for the Amazon Bedrock Converse API.
Example usage:
Install the `boto3` package by running `pip install --upgrade boto3`.
- https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
```python
import autogen
config_list = [
{
"api_type": "bedrock",
"model": "meta.llama3-1-8b-instruct-v1:0",
"aws_region": "us-west-2",
"aws_access_key": "",
"aws_secret_key": "",
"price": [0.003, 0.015],
}
]
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
```
"""
from __future__ import annotations
import base64
import json
import os
import re
import time
import warnings
from typing import Any, Literal, Optional
import requests
from pydantic import Field, SecretStr, field_serializer
from ..import_utils import optional_import_block, require_optional_import
from ..llm_config import LLMConfigEntry, register_llm_config
from .client_utils import validate_parameter
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
with optional_import_block():
import boto3
from botocore.config import Config
@register_llm_config
class BedrockLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for the Amazon Bedrock Converse API (api_type "bedrock")."""

    api_type: Literal["bedrock"] = "bedrock"
    # Region is the only mandatory credential field; access/secret keys may
    # instead come from the environment or an attached IAM role.
    aws_region: str
    aws_access_key: Optional[SecretStr] = None
    aws_secret_key: Optional[SecretStr] = None
    aws_session_token: Optional[SecretStr] = None
    aws_profile_name: Optional[str] = None
    # "Base" inference parameters; camelCase names are forwarded verbatim to
    # the Converse inferenceConfig.
    temperature: Optional[float] = None
    topP: Optional[float] = None  # noqa: N815
    maxTokens: Optional[int] = None  # noqa: N815
    # Model-specific parameters, forwarded as additionalModelRequestFields.
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    k: Optional[int] = None
    seed: Optional[int] = None
    cache_seed: Optional[int] = None
    # Some models (e.g. Mistral Instruct) do not accept a system prompt.
    supports_system_prompts: bool = True
    stream: bool = False
    # Exactly two entries when set: [input, output] price per 1k tokens.
    price: Optional[list[float]] = Field(default=None, min_length=2, max_length=2)
    timeout: Optional[int] = None

    @field_serializer("aws_access_key", "aws_secret_key", "aws_session_token", when_used="unless-none")
    def serialize_aws_secrets(self, v: SecretStr) -> str:
        """Serialize AWS secrets as their plain values (only when not None)."""
        return v.get_secret_value()

    def create_client(self):
        # Client construction is handled elsewhere (see BedrockClient).
        raise NotImplementedError("BedrockLLMConfigEntry.create_client must be implemented.")
@require_optional_import("boto3", "bedrock")
class BedrockClient:
"""Client for Amazon's Bedrock Converse API."""
_retries = 5
def __init__(self, **kwargs: Any):
"""Initialises BedrockClient for Amazon's Bedrock Converse API"""
self._aws_access_key = kwargs.get("aws_access_key")
self._aws_secret_key = kwargs.get("aws_secret_key")
self._aws_session_token = kwargs.get("aws_session_token")
self._aws_region = kwargs.get("aws_region")
self._aws_profile_name = kwargs.get("aws_profile_name")
self._timeout = kwargs.get("timeout")
if not self._aws_access_key:
self._aws_access_key = os.getenv("AWS_ACCESS_KEY")
if not self._aws_secret_key:
self._aws_secret_key = os.getenv("AWS_SECRET_KEY")
if not self._aws_session_token:
self._aws_session_token = os.getenv("AWS_SESSION_TOKEN")
if not self._aws_region:
self._aws_region = os.getenv("AWS_REGION")
if self._aws_region is None:
raise ValueError("Region is required to use the Amazon Bedrock API.")
if self._timeout is None:
self._timeout = 60
# Initialize Bedrock client, session, and runtime
bedrock_config = Config(
region_name=self._aws_region,
signature_version="v4",
retries={"max_attempts": self._retries, "mode": "standard"},
read_timeout=self._timeout,
)
session = boto3.Session(
aws_access_key_id=self._aws_access_key,
aws_secret_access_key=self._aws_secret_key,
aws_session_token=self._aws_session_token,
profile_name=self._aws_profile_name,
)
if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Bedrock, it will be ignored.", UserWarning)
# if haven't got any access_key or secret_key in environment variable or via arguments then
if (
self._aws_access_key is None
or self._aws_access_key == ""
or self._aws_secret_key is None
or self._aws_secret_key == ""
):
# attempts to get client from attached role of managed service (lambda, ec2, ecs, etc.)
self.bedrock_runtime = boto3.client(service_name="bedrock-runtime", config=bedrock_config)
else:
session = boto3.Session(
aws_access_key_id=self._aws_access_key,
aws_secret_access_key=self._aws_secret_key,
aws_session_token=self._aws_session_token,
profile_name=self._aws_profile_name,
)
self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)
def message_retrieval(self, response):
"""Retrieve the messages from the response."""
return [choice.message for choice in response.choices]
def parse_custom_params(self, params: dict[str, Any]):
"""Parses custom parameters for logic in this client class"""
# Should we separate system messages into its own request parameter, default is True
# This is required because not all models support a system prompt (e.g. Mistral Instruct).
self._supports_system_prompts = params.get("supports_system_prompts", True)
def parse_params(self, params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
"""Loads the valid parameters required to invoke Bedrock Converse
Returns a tuple of (base_params, additional_params)
"""
base_params = {}
additional_params = {}
# Amazon Bedrock base model IDs are here:
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
self._model_id = params.get("model")
assert self._model_id, "Please provide the 'model` in the config_list to use Amazon Bedrock"
# Parameters vary based on the model used.
# As we won't cater for all models and parameters, it's the developer's
# responsibility to implement the parameters and they will only be
# included if the developer has it in the config.
#
# Important:
# No defaults will be used (as they can vary per model)
# No ranges will be used (as they can vary)
# We will cover all the main parameters but there may be others
# that need to be added later
#
# Here are some pages that show the parameters available for different models
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-chat-completion.html
# Here are the possible "base" parameters and their suitable types
base_parameters = [["temperature", (float, int)], ["topP", (float, int)], ["maxTokens", (int)]]
for param_name, suitable_types in base_parameters:
if param_name in params:
base_params[param_name] = validate_parameter(
params, param_name, suitable_types, False, None, None, None
)
# Here are the possible "model-specific" parameters and their suitable types, known as additional parameters
additional_parameters = [
["top_p", (float, int)],
["top_k", (int)],
["k", (int)],
["seed", (int)],
]
for param_name, suitable_types in additional_parameters:
if param_name in params:
additional_params[param_name] = validate_parameter(
params, param_name, suitable_types, False, None, None, None
)
# Streaming
self._streaming = params.get("stream", False)
# For this release we will not support streaming as many models do not support streaming with tool use
if self._streaming:
warnings.warn(
"Streaming is not currently supported, streaming will be disabled.",
UserWarning,
)
self._streaming = False
return base_params, additional_params
    def create(self, params) -> ChatCompletion:
        """Run Amazon Bedrock inference via the Converse API and return an OpenAI-style ChatCompletion.

        Args:
            params (dict): AG2 request parameters; must contain "messages" and may
                contain "tools", inference parameters, and custom client settings.

        Returns:
            ChatCompletion: OpenAI-compatible completion with token usage populated.
        """
        # Set custom client class settings
        self.parse_custom_params(params)
        # Parse the inference parameters
        base_params, additional_params = self.parse_params(params)
        has_tools = "tools" in params
        messages = oai_messages_to_bedrock_messages(params["messages"], has_tools, self._supports_system_prompts)
        # System prompts are sent separately via the "system" request field, but
        # only when the model supports them.
        if self._supports_system_prompts:
            system_messages = extract_system_messages(params["messages"])
        tool_config = format_tools(params["tools"] if has_tools else [])
        request_args = {"messages": messages, "modelId": self._model_id}
        # Base and additional args
        if len(base_params) > 0:
            request_args["inferenceConfig"] = base_params
        if len(additional_params) > 0:
            request_args["additionalModelRequestFields"] = additional_params
        if self._supports_system_prompts:
            request_args["system"] = system_messages
        if len(tool_config["tools"]) > 0:
            request_args["toolConfig"] = tool_config
        response = self.bedrock_runtime.converse(**request_args)
        if response is None:
            raise RuntimeError(f"Failed to get response from Bedrock after retrying {self._retries} times.")
        finish_reason = convert_stop_reason_to_finish_reason(response["stopReason"])
        response_message = response["output"]["message"]
        # Only surface tool calls when Bedrock reported a tool-use stop reason
        tool_calls = format_tool_calls(response_message["content"]) if finish_reason == "tool_calls" else None
        text = ""
        # If several text blocks are present, the last one wins
        for content in response_message["content"]:
            if "text" in content:
                text = content["text"]
            # NOTE: other types of output may be dealt with here
        message = ChatCompletionMessage(role="assistant", content=text, tool_calls=tool_calls)
        response_usage = response["usage"]
        usage = CompletionUsage(
            prompt_tokens=response_usage["inputTokens"],
            completion_tokens=response_usage["outputTokens"],
            total_tokens=response_usage["totalTokens"],
        )
        return ChatCompletion(
            id=response["ResponseMetadata"]["RequestId"],
            choices=[Choice(finish_reason=finish_reason, index=0, message=message)],
            created=int(time.time()),
            model=self._model_id,
            object="chat.completion",
            usage=usage,
        )
def cost(self, response: ChatCompletion) -> float:
"""Calculate the cost of the response."""
return calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens, response.model)
@staticmethod
def get_usage(response) -> dict:
"""Get the usage of tokens and their cost information."""
return {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
"cost": response.cost,
"model": response.model,
}
def extract_system_messages(messages: list[dict[str, Any]]) -> list:
    """Extract the first system message from a list of OAI-format messages.

    Only the first message with role "system" is used, converted to Bedrock's
    system-content format. Its content may be a plain string or a list of
    content parts, in which case only the first part's text is taken.

    Args:
        messages (list[dict[str, Any]]): List of OAI-format messages.

    Returns:
        list: A single-element list like [{"text": ...}], or [] when no system
            message is present.
    """
    for message in messages:
        if message.get("role") == "system":
            content = message.get("content")
            # Content is either a plain string or a list of content parts
            if isinstance(content, str):
                return [{"text": content}]
            return [{"text": content[0]["text"]}]
    return []
def oai_messages_to_bedrock_messages(
    messages: list[dict[str, Any]], has_tools: bool, supports_system_prompts: bool
) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Bedrock (Converse API) format.

    We correct for any specific role orders and types, etc.

    AWS Bedrock requires messages to alternate between user and assistant roles. This function ensures that the messages
    are in the correct order and format for Bedrock by inserting "Please continue" messages as needed.

    Args:
        messages: Messages in OAI format.
        has_tools: Whether a "tools" parameter accompanies the request. If not,
            tool use/result messages are downgraded to plain text, since Bedrock
            rejects tool content without a tools parameter.
        supports_system_prompts: Whether the model takes a separate system field;
            if so, system messages are removed here (sent separately), otherwise
            they are re-labelled as user messages.

    Returns:
        Messages in Bedrock Converse format with strict user/assistant alternation.

    This is the same method as the one in the Autogen Anthropic client
    """
    # Track whether we have tools passed in. If not, tool use / result messages should be converted to text messages.
    # Bedrock requires a tools parameter with the tools listed, if there are other messages with tool use or tool results.
    # This can occur when we don't need tool calling, such as for group chat speaker selection

    # Convert messages to Bedrock compliant format

    # Take out system messages if the model supports it, otherwise leave them in.
    if supports_system_prompts:
        messages = [x for x in messages if x["role"] != "system"]
    else:
        # Replace role="system" with role="user"
        # NOTE(review): this mutates the caller's message dicts in place - confirm
        # callers do not rely on the original roles afterwards.
        for msg in messages:
            if msg["role"] == "system":
                msg["role"] = "user"

    processed_messages = []

    # Used to interweave user messages to ensure user/assistant alternating
    user_continue_message = {"content": [{"text": "Please continue."}], "role": "user"}
    assistant_continue_message = {
        "content": [{"text": "Please continue."}],
        "role": "assistant",
    }

    # Counters/indices used to detect a dangling tool_use without a tool_result
    tool_use_messages = 0
    tool_result_messages = 0
    last_tool_use_index = -1
    last_tool_result_index = -1
    # user_role_index = 0 if supports_system_prompts else 1 # If system prompts are supported, messages start with user, otherwise they'll be the second message
    for message in messages:
        # New messages will be added here, manage role alternations
        expected_role = "user" if len(processed_messages) % 2 == 0 else "assistant"

        if "tool_calls" in message:
            # Map the tool call options to Bedrock's format
            tool_uses = []
            tool_names = []
            for tool_call in message["tool_calls"]:
                tool_uses.append({
                    "toolUse": {
                        "toolUseId": tool_call["id"],
                        "name": tool_call["function"]["name"],
                        "input": json.loads(tool_call["function"]["arguments"]),
                    }
                })
                if has_tools:
                    tool_use_messages += 1
                tool_names.append(tool_call["function"]["name"])

            if expected_role == "user":
                # Insert an extra user message as we will append an assistant message
                processed_messages.append(user_continue_message)

            if has_tools:
                processed_messages.append({"role": "assistant", "content": tool_uses})
                last_tool_use_index = len(processed_messages) - 1
            else:
                # Not using tools, so put in a plain text message
                processed_messages.append({
                    "role": "assistant",
                    "content": [{"text": f"Some internal function(s) that could be used: [{', '.join(tool_names)}]"}],
                })
        elif "tool_call_id" in message:
            if has_tools:
                # Map the tool usage call to tool_result for Bedrock
                tool_result = {
                    "toolResult": {
                        "toolUseId": message["tool_call_id"],
                        "content": [{"text": message["content"]}],
                    }
                }

                # If the previous message also had a tool_result, add it to that
                # Otherwise append a new message
                if last_tool_result_index == len(processed_messages) - 1:
                    processed_messages[-1]["content"].append(tool_result)
                else:
                    if expected_role == "assistant":
                        # Insert an extra assistant message as we will append a user message
                        processed_messages.append(assistant_continue_message)

                    processed_messages.append({"role": "user", "content": [tool_result]})
                    last_tool_result_index = len(processed_messages) - 1

                tool_result_messages += 1
            else:
                # Not using tools, so put in a plain text message
                processed_messages.append({
                    "role": "user",
                    "content": [{"text": f"Running the function returned: {message['content']}"}],
                })
        elif message["content"] == "":
            # Ignoring empty messages
            pass
        else:
            if expected_role != message["role"] and not (len(processed_messages) == 0 and message["role"] == "system"):
                # Inserting the alternating continue message (ignore if it's the first message and a system message)
                processed_messages.append(
                    user_continue_message if expected_role == "user" else assistant_continue_message
                )

            processed_messages.append({
                "role": message["role"],
                "content": parse_content_parts(message=message),
            })

    # We'll replace the last tool_use if there's no tool_result (occurs if we finish the conversation before running the function)
    if has_tools and tool_use_messages != tool_result_messages:
        processed_messages[last_tool_use_index] = assistant_continue_message

    # name is not a valid field on messages
    for message in processed_messages:
        if "name" in message:
            message.pop("name", None)

    # Note: When using reflection_with_llm we may end up with an "assistant" message as the last message and that may cause a blank response
    # So, if the last role is not user, add a 'user' continue message at the end
    # NOTE(review): assumes at least one processed message exists; an input where
    # every message is empty would raise IndexError here - confirm callers.
    if processed_messages[-1]["role"] != "user":
        processed_messages.append(user_continue_message)

    return processed_messages
def parse_content_parts(
    message: dict[str, Any],
) -> list[dict[str, Any]]:
    """Convert an OAI message's content into Bedrock content blocks.

    A plain string becomes a single text block; a list of parts is mapped
    item-by-item ("text" and "image_url" parts are supported, anything else
    is skipped).
    """
    raw_content: str | list[dict[str, Any]] = message.get("content")
    if isinstance(raw_content, str):
        return [{"text": raw_content}]
    parsed: list[dict[str, Any]] = []
    for item in raw_content:
        if "text" in item:
            parsed.append({"text": item.get("text")})
        elif "image_url" in item:
            raw_bytes, mime_type = parse_image(item.get("image_url").get("url"))
            parsed.append({
                "image": {
                    "format": mime_type[6:],  # strip the leading "image/"
                    "source": {"bytes": raw_bytes},
                },
            })
        # Anything else is silently ignored
    return parsed
def parse_image(image_url: str) -> tuple[bytes, str]:
    """Try to get the raw data from an image url.

    Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html
    returns a tuple of (Image Data, Content Type)

    Raises:
        RuntimeError: If the url is not a data URI and fetching it fails.
    """
    pattern = r"^data:(image/[a-z]*);base64,\s*"
    content_type = re.search(pattern, image_url)
    # If already base64 encoded.
    # Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp'
    if content_type:
        image_data = re.sub(pattern, "", image_url)
        return base64.b64decode(image_data), content_type.group(1)

    # Send a request to the image URL
    response = requests.get(image_url)
    # Check if the request was successful
    if response.status_code == 200:
        content_type = response.headers.get("Content-Type")
        # Guard against a missing Content-Type header (previously an
        # AttributeError on None) and non-image types; default to JPEG.
        if not content_type or not content_type.startswith("image"):
            content_type = "image/jpeg"

        # Get the image content
        image_content = response.content
        return image_content, content_type
    else:
        raise RuntimeError("Unable to access the image url")
def format_tools(tools: list[dict[str, Any]]) -> dict[Literal["tools"], list[dict[str, Any]]]:
    """Translate OAI function-tool definitions into Bedrock toolSpec format."""
    bedrock_tools: list[dict[str, Any]] = []
    for tool in tools:
        # Only function-type tools are supported
        if tool["type"] != "function":
            continue
        function = tool["function"]
        parameters = function["parameters"]

        json_schema: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
        for prop_name, prop_details in parameters["properties"].items():
            prop_schema: dict[str, Any] = {
                "type": prop_details["type"],
                "description": prop_details.get("description", ""),
            }
            # Optional JSON-schema fields are copied through only when present
            if "enum" in prop_details:
                prop_schema["enum"] = prop_details["enum"]
            if "default" in prop_details:
                prop_schema["default"] = prop_details["default"]
            json_schema["properties"][prop_name] = prop_schema

        if "required" in parameters:
            json_schema["required"] = parameters["required"]

        bedrock_tools.append({
            "toolSpec": {
                "name": function["name"],
                "description": function["description"],
                "inputSchema": {"json": json_schema},
            }
        })
    return {"tools": bedrock_tools}
def format_tool_calls(content):
    """Converts Converse API response tool calls to AG2 format"""
    calls = []
    for content_block in content:
        # Only toolUse blocks translate into tool calls; skip everything else
        if "toolUse" not in content_block:
            continue
        tool_use = content_block["toolUse"]
        calls.append(
            ChatCompletionMessageToolCall(
                id=tool_use["toolUseId"],
                function={
                    "name": tool_use["name"],
                    "arguments": json.dumps(tool_use["input"]),
                },
                type="function",
            )
        )
    return calls
def convert_stop_reason_to_finish_reason(
    stop_reason: str,
) -> Literal["stop", "length", "tool_calls", "content_filter"]:
    """Converts Bedrock finish reasons to our finish reasons, according to OpenAI:

    - stop: if the model hit a natural stop point or a provided stop sequence,
    - length: if the maximum number of tokens specified in the request was reached,
    - content_filter: if content was omitted due to a flag from our content filters,
    - tool_calls: if the model called a tool

    Unknown stop reasons are passed through lower-cased (with a warning);
    a falsy stop_reason yields None.
    """
    if not stop_reason:
        return None

    finish_reason_mapping = {
        "tool_use": "tool_calls",
        "finished": "stop",
        "end_turn": "stop",
        "max_tokens": "length",
        "stop_sequence": "stop",
        "complete": "stop",
        "content_filtered": "content_filter",
    }
    normalized = stop_reason.lower()
    if normalized in finish_reason_mapping:
        return finish_reason_mapping[normalized]

    # Previously this warning was unreachable for unmapped reasons; warn here
    # while keeping the lower-cased pass-through behaviour.
    warnings.warn(f"Unsupported stop reason: {stop_reason}", UserWarning)
    return normalized
# NOTE: As this will be quite dynamic, it's expected that the developer will use the "price" parameter in their config
# These may be removed.
# (input, output) price in USD per 1K tokens for known Bedrock models.
PRICES_PER_K_TOKENS = {
    "meta.llama3-8b-instruct-v1:0": (0.0003, 0.0006),
    "meta.llama3-70b-instruct-v1:0": (0.00265, 0.0035),
    "mistral.mistral-7b-instruct-v0:2": (0.00015, 0.0002),
    "mistral.mixtral-8x7b-instruct-v0:1": (0.00045, 0.0007),
    "mistral.mistral-large-2402-v1:0": (0.004, 0.012),
    "mistral.mistral-small-2402-v1:0": (0.001, 0.003),
}


def calculate_cost(input_tokens: int, output_tokens: int, model_id: str) -> float:
    """Calculate the cost of the completion using the Bedrock pricing.

    Unknown models warn and cost 0; per-request pricing can be supplied via the
    "price" config field instead.
    """
    pricing = PRICES_PER_K_TOKENS.get(model_id)
    if pricing is None:
        warnings.warn(
            f'Cannot get the costs for {model_id}. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.',
            UserWarning,
        )
        return 0
    prompt_rate, completion_rate = pricing
    # Prices are quoted per 1K tokens
    return (input_tokens / 1000) * prompt_rate + (output_tokens / 1000) * completion_rate

View File

@@ -0,0 +1,299 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Create an OpenAI-compatible client using Cerebras's API.
Example:
```python
llm_config = {
"config_list": [{"api_type": "cerebras", "model": "llama3.1-8b", "api_key": os.environ.get("CEREBRAS_API_KEY")}]
}
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
```
Install Cerebras's python library using: pip install --upgrade cerebras_cloud_sdk
Resources:
- https://inference-docs.cerebras.ai/quickstart
"""
from __future__ import annotations
import copy
import math
import os
import time
import warnings
from typing import Any, Literal, Optional
from pydantic import Field, ValidationInfo, field_validator
from ..import_utils import optional_import_block, require_optional_import
from ..llm_config import LLMConfigEntry, register_llm_config
from .client_utils import should_hide_tools, validate_parameter
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
with optional_import_block():
from cerebras.cloud.sdk import Cerebras, Stream
# (input, output) cost in USD per 1K tokens for supported Cerebras models.
CEREBRAS_PRICING_1K = {
    # Convert pricing per million to per thousand tokens.
    "llama3.1-8b": (0.10 / 1000, 0.10 / 1000),
    "llama-3.3-70b": (0.85 / 1000, 1.20 / 1000),
}
@register_llm_config
class CerebrasLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for the Cerebras API (api_type "cerebras")."""

    api_type: Literal["cerebras"] = "cerebras"
    # Generation controls; None means "let the API use its default".
    max_tokens: Optional[int] = None
    seed: Optional[int] = None
    stream: bool = False
    temperature: float = Field(default=1.0, ge=0.0, le=1.5)
    top_p: Optional[float] = None
    # When to omit tools from the request; see client_utils.should_hide_tools.
    hide_tools: Literal["if_all_run", "if_any_run", "never"] = "never"
    tool_choice: Optional[Literal["none", "auto", "required"]] = None

    @field_validator("top_p", mode="before")
    @classmethod
    def check_top_p(cls, v: Any, info: ValidationInfo) -> Any:
        # Reject configs that set both sampling knobs at once.
        # NOTE(review): temperature defaults to 1.0, so it may already be present
        # in info.data even when the user never set it - confirm this validator
        # does not also fire for top_p-only configs.
        if v is not None and info.data.get("temperature") is not None:
            raise ValueError("temperature and top_p cannot be set at the same time.")
        return v

    def create_client(self):
        # Client construction is handled elsewhere; see CerebrasClient.
        raise NotImplementedError("CerebrasLLMConfigEntry.create_client is not implemented.")
class CerebrasClient:
    """Client for Cerebras's API.

    Implements AG2's model-client surface (create / message_retrieval / cost /
    get_usage) on top of the cerebras_cloud_sdk chat-completions endpoint.
    """

    def __init__(self, api_key=None, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            api_key (str): The API key for using Cerebras (or environment variable CEREBRAS_API_KEY needs to be set)
            **kwargs: Additional keyword arguments to pass to the Cerebras client
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = api_key
        if not self.api_key:
            self.api_key = os.getenv("CEREBRAS_API_KEY")
        assert self.api_key, (
            "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."
        )

        # Structured outputs are not supported by this client.
        if "response_format" in kwargs and kwargs["response_format"] is not None:
            warnings.warn("response_format is not supported for Cerebras, it will be ignored.", UserWarning)

    def message_retrieval(self, response: ChatCompletion) -> list:
        """Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response: ChatCompletion) -> float:
        """Return the cost attached to the response during create()."""
        # Note: This field isn't explicitly in `ChatCompletion`, but is injected during chat creation.
        return response.cost

    @staticmethod
    def get_usage(response: ChatCompletion) -> dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        # ... # pragma: no cover
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Loads the parameters for Cerebras API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
        cerebras_params = {}

        # Check that we have what we need to use Cerebras's API
        # We won't enforce the available models as they are likely to change
        cerebras_params["model"] = params.get("model")
        assert cerebras_params["model"], (
            "Please specify the 'model' in your config list entry to nominate the Cerebras model to use."
        )

        # Validate allowed Cerebras parameters
        # https://inference-docs.cerebras.ai/api-reference/chat-completions
        cerebras_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        cerebras_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
        cerebras_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
        cerebras_params["temperature"] = validate_parameter(
            params, "temperature", (int, float), True, 1, (0, 1.5), None
        )
        cerebras_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
        cerebras_params["tool_choice"] = validate_parameter(
            params, "tool_choice", str, True, None, None, ["none", "auto", "required"]
        )

        return cerebras_params

    @require_optional_import("cerebras", "cerebras")
    def create(self, params: dict) -> ChatCompletion:
        """Run (optionally streaming) chat-completion inference on Cerebras.

        Args:
            params (dict): AG2 request parameters (messages, model, tools, ...).

        Returns:
            ChatCompletion: OpenAI-compatible response with usage and cost attached.
        """
        messages = params.get("messages", [])

        # Convert AG2 messages to Cerebras messages
        cerebras_messages = oai_messages_to_cerebras_messages(messages)

        # Parse parameters to the Cerebras API's parameters
        cerebras_params = self.parse_params(params)

        # Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(cerebras_messages, params["tools"], hide_tools):
                cerebras_params["tools"] = params["tools"]

        cerebras_params["messages"] = cerebras_messages

        # We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
        client = Cerebras(api_key=self.api_key, max_retries=5)

        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0

        # Streaming tool call recommendations
        streaming_tool_calls = []

        ans = None
        response = client.chat.completions.create(**cerebras_params)
        if cerebras_params["stream"]:
            # Read in the chunks as they stream, taking in tool_calls which may be across
            # multiple chunks if more than one suggested
            ans = ""
            for chunk in response:
                # Grab first choice, which _should_ always be generated.
                ans = ans + (getattr(chunk.choices[0].delta, "content", None) or "")

                # NOTE(review): membership test / item access assumes the streamed
                # delta behaves like a mapping - confirm against the SDK delta type.
                if "tool_calls" in chunk.choices[0].delta:
                    # We have a tool call recommendation
                    for tool_call in chunk.choices[0].delta["tool_calls"]:
                        streaming_tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id=tool_call["id"],
                                function={
                                    "name": tool_call["function"]["name"],
                                    "arguments": tool_call["function"]["arguments"],
                                },
                                type="function",
                            )
                        )

                if chunk.choices[0].finish_reason:
                    # Usage arrives on the final chunk
                    prompt_tokens = chunk.usage.prompt_tokens
                    completion_tokens = chunk.usage.completion_tokens
                    total_tokens = chunk.usage.total_tokens
        else:
            # Non-streaming finished
            ans: str = response.choices[0].message.content

            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens

        if response is not None:
            if isinstance(response, Stream):
                # Streaming response
                if chunk.choices[0].finish_reason == "tool_calls":
                    cerebras_finish = "tool_calls"
                    tool_calls = streaming_tool_calls
                else:
                    cerebras_finish = "stop"
                    tool_calls = None

                response_content = ans
                response_id = chunk.id
            else:
                # Non-streaming response
                # If we have tool calls as the response, populate completed tool calls for our return OAI response
                if response.choices[0].finish_reason == "tool_calls":
                    cerebras_finish = "tool_calls"
                    tool_calls = []
                    for tool_call in response.choices[0].message.tool_calls:
                        tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id=tool_call.id,
                                function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                                type="function",
                            )
                        )
                else:
                    cerebras_finish = "stop"
                    tool_calls = None

                response_content = response.choices[0].message.content
                response_id = response.id

        # 3. convert output
        message = ChatCompletionMessage(
            role="assistant",
            content=response_content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=cerebras_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response_id,
            model=cerebras_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            # Note: This seems to be a field that isn't in the schema of `ChatCompletion`, so Pydantic
            # just adds it dynamically.
            cost=calculate_cerebras_cost(prompt_tokens, completion_tokens, cerebras_params["model"]),
        )

        return response_oai
def oai_messages_to_cerebras_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Cerebras's format.

    We correct for any specific role orders and types.
    """
    # Work on a deep copy so the caller's message dicts are never mutated
    converted = copy.deepcopy(messages)
    # Cerebras does not accept the OAI "name" field on messages
    for entry in converted:
        entry.pop("name", None)
    return converted
def calculate_cerebras_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Calculate the cost of the completion using the Cerebras pricing."""
    if model not in CEREBRAS_PRICING_1K:
        # Unknown model: no pricing data, so report zero cost
        warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
        return 0.0
    input_rate, output_rate = CEREBRAS_PRICING_1K[model]
    # Round each component (and the total) up at the sixth decimal place
    input_cost = math.ceil((input_tokens / 1000) * input_rate * 1e6) / 1e6
    output_cost = math.ceil((output_tokens / 1000) * output_rate * 1e6) / 1e6
    return math.ceil((input_cost + output_cost) * 1e6) / 1e6

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,169 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Utilities for client classes"""
import logging
import warnings
from typing import Any, Optional, Protocol, runtime_checkable
@runtime_checkable
class FormatterProtocol(Protocol):
    """Structured Output classes with a format method"""

    # Implementations render themselves to a display string.
    def format(self) -> str: ...
def validate_parameter(
    params: dict[str, Any],
    param_name: str,
    allowed_types: tuple[Any, ...],
    allow_None: bool,  # noqa: N803
    default_value: Any,
    numerical_bound: Optional[tuple[Optional[float], Optional[float]]],
    allowed_values: Optional[list[Any]],
) -> Any:
    """Validates a given config parameter, checking its type, values, and setting defaults

    Parameters:
        params (Dict[str, Any]): Dictionary containing parameters to validate.
        param_name (str): The name of the parameter to validate.
        allowed_types (Tuple): Tuple of acceptable types for the parameter.
        allow_None (bool): Whether the parameter can be `None`.
        default_value (Any): The default value to use if the parameter is invalid or missing.
        numerical_bound (Optional[Tuple[Optional[float], Optional[float]]]):
            A tuple specifying the lower and upper bounds for numerical parameters.
            Each bound can be `None` if not applicable.
        allowed_values (Optional[List[Any]]): A list of acceptable values for the parameter.
            Can be `None` if no specific values are required.

    Returns:
        Any: The validated parameter value or the default value if validation fails.

    Raises:
        TypeError: If `allowed_values` is provided but is not a list.

    Example Usage:
    ```python
    # Validating a numerical parameter within specific bounds
    params = {"temperature": 0.5, "safety_model": "Meta-Llama/Llama-Guard-7b"}
    temperature = validate_parameter(params, "temperature", (int, float), True, 0.7, (0, 1), None)
    # Result: 0.5

    # Validating a parameter that can be one of a list of allowed values
    model = validate_parameter(
        params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
    )
    # If "safety_model" is missing or invalid in params, defaults to None
    ```
    """
    if allowed_values is not None and not isinstance(allowed_values, list):
        raise TypeError(f"allowed_values should be a list or None, got {type(allowed_values).__name__}")

    param_value = params.get(param_name, default_value)
    warning = ""

    if param_value is None and allow_None:
        # None is explicitly permitted
        pass
    elif param_value is None:
        # None present but not allowed (previously an always-true nested check)
        warning = "cannot be None"
    elif not isinstance(param_value, allowed_types):
        # Check types and list possible types if invalid
        if isinstance(allowed_types, tuple):
            formatted_types = "(" + ", ".join(f"{t.__name__}" for t in allowed_types) + ")"
        else:
            formatted_types = f"{allowed_types.__name__}"
        warning = f"must be of type {formatted_types}{' or None' if allow_None else ''}"
    elif numerical_bound:
        # Check the value fits in possible bounds
        lower_bound, upper_bound = numerical_bound
        if (lower_bound is not None and param_value < lower_bound) or (
            upper_bound is not None and param_value > upper_bound
        ):
            warning = "has numerical bounds"
            if lower_bound is not None:
                warning += f", >= {lower_bound!s}"
            if upper_bound is not None:
                if lower_bound is not None:
                    warning += " and"
                warning += f" <= {upper_bound!s}"
            if allow_None:
                warning += ", or can be None"
    elif allowed_values:  # noqa: SIM102
        # Check if the value matches any allowed values
        if not (allow_None and param_value is None) and param_value not in allowed_values:
            warning = f"must be one of these values [{allowed_values}]{', or can be None' if allow_None else ''}"

    # If we failed any checks, warn and set to default value
    if warning:
        warnings.warn(
            f"Config error - {param_name} {warning}, defaulting to {default_value}.",
            UserWarning,
        )
        param_value = default_value

    return param_value
def should_hide_tools(messages: list[dict[str, Any]], tools: list[dict[str, Any]], hide_tools_param: str) -> bool:
    """Determines if tools should be hidden. This function is used to hide tools when they have been run, minimising the chance of the LLM choosing them when they shouldn't.

    Parameters:
        messages (List[Dict[str, Any]]): List of messages
        tools (List[Dict[str, Any]]): List of tools
        hide_tools_param (str): "hide_tools" parameter value. Can be "if_all_run" (hide tools if all tools have been run), "if_any_run" (hide tools if any of the tools have been run), "never" (never hide tools). Default is "never".

    Returns:
        bool: Indicates whether the tools should be excluded from the response create request

    Raises:
        TypeError: If hide_tools_param is not one of the recognised values.

    Example Usage:
    ```python
    messages = params.get("messages", [])
    tools = params.get("tools", None)
    hide_tools = should_hide_tools(messages, tools, params["hide_tools"])
    ```
    """
    if hide_tools_param == "never" or tools is None or len(tools) == 0:
        return False
    elif hide_tools_param == "if_any_run":
        # Return True if any tool_call_id exists, indicating a tool call has been executed. False otherwise.
        # (generator expression - no need to materialise a list for any())
        return any("tool_call_id" in dictionary for dictionary in messages)
    elif hide_tools_param == "if_all_run":
        # Return True if all tools have been executed at least once. False otherwise.

        # Get the list of tool names
        check_tool_names = [item["function"]["name"] for item in tools]

        # Prepare a list of tool call ids and related function names
        tool_call_ids = {}

        # Loop through the messages and check if the tools have been run, removing them as we go
        for message in messages:
            if "tool_calls" in message:
                # Register the tool ids and the function names (there could be multiple tool calls)
                for tool_call in message["tool_calls"]:
                    tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]
            elif "tool_call_id" in message:
                # Tool called, get the name of the function based on the id
                tool_name_called = tool_call_ids[message["tool_call_id"]]

                # If we had not yet called the tool, check and remove it to indicate we have
                if tool_name_called in check_tool_names:
                    check_tool_names.remove(tool_name_called)

        # Return True if all tools have been called at least once (accounted for)
        return len(check_tool_names) == 0
    else:
        raise TypeError(
            f"hide_tools_param is not a valid value ['if_all_run','if_any_run','never'], got '{hide_tools_param}'"
        )
# Logging format (originally from FLAML)
# Shared by client modules that attach their own console handler (e.g. cohere.py).
logging_formatter = logging.Formatter(
    "[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s", "%m-%d %H:%M:%S"
)

View File

@@ -0,0 +1,479 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Create an OpenAI-compatible client using Cohere's API.
Example:
```python
llm_config={
"config_list": [{
"api_type": "cohere",
"model": "command-r-plus",
"api_key": os.environ.get("COHERE_API_KEY")
"client_name": "autogen-cohere", # Optional parameter
}
]}
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
```
Install Cohere's python library using: pip install --upgrade cohere
Resources:
- https://docs.cohere.com/reference/chat
"""
from __future__ import annotations
import json
import logging
import os
import sys
import time
import warnings
from typing import Any, Literal, Optional, Type
from pydantic import BaseModel, Field
from autogen.oai.client_utils import FormatterProtocol, logging_formatter, validate_parameter
from ..import_utils import optional_import_block, require_optional_import
from ..llm_config import LLMConfigEntry, register_llm_config
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
with optional_import_block():
from cohere import ClientV2 as CohereV2
from cohere.types import ToolResult
# Module logger; attach a stdout handler only once so repeated imports don't
# produce duplicate log lines.
logger = logging.getLogger(__name__)
if not logger.handlers:
    # Add the console handler.
    _ch = logging.StreamHandler(stream=sys.stdout)
    _ch.setFormatter(logging_formatter)
    logger.addHandler(_ch)
# (input, output) cost in USD per 1K tokens for known Cohere models.
COHERE_PRICING_1K = {
    "command-r-plus": (0.003, 0.015),
    "command-r": (0.0005, 0.0015),
    "command-nightly": (0.00025, 0.00125),
    "command": (0.015, 0.075),
    "command-light": (0.008, 0.024),
    "command-light-nightly": (0.008, 0.024),
}
@register_llm_config
class CohereLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for the Cohere API (api_type "cohere")."""

    api_type: Literal["cohere"] = "cohere"
    # Sampling / generation controls; bounds follow the Cohere chat API.
    temperature: float = Field(default=0.3, ge=0)
    max_tokens: Optional[int] = Field(default=None, ge=0)
    k: int = Field(default=0, ge=0, le=500)
    p: float = Field(default=0.75, ge=0.01, le=0.99)
    seed: Optional[int] = None
    frequency_penalty: float = Field(default=0, ge=0, le=1)
    presence_penalty: float = Field(default=0, ge=0, le=1)
    # Optional client name passed through to the Cohere SDK.
    client_name: Optional[str] = None
    strict_tools: bool = False
    stream: bool = False
    tool_choice: Optional[Literal["NONE", "REQUIRED"]] = None

    def create_client(self):
        # Client construction is handled elsewhere; see CohereClient.
        raise NotImplementedError("CohereLLMConfigEntry.create_client is not implemented.")
class CohereClient:
    """Client for Cohere's API.

    Wraps Cohere's ``ClientV2`` chat endpoint behind the ModelClient-style
    interface AG2 expects (``create``, ``message_retrieval``, ``cost``,
    ``get_usage``), including streaming, tool calling and Pydantic-based
    structured outputs.
    """

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            **kwargs: The keyword arguments to pass to the Cohere API.
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key")
        if not self.api_key:
            self.api_key = os.getenv("COHERE_API_KEY")
        assert self.api_key, (
            "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable."
        )
        # Store the response format, if provided (for structured outputs)
        self._response_format: Optional[Type[BaseModel]] = None

    def message_retrieval(self, response) -> list:
        """Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        """Return the USD cost attached to the response by create()."""
        return response.cost

    @staticmethod
    def get_usage(response) -> dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Loads the parameters for Cohere API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
        cohere_params = {}
        # Check that we have what we need to use Cohere's API
        # We won't enforce the available models as they are likely to change
        cohere_params["model"] = params.get("model")
        assert cohere_params["model"], (
            "Please specify the 'model' in your config list entry to nominate the Cohere model to use."
        )
        # Handle structured output response format from Pydantic model
        if "response_format" in params and params["response_format"] is not None:
            response_format = params["response_format"]
            self._response_format = response_format
            # Check if it's a Pydantic model
            if hasattr(response_format, "model_json_schema"):
                # Get the JSON schema from the Pydantic model
                schema = response_format.model_json_schema()

                def resolve_ref(ref: str, defs: dict) -> dict:
                    """Resolve a $ref to its actual schema definition"""
                    # Extract the definition name from "#/$defs/Name"
                    def_name = ref.split("/")[-1]
                    return defs[def_name]

                def ensure_type_fields(obj: dict, defs: dict) -> dict:
                    """Recursively inline $refs so all schema objects carry explicit type/properties fields."""
                    if isinstance(obj, dict):
                        # If it has a $ref, replace it with the actual definition
                        if "$ref" in obj:
                            ref_def = resolve_ref(obj["$ref"], defs)
                            # Merge the reference definition with any existing fields
                            obj = {**ref_def, **obj}
                            # Remove the $ref as we've replaced it
                            del obj["$ref"]
                        # Process each value recursively
                        return {
                            k: ensure_type_fields(v, defs) if isinstance(v, (dict, list)) else v
                            for k, v in obj.items()
                        }
                    elif isinstance(obj, list):
                        return [ensure_type_fields(item, defs) for item in obj]
                    return obj

                # Make a copy of $defs before processing
                defs = schema.get("$defs", {})
                # Process the schema
                processed_schema = ensure_type_fields(schema, defs)
                cohere_params["response_format"] = {"type": "json_object", "json_schema": processed_schema}
            else:
                raise ValueError("response_format must be a Pydantic BaseModel")
        # Handle strict tools parameter for structured outputs with tools
        if "tools" in params:
            cohere_params["strict_tools"] = validate_parameter(params, "strict_tools", bool, False, False, None, None)
        # Validate allowed Cohere parameters
        # https://docs.cohere.com/reference/chat
        if "temperature" in params:
            cohere_params["temperature"] = validate_parameter(
                params, "temperature", (int, float), False, 0.3, (0, None), None
            )
        if "max_tokens" in params:
            cohere_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        if "k" in params:
            cohere_params["k"] = validate_parameter(params, "k", int, False, 0, (0, 500), None)
        if "p" in params:
            cohere_params["p"] = validate_parameter(params, "p", (int, float), False, 0.75, (0.01, 0.99), None)
        if "seed" in params:
            cohere_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
        if "frequency_penalty" in params:
            cohere_params["frequency_penalty"] = validate_parameter(
                params, "frequency_penalty", (int, float), True, 0, (0, 1), None
            )
        if "presence_penalty" in params:
            cohere_params["presence_penalty"] = validate_parameter(
                params, "presence_penalty", (int, float), True, 0, (0, 1), None
            )
        if "tool_choice" in params:
            cohere_params["tool_choice"] = validate_parameter(
                params, "tool_choice", str, True, None, None, ["NONE", "REQUIRED"]
            )
        return cohere_params

    @require_optional_import("cohere", "cohere")
    def create(self, params: dict) -> ChatCompletion:
        """Run a chat completion (streaming or not) against Cohere and return an OpenAI-style ChatCompletion.

        Messages are normalized for Cohere (names folded into content, unknown
        tool calls demoted to plain assistant text), usage/cost are computed
        from billed token counts, and structured outputs are validated against
        the stored response format.
        """
        messages = params.get("messages", [])
        client_name = params.get("client_name") or "AG2"
        cohere_tool_names = set()
        tool_calls_modified_ids = set()
        # Parse parameters to the Cohere API's parameters
        cohere_params = self.parse_params(params)
        cohere_params["messages"] = messages
        if "tools" in params:
            cohere_tool_names = set([tool["function"]["name"] for tool in params["tools"]])
            cohere_params["tools"] = params["tools"]
        # Strip out name and fold it into content/tool_plan
        for message in cohere_params["messages"]:
            message_name = message.pop("name", "")
            # Extract and prepend name to content or tool_plan if available
            message["content"] = (
                f"{message_name}: {(message.get('content') or message.get('tool_plan'))}"
                if message_name
                else (message.get("content") or message.get("tool_plan"))
            )
            # Handle tool calls
            if message.get("tool_calls") is not None and len(message["tool_calls"]) > 0:
                message["tool_plan"] = message.get("tool_plan", message["content"])
                del message["content"]  # Remove content as tool_plan is prioritized
                # If tool call name is missing or not recognized, modify role and content
                for tool_call in message["tool_calls"] or []:
                    if (not tool_call.get("function", {}).get("name")) or tool_call.get("function", {}).get(
                        "name"
                    ) not in cohere_tool_names:
                        message["role"] = "assistant"
                        message["content"] = f"{message.pop('tool_plan', '')}{str(message['tool_calls'])}"
                        tool_calls_modified_ids = tool_calls_modified_ids.union(
                            set([tool_call.get("id") for tool_call in message["tool_calls"]])
                        )
                        del message["tool_calls"]
                        break
            # Adjust role if message comes from a tool with a modified ID
            if message.get("role") == "tool":
                tool_id = message.get("tool_call_id")
                if tool_id in tool_calls_modified_ids:
                    message["role"] = "user"
                    del message["tool_call_id"]  # Remove the tool call ID
        # We use chat model by default
        client = CohereV2(api_key=self.api_key, client_name=client_name)
        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0
        # Stream if in parameters
        streaming = params.get("stream")
        cohere_finish = "stop"
        tool_calls = None
        ans = None
        # FIX: pre-initialize so a stream that never emits "message-start" cannot leave this unbound.
        response_id = ""
        if streaming:
            response = client.chat_stream(**cohere_params)
            # Streaming...
            ans = ""
            plan = ""
            prompt_tokens = 0
            completion_tokens = 0
            # FIX: was referenced before assignment if a tool-call-delta/end chunk arrived
            # before any tool-call-start (previously a NameError, not the intended None check).
            current_tool = None
            for chunk in response:
                if chunk.type == "content-delta":
                    ans = ans + chunk.delta.message.content.text
                elif chunk.type == "tool-plan-delta":
                    plan = plan + chunk.delta.message.tool_plan
                elif chunk.type == "tool-call-start":
                    cohere_finish = "tool_calls"
                    # Initialize a new tool call
                    tool_call = chunk.delta.message.tool_calls
                    current_tool = {
                        "id": tool_call.id,
                        "type": "function",
                        "function": {"name": tool_call.function.name, "arguments": ""},
                    }
                elif chunk.type == "tool-call-delta":
                    # Progressively build the arguments as they stream in
                    if current_tool is not None:
                        current_tool["function"]["arguments"] += chunk.delta.message.tool_calls.function.arguments
                elif chunk.type == "tool-call-end":
                    # Append the finished tool call to the list
                    if current_tool is not None:
                        if tool_calls is None:
                            tool_calls = []
                        tool_calls.append(ChatCompletionMessageToolCall(**current_tool))
                        current_tool = None
                elif chunk.type == "message-start":
                    response_id = chunk.id
                elif chunk.type == "message-end":
                    prompt_tokens = (
                        chunk.delta.usage.billed_units.input_tokens
                    )  # Note total (billed+non-billed) available with ...usage.tokens...
                    completion_tokens = chunk.delta.usage.billed_units.output_tokens
                    total_tokens = prompt_tokens + completion_tokens
            # FIX: surface the streamed tool plan as content, mirroring the
            # non-streaming branch (which uses response.message.tool_plan).
            if cohere_finish == "tool_calls" and not ans:
                ans = plan
        else:
            response = client.chat(**cohere_params)
            if response.message.tool_calls is not None:
                ans = response.message.tool_plan
                cohere_finish = "tool_calls"
                tool_calls = []
                for tool_call in response.message.tool_calls:
                    # if parameters are null, clear them out (Cohere can return a string "null" if no parameter values)
                    tool_calls.append(
                        ChatCompletionMessageToolCall(
                            id=tool_call.id,
                            function={
                                "name": tool_call.function.name,
                                "arguments": (
                                    "" if tool_call.function.arguments is None else tool_call.function.arguments
                                ),
                            },
                            type="function",
                        )
                    )
            else:
                ans: str = response.message.content[0].text
            # Not using billed_units, but that may be better for cost purposes
            prompt_tokens = (
                response.usage.billed_units.input_tokens
            )  # Note total (billed+non-billed) available with ...usage.tokens...
            completion_tokens = response.usage.billed_units.output_tokens
            total_tokens = prompt_tokens + completion_tokens
            response_id = response.id
        # Clean up structured output if needed
        if self._response_format:
            try:
                parsed_response = self._convert_json_response(ans)
                ans = _format_json_response(parsed_response, ans)
            except ValueError as e:
                # Surface the validation failure as the message content rather than raising.
                ans = str(e)
        # 3. convert output
        message = ChatCompletionMessage(
            role="assistant",
            content=ans,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=cohere_finish, index=0, message=message)]
        response_oai = ChatCompletion(
            id=response_id,
            model=cohere_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            cost=calculate_cohere_cost(prompt_tokens, completion_tokens, cohere_params["model"]),
        )
        return response_oai

    def _convert_json_response(self, response: str) -> Any:
        """Extract and validate JSON response from the output for structured outputs.

        Args:
            response (str): The response from the API.

        Returns:
            Any: The parsed JSON response.

        Raises:
            ValueError: If the response is not valid JSON or fails model validation.
        """
        if not self._response_format:
            return response
        try:
            # Parse JSON and validate against the Pydantic model
            json_data = json.loads(response)
            return self._response_format.model_validate(json_data)
        except Exception as e:
            raise ValueError(
                f"Failed to parse response as valid JSON matching the schema for Structured Output: {str(e)}"
            )
def _format_json_response(response: Any, original_answer: str) -> str:
    """Render a structured-output result as a string.

    A parsed object that implements ``FormatterProtocol`` formats itself;
    anything else falls back to normalizing the raw answer through
    ``clean_return_response_format``.
    """
    if isinstance(response, FormatterProtocol):
        return response.format()
    return clean_return_response_format(original_answer)
def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> list[dict[str, Any]]:
    """Build Cohere ``ToolResult`` entries for the tool call matching ``tool_call_id``.

    The matching call's JSON arguments string is decoded (an empty string is
    treated as an empty parameter dict) and paired with the tool's text output.
    """
    matched_results = []
    for candidate in all_tool_calls:
        if candidate["id"] != tool_call_id:
            continue
        raw_args = candidate["function"]["arguments"]
        call = {
            "name": candidate["function"]["name"],
            "parameters": json.loads(raw_args if raw_args != "" else "{}"),
        }
        matched_results.append(ToolResult(call=call, outputs=[{"value": content_output}]))
    return matched_results
def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Calculate the cost of the completion using the Cohere pricing.

    Unknown models cost 0.0 and emit a UserWarning.
    """
    pricing = COHERE_PRICING_1K.get(model)
    if pricing is None:
        warnings.warn(f"Cost calculation not available for {model} model", UserWarning)
        return 0.0
    input_rate, output_rate = pricing
    return (input_tokens / 1000) * input_rate + (output_tokens / 1000) * output_rate
def clean_return_response_format(response_str: str) -> str:
    """Clean up the response string by parsing through json library.

    Round-tripping via ``json.loads``/``json.dumps`` resolves escape
    sequences and collapses the payload to a canonical compact form.
    """
    return json.dumps(json.loads(response_str))
class CohereError(Exception):
    """Base class for other Cohere exceptions"""

    # Catch-all ancestor so callers can handle any Cohere-specific failure with one except clause.
    pass
class CohereRateLimitError(CohereError):
    """Raised when rate limit is exceeded"""

    pass

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,156 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import enum
import warnings
from typing import Any, Optional, Type, TypeVar, Union, get_args, get_origin
from pydantic import BaseModel as BaseModel
from pydantic import ConfigDict, Field, alias_generators
def _remove_extra_fields(model: Any, response: dict[str, object]) -> None:
"""Removes extra fields from the response that are not in the model.
Mutates the response in place.
"""
key_values = list(response.items())
for key, value in key_values:
# Need to convert to snake case to match model fields names
# ex: UsageMetadata
alias_map = {field_info.alias: key for key, field_info in model.model_fields.items()}
if key not in model.model_fields and key not in alias_map:
response.pop(key)
continue
key = alias_map.get(key, key)
annotation = model.model_fields[key].annotation
# Get the BaseModel if Optional
if get_origin(annotation) is Union:
annotation = get_args(annotation)[0]
# if dict, assume BaseModel but also check that field type is not dict
# example: FunctionCall.args
if isinstance(value, dict) and get_origin(annotation) is not dict:
_remove_extra_fields(annotation, value)
elif isinstance(value, list):
for item in value:
# assume a list of dict is list of BaseModel
if isinstance(item, dict):
_remove_extra_fields(get_args(annotation)[0], item)
# Generic bound used by CommonBaseModel._from_response so it returns the concrete subclass type.
T = TypeVar("T", bound="BaseModel")
class CommonBaseModel(BaseModel):
    """Shared pydantic base: camelCase aliases, forbidden extras, base64 bytes (de)serialization."""

    model_config = ConfigDict(
        alias_generator=alias_generators.to_camel,
        populate_by_name=True,
        from_attributes=True,
        protected_namespaces=(),
        extra="forbid",
        # This allows us to use arbitrary types in the model. E.g. PIL.Image.
        arbitrary_types_allowed=True,
        ser_json_bytes="base64",
        val_json_bytes="base64",
        ignored_types=(TypeVar,),
    )

    @classmethod
    def _from_response(cls: Type[T], *, response: dict[str, object], kwargs: dict[str, object]) -> T:
        """Build a validated instance from a raw response dict, dropping unknown fields first."""
        # To maintain forward compatibility, we need to remove extra fields from
        # the response.
        # We will provide another mechanism to allow users to access these fields.
        # NOTE(review): `kwargs` is not used in this method — presumably kept for
        # signature compatibility with callers; confirm before removing.
        _remove_extra_fields(cls, response)
        validated_response = cls.model_validate(response)
        return validated_response

    def to_json_dict(self) -> dict[str, object]:
        """Serialize to a JSON-compatible dict, omitting None-valued fields."""
        return self.model_dump(exclude_none=True, mode="json")
class CaseInSensitiveEnum(str, enum.Enum):
    """Case insensitive enum.

    Unknown values fall back through member-name lookup in upper then lower
    case; if both fail, a warning is emitted and a synthetic member is created
    (or None is returned when even that fails).
    """

    @classmethod
    def _missing_(cls, value: Any) -> Optional["CaseInSensitiveEnum"]:
        try:
            return cls[value.upper()]  # Try to access directly with uppercase
        except KeyError:
            try:
                return cls[value.lower()]  # Try to access directly with lowercase
            except KeyError:
                warnings.warn(f"{value} is not a valid {cls.__name__}")
                try:
                    # Creating a enum instance based on the value
                    # We need to use super() to avoid infinite recursion.
                    unknown_enum_val = super().__new__(cls, value)
                    unknown_enum_val._name_ = str(value)  # pylint: disable=protected-access
                    unknown_enum_val._value_ = value  # pylint: disable=protected-access
                    return unknown_enum_val
                except:  # noqa: E722
                    return None
class FunctionCallingConfigMode(CaseInSensitiveEnum):
    """Config for the function calling config mode."""

    MODE_UNSPECIFIED = "MODE_UNSPECIFIED"
    AUTO = "AUTO"
    ANY = "ANY"
    NONE = "NONE"
class LatLng(CommonBaseModel):
    """An object that represents a latitude/longitude pair.

    This is expressed as a pair of doubles to represent degrees latitude and
    degrees longitude. Unless specified otherwise, this object must conform to the
    <a href="https://en.wikipedia.org/wiki/World_Geodetic_System#1984_version">
    WGS84 standard</a>. Values must be within normalized ranges.
    """

    latitude: Optional[float] = Field(
        default=None,
        description="""The latitude in degrees. It must be in the range [-90.0, +90.0].""",
    )
    longitude: Optional[float] = Field(
        default=None,
        description="""The longitude in degrees. It must be in the range [-180.0, +180.0]""",
    )
class FunctionCallingConfig(CommonBaseModel):
    """Function calling config."""

    mode: Optional[FunctionCallingConfigMode] = Field(default=None, description="""Optional. Function calling mode.""")
    allowed_function_names: Optional[list[str]] = Field(
        default=None,
        description="""Optional. Function names to call. Only set when the Mode is ANY. Function names should match [FunctionDeclaration.name]. With mode set to ANY, model will predict a function call from the set of function names provided.""",
    )
class RetrievalConfig(CommonBaseModel):
    """Retrieval config."""

    lat_lng: Optional[LatLng] = Field(default=None, description="""Optional. The location of the user.""")
    language_code: Optional[str] = Field(default=None, description="""The language code of the user.""")
class ToolConfig(CommonBaseModel):
    """Tool config.

    This config is shared for all tools provided in the request.
    """

    function_calling_config: Optional[FunctionCallingConfig] = Field(
        default=None, description="""Optional. Function calling config."""
    )
    retrieval_config: Optional[RetrievalConfig] = Field(default=None, description="""Optional. Retrieval config.""")

View File

@@ -0,0 +1,305 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Create an OpenAI-compatible client using Groq's API.
Example:
```python
llm_config = {
"config_list": [{"api_type": "groq", "model": "mixtral-8x7b-32768", "api_key": os.environ.get("GROQ_API_KEY")}]
}
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
```
Install Groq's python library using: pip install --upgrade groq
Resources:
- https://console.groq.com/docs/quickstart
"""
from __future__ import annotations
import copy
import os
import time
import warnings
from typing import Any, Literal, Optional
from pydantic import Field
from ..import_utils import optional_import_block, require_optional_import
from ..llm_config import LLMConfigEntry, register_llm_config
from .client_utils import should_hide_tools, validate_parameter
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
with optional_import_block():
from groq import Groq, Stream
# Cost per thousand tokens - Input / Output (NOTE: Convert $/Million to $/K)
# Models missing from this table are billed as 0.0 (with a UserWarning) by calculate_groq_cost().
GROQ_PRICING_1K = {
    "llama3-70b-8192": (0.00059, 0.00079),
    "mixtral-8x7b-32768": (0.00024, 0.00024),
    "llama3-8b-8192": (0.00005, 0.00008),
    "gemma-7b-it": (0.00007, 0.00007),
}
@register_llm_config
class GroqLLMConfigEntry(LLMConfigEntry):
    """LLM config-list entry for Groq's OpenAI-compatible chat API.

    Bounds follow https://console.groq.com/docs/api-reference#chat. Fields
    defaulting to None are declared Optional so the annotations agree with
    their defaults (consistent with MistralLLMConfigEntry and pydantic
    default validation).
    """

    api_type: Literal["groq"] = "groq"
    frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
    max_tokens: Optional[int] = Field(default=None, ge=0)
    presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
    seed: Optional[int] = Field(default=None)
    stream: bool = Field(default=False)
    temperature: float = Field(default=1, ge=0, le=2)
    top_p: Optional[float] = Field(default=None)
    hide_tools: Literal["if_all_run", "if_any_run", "never"] = "never"
    tool_choice: Optional[Literal["none", "auto", "required"]] = None

    def create_client(self):
        # Client construction is handled by GroqClient, not by the config entry.
        raise NotImplementedError("GroqLLMConfigEntry.create_client is not implemented.")
class GroqClient:
    """Client for Groq's API.

    Exposes the ModelClient-style interface used by AG2 (``create``,
    ``message_retrieval``, ``cost``, ``get_usage``) over Groq's
    OpenAI-compatible chat completions endpoint.
    """

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            **kwargs: Additional parameters to pass to the Groq API
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key")
        if not self.api_key:
            self.api_key = os.getenv("GROQ_API_KEY")
        assert self.api_key, (
            "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."
        )
        # Structured outputs are not supported; warn rather than fail.
        if "response_format" in kwargs and kwargs["response_format"] is not None:
            warnings.warn("response_format is not supported for Groq API, it will be ignored.", UserWarning)
        self.base_url = kwargs.get("base_url")

    def message_retrieval(self, response) -> list:
        """Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        # Cost is computed in create() and attached to the returned ChatCompletion.
        return response.cost

    @staticmethod
    def get_usage(response) -> dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        # ... # pragma: no cover
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Loads the parameters for Groq API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
        groq_params = {}
        # Check that we have what we need to use Groq's API
        # We won't enforce the available models as they are likely to change
        groq_params["model"] = params.get("model")
        assert groq_params["model"], (
            "Please specify the 'model' in your config list entry to nominate the Groq model to use."
        )
        # Validate allowed Groq parameters
        # https://console.groq.com/docs/api-reference#chat
        groq_params["frequency_penalty"] = validate_parameter(
            params, "frequency_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        groq_params["presence_penalty"] = validate_parameter(
            params, "presence_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
        groq_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
        groq_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 2), None)
        groq_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
        if "tool_choice" in params:
            groq_params["tool_choice"] = validate_parameter(
                params, "tool_choice", str, True, None, None, ["none", "auto", "required"]
            )
        # Groq parameters not supported by their models yet, ignoring
        # logit_bias, logprobs, top_logprobs
        # Groq parameters we are ignoring:
        # n (must be 1), response_format (to enforce JSON but needs prompting as well), user,
        # parallel_tool_calls (defaults to True), stop
        # function_call (deprecated), functions (deprecated)
        # tool_choice (none if no tools, auto if there are tools)
        return groq_params

    @require_optional_import("groq", "groq")
    def create(self, params: dict) -> ChatCompletion:
        """Run a chat completion (streaming or not) against Groq and return an OpenAI-style ChatCompletion."""
        messages = params.get("messages", [])
        # Convert AG2 messages to Groq messages
        groq_messages = oai_messages_to_groq_messages(messages)
        # Parse parameters to the Groq API's parameters
        groq_params = self.parse_params(params)
        # Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(groq_messages, params["tools"], hide_tools):
                groq_params["tools"] = params["tools"]
        groq_params["messages"] = groq_messages
        # We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
        client = Groq(api_key=self.api_key, max_retries=5, base_url=self.base_url)
        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0
        # Streaming tool call recommendations
        streaming_tool_calls = []
        ans = None
        response = client.chat.completions.create(**groq_params)
        if groq_params["stream"]:
            # Read in the chunks as they stream, taking in tool_calls which may be across
            # multiple chunks if more than one suggested
            ans = ""
            for chunk in response:
                ans = ans + (chunk.choices[0].delta.content or "")
                if chunk.choices[0].delta.tool_calls:
                    # We have a tool call recommendation
                    for tool_call in chunk.choices[0].delta.tool_calls:
                        streaming_tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id=tool_call.id,
                                function={
                                    "name": tool_call.function.name,
                                    "arguments": tool_call.function.arguments,
                                },
                                type="function",
                            )
                        )
                # Usage is delivered on the final chunk (the one carrying finish_reason).
                if chunk.choices[0].finish_reason:
                    prompt_tokens = chunk.x_groq.usage.prompt_tokens
                    completion_tokens = chunk.x_groq.usage.completion_tokens
                    total_tokens = chunk.x_groq.usage.total_tokens
        else:
            # Non-streaming finished
            ans: str = response.choices[0].message.content
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens
        # NOTE(review): `response` cannot be None here — its attributes were already
        # accessed above, so the else-branch raise below is effectively unreachable.
        if response is not None:
            if isinstance(response, Stream):
                # Streaming response — `chunk` is the last chunk seen in the loop above.
                if chunk.choices[0].finish_reason == "tool_calls":
                    groq_finish = "tool_calls"
                    tool_calls = streaming_tool_calls
                else:
                    groq_finish = "stop"
                    tool_calls = None
                response_content = ans
                response_id = chunk.id
            else:
                # Non-streaming response
                # If we have tool calls as the response, populate completed tool calls for our return OAI response
                if response.choices[0].finish_reason == "tool_calls":
                    groq_finish = "tool_calls"
                    tool_calls = []
                    for tool_call in response.choices[0].message.tool_calls:
                        tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id=tool_call.id,
                                function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                                type="function",
                            )
                        )
                else:
                    groq_finish = "stop"
                    tool_calls = None
                response_content = response.choices[0].message.content
                response_id = response.id
        else:
            raise RuntimeError("Failed to get response from Groq after retrying 5 times.")
        # 3. convert output
        message = ChatCompletionMessage(
            role="assistant",
            content=response_content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=groq_finish, index=0, message=message)]
        response_oai = ChatCompletion(
            id=response_id,
            model=groq_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            cost=calculate_groq_cost(prompt_tokens, completion_tokens, groq_params["model"]),
        )
        return response_oai
def oai_messages_to_groq_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Groq's format.

    Returns a deep copy of ``messages`` with the "name" field stripped from
    every entry; the input list is left untouched.
    """
    converted = copy.deepcopy(messages)
    # Groq does not accept the OAI "name" field, so drop it everywhere.
    for entry in converted:
        entry.pop("name", None)
    return converted
def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Calculate the cost of the completion using the Groq pricing.

    Unknown models cost 0.0 and emit a UserWarning.
    """
    pricing = GROQ_PRICING_1K.get(model)
    if pricing is None:
        warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
        return 0.0
    input_rate, output_rate = pricing
    return (input_tokens / 1000) * input_rate + (output_tokens / 1000) * output_rate

View File

@@ -0,0 +1,303 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Create an OpenAI-compatible client using Mistral.AI's API.
Example:
```python
llm_config = {
"config_list": [
{"api_type": "mistral", "model": "open-mixtral-8x22b", "api_key": os.environ.get("MISTRAL_API_KEY")}
]
}
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
```
Install Mistral.AI python library using: pip install --upgrade mistralai
Resources:
- https://docs.mistral.ai/getting-started/quickstart/
NOTE: Requires mistralai package version >= 1.0.1
"""
import json
import os
import time
import warnings
from typing import Any, Literal, Optional, Union
from pydantic import Field
from ..import_utils import optional_import_block, require_optional_import
from ..llm_config import LLMConfigEntry, register_llm_config
from .client_utils import should_hide_tools, validate_parameter
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
with optional_import_block():
# Mistral libraries
# pip install mistralai
from mistralai import (
AssistantMessage,
Function,
FunctionCall,
Mistral,
SystemMessage,
ToolCall,
ToolMessage,
UserMessage,
)
@register_llm_config
class MistralLLMConfigEntry(LLMConfigEntry):
    """LLM config-list entry for Mistral.AI's chat API."""

    api_type: Literal["mistral"] = "mistral"
    temperature: float = Field(default=0.7)
    top_p: Optional[float] = None
    max_tokens: Optional[int] = Field(default=None, ge=0)
    safe_prompt: bool = False
    random_seed: Optional[int] = None
    stream: bool = False
    hide_tools: Literal["if_all_run", "if_any_run", "never"] = "never"
    tool_choice: Optional[Literal["none", "auto", "any"]] = None

    def create_client(self):
        # Client construction is handled by MistralAIClient, not by the config entry.
        raise NotImplementedError("MistralLLMConfigEntry.create_client is not implemented.")
@require_optional_import("mistralai", "mistral")
class MistralAIClient:
"""Client for Mistral.AI's API."""
    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            **kwargs: Additional keyword arguments to pass to the Mistral client.
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key")
        if not self.api_key:
            self.api_key = os.getenv("MISTRAL_API_KEY", None)
        assert self.api_key, (
            "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable."
        )
        # Structured outputs are not supported; warn rather than fail.
        if "response_format" in kwargs and kwargs["response_format"] is not None:
            warnings.warn("response_format is not supported for Mistral.AI, it will be ignored.", UserWarning)
        self._client = Mistral(api_key=self.api_key)
    def message_retrieval(self, response: ChatCompletion) -> Union[list[str], list[ChatCompletionMessage]]:
        """Retrieve the messages from the response (one Choice.Message per choice)."""
        return [choice.message for choice in response.choices]
    def cost(self, response) -> float:
        """Return the USD cost stored on the response object."""
        return response.cost
@require_optional_import("mistralai", "mistral")
def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
"""Loads the parameters for Mistral.AI API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
mistral_params = {}
# 1. Validate models
mistral_params["model"] = params.get("model")
assert mistral_params["model"], (
"Please specify the 'model' in your config list entry to nominate the Mistral.ai model to use."
)
# 2. Validate allowed Mistral.AI parameters
mistral_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 0.7, None, None)
mistral_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
mistral_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
mistral_params["safe_prompt"] = validate_parameter(
params, "safe_prompt", bool, False, False, None, [True, False]
)
mistral_params["random_seed"] = validate_parameter(params, "random_seed", int, True, None, False, None)
mistral_params["tool_choice"] = validate_parameter(
params, "tool_choice", str, False, None, None, ["none", "auto", "any"]
)
# TODO
if params.get("stream", False):
warnings.warn(
"Streaming is not currently supported, streaming will be disabled.",
UserWarning,
)
# 3. Convert messages to Mistral format
mistral_messages = []
tool_call_ids = {} # tool call ids to function name mapping
for message in params["messages"]:
if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"] is not None:
# Convert OAI ToolCall to Mistral ToolCall
mistral_messages_tools = []
for toolcall in message["tool_calls"]:
mistral_messages_tools.append(
ToolCall(
id=toolcall["id"],
function=FunctionCall(
name=toolcall["function"]["name"],
arguments=json.loads(toolcall["function"]["arguments"]),
),
)
)
mistral_messages.append(AssistantMessage(content="", tool_calls=mistral_messages_tools))
# Map tool call id to the function name
for tool_call in message["tool_calls"]:
tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]
elif message["role"] == "system":
if len(mistral_messages) > 0 and mistral_messages[-1].role == "assistant":
# System messages can't appear after an Assistant message, so use a UserMessage
mistral_messages.append(UserMessage(content=message["content"]))
else:
mistral_messages.append(SystemMessage(content=message["content"]))
elif message["role"] == "assistant":
mistral_messages.append(AssistantMessage(content=message["content"]))
elif message["role"] == "user":
mistral_messages.append(UserMessage(content=message["content"]))
elif message["role"] == "tool":
# Indicates the result of a tool call, the name is the function name called
mistral_messages.append(
ToolMessage(
name=tool_call_ids[message["tool_call_id"]],
content=message["content"],
tool_call_id=message["tool_call_id"],
)
)
else:
warnings.warn(f"Unknown message role {message['role']}", UserWarning)
# 4. Last message needs to be user or tool, if not, add a "please continue" message
if not isinstance(mistral_messages[-1], UserMessage) and not isinstance(mistral_messages[-1], ToolMessage):
mistral_messages.append(UserMessage(content="Please continue."))
mistral_params["messages"] = mistral_messages
# 5. Add tools to the call if we have them and aren't hiding them
if "tools" in params:
hide_tools = validate_parameter(
params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
)
if not should_hide_tools(params["messages"], params["tools"], hide_tools):
mistral_params["tools"] = tool_def_to_mistral(params["tools"])
return mistral_params
@require_optional_import("mistralai", "mistral")
def create(self, params: dict[str, Any]) -> ChatCompletion:
# 1. Parse parameters to Mistral.AI API's parameters
mistral_params = self.parse_params(params)
# 2. Call Mistral.AI API
mistral_response = self._client.chat.complete(**mistral_params)
# TODO: Handle streaming
# 3. Convert Mistral response to OAI compatible format
if mistral_response.choices[0].finish_reason == "tool_calls":
mistral_finish = "tool_calls"
tool_calls = []
for tool_call in mistral_response.choices[0].message.tool_calls:
tool_calls.append(
ChatCompletionMessageToolCall(
id=tool_call.id,
function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
type="function",
)
)
else:
mistral_finish = "stop"
tool_calls = None
message = ChatCompletionMessage(
role="assistant",
content=mistral_response.choices[0].message.content,
function_call=None,
tool_calls=tool_calls,
)
choices = [Choice(finish_reason=mistral_finish, index=0, message=message)]
response_oai = ChatCompletion(
id=mistral_response.id,
model=mistral_response.model,
created=int(time.time()),
object="chat.completion",
choices=choices,
usage=CompletionUsage(
prompt_tokens=mistral_response.usage.prompt_tokens,
completion_tokens=mistral_response.usage.completion_tokens,
total_tokens=mistral_response.usage.prompt_tokens + mistral_response.usage.completion_tokens,
),
cost=calculate_mistral_cost(
mistral_response.usage.prompt_tokens, mistral_response.usage.completion_tokens, mistral_response.model
),
)
return response_oai
@staticmethod
def get_usage(response: ChatCompletion) -> dict:
return {
"prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
"completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
"total_tokens": (
response.usage.prompt_tokens + response.usage.completion_tokens if response.usage is not None else 0
),
"cost": response.cost if hasattr(response, "cost") else 0,
"model": response.model,
}
@require_optional_import("mistralai", "mistral")
def tool_def_to_mistral(tool_definitions: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Converts AG2 tool definition to a mistral tool format

    Each OpenAI-style tool dict becomes a ``{"type": "function", "function": Function(...)}``
    entry, copying name, description, and parameters verbatim.
    """
    return [
        {
            "type": "function",
            "function": Function(
                name=tool["function"]["name"],
                description=tool["function"]["description"],
                parameters=tool["function"]["parameters"],
            ),
        }
        for tool in tool_definitions
    ]
def calculate_mistral_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """Calculate the cost of the mistral response.

    Args:
        input_tokens: Number of prompt tokens billed.
        output_tokens: Number of completion tokens billed.
        model_name: Mistral model identifier used for the request.

    Returns:
        Estimated cost in USD; 0 (with a UserWarning) for unknown models.
    """
    # Prices per 1 thousand tokens
    # https://mistral.ai/technology/
    per_thousand = {
        "open-mistral-7b": {"input": 0.00025, "output": 0.00025},
        "open-mixtral-8x7b": {"input": 0.0007, "output": 0.0007},
        "open-mixtral-8x22b": {"input": 0.002, "output": 0.006},
        "mistral-small-latest": {"input": 0.001, "output": 0.003},
        "mistral-medium-latest": {"input": 0.00275, "output": 0.0081},
        "mistral-large-latest": {"input": 0.0003, "output": 0.0003},
        "mistral-large-2407": {"input": 0.0003, "output": 0.0003},
        "open-mistral-nemo-2407": {"input": 0.0003, "output": 0.0003},
        "codestral-2405": {"input": 0.001, "output": 0.003},
    }
    # Ensure we have the model they are using and return the total cost
    rates = per_thousand.get(model_name)
    if rates is None:
        warnings.warn(f"Cost calculation is not implemented for model {model_name}, will return $0.", UserWarning)
        return 0
    return (input_tokens * rates["input"] / 1000) + (output_tokens * rates["output"] / 1000)

View File

@@ -0,0 +1,11 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .chat_completion import ChatCompletionExtended as ChatCompletion
from .chat_completion import Choice
from .chat_completion_message import ChatCompletionMessage
from .chat_completion_message_tool_call import ChatCompletionMessageToolCall
from .completion_usage import CompletionUsage
__all__ = ["ChatCompletion", "ChatCompletionMessage", "ChatCompletionMessageToolCall", "Choice", "CompletionUsage"]

View File

@@ -0,0 +1,16 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Taken over from https://github.com/openai/openai-python/blob/main/src/openai/_models.py
import pydantic
import pydantic.generics
from pydantic import ConfigDict
from typing_extensions import ClassVar
__all__ = ["BaseModel"]
class BaseModel(pydantic.BaseModel):
    """Lenient pydantic base for the vendored OpenAI response models.

    extra="allow" keeps unknown fields instead of rejecting them, mirroring
    openai-python's tolerant deserialization behavior.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")

View File

@@ -0,0 +1,87 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Taken over from https://github.com/openai/openai-python/blob/3e69750d47df4f0759d4a28ddc68e4b38756d9ca/src/openai/types/chat/chat_completion.py
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Any, Callable, List, Optional
from typing_extensions import Literal
from ._models import BaseModel
from .chat_completion_message import ChatCompletionMessage
from .chat_completion_token_logprob import ChatCompletionTokenLogprob
from .completion_usage import CompletionUsage
__all__ = ["ChatCompletion", "Choice", "ChoiceLogprobs"]
class ChoiceLogprobs(BaseModel):
    """Log-probability information for one chat completion choice."""

    content: Optional[List[ChatCompletionTokenLogprob]] = None
    """A list of message content tokens with log probability information."""

    refusal: Optional[List[ChatCompletionTokenLogprob]] = None
    """A list of message refusal tokens with log probability information."""
class Choice(BaseModel):
    """A single completion choice within a ChatCompletion response."""

    finish_reason: Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
    """The reason the model stopped generating tokens.
    This will be `stop` if the model hit a natural stop point or a provided stop
    sequence, `length` if the maximum number of tokens specified in the request was
    reached, `content_filter` if content was omitted due to a flag from our content
    filters, `tool_calls` if the model called a tool, or `function_call`
    (deprecated) if the model called a function.
    """

    index: int
    """The index of the choice in the list of choices."""

    logprobs: Optional[ChoiceLogprobs] = None
    """Log probability information for the choice."""

    message: ChatCompletionMessage
    """A chat completion message generated by the model."""
class ChatCompletion(BaseModel):
    """OpenAI-compatible chat completion response (vendored from openai-python)."""

    id: str
    """A unique identifier for the chat completion."""

    choices: List[Choice]
    """A list of chat completion choices.
    Can be more than one if `n` is greater than 1.
    """

    created: int
    """The Unix timestamp (in seconds) of when the chat completion was created."""

    model: str
    """The model used for the chat completion."""

    object: Literal["chat.completion"]
    """The object type, which is always `chat.completion`."""

    service_tier: Optional[Literal["auto", "default", "flex", "scale"]] = None
    """The service tier used for processing the request."""

    system_fingerprint: Optional[str] = None
    """This fingerprint represents the backend configuration that the model runs with.
    Can be used in conjunction with the `seed` request parameter to understand when
    backend changes have been made that might impact determinism.
    """

    usage: Optional[CompletionUsage] = None
    """Usage statistics for the completion request."""
class ChatCompletionExtended(ChatCompletion):
    """ChatCompletion extended with AG2 bookkeeping fields (cost, config id, filter hooks)."""

    # Hook used by the client wrapper to extract messages from this response.
    message_retrieval_function: Optional[Callable[[Any, "ChatCompletion"], list[ChatCompletionMessage]]] = None
    # Identifier of the config-list entry that produced this response.
    config_id: Optional[str] = None
    # Optional predicate/result of a reply filter applied to this response.
    pass_filter: Optional[Callable[..., bool]] = None
    # Estimated cost (USD) of the request, computed by the provider client.
    cost: Optional[float] = None

View File

@@ -0,0 +1,32 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Taken over from https://github.com/openai/openai-python/blob/3e69750d47df4f0759d4a28ddc68e4b38756d9ca/src/openai/types/chat/chat_completion_audio.py
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from ._models import BaseModel
__all__ = ["ChatCompletionAudio"]
class ChatCompletionAudio(BaseModel):
    """Audio payload returned when the audio output modality is requested."""

    id: str
    """Unique identifier for this audio response."""

    data: str
    """
    Base64 encoded audio bytes generated by the model, in the format specified in
    the request.
    """

    expires_at: int
    """
    The Unix timestamp (in seconds) for when this audio response will no longer be
    accessible on the server for use in multi-turn conversations.
    """

    transcript: str
    """Transcript of the audio generated by the model."""

View File

@@ -0,0 +1,86 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Taken over from https://github.com/openai/openai-python/blob/16a10604fbd0d82c1382b84b417a1d6a2d33a7f1/src/openai/types/chat/chat_completion_message.py
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import List, Optional
from typing_extensions import Literal
from ._models import BaseModel
from .chat_completion_audio import ChatCompletionAudio
from .chat_completion_message_tool_call import ChatCompletionMessageToolCall
__all__ = ["Annotation", "AnnotationURLCitation", "ChatCompletionMessage", "FunctionCall"]
class AnnotationURLCitation(BaseModel):
    """A URL citation locating a web-search source within the message text."""

    end_index: int
    """The index of the last character of the URL citation in the message."""

    start_index: int
    """The index of the first character of the URL citation in the message."""

    title: str
    """The title of the web resource."""

    url: str
    """The URL of the web resource."""
class Annotation(BaseModel):
    """A message annotation; currently always a URL citation from web search."""

    type: Literal["url_citation"]
    """The type of the URL citation. Always `url_citation`."""

    url_citation: AnnotationURLCitation
    """A URL citation when using web search."""
class FunctionCall(BaseModel):
    """Deprecated function-call payload (superseded by tool_calls)."""

    arguments: str
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str
    """The name of the function to call."""
class ChatCompletionMessage(BaseModel):
    """An assistant message generated by the model within a chat completion."""

    content: Optional[str] = None
    """The contents of the message."""

    refusal: Optional[str] = None
    """The refusal message generated by the model."""

    role: Literal["assistant"]
    """The role of the author of this message."""

    annotations: Optional[List[Annotation]] = None
    """
    Annotations for the message, when applicable, as when using the
    [web search tool](https://platform.openai.com/docs/guides/tools-web-search?api-mode=chat).
    """

    audio: Optional[ChatCompletionAudio] = None
    """
    If the audio output modality is requested, this object contains data about the
    audio response from the model.
    [Learn more](https://platform.openai.com/docs/guides/audio).
    """

    function_call: Optional[FunctionCall] = None
    """Deprecated and replaced by `tool_calls`.
    The name and arguments of a function that should be called, as generated by the
    model.
    """

    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
    """The tool calls generated by the model, such as function calls."""

View File

@@ -0,0 +1,37 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Taken over from https://github.com/openai/openai-python/blob/3e69750d47df4f0759d4a28ddc68e4b38756d9ca/src/openai/types/chat/chat_completion_message_tool_call.py
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal
from ._models import BaseModel
__all__ = ["ChatCompletionMessageToolCall", "Function"]
class Function(BaseModel):
    """The function (name + JSON-encoded arguments) referenced by a tool call."""

    arguments: str
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str
    """The name of the function to call."""
class ChatCompletionMessageToolCall(BaseModel):
    """A tool call emitted by the model within an assistant message."""

    id: str
    """The ID of the tool call."""

    function: Function
    """The function that the model called."""

    type: Literal["function"]
    """The type of the tool. Currently, only `function` is supported."""

View File

@@ -0,0 +1,63 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Taken over from https://github.com/openai/openai-python/blob/3e69750d47df4f0759d4a28ddc68e4b38756d9ca/src/openai/types/chat/chat_completion_token_logprob.py
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import List, Optional
from ._models import BaseModel
__all__ = ["ChatCompletionTokenLogprob", "TopLogprob"]
class TopLogprob(BaseModel):
    """One of the most-likely alternative tokens at a given position."""

    token: str
    """The token."""

    bytes: Optional[List[int]] = None
    """A list of integers representing the UTF-8 bytes representation of the token.
    Useful in instances where characters are represented by multiple tokens and
    their byte representations must be combined to generate the correct text
    representation. Can be `null` if there is no bytes representation for the token.
    """

    logprob: float
    """The log probability of this token, if it is within the top 20 most likely
    tokens.
    Otherwise, the value `-9999.0` is used to signify that the token is very
    unlikely.
    """
class ChatCompletionTokenLogprob(BaseModel):
    """Log-probability details for a single generated token."""

    token: str
    """The token."""

    bytes: Optional[List[int]] = None
    """A list of integers representing the UTF-8 bytes representation of the token.
    Useful in instances where characters are represented by multiple tokens and
    their byte representations must be combined to generate the correct text
    representation. Can be `null` if there is no bytes representation for the token.
    """

    logprob: float
    """The log probability of this token, if it is within the top 20 most likely
    tokens.
    Otherwise, the value `-9999.0` is used to signify that the token is very
    unlikely.
    """

    top_logprobs: List[TopLogprob]
    """List of the most likely tokens and their log probability, at this token
    position.
    In rare cases, there may be fewer than the number of requested `top_logprobs`
    returned.
    """

View File

@@ -0,0 +1,60 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Taken over from https://github.com/openai/openai-python/blob/3e69750d47df4f0759d4a28ddc68e4b38756d9ca/src/openai/types/completion_usage.py
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from ._models import BaseModel
__all__ = ["CompletionTokensDetails", "CompletionUsage", "PromptTokensDetails"]
class CompletionTokensDetails(BaseModel):
    """Breakdown of the tokens counted in the completion."""

    accepted_prediction_tokens: Optional[int] = None
    """
    When using Predicted Outputs, the number of tokens in the prediction that
    appeared in the completion.
    """

    audio_tokens: Optional[int] = None
    """Audio input tokens generated by the model."""

    reasoning_tokens: Optional[int] = None
    """Tokens generated by the model for reasoning."""

    rejected_prediction_tokens: Optional[int] = None
    """
    When using Predicted Outputs, the number of tokens in the prediction that did
    not appear in the completion. However, like reasoning tokens, these tokens are
    still counted in the total completion tokens for purposes of billing, output,
    and context window limits.
    """
class PromptTokensDetails(BaseModel):
    """Breakdown of the tokens counted in the prompt."""

    audio_tokens: Optional[int] = None
    """Audio input tokens present in the prompt."""

    cached_tokens: Optional[int] = None
    """Cached tokens present in the prompt."""
class CompletionUsage(BaseModel):
    """Token usage statistics for a completion request."""

    completion_tokens: int
    """Number of tokens in the generated completion."""

    prompt_tokens: int
    """Number of tokens in the prompt."""

    total_tokens: int
    """Total number of tokens used in the request (prompt + completion)."""

    completion_tokens_details: Optional[CompletionTokensDetails] = None
    """Breakdown of tokens used in a completion."""

    prompt_tokens_details: Optional[PromptTokensDetails] = None
    """Breakdown of tokens used in the prompt."""

View File

@@ -0,0 +1,643 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Create an OpenAI-compatible client using Ollama's API.
Example:
```python
llm_config = {"config_list": [{"api_type": "ollama", "model": "mistral:7b-instruct-v0.3-q6_K"}]}
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
```
Install Ollama's python library using: pip install --upgrade ollama
Install fix-busted-json library: pip install --upgrade fix-busted-json
Resources:
- https://github.com/ollama/ollama-python
"""
from __future__ import annotations
import copy
import json
import random
import re
import time
import warnings
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, Field, HttpUrl
from ..import_utils import optional_import_block, require_optional_import
from ..llm_config import LLMConfigEntry, register_llm_config
from .client_utils import FormatterProtocol, should_hide_tools, validate_parameter
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
with optional_import_block():
import ollama
from fix_busted_json import repair_json
from ollama import Client
@register_llm_config
class OllamaLLMConfigEntry(LLMConfigEntry):
    """LLM config-list entry for a local/remote Ollama server (``api_type="ollama"``)."""

    api_type: Literal["ollama"] = "ollama"
    # Optional Ollama server URL; when absent the default local server is used.
    client_host: Optional[HttpUrl] = None
    stream: bool = False
    num_predict: int = Field(
        default=-1,
        description="Maximum number of tokens to predict, note: -1 is infinite (default), -2 is fill context.",
    )
    # Context window size used to generate the next token.
    num_ctx: int = Field(default=2048)
    # Penalty applied to repeated tokens.
    repeat_penalty: float = Field(default=1.1)
    # Sampling seed for reproducibility.
    seed: int = Field(default=0)
    temperature: float = Field(default=0.8)
    top_k: int = Field(default=40)
    top_p: float = Field(default=0.9)
    # Second-level tool filtering: optionally hide tools from the LLM once they have run.
    hide_tools: Literal["if_all_run", "if_any_run", "never"] = "never"

    def create_client(self):
        # Client construction is handled elsewhere (OllamaClient); this entry is config-only.
        raise NotImplementedError("OllamaLLMConfigEntry.create_client is not implemented.")
class OllamaClient:
    """Client for Ollama's API.

    Supports native Ollama tool calling as well as a "manual" mode where tools are
    injected into the prompt and parsed back out of plain-text responses.
    """

    # Defaults for manual tool calling
    # Instruction is added to the first system message and provides directions to follow a two step
    # process
    # 1. (before tools have been called) Return JSON with the functions to call
    # 2. (directly after tools have been called) Return Text describing the results of the function calls in text format
    # Override using "manual_tool_call_instruction" config parameter
    TOOL_CALL_MANUAL_INSTRUCTION = (
        "You are to follow a strict two step process that will occur over "
        "a number of interactions, so pay attention to what step you are in based on the full "
        "conversation. We will be taking turns so only do one step at a time so don't perform step "
        "2 until step 1 is complete and I've told you the result. The first step is to choose one "
        "or more functions based on the request given and return only JSON with the functions and "
        "arguments to use. The second step is to analyse the given output of the function and summarise "
        "it returning only TEXT and not Python or JSON. "
        "For argument values, be sure numbers aren't strings, they should not have double quotes around them. "
        "In terms of your response format, for step 1 return only JSON and NO OTHER text, "
        "for step 2 return only text and NO JSON/Python/Markdown. "
        'The format for running a function is [{"name": "function_name1", "arguments":{"argument_name": "argument_value"}},{"name": "function_name2", "arguments":{"argument_name": "argument_value"}}] '
        'Make sure the keys "name" and "arguments" are as described. '
        "If you don't get the format correct, try again. "
        "The following functions are available to you:[FUNCTIONS_LIST]"
    )
    # Appended to the last user message if no tools have been called
    # Override using "manual_tool_call_step1" config parameter
    TOOL_CALL_MANUAL_STEP1 = " (proceed with step 1)"
    # Appended to the user message after tools have been executed. Will create a 'user' message if one doesn't exist.
    # Override using "manual_tool_call_step2" config parameter
    TOOL_CALL_MANUAL_STEP2 = " (proceed with step 2)"
    def __init__(self, response_format: Optional[Union[BaseModel, dict[str, Any]]] = None, **kwargs):
        """Note that no api_key or environment variable is required for Ollama.

        Args:
            response_format: Optional structured-output schema (Pydantic model class
                or JSON-schema dict) applied to requests made by this client.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        # Store the response format, if provided (for structured outputs)
        self._response_format: Optional[Union[BaseModel, dict[str, Any]]] = response_format
    def message_retrieval(self, response) -> list:
        """Retrieve and return a list of strings or a list of Choice.Message from the response.
        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        # One message per returned choice.
        return [choice.message for choice in response.choices]
    def cost(self, response) -> float:
        """Return the cost attached to the response (create() sets 0 for local models)."""
        return response.cost
@staticmethod
def get_usage(response) -> dict:
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
# ... # pragma: no cover
return {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
"cost": response.cost,
"model": response.model,
}
    def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Loads the parameters for Ollama API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults

        NOTE(review): reads self._native_tool_calls, self._tools_in_conversation and
        self._should_hide_tools, which are assigned in create() before this is
        called — do not call this method directly without going through create().
        """
        ollama_params = {}
        # Check that we have what we need to use Ollama's API
        # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
        # The main parameters are model, prompt, stream, and options
        # Options is a dictionary of parameters for the model
        # There are other, advanced, parameters such as format, system (to override system message), template, raw, etc. - not used
        # We won't enforce the available models
        ollama_params["model"] = params.get("model")
        assert ollama_params["model"], (
            "Please specify the 'model' in your config list entry to nominate the Ollama model to use."
        )
        ollama_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
        # Build up the options dictionary
        # https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
        options_dict = {}
        if "num_predict" in params:
            # Maximum number of tokens to predict, note: -1 is infinite, -2 is fill context, 128 is default
            options_dict["num_predict"] = validate_parameter(params, "num_predict", int, False, 128, None, None)
        if "num_ctx" in params:
            # Set size of context window used to generate next token, 2048 is default
            options_dict["num_ctx"] = validate_parameter(params, "num_ctx", int, False, 2048, None, None)
        if "repeat_penalty" in params:
            options_dict["repeat_penalty"] = validate_parameter(
                params, "repeat_penalty", (int, float), False, 1.1, None, None
            )
        if "seed" in params:
            options_dict["seed"] = validate_parameter(params, "seed", int, False, 42, None, None)
        if "temperature" in params:
            options_dict["temperature"] = validate_parameter(
                params, "temperature", (int, float), False, 0.8, None, None
            )
        if "top_k" in params:
            options_dict["top_k"] = validate_parameter(params, "top_k", int, False, 40, None, None)
        if "top_p" in params:
            options_dict["top_p"] = validate_parameter(params, "top_p", (int, float), False, 0.9, None, None)
        # Only pass tools through when using native tool calling and tools aren't hidden
        if self._native_tool_calls and self._tools_in_conversation and not self._should_hide_tools:
            ollama_params["tools"] = params["tools"]
        # Ollama doesn't support streaming with tools natively
        if ollama_params["stream"] and self._native_tool_calls:
            warnings.warn(
                "Streaming is not supported when using tools and 'Native' tool calling, streaming will be disabled.",
                UserWarning,
            )
            ollama_params["stream"] = False
        if not self._native_tool_calls and self._tools_in_conversation:
            # For manual tool calling we have injected the available tools into the prompt
            # and we don't want to force JSON mode
            ollama_params["format"] = ""  # Don't force JSON for manual tool calling mode
        if len(options_dict) != 0:
            ollama_params["options"] = options_dict
        # Structured outputs (see https://ollama.com/blog/structured-outputs)
        if not self._response_format and params.get("response_format"):
            self._response_format = params["response_format"]
        if self._response_format:
            if isinstance(self._response_format, dict):
                # A dict is passed through as-is as the JSON schema
                ollama_params["format"] = self._response_format
            else:
                # Keep self._response_format as a Pydantic model for when process the response
                ollama_params["format"] = self._response_format.model_json_schema()
        return ollama_params
    @require_optional_import(["ollama", "fix_busted_json"], "ollama")
    def create(self, params: dict) -> ChatCompletion:
        """Send a chat request to Ollama and return an OpenAI-compatible ChatCompletion.

        Handles both native and manual tool calling, optional streaming, and
        structured-output post-processing.

        Args:
            params: AG2/OpenAI-style request parameters (model, messages, tools, ...).

        Raises:
            RuntimeError: If no response was received from Ollama.
        """
        messages = params.get("messages", [])
        # Are tools involved in this conversation?
        self._tools_in_conversation = "tools" in params
        # We provide second-level filtering out of tools to avoid LLMs re-calling tools continuously
        if self._tools_in_conversation:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            self._should_hide_tools = should_hide_tools(messages, params["tools"], hide_tools)
        else:
            self._should_hide_tools = False
        # Are we using native Ollama tool calling, otherwise we're doing manual tool calling
        # We allow the user to decide if they want to use Ollama's tool calling
        # or for tool calling to be handled manually through text messages
        # Default is True = Ollama's tool calling
        self._native_tool_calls = validate_parameter(params, "native_tool_calls", bool, False, True, None, None)
        if not self._native_tool_calls:
            # Load defaults
            self._manual_tool_call_instruction = validate_parameter(
                params, "manual_tool_call_instruction", str, False, self.TOOL_CALL_MANUAL_INSTRUCTION, None, None
            )
            self._manual_tool_call_step1 = validate_parameter(
                params, "manual_tool_call_step1", str, False, self.TOOL_CALL_MANUAL_STEP1, None, None
            )
            self._manual_tool_call_step2 = validate_parameter(
                params, "manual_tool_call_step2", str, False, self.TOOL_CALL_MANUAL_STEP2, None, None
            )
        # Convert AG2 messages to Ollama messages
        ollama_messages = self.oai_messages_to_ollama_messages(
            messages,
            (
                params["tools"]
                if (not self._native_tool_calls and self._tools_in_conversation) and not self._should_hide_tools
                else None
            ),
        )
        # Parse parameters to the Ollama API's parameters (must run after the flags above are set)
        ollama_params = self.parse_params(params)
        ollama_params["messages"] = ollama_messages
        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0
        ans = None
        if "client_host" in params:
            # Convert client_host to string from HttpUrl
            client = Client(host=str(params["client_host"]))
            response = client.chat(**ollama_params)
        else:
            response = ollama.chat(**ollama_params)
        if ollama_params["stream"]:
            # Read in the chunks as they stream, taking in tool_calls which may be across
            # multiple chunks if more than one suggested
            ans = ""
            for chunk in response:
                ans = ans + (chunk["message"]["content"] or "")
                if "done_reason" in chunk:
                    # Final chunk carries the token counts
                    prompt_tokens = chunk.get("prompt_eval_count", 0)
                    completion_tokens = chunk.get("eval_count", 0)
                    total_tokens = prompt_tokens + completion_tokens
        else:
            # Non-streaming finished
            ans: str = response["message"]["content"]
            prompt_tokens = response.get("prompt_eval_count", 0)
            completion_tokens = response.get("eval_count", 0)
            total_tokens = prompt_tokens + completion_tokens
        if response is not None:
            # Defaults
            ollama_finish = "stop"
            tool_calls = None
            # Id and streaming text into response
            if ollama_params["stream"]:
                response_content = ans
                response_id = chunk["created_at"]
            else:
                response_content = response["message"]["content"]
                response_id = response["created_at"]
            # Process tools in the response
            if self._tools_in_conversation:
                if self._native_tool_calls:
                    if not ollama_params["stream"]:
                        response_content = response["message"]["content"]
                        # Native tool calling
                        if "tool_calls" in response["message"]:
                            ollama_finish = "tool_calls"
                            tool_calls = []
                            # Ollama doesn't return tool-call ids; synthesize sequential ones
                            random_id = random.randint(0, 10000)
                            for tool_call in response["message"]["tool_calls"]:
                                tool_calls.append(
                                    ChatCompletionMessageToolCall(
                                        id=f"ollama_func_{random_id}",
                                        function={
                                            "name": tool_call["function"]["name"],
                                            "arguments": json.dumps(tool_call["function"]["arguments"]),
                                        },
                                        type="function",
                                    )
                                )
                                random_id += 1
                elif not self._native_tool_calls:
                    # Try to convert the response to a tool call object
                    response_toolcalls = response_to_tool_call(ans)
                    # If we can, then we've got tool call(s)
                    if response_toolcalls is not None:
                        ollama_finish = "tool_calls"
                        tool_calls = []
                        random_id = random.randint(0, 10000)
                        for json_function in response_toolcalls:
                            tool_calls.append(
                                ChatCompletionMessageToolCall(
                                    id=f"ollama_manual_func_{random_id}",
                                    function={
                                        "name": json_function["name"],
                                        "arguments": (
                                            json.dumps(json_function["arguments"])
                                            if "arguments" in json_function
                                            else "{}"
                                        ),
                                    },
                                    type="function",
                                )
                            )
                            random_id += 1
                        # Blank the message content
                        response_content = ""
            if ollama_finish == "stop":  # noqa: SIM102
                # Not a tool call, so let's check if we need to process structured output
                if self._response_format and response_content:
                    try:
                        parsed_response = self._convert_json_response(response_content)
                        response_content = _format_json_response(parsed_response, response_content)
                    except ValueError as e:
                        # Surface the parsing error as the message content rather than raising
                        response_content = str(e)
        else:
            raise RuntimeError("Failed to get response from Ollama.")
        # Convert response to AG2 response
        message = ChatCompletionMessage(
            role="assistant",
            content=response_content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=ollama_finish, index=0, message=message)]
        response_oai = ChatCompletion(
            id=response_id,
            model=ollama_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            cost=0,  # Local models, FREE!
        )
        return response_oai
def oai_messages_to_ollama_messages(self, messages: list[dict[str, Any]], tools: list) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Ollama's format.

    We correct for any specific role orders and types, and convert tools to messages (as Ollama can't use tool messages)

    Args:
        messages: Conversation history in OpenAI chat format.
        tools: Tool/function definitions; only consulted when native tool
            calling is disabled (manual tool-call prompting is injected).

    Returns:
        A new list of messages shaped for Ollama; the input is not mutated
        (a deep copy is taken first).
    """
    ollama_messages = copy.deepcopy(messages)

    # Remove the name field
    for message in ollama_messages:
        if "name" in message:
            message.pop("name", None)

    # Having a 'system' message on the end does not work well with Ollama, so we change it to 'user'
    # 'system' messages on the end are typical of the summarisation message: summary_method="reflection_with_llm"
    if len(ollama_messages) > 1 and ollama_messages[-1]["role"] == "system":
        ollama_messages[-1]["role"] = "user"

    # Process messages for tool calling manually
    if tools is not None and not self._native_tool_calls:
        # 1. We need to append instructions to the starting system message on function calling
        # 2. If we have not yet called tools we append "step 1 instruction" to the latest user message
        # 3. If we have already called tools we append "step 2 instruction" to the latest user message

        have_tool_calls = False
        have_tool_results = False
        last_tool_result_index = -1

        # Single scan to learn where we are in the manual tool-call workflow
        for i, message in enumerate(ollama_messages):
            if "tool_calls" in message:
                have_tool_calls = True
            if "tool_call_id" in message:
                have_tool_results = True
                last_tool_result_index = i

        tool_result_is_last_msg = have_tool_results and last_tool_result_index == len(ollama_messages) - 1

        if ollama_messages[0]["role"] == "system":
            manual_instruction = self._manual_tool_call_instruction

            # Build a string of the functions available
            functions_string = ""
            for function in tools:
                functions_string += f"""\n{function}\n"""

            # Replace single quotes with double quotes - Not sure why this helps the LLM perform
            # better, but it seems to. Monitor and remove if not necessary.
            functions_string = functions_string.replace("'", '"')

            manual_instruction = manual_instruction.replace("[FUNCTIONS_LIST]", functions_string)

            # Update the system message with the instructions and functions
            ollama_messages[0]["content"] = ollama_messages[0]["content"] + manual_instruction.rstrip()

        # If we are still in the function calling or evaluating process, append the steps instruction
        if (not have_tool_calls or tool_result_is_last_msg) and ollama_messages[0]["role"] == "system":
            # NOTE: we require a system message to exist for the manual steps texts
            # Append the manual step instructions
            content_to_append = (
                self._manual_tool_call_step1 if not have_tool_results else self._manual_tool_call_step2
            )

            if content_to_append != "":
                # Append the relevant tool call instruction to the latest user message
                if ollama_messages[-1]["role"] == "user":
                    ollama_messages[-1]["content"] = ollama_messages[-1]["content"] + content_to_append
                else:
                    ollama_messages.append({"role": "user", "content": content_to_append})

    # Convert tool call and tool result messages to normal text messages for Ollama
    for i, message in enumerate(ollama_messages):
        if "tool_calls" in message:
            # Recommended tool calls
            content = "Run the following function(s):"
            for tool_call in message["tool_calls"]:
                content = content + "\n" + str(tool_call)
            ollama_messages[i] = {"role": "assistant", "content": content}
        if "tool_call_id" in message:
            # Executed tool results
            # The original dict is repurposed: content is renamed to "result" and the
            # whole remainder is stringified into a plain user message.
            message["result"] = message["content"]
            del message["content"]
            del message["role"]
            content = "The following function was run: " + str(message)
            ollama_messages[i] = {"role": "user", "content": content}

    # As we are changing messages, let's merge if they have two user messages on the end and the last one is tool call step instructions
    if (
        len(ollama_messages) >= 2
        and not self._native_tool_calls
        and ollama_messages[-2]["role"] == "user"
        and ollama_messages[-1]["role"] == "user"
        and (
            ollama_messages[-1]["content"] == self._manual_tool_call_step1
            or ollama_messages[-1]["content"] == self._manual_tool_call_step2
        )
    ):
        ollama_messages[-2]["content"] = ollama_messages[-2]["content"] + ollama_messages[-1]["content"]
        del ollama_messages[-1]

    # Ensure the last message is a user / system message, if not, add a user message
    if ollama_messages[-1]["role"] != "user" and ollama_messages[-1]["role"] != "system":
        ollama_messages.append({"role": "user", "content": "Please continue."})

    return ollama_messages
def _convert_json_response(self, response: str) -> Any:
    """Extract and validate JSON response from the output for structured outputs.

    Args:
        response (str): The response from the API.

    Returns:
        Any: The parsed JSON response.

    Raises:
        ValueError: If the response cannot be validated against the configured
            Pydantic response format.
    """
    # No structured-output format configured: hand the raw text back untouched.
    if not self._response_format:
        return response

    # A dict response format (plain JSON schema) is not validated client-side.
    if isinstance(self._response_format, dict):
        return response

    # Otherwise the format is a Pydantic model; validate the JSON against it.
    try:
        return self._response_format.model_validate_json(response)
    except Exception as e:
        raise ValueError(f"Failed to parse response as valid JSON matching the schema for Structured Output: {e!s}")
def _format_json_response(response: Any, original_answer: str) -> str:
    """Format a structured-output response for returning to the caller.

    When the parsed response implements FormatterProtocol (exposes a
    ``format()`` method), use it; otherwise fall back to the raw model answer.
    """
    if isinstance(response, FormatterProtocol):
        return response.format()
    return original_answer
@require_optional_import("fix_busted_json", "ollama")
def response_to_tool_call(response_string: str) -> Any:
    """Attempts to convert the response to an object, aimed to align with function format `[{},{}]`

    Args:
        response_string: Raw text returned by the model.

    Returns:
        A validated list of tool-call dictionaries, or None when no
        recognizable tool call is present in the response.
    """
    # We try and detect the list[dict[str, Any]] format:
    # Pattern 1 is [{},{}]
    # Pattern 2 is {} (without the [], so could be a single function call)
    patterns = [r"\[[\s\S]*?\]", r"\{[\s\S]*\}"]

    for i, pattern in enumerate(patterns):
        # Search for the pattern in the input string
        matches = re.findall(pattern, response_string.strip())

        for match in matches:
            # It has matched, extract it and load it
            json_str = match.strip()
            data_object = None

            try:
                # Attempt to convert it as is
                data_object = json.loads(json_str)
            except Exception:
                try:
                    # If that fails, attempt to repair it
                    if i == 0:
                        # Enclose in a JSON object for repairing, which is restored upon fix
                        fixed_json = repair_json("{'temp':" + json_str + "}")
                        data_object = json.loads(fixed_json)["temp"]
                    else:
                        fixed_json = repair_json(json_str)
                        data_object = json.loads(fixed_json)
                except json.JSONDecodeError as e:
                    if e.msg == "Invalid \\escape":
                        # Handle Mistral/Mixtral trying to escape underlines with \\
                        try:
                            json_str = json_str.replace("\\_", "_")
                            if i == 0:
                                fixed_json = repair_json("{'temp':" + json_str + "}")
                                data_object = json.loads(fixed_json)["temp"]
                            else:
                                # BUG FIX: this branch previously also wrapped the payload in
                                # {'temp': ...} for repair but never unwrapped it, so the result
                                # carried a spurious "temp" key. Repair the object directly,
                                # mirroring the non-escape repair path above.
                                fixed_json = repair_json(json_str)
                                data_object = json.loads(fixed_json)
                        except Exception:
                            pass
                except Exception:
                    pass

            if data_object is not None:
                data_object = _object_to_tool_call(data_object)

                if data_object is not None:
                    return data_object

    # There's no tool call in the response
    return None
def _object_to_tool_call(data_object: Any) -> list[dict[str, Any]]:
    """Coerce a parsed object into a valid tool-call list (List[Dict]), or None.

    Accepts a single dict (wrapped into a one-element list), a list of dicts,
    or a list of strings that evaluate to dicts. Returns None when the shape
    does not form a valid tool-call list.
    """
    # A lone dictionary is treated as a single-call list.
    if isinstance(data_object, dict):
        data_object = [data_object]

    if not isinstance(data_object, list):
        return None

    if all(isinstance(entry, dict) for entry in data_object):
        # Perfect shape: every entry must carry 'name' (and optionally 'arguments').
        if all(is_valid_tool_call_item(entry) for entry in data_object):
            return data_object
        return None

    # Not all items are dicts; try to evaluate string items into dicts in place.
    # SECURITY NOTE(review): eval() executes arbitrary expressions coming from
    # model output; ast.literal_eval would be safer if plain literals suffice —
    # flagged for follow-up rather than changed here.
    for index, entry in enumerate(data_object.copy()):
        try:
            evaluated = eval(entry)
        except Exception:
            return None
        if not isinstance(evaluated, dict) or not is_valid_tool_call_item(evaluated):
            return None
        data_object[index] = evaluated
    return data_object
def is_valid_tool_call_item(call_item: dict) -> bool:
    """Check that a dictionary item has at least 'name', optionally 'arguments' and no other keys to match a tool call JSON"""
    # 'name' must be present and a string.
    if not isinstance(call_item.get("name"), str):
        return False
    # No keys beyond 'name' and the optional 'arguments' are allowed.
    return not (set(call_item) - {"name", "arguments"})

View File

@@ -0,0 +1,881 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import importlib
import importlib.metadata
import inspect
import json
import logging
import os
import re
import tempfile
import time
import warnings
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
from dotenv import find_dotenv, load_dotenv
from packaging.version import parse
from pydantic_core import to_jsonable_python
if TYPE_CHECKING:
from openai import OpenAI
from openai.types.beta.assistant import Assistant
from ..doc_utils import export_module
from ..llm_config import LLMConfig
# Config keys that identify credentials/endpoints rather than model behavior.
# get_key() strips these before serializing a config, so response caching is
# keyed only on fields that can affect the completion result.
NON_CACHE_KEY = [
    "api_key",
    "base_url",
    "api_type",
    "api_version",
    "azure_ad_token",
    "azure_ad_token_provider",
    "credentials",
]

# Azure OpenAI API version applied when a config does not specify one.
DEFAULT_AZURE_API_VERSION = "2024-02-01"

# The below pricing is for 1K tokens. Whenever there is an update in the LLM's pricing,
# Please convert it to 1K tokens and update in the below dictionary in the format: (input_token_price, output_token_price).
# Entries with a single float are base models priced the same for input and output.
OAI_PRICE1K = {
    # https://openai.com/api/pricing/
    # o1
    "o1-preview-2024-09-12": (0.0015, 0.0060),
    "o1-preview": (0.0015, 0.0060),
    "o1-mini-2024-09-12": (0.0003, 0.0012),
    "o1-mini": (0.0003, 0.0012),
    "o1": (0.0015, 0.0060),
    "o1-2024-12-17": (0.0015, 0.0060),
    # o1 pro
    "o1-pro": (0.15, 0.6),  # $150 / $600!
    "o1-pro-2025-03-19": (0.15, 0.6),
    # o3
    "o3": (0.0011, 0.0044),
    "o3-mini-2025-01-31": (0.0011, 0.0044),
    # gpt-4o
    "gpt-4o": (0.005, 0.015),
    "gpt-4o-2024-05-13": (0.005, 0.015),
    "gpt-4o-2024-08-06": (0.0025, 0.01),
    "gpt-4o-2024-11-20": (0.0025, 0.01),
    # gpt-4o-mini
    "gpt-4o-mini": (0.000150, 0.000600),
    "gpt-4o-mini-2024-07-18": (0.000150, 0.000600),
    # gpt-4-turbo
    "gpt-4-turbo-2024-04-09": (0.01, 0.03),
    # gpt-4
    "gpt-4": (0.03, 0.06),
    "gpt-4-32k": (0.06, 0.12),
    # gpt-4.1
    "gpt-4.1": (0.002, 0.008),
    "gpt-4.1-2025-04-14": (0.002, 0.008),
    # gpt-4.1 mini
    "gpt-4.1-mini": (0.0004, 0.0016),
    "gpt-4.1-mini-2025-04-14": (0.0004, 0.0016),
    # gpt-4.1 nano
    "gpt-4.1-nano": (0.0001, 0.0004),
    "gpt-4.1-nano-2025-04-14": (0.0001, 0.0004),
    # gpt-3.5 turbo
    "gpt-3.5-turbo": (0.0005, 0.0015),  # default is 0125
    "gpt-3.5-turbo-0125": (0.0005, 0.0015),  # 16k
    "gpt-3.5-turbo-instruct": (0.0015, 0.002),
    # base model
    "davinci-002": 0.002,
    "babbage-002": 0.0004,
    # old model
    "gpt-4-0125-preview": (0.01, 0.03),
    "gpt-4-1106-preview": (0.01, 0.03),
    "gpt-4-1106-vision-preview": (0.01, 0.03),  # TODO: support vision pricing of images
    "gpt-3.5-turbo-1106": (0.001, 0.002),
    "gpt-3.5-turbo-0613": (0.0015, 0.002),
    # "gpt-3.5-turbo-16k": (0.003, 0.004),
    "gpt-3.5-turbo-16k-0613": (0.003, 0.004),
    "gpt-3.5-turbo-0301": (0.0015, 0.002),
    "text-ada-001": 0.0004,
    "text-babbage-001": 0.0005,
    "text-curie-001": 0.002,
    "code-cushman-001": 0.024,
    "code-davinci-002": 0.1,
    "text-davinci-002": 0.02,
    "text-davinci-003": 0.02,
    "gpt-4-0314": (0.03, 0.06),  # deprecate in Sep
    "gpt-4-32k-0314": (0.06, 0.12),  # deprecate in Sep
    "gpt-4-0613": (0.03, 0.06),
    "gpt-4-32k-0613": (0.06, 0.12),
    "gpt-4-turbo-preview": (0.01, 0.03),
    # https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/#pricing
    "gpt-35-turbo": (0.0005, 0.0015),  # what's the default? using 0125 here.
    "gpt-35-turbo-0125": (0.0005, 0.0015),
    "gpt-35-turbo-instruct": (0.0015, 0.002),
    "gpt-35-turbo-1106": (0.001, 0.002),
    "gpt-35-turbo-0613": (0.0015, 0.002),
    "gpt-35-turbo-0301": (0.0015, 0.002),
    "gpt-35-turbo-16k": (0.003, 0.004),
    "gpt-35-turbo-16k-0613": (0.003, 0.004),
    # deepseek
    "deepseek-chat": (0.00027, 0.0011),
}
def get_key(config: dict[str, Any]) -> str:
    """Get a unique identifier of a configuration.

    Credential/endpoint fields listed in NON_CACHE_KEY are stripped first so
    that configurations differing only in credentials share the same key.

    Args:
        config (dict or list): A configuration.

    Returns:
        A JSON-serializable representation of the configuration, usable as a
        cache key.
    """
    # Build a filtered copy instead of popping from a copy; the input is never mutated.
    sensitive_present = [key for key in NON_CACHE_KEY if key in config]
    if sensitive_present:
        config = {k: v for k, v in config.items() if k not in sensitive_present}
    return to_jsonable_python(config)  # type: ignore [no-any-return]
def is_valid_api_key(api_key: str) -> bool:
    """Determine if input is valid OpenAI API key. As of 2024-09-24 there's no official definition of the key structure
    so we will allow anything starting with "sk-" and having at least 48 alphanumeric (plus underscore and dash) characters.
    Keys are known to start with "sk-", "sk-proj", "sk-None", and "sk-svcaat"

    Args:
        api_key (str): An input string to be validated.

    Returns:
        bool: A boolean that indicates if input is valid OpenAI API key.
    """
    # re.fullmatch caches compiled patterns, so inlining the pattern is fine.
    return re.fullmatch(r"^sk-[A-Za-z0-9_-]{48,}$", api_key) is not None
@export_module("autogen")
def get_config_list(
    api_keys: list[str],
    base_urls: Optional[list[str]] = None,
    api_type: Optional[str] = None,
    api_version: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Build one OpenAI client config per (non-blank) API key.

    Args:
        api_keys (list): The api keys for openai api calls. Blank/whitespace-only
            keys are skipped.
        base_urls (list, optional): The api bases for openai api calls. If provided,
            should match the length of api_keys; entries are matched positionally.
        api_type (str, optional): The api type applied to every config (e.g. "azure").
        api_version (str, optional): The api version applied to every config.

    Returns:
        list: A list of configs for OpenAI API calls.

    Example:
        ```python
        api_keys = ["key1", "key2", "key3"]
        base_urls = ["https://api.service1.com", "https://api.service2.com", "https://api.service3.com"]
        config_list = get_config_list(api_keys, base_urls, api_type="azure", api_version="2024-02-01")
        ```
    """
    if base_urls is not None:
        assert len(api_keys) == len(base_urls), "The length of api_keys must match the length of base_urls"

    configs: list[dict[str, Any]] = []
    for index, key in enumerate(api_keys):
        # Skipping blank keys does not break base_urls alignment: `index` still
        # tracks the position in the original api_keys list.
        if not key.strip():
            continue
        entry: dict[str, Any] = {"api_key": key}
        if base_urls:
            entry["base_url"] = base_urls[index]
        if api_type:
            entry["api_type"] = api_type
        if api_version:
            entry["api_version"] = api_version
        configs.append(entry)
    return configs
@export_module("autogen")
def get_first_llm_config(
    llm_config: Union[LLMConfig, dict[str, Any]],
) -> dict[str, Any]:
    """Get the first LLM config from the given LLM config.

    Args:
        llm_config (dict): The LLM config (either a full config with a
            "config_list" or a single config containing "model").

    Returns:
        dict: The first LLM config, as a plain dictionary.

    Raises:
        ValueError: If the LLM config is invalid or its config list is empty.
    """
    llm_config = deepcopy(llm_config)

    # A bare single-model config is returned directly.
    if "config_list" not in llm_config:
        if "model" not in llm_config:
            raise ValueError("llm_config must be a valid config dictionary.")
        return llm_config  # type: ignore [return-value]

    if len(llm_config["config_list"]) == 0:
        raise ValueError("Config list must contain at least one config.")

    first = llm_config["config_list"][0]
    # Entries may be pydantic models; normalize those to plain dictionaries.
    return first if isinstance(first, dict) else first.model_dump()  # type: ignore [no-any-return]
@export_module("autogen")
def config_list_openai_aoai(
    key_file_path: Optional[str] = ".",
    openai_api_key_file: Optional[str] = "key_openai.txt",
    aoai_api_key_file: Optional[str] = "key_aoai.txt",
    openai_api_base_file: Optional[str] = "base_openai.txt",
    aoai_api_base_file: Optional[str] = "base_aoai.txt",
    exclude: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Get a list of configs for OpenAI API client (including Azure or local model deployments that support OpenAI's chat completion API).

    This function constructs configurations by reading API keys and base URLs from environment variables or text files.
    It supports configurations for both OpenAI and Azure OpenAI services, allowing for the exclusion of one or the other.
    When text files are used, the environment variables will be overwritten.
    To prevent text files from being used, set the corresponding file name to None.
    Or set key_file_path to None to disallow reading from text files.

    Args:
        key_file_path (str, optional): The directory path where the API key files are located. Defaults to the current directory.
        openai_api_key_file (str, optional): The filename containing the OpenAI API key. Defaults to 'key_openai.txt'.
        aoai_api_key_file (str, optional): The filename containing the Azure OpenAI API key. Defaults to 'key_aoai.txt'.
        openai_api_base_file (str, optional): The filename containing the OpenAI API base URL. Defaults to 'base_openai.txt'.
        aoai_api_base_file (str, optional): The filename containing the Azure OpenAI API base URL. Defaults to 'base_aoai.txt'.
        exclude (str, optional): The API type to exclude from the configuration list. Can be 'openai' or 'aoai'. Defaults to None.

    Returns:
        List[Dict]: A list of configuration dictionaries. Each dictionary contains keys for 'api_key',
            and optionally 'base_url', 'api_type', and 'api_version'.

    Raises:
        FileNotFoundError: If the specified key files are not found and the corresponding API key is not set in the environment variables.

    Example:
        # To generate configurations excluding Azure OpenAI:
        configs = config_list_openai_aoai(exclude='aoai')

    File samples:
        - key_aoai.txt
        ```
        aoai-12345abcdef67890ghijklmnopqr
        aoai-09876zyxwvuts54321fedcba
        ```
        - base_aoai.txt
        ```
        https://api.azure.com/v1
        https://api.azure2.com/v1
        ```

    Notes:
        - The function checks for API keys and base URLs in the following environment variables: 'OPENAI_API_KEY', 'AZURE_OPENAI_API_KEY',
          'OPENAI_API_BASE' and 'AZURE_OPENAI_API_BASE'. If these are not found, it attempts to read from the specified files in the
          'key_file_path' directory.
        - The API version for Azure configurations is set to DEFAULT_AZURE_API_VERSION by default.
        - If 'exclude' is set to 'openai', only Azure OpenAI configurations are returned, and vice versa.
        - The function assumes that the API keys and base URLs in the environment variables are separated by new lines if there are
          multiple entries.
        - SIDE EFFECT: when key/base files are found, this function overwrites the
          corresponding os.environ entries for the rest of the process.
    """
    # Phase 1: load OpenAI key/base from files into the environment (files win over env).
    if exclude != "openai" and key_file_path is not None:
        # skip if key_file_path is None
        if openai_api_key_file is not None:
            # skip if openai_api_key_file is None
            try:
                with open(f"{key_file_path}/{openai_api_key_file}") as key_file:
                    os.environ["OPENAI_API_KEY"] = key_file.read().strip()
            except FileNotFoundError:
                logging.info(
                    "OPENAI_API_KEY is not found in os.environ "
                    "and key_openai.txt is not found in the specified path. You can specify the api_key in the config_list."
                )
        if openai_api_base_file is not None:
            # skip if openai_api_base_file is None
            try:
                with open(f"{key_file_path}/{openai_api_base_file}") as key_file:
                    os.environ["OPENAI_API_BASE"] = key_file.read().strip()
            except FileNotFoundError:
                logging.info(
                    "OPENAI_API_BASE is not found in os.environ "
                    "and base_openai.txt is not found in the specified path. You can specify the base_url in the config_list."
                )
    # Phase 2: same for Azure OpenAI.
    if exclude != "aoai" and key_file_path is not None:
        # skip if key_file_path is None
        if aoai_api_key_file is not None:
            try:
                with open(f"{key_file_path}/{aoai_api_key_file}") as key_file:
                    os.environ["AZURE_OPENAI_API_KEY"] = key_file.read().strip()
            except FileNotFoundError:
                logging.info(
                    "AZURE_OPENAI_API_KEY is not found in os.environ "
                    "and key_aoai.txt is not found in the specified path. You can specify the api_key in the config_list."
                )
        if aoai_api_base_file is not None:
            try:
                with open(f"{key_file_path}/{aoai_api_base_file}") as key_file:
                    os.environ["AZURE_OPENAI_API_BASE"] = key_file.read().strip()
            except FileNotFoundError:
                logging.info(
                    "AZURE_OPENAI_API_BASE is not found in os.environ "
                    "and base_aoai.txt is not found in the specified path. You can specify the base_url in the config_list."
                )
    # Phase 3: materialize configs from whatever is now in the environment.
    aoai_config = (
        get_config_list(
            # Assuming Azure OpenAI api keys in os.environ["AZURE_OPENAI_API_KEY"], in separated lines
            api_keys=os.environ.get("AZURE_OPENAI_API_KEY", "").split("\n"),
            # Assuming Azure OpenAI api bases in os.environ["AZURE_OPENAI_API_BASE"], in separated lines
            base_urls=os.environ.get("AZURE_OPENAI_API_BASE", "").split("\n"),
            api_type="azure",
            api_version=DEFAULT_AZURE_API_VERSION,
        )
        if exclude != "aoai"
        else []
    )
    # process openai base urls
    base_urls_env_var = os.environ.get("OPENAI_API_BASE", None)
    base_urls = base_urls_env_var if base_urls_env_var is None else base_urls_env_var.split("\n")
    openai_config = (
        get_config_list(
            # Assuming OpenAI API_KEY in os.environ["OPENAI_API_KEY"]
            api_keys=os.environ.get("OPENAI_API_KEY", "").split("\n"),
            base_urls=base_urls,
        )
        if exclude != "openai"
        else []
    )
    config_list = openai_config + aoai_config
    return config_list
@export_module("autogen")
def config_list_from_models(
    key_file_path: Optional[str] = ".",
    openai_api_key_file: Optional[str] = "key_openai.txt",
    aoai_api_key_file: Optional[str] = "key_aoai.txt",
    aoai_api_base_file: Optional[str] = "base_aoai.txt",
    exclude: Optional[str] = None,
    model_list: Optional[list[str]] = None,
) -> list[dict[str, Any]]:
    """Get a list of configs for API calls with models specified in the model list.

    This function extends `config_list_openai_aoai` by cloning each discovered
    endpoint config once per model name, adding a 'model' key. This is
    particularly useful when all endpoints serve the same set of models.

    Args:
        key_file_path (str, optional): The path to the key files.
        openai_api_key_file (str, optional): The file name of the OpenAI API key.
        aoai_api_key_file (str, optional): The file name of the Azure OpenAI API key.
        aoai_api_base_file (str, optional): The file name of the Azure OpenAI API base.
        exclude (str, optional): The API type to exclude, "openai" or "aoai".
        model_list (list, optional): The list of model names to include in the configs.

    Returns:
        list: A list of configs for OpenAI API calls, each including model information.

    Example:
        ```python
        config_list = config_list_from_models(
            key_file_path="/path/to/key/files",
            model_list=["gpt-4", "gpt-3.5-turbo"],
        )
        # [
        #   {'api_key': '...', 'base_url': 'https://api.openai.com', 'model': 'gpt-4'},
        #   {'api_key': '...', 'base_url': 'https://api.openai.com', 'model': 'gpt-3.5-turbo'}
        # ]
        ```
    """
    endpoint_configs = config_list_openai_aoai(
        key_file_path=key_file_path,
        openai_api_key_file=openai_api_key_file,
        aoai_api_key_file=aoai_api_key_file,
        aoai_api_base_file=aoai_api_base_file,
        exclude=exclude,
    )
    if not model_list:
        return endpoint_configs
    # Cartesian product: every model paired with every endpoint config.
    return [{**config, "model": model} for model in model_list for config in endpoint_configs]
@export_module("autogen")
def config_list_gpt4_gpt35(
    key_file_path: Optional[str] = ".",
    openai_api_key_file: Optional[str] = "key_openai.txt",
    aoai_api_key_file: Optional[str] = "key_aoai.txt",
    aoai_api_base_file: Optional[str] = "base_aoai.txt",
    exclude: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Get a list of configs for 'gpt-4' followed by 'gpt-3.5-turbo' API calls.

    Thin convenience wrapper over `config_list_from_models` with a fixed model list.

    Args:
        key_file_path (str, optional): The path to the key files.
        openai_api_key_file (str, optional): The file name of the openai api key.
        aoai_api_key_file (str, optional): The file name of the azure openai api key.
        aoai_api_base_file (str, optional): The file name of the azure openai api base.
        exclude (str, optional): The api type to exclude, "openai" or "aoai".

    Returns:
        list: A list of configs for openai api calls.
    """
    return config_list_from_models(
        key_file_path=key_file_path,
        openai_api_key_file=openai_api_key_file,
        aoai_api_key_file=aoai_api_key_file,
        aoai_api_base_file=aoai_api_base_file,
        exclude=exclude,
        model_list=["gpt-4", "gpt-3.5-turbo"],
    )
@export_module("autogen")
def filter_config(
    config_list: list[dict[str, Any]],
    filter_dict: Optional[dict[str, Union[list[Union[str, None]], set[Union[str, None]]]]],
    exclude: bool = False,
) -> list[dict[str, Any]]:
    """Filter a list of configuration dictionaries by field criteria.

    A configuration is retained when, for every key in `filter_dict`, its value
    satisfies the acceptable values: scalar fields must be found among them,
    list fields must have a non-empty intersection with them. With
    `exclude=True` matching configurations are removed instead of retained.

    Args:
        config_list (list of dict): A list of configuration dictionaries to be filtered.
        filter_dict (dict): A dictionary representing the filter criteria, where each key is a
            field name to check within the configuration dictionaries, and the
            corresponding value is a list of acceptable values for that field.
        exclude (bool): If False (the default value), configs that match the filter will be included in the returned
            list. If True, configs that match the filter will be excluded in the returned list.

    Returns:
        list of dict: A list of configuration dictionaries that meet all the criteria specified in `filter_dict`.

    Example:
        ```python
        configs = [
            {"model": "gpt-3.5-turbo"},
            {"model": "gpt-4"},
            {"model": "gpt-3.5-turbo", "api_type": "azure"},
            {"model": "gpt-3.5-turbo", "tags": ["gpt35_turbo", "gpt-35-turbo"]},
        ]
        filtered = filter_config(configs, {"model": ["gpt-3.5-turbo"], "api_type": ["azure"]})
        # [{'model': 'gpt-3.5-turbo', 'api_type': 'azure'}]
        ```

    Note:
        - If `filter_dict` is empty or None, no filtering is applied and `config_list` is returned as is.
        - If a configuration dictionary in `config_list` does not contain a key specified in `filter_dict`,
          it is considered a non-match and is excluded from the result.
        - If the list of acceptable values for a key in `filter_dict` includes None, then configuration
          dictionaries that do not have that key will also be considered a match.
    """
    # Deprecated for external callers; LLMConfig.where() is the supported path,
    # so suppress the warning when invoked from it.
    if inspect.stack()[1].function != "where":
        warnings.warn(
            "filter_config is deprecated and will be removed in a future release. "
            'Please use the "autogen.LLMConfig.from_json(path="OAI_CONFIG_LIST").where(model="gpt-4o")" method instead.',
            DeprecationWarning,
        )

    if not filter_dict:
        return config_list

    retained = []
    for item in config_list:
        matches = all(
            _satisfies_criteria(item.get(key), values) != exclude for key, values in filter_dict.items()
        )
        if matches:
            retained.append(item)
    return retained
def _satisfies_criteria(value: Any, criteria_values: Any) -> bool:
    """Return True when a config field `value` matches the acceptable `criteria_values`.

    Matching rules:
      - value is None (field missing from the config): matches only when None is
        among the acceptable values. BUG FIX: previously this always returned
        False, contradicting filter_config's documented behavior that listing
        None accepts configs missing the field.
      - list value (e.g. tags): matches on a non-empty intersection.
      - scalar value: membership when criteria is a list/set, equality otherwise.
    """
    if value is None:
        # Documented in filter_config: a missing field matches iff None is acceptable.
        if isinstance(criteria_values, (list, set)):
            return None in criteria_values
        return criteria_values is None
    if isinstance(value, list):
        return bool(set(value) & set(criteria_values))  # Non-empty intersection
    # In filter_dict, filter could be either a list of values or a single value.
    # For example, filter_dict = {"model": ["gpt-3.5-turbo"]} or {"model": "gpt-3.5-turbo"}
    if isinstance(criteria_values, (list, set)):
        return value in criteria_values
    return bool(value == criteria_values)
@export_module("autogen")
def config_list_from_json(
    env_or_file: str,
    file_location: Optional[str] = "",
    filter_dict: Optional[dict[str, Union[list[Union[str, None]], set[Union[str, None]]]]] = None,
) -> list[dict[str, Any]]:
    """Retrieves a list of API configurations from a JSON stored in an environment variable or a file.

    This function attempts to parse JSON data from the given `env_or_file` parameter. If `env_or_file` is an
    environment variable containing JSON data (or the path of a file), it will be used directly. Otherwise, it
    is assumed to be a filename, and the function will attempt to read the file from `file_location`.

    Args:
        env_or_file (str): The name of the environment variable, the filename, or the environment variable of the filename
            that containing the JSON data.
        file_location (str, optional): The directory path where the file is located, if `env_or_file` is a filename.
        filter_dict (dict, optional): A dictionary specifying the filtering criteria for the configurations, with
            keys representing field names and values being lists or sets of acceptable values for those fields.

    Example:
        ```python
        # Suppose we have an environment variable 'CONFIG_JSON' with the following content:
        # '[{"model": "gpt-3.5-turbo", "api_type": "azure"}, {"model": "gpt-4"}]'
        configs = config_list_from_json("CONFIG_JSON", filter_dict={"model": ["gpt-3.5-turbo"]})
        ```

    Returns:
        List[Dict]: A list of configuration dictionaries that match the filtering criteria specified in `filter_dict`.

    Raises:
        FileNotFoundError: if env_or_file is neither found as an environment variable nor a file
    """
    # Deprecated for external callers; suppress the warning when invoked from LLMConfig.from_json.
    if inspect.stack()[1].function != "from_json":
        warnings.warn(
            "config_list_from_json is deprecated and will be removed in a future release. "
            'Please use the "autogen.LLMConfig.from_json(path="OAI_CONFIG_LIST")" method instead.',
            DeprecationWarning,
        )

    env_str = os.environ.get(env_or_file)
    if env_str:
        # The environment variable exists. We should use information from it.
        if os.path.exists(env_str):
            # It is a file location, and we need to load the json from the file.
            with open(env_str) as file:
                json_str = file.read()
        else:
            # Else, it should be a JSON string by itself.
            json_str = env_str
        config_list = json.loads(json_str)
    else:
        # The environment variable does not exist.
        # So, `env_or_file` is a filename. We should use the file location.
        config_list_path = os.path.join(file_location, env_or_file) if file_location is not None else env_or_file
        with open(config_list_path) as json_file:
            config_list = json.load(json_file)

    # BUG FIX: filter_config was previously applied twice in a row (harmless
    # because filtering is idempotent, but redundant work and a duplicated
    # deprecation-warning trigger). Filter exactly once.
    return filter_config(config_list, filter_dict)
def get_config(
    api_key: Optional[str],
    base_url: Optional[str] = None,
    api_type: Optional[str] = None,
    api_version: Optional[str] = None,
) -> dict[str, Any]:
    """Constructs a configuration dictionary for a single model with the provided API configurations.

    NOTE(review): for base_url/api_type/api_version, the supplied value is first
    looked up as an environment-variable *name* and falls back to the literal
    value when no such variable exists — confirm this indirection is intended.

    Args:
        api_key (str): The API key for authenticating API requests.
        base_url (Optional[str]): The base URL of the API. If not provided, defaults to None.
        api_type (Optional[str]): The type of API. If not provided, defaults to None.
        api_version (Optional[str]): The version of the API. If not provided, defaults to None.

    Returns:
        Dict: A dictionary containing the provided API configurations.

    Example:
        ```python
        config = get_config(api_key="sk-abcdef1234567890", base_url="https://api.openai.com", api_version="v1")
        # {"api_key": "sk-abcdef1234567890", "base_url": "https://api.openai.com", "api_version": "v1"}
        ```
    """
    config: dict[str, Any] = {"api_key": api_key}
    optional_fields = (
        ("base_url", base_url),
        ("api_type", api_type),
        ("api_version", api_version),
    )
    for field_name, field_value in optional_fields:
        if field_value:
            config[field_name] = os.getenv(field_value, default=field_value)
    return config
@export_module("autogen")
def config_list_from_dotenv(
    dotenv_file_path: Optional[str] = None,
    model_api_key_map: Optional[dict[str, Any]] = None,
    filter_dict: Optional[dict[str, Union[list[Union[str, None]], set[Union[str, None]]]]] = None,
) -> list[dict[str, Union[str, set[str]]]]:
    """Load API configurations from a specified .env file or environment variables and construct a list of configurations.

    This function will:
    - Load API keys from a provided .env file or from existing environment variables.
    - Create a configuration dictionary for each model using the API keys and additional configurations.
    - Filter and return the configurations based on provided filters.

    model_api_key_map will default to `{"gpt-4": "OPENAI_API_KEY", "gpt-3.5-turbo": "OPENAI_API_KEY"}` if none

    Args:
        dotenv_file_path (str, optional): The path to the .env file. Defaults to None.
        model_api_key_map (str/dict, optional): A dictionary mapping models to their API key configurations.
            If a string is provided as configuration, it is considered as an environment
            variable name storing the API key.
            If a dict is provided, it should contain at least 'api_key_env_var' key,
            and optionally other API configurations like 'base_url', 'api_type', and 'api_version'.
            Defaults to a basic map with 'gpt-4' and 'gpt-3.5-turbo' mapped to 'OPENAI_API_KEY'.
            Entries of any other type are skipped with a warning.
        filter_dict (dict, optional): A dictionary containing the models to be loaded.
            Containing a 'model' key mapped to a set of model names to be loaded.
            Defaults to None, which loads all found configurations.

    Returns:
        List[Dict[str, Union[str, Set[str]]]]: A list of configuration dictionaries for each model.
    """
    if dotenv_file_path:
        dotenv_path = Path(dotenv_file_path)
        if dotenv_path.exists():
            load_dotenv(dotenv_path)
        else:
            # Missing file is not fatal: fall back to the current environment.
            logging.warning(f"The specified .env file {dotenv_path} does not exist.")
    else:
        dotenv_path_str = find_dotenv()
        if not dotenv_path_str:
            logging.warning("No .env file found. Loading configurations from environment variables.")
        dotenv_path = Path(dotenv_path_str)
        load_dotenv(dotenv_path)

    # Ensure the model_api_key_map is not None to prevent TypeErrors during key assignment.
    model_api_key_map = model_api_key_map or {}

    # Ensure default models are always considered
    default_models = ["gpt-4", "gpt-3.5-turbo"]
    for model in default_models:
        # Only assign default API key if the model is not present in the map.
        # If model is present but set to invalid/empty, do not overwrite.
        if model not in model_api_key_map:
            model_api_key_map[model] = "OPENAI_API_KEY"

    env_var = []
    # Loop over the models and create configuration dictionaries
    for model, config in model_api_key_map.items():
        if isinstance(config, str):
            # A plain string names the environment variable holding the API key.
            api_key_env_var = config
            config_dict = get_config(api_key=os.getenv(api_key_env_var))
        elif isinstance(config, dict):
            api_key = os.getenv(config.get("api_key_env_var", "OPENAI_API_KEY"))
            config_without_key_var = {k: v for k, v in config.items() if k != "api_key_env_var"}
            config_dict = get_config(api_key=api_key, **config_without_key_var)
        else:
            logging.warning(f"Unsupported type {type(config)} for model {model} configuration")
            # BUGFIX: previously fell through and dereferenced `config_dict`,
            # raising NameError on the first iteration or silently reusing the
            # previous model's config. Skip the unsupported entry instead.
            continue

        if not config_dict["api_key"] or config_dict["api_key"].strip() == "":
            logging.warning(
                f"API key not found or empty for model {model}. Please ensure path to .env file is correct."
            )
            continue  # Skip this configuration and continue with the next

        # Add model to the configuration and append to the list
        config_dict["model"] = model
        env_var.append(config_dict)

    # Round-trip through a temp JSON file so that config_list_from_json applies
    # its normal parsing and filtering logic.
    fd, temp_name = tempfile.mkstemp()
    try:
        with os.fdopen(fd, "w+") as temp:
            env_var_str = json.dumps(env_var)
            temp.write(env_var_str)
            temp.flush()
            config_list = config_list_from_json(env_or_file=temp_name, filter_dict=filter_dict)
    finally:
        # The file is deleted after using its name (to prevent windows build from breaking)
        os.remove(temp_name)

    if len(config_list) == 0:
        logging.error("No configurations loaded.")
        return []

    logging.info(f"Models available: {[config['model'] for config in config_list]}")
    return config_list
def retrieve_assistants_by_name(client: "OpenAI", name: str) -> list["Assistant"]:
    """Return all assistants whose name equals `name` from the OAI assistant API."""
    listing = client.beta.assistants.list()
    return [assistant for assistant in listing.data if assistant.name == name]
def detect_gpt_assistant_api_version() -> str:
    """Return the OpenAI assistant API version for the installed `openai` package.

    "v1" for openai versions below 1.21, "v2" otherwise.
    """
    installed_version = parse(importlib.metadata.version("openai"))
    return "v1" if installed_version < parse("1.21") else "v2"
def create_gpt_vector_store(client: "OpenAI", name: str, fild_ids: list[str]) -> Any:
    """Create an OpenAI vector store for a GPT assistant and attach the given files.

    Args:
        client: OpenAI client used to create the vector store.
        name: Name for the new vector store.
        fild_ids: IDs of previously uploaded files to attach.
            (Parameter name keeps its historical typo for backward compatibility.)

    Returns:
        The created vector store object.

    Raises:
        AttributeError: If the installed OpenAI package does not support vector stores.
        ValueError: If the file batch finishes in any status other than "completed".
    """
    try:
        vector_store = client.vector_stores.create(name=name)
    except Exception as e:
        raise AttributeError(f"Failed to create vector store, please install the latest OpenAI python package: {e}")

    # Attach the files as a batch, then poll until the batch leaves "in_progress".
    batch = client.vector_stores.file_batches.create_and_poll(vector_store_id=vector_store.id, file_ids=fild_ids)
    # BUGFIX: poll in a loop. Previously a single extra poll was made (plain `if`),
    # so a batch still in progress after ~1s was reported as failed even though the
    # upload could still complete.
    while batch.status == "in_progress":
        time.sleep(1)
        logging.debug(f"file batch status: {batch.file_counts}")
        batch = client.vector_stores.file_batches.poll(vector_store_id=vector_store.id, batch_id=batch.id)
    if batch.status == "completed":
        return vector_store
    raise ValueError(f"Failed to upload files to vector store {vector_store.id}:{batch.status}")
def create_gpt_assistant(
    client: "OpenAI", name: str, instructions: str, model: str, assistant_config: dict[str, Any]
) -> "Assistant":
    """Create an OpenAI GPT assistant, adapting the config to the installed assistant API version.

    Args:
        client: OpenAI client used to create the assistant.
        name: Assistant name; also used to derive the vector store name for file search.
        instructions: System instructions for the assistant.
        model: Model name for the assistant.
        assistant_config: May contain "tools", "file_ids" (V1-style), and, for the
            V2 API, "tool_resources".

    Returns:
        The created assistant object.

    Raises:
        ValueError: On V2 when both `tool_resources['file_search']` and `file_ids`
            are given, or on V1 when V2-only options are used.
    """
    assistant_create_kwargs = {}
    gpt_assistant_api_version = detect_gpt_assistant_api_version()
    tools = assistant_config.get("tools", [])
    if gpt_assistant_api_version == "v2":
        # V2: files are attached via tool_resources, not the V1 file_ids field.
        tool_resources = assistant_config.get("tool_resources", {})
        file_ids = assistant_config.get("file_ids")
        # Ambiguous to specify both the V2 and the legacy V1 file mechanisms.
        if tool_resources.get("file_search") is not None and file_ids is not None:
            raise ValueError(
                "Cannot specify both `tool_resources['file_search']` tool and `file_ids` in the assistant config."
            )

        # Designed for backwards compatibility for the V1 API
        # Instead of V1 AssistantFile, files are attached to Assistants using the tool_resources object.
        for tool in tools:
            # V1's "retrieval" tool was renamed to "file_search" in V2.
            if tool["type"] == "retrieval":
                tool["type"] = "file_search"
                if file_ids is not None:
                    # create a vector store for the file search tool
                    vs = create_gpt_vector_store(client, f"{name}-vectorestore", file_ids)
                    tool_resources["file_search"] = {
                        "vector_store_ids": [vs.id],
                    }
            elif tool["type"] == "code_interpreter" and file_ids is not None:
                # V1 file_ids map directly onto the code interpreter's resources in V2.
                tool_resources["code_interpreter"] = {
                    "file_ids": file_ids,
                }

        assistant_create_kwargs["tools"] = tools
        if len(tool_resources) > 0:
            assistant_create_kwargs["tool_resources"] = tool_resources
    else:
        # not support forwards compatibility
        if "tool_resources" in assistant_config:
            raise ValueError("`tool_resources` argument are not supported in the openai assistant V1 API.")
        if any(tool["type"] == "file_search" for tool in tools):
            raise ValueError(
                "`file_search` tool are not supported in the openai assistant V1 API, please use `retrieval`."
            )
        assistant_create_kwargs["tools"] = tools
        # V1 attaches files directly on the assistant.
        assistant_create_kwargs["file_ids"] = assistant_config.get("file_ids", [])

    logging.info(f"Creating assistant with config: {assistant_create_kwargs}")
    return client.beta.assistants.create(name=name, instructions=instructions, model=model, **assistant_create_kwargs)
def update_gpt_assistant(client: "OpenAI", assistant_id: str, assistant_config: dict[str, Any]) -> "Assistant":
    """Update an existing OpenAI GPT assistant from the given config.

    Only keys present (non-None) in `assistant_config` are sent; "tool_resources"
    applies to the V2 assistant API, "file_ids" to V1.
    """
    api_version = detect_gpt_assistant_api_version()
    update_kwargs: dict[str, Any] = {}
    # Common fields shared by both API versions.
    for key in ("tools", "instructions"):
        if assistant_config.get(key) is not None:
            update_kwargs[key] = assistant_config[key]
    # Version-specific file attachment mechanism.
    if api_version == "v2":
        if assistant_config.get("tool_resources") is not None:
            update_kwargs["tool_resources"] = assistant_config["tool_resources"]
    elif assistant_config.get("file_ids") is not None:
        update_kwargs["file_ids"] = assistant_config["file_ids"]
    return client.beta.assistants.update(assistant_id=assistant_id, **update_kwargs)
def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
if isinstance(config_value, list):
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
else:
return config_value in acceptable_values

View File

@@ -0,0 +1,370 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Create an OpenAI-compatible client using Together.AI's API.
Example:
```python
llm_config = {
"config_list": [
{
"api_type": "together",
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"api_key": os.environ.get("TOGETHER_API_KEY"),
}
]
}
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
```
Install Together.AI python library using: pip install --upgrade together
Resources:
- https://docs.together.ai/docs/inference-python
"""
from __future__ import annotations
import copy
import os
import time
import warnings
from typing import Any, Literal, Optional, Union
from pydantic import Field
from ..import_utils import optional_import_block, require_optional_import
from ..llm_config import LLMConfigEntry, register_llm_config
from .client_utils import should_hide_tools, validate_parameter
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
with optional_import_block():
from together import Together
@register_llm_config
class TogetherLLMConfigEntry(LLMConfigEntry):
    """LLM config entry for Together.AI (selected via `api_type="together"`).

    Optional fields default to None, meaning the API's own default is used.
    Field bounds mirror the validation performed in `TogetherClient.parse_params`.
    """

    api_type: Literal["together"] = "together"
    # Maximum tokens to generate; 512 matches the default used by TogetherClient.
    max_tokens: int = Field(default=512, ge=0)
    # Stream response chunks instead of returning a single completion.
    stream: bool = False
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    top_k: Optional[int] = Field(default=None)
    repetition_penalty: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
    min_p: Optional[float] = Field(default=None, ge=0, le=1)
    # Optional Together.AI moderation model; not validated client-side.
    safety_model: Optional[str] = None
    # When to withhold tool definitions from the model (see should_hide_tools).
    hide_tools: Literal["if_all_run", "if_any_run", "never"] = "never"
    # Two-element price override used for cost tracking — presumably
    # [input_price, output_price]; confirm against LLMConfigEntry.
    price: Optional[list[float]] = Field(default=None, min_length=2, max_length=2)
    tool_choice: Optional[Union[str, dict[str, Union[str, dict[str, str]]]]] = (
        None  # dict is the tool to call: {"type": "function", "function": {"name": "my_function"}}
    )

    def create_client(self):
        """Not implemented; raising is this entry's defined behavior."""
        raise NotImplementedError("TogetherLLMConfigEntry.create_client is not implemented.")
class TogetherClient:
    """Client for Together.AI's API.

    Implements the ModelClient-style surface used by the rest of the codebase:
    `message_retrieval`, `cost`, `get_usage`, and `create`.
    """

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            **kwargs: Additional keyword arguments to pass to the client.
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key")
        if not self.api_key:
            # Fall back to the environment when no explicit key is configured.
            self.api_key = os.getenv("TOGETHER_API_KEY")
        if "response_format" in kwargs and kwargs["response_format"] is not None:
            # Structured output is not supported by this client; warn rather than fail.
            warnings.warn("response_format is not supported for Together.AI, it will be ignored.", UserWarning)

        assert self.api_key, (
            "Please include the api_key in your config list entry for Together.AI or set the TOGETHER_API_KEY env variable."
        )

    def message_retrieval(self, response) -> list:
        """Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        # `cost` is attached to the ChatCompletion object in `create` below.
        return response.cost

    @staticmethod
    def get_usage(response) -> dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        # ...  # pragma: no cover
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def parse_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Loads the parameters for Together.AI API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
        together_params = {}

        # Check that we have what we need to use Together.AI's API
        together_params["model"] = params.get("model")
        assert together_params["model"], (
            "Please specify the 'model' in your config list entry to nominate the Together.AI model to use."
        )

        # Validate allowed Together.AI parameters
        # https://github.com/togethercomputer/together-python/blob/94ffb30daf0ac3e078be986af7228f85f79bde99/src/together/resources/completions.py#L44
        together_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, 512, (0, None), None)
        together_params["stream"] = validate_parameter(params, "stream", bool, False, False, None, None)
        together_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, None, None, None)
        together_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
        together_params["top_k"] = validate_parameter(params, "top_k", int, True, None, None, None)
        together_params["repetition_penalty"] = validate_parameter(
            params, "repetition_penalty", float, True, None, None, None
        )
        together_params["presence_penalty"] = validate_parameter(
            params, "presence_penalty", (int, float), True, None, (-2, 2), None
        )
        together_params["frequency_penalty"] = validate_parameter(
            params, "frequency_penalty", (int, float), True, None, (-2, 2), None
        )
        together_params["min_p"] = validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None)
        together_params["safety_model"] = validate_parameter(
            params, "safety_model", str, True, None, None, None
        )  # We won't enforce the available models as they are likely to change

        # Check if they want to stream and use tools, which isn't currently supported (TODO)
        if together_params["stream"] and "tools" in params:
            warnings.warn(
                "Streaming is not supported when using tools, streaming will be disabled.",
                UserWarning,
            )

            together_params["stream"] = False

        if "tool_choice" in params:
            together_params["tool_choice"] = params["tool_choice"]

        return together_params

    @require_optional_import("together", "together")
    def create(self, params: dict) -> ChatCompletion:
        """Run a chat completion against Together.AI and return an OpenAI-shaped ChatCompletion."""
        messages = params.get("messages", [])

        # Convert AG2 messages to Together.AI messages
        together_messages = oai_messages_to_together_messages(messages)

        # Parse parameters to Together.AI API's parameters
        together_params = self.parse_params(params)

        # Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(together_messages, params["tools"], hide_tools):
                together_params["tools"] = params["tools"]

        together_params["messages"] = together_messages

        # We use chat model by default
        client = Together(api_key=self.api_key)

        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0

        response = client.chat.completions.create(**together_params)
        if together_params["stream"]:
            # Read in the chunks as they stream
            ans = ""
            for chunk in response:
                ans = ans + (chunk.choices[0].delta.content or "")

                prompt_tokens = chunk.usage.prompt_tokens
                completion_tokens = chunk.usage.completion_tokens
                total_tokens = chunk.usage.total_tokens
        else:
            ans: str = response.choices[0].message.content

            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens

        # NOTE(review): in the streaming branch above, `response` is the stream
        # iterator, yet `response.choices`, `response.id`, and the accumulated
        # `ans` vs `response.choices[0].message.content` below assume a
        # non-streaming response object — confirm against the together SDK
        # whether these attributes exist after the stream is consumed.
        if response.choices[0].finish_reason == "tool_calls":
            together_finish = "tool_calls"
            tool_calls = []
            for tool_call in response.choices[0].message.tool_calls:
                # Re-wrap Together tool calls into OpenAI-shaped tool call objects.
                tool_calls.append(
                    ChatCompletionMessageToolCall(
                        id=tool_call.id,
                        function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                        type="function",
                    )
                )
        else:
            together_finish = "stop"
            tool_calls = None

        # 3. convert output
        message = ChatCompletionMessage(
            role="assistant",
            content=response.choices[0].message.content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=together_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response.id,
            model=together_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            # Estimated from the built-in size/price tables; 0 for unknown models.
            cost=calculate_together_cost(prompt_tokens, completion_tokens, together_params["model"]),
        )

        return response_oai
def oai_messages_to_together_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Together.AI format.

    The input list is deep-copied and left unmodified. Messages with
    role 'tool' (produced when a function is executed) are re-labelled
    as 'user', since that is the role order Together.AI expects.
    """
    converted = copy.deepcopy(messages)
    for message in converted:
        if message.get("role") == "tool":
            message["role"] = "user"
    return converted
# MODELS AND COSTS
# Known chat/language/code models and their sizes in billions of parameters.
# Used by calculate_together_cost to pick a pricing tier.
chat_lang_code_model_sizes = {
    "zero-one-ai/Yi-34B-Chat": 34,
    "allenai/OLMo-7B-Instruct": 7,
    "allenai/OLMo-7B-Twin-2T": 7,
    "allenai/OLMo-7B": 7,
    "Austism/chronos-hermes-13b": 13,
    "deepseek-ai/deepseek-coder-33b-instruct": 33,
    "deepseek-ai/deepseek-llm-67b-chat": 67,
    "garage-bAInd/Platypus2-70B-instruct": 70,
    "google/gemma-2b-it": 2,
    "google/gemma-7b-it": 7,
    "Gryphe/MythoMax-L2-13b": 13,
    "lmsys/vicuna-13b-v1.5": 13,
    "lmsys/vicuna-7b-v1.5": 7,
    "codellama/CodeLlama-13b-Instruct-hf": 13,
    "codellama/CodeLlama-34b-Instruct-hf": 34,
    "codellama/CodeLlama-70b-Instruct-hf": 70,
    "codellama/CodeLlama-7b-Instruct-hf": 7,
    "meta-llama/Llama-2-70b-chat-hf": 70,
    "meta-llama/Llama-2-13b-chat-hf": 13,
    "meta-llama/Llama-2-7b-chat-hf": 7,
    "meta-llama/Llama-3-8b-chat-hf": 8,
    "meta-llama/Llama-3-70b-chat-hf": 70,
    "mistralai/Mistral-7B-Instruct-v0.1": 7,
    "mistralai/Mistral-7B-Instruct-v0.2": 7,
    "mistralai/Mistral-7B-Instruct-v0.3": 7,
    "NousResearch/Nous-Capybara-7B-V1p9": 7,
    "NousResearch/Nous-Hermes-llama-2-7b": 7,
    "NousResearch/Nous-Hermes-Llama2-13b": 13,
    "NousResearch/Nous-Hermes-2-Yi-34B": 34,
    "openchat/openchat-3.5-1210": 7,
    "Open-Orca/Mistral-7B-OpenOrca": 7,
    "Qwen/Qwen1.5-0.5B-Chat": 0.5,
    "Qwen/Qwen1.5-1.8B-Chat": 1.8,
    "Qwen/Qwen1.5-4B-Chat": 4,
    "Qwen/Qwen1.5-7B-Chat": 7,
    "Qwen/Qwen1.5-14B-Chat": 14,
    "Qwen/Qwen1.5-32B-Chat": 32,
    "Qwen/Qwen1.5-72B-Chat": 72,
    "Qwen/Qwen1.5-110B-Chat": 110,
    "Qwen/Qwen2-72B-Instruct": 72,
    "snorkelai/Snorkel-Mistral-PairRM-DPO": 7,
    "togethercomputer/alpaca-7b": 7,
    "teknium/OpenHermes-2-Mistral-7B": 7,
    "teknium/OpenHermes-2p5-Mistral-7B": 7,
    "togethercomputer/Llama-2-7B-32K-Instruct": 7,
    "togethercomputer/RedPajama-INCITE-Chat-3B-v1": 3,
    "togethercomputer/RedPajama-INCITE-7B-Chat": 7,
    "togethercomputer/StripedHyena-Nous-7B": 7,
    "Undi95/ReMM-SLERP-L2-13B": 13,
    "Undi95/Toppy-M-7B": 7,
    "WizardLM/WizardLM-13B-V1.2": 13,
    "upstage/SOLAR-10.7B-Instruct-v1.0": 11,
}

# Cost per million tokens based on up to X Billion parameters, e.g. up 4B is $0.1/million
# Keys must remain in ascending order: the first tier >= the model size wins.
chat_lang_code_model_costs = {4: 0.1, 8: 0.2, 21: 0.3, 41: 0.8, 80: 0.9, 110: 1.8}

# Mixture-of-experts models and their (total) sizes in billions of parameters.
mixture_model_sizes = {
    "cognitivecomputations/dolphin-2.5-mixtral-8x7b": 56,
    "databricks/dbrx-instruct": 132,
    "mistralai/Mixtral-8x7B-Instruct-v0.1": 47,
    "mistralai/Mixtral-8x22B-Instruct-v0.1": 141,
    "NousResearch/Nous-Hermes-2-Mistral-7B-DPO": 7,
    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 47,
    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT": 47,
    "Snowflake/snowflake-arctic-instruct": 480,
}

# Cost per million tokens based on up to X Billion parameters, e.g. up 56B is $0.6/million
# Keys must remain in ascending order: the first tier >= the model size wins.
mixture_costs = {56: 0.6, 176: 1.2, 480: 2.4}
def calculate_together_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """Estimate the inference cost (USD) for a Together.AI request.

    Looks the model up in the built-in size tables, picks the smallest
    pricing tier that accommodates its parameter count, and charges that
    rate per million tokens (input + output combined). Unknown models
    trigger a warning and cost 0.
    """
    known_chat = model_name in chat_lang_code_model_sizes
    known_mixture = model_name in mixture_model_sizes
    if not known_chat and not known_mixture:
        # Model is not in our list of models, can't determine the cost
        warnings.warn(
            "The model isn't catered for costing, to apply costs you can use the 'price' key on your config_list.",
            UserWarning,
        )
        return 0

    # Chat/Language/Code models use one tier table, mixture-of-experts another.
    if known_chat:
        size_in_b = chat_lang_code_model_sizes[model_name]
        tier_table = chat_lang_code_model_costs
    else:
        size_in_b = mixture_model_sizes[model_name]
        tier_table = mixture_costs

    # First tier whose upper bound accommodates the model size (tables ascend).
    cost_per_mil = next(
        (tier_cost for tier_limit, tier_cost in tier_table.items() if size_in_b <= tier_limit),
        0,
    )
    if cost_per_mil == 0:
        warnings.warn("Model size doesn't align with cost structure.", UserWarning)

    return cost_per_mil * ((input_tokens + output_tokens) / 1e6)