CoACT initialize (#292)
This commit is contained in:
628
mm_agents/coact/autogen/oai/bedrock.py
Normal file
628
mm_agents/coact/autogen/oai/bedrock.py
Normal file
@@ -0,0 +1,628 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
"""Create a compatible client for the Amazon Bedrock Converse API.
|
||||
|
||||
Example usage:
|
||||
Install the `boto3` package by running `pip install --upgrade boto3`.
|
||||
- https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
|
||||
|
||||
```python
|
||||
import autogen
|
||||
|
||||
config_list = [
|
||||
{
|
||||
"api_type": "bedrock",
|
||||
"model": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"aws_region": "us-west-2",
|
||||
"aws_access_key": "",
|
||||
"aws_secret_key": "",
|
||||
"price": [0.003, 0.015],
|
||||
}
|
||||
]
|
||||
|
||||
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
|
||||
```
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import Field, SecretStr, field_serializer
|
||||
|
||||
from ..import_utils import optional_import_block, require_optional_import
|
||||
from ..llm_config import LLMConfigEntry, register_llm_config
|
||||
from .client_utils import validate_parameter
|
||||
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
|
||||
|
||||
with optional_import_block():
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
|
||||
|
||||
@register_llm_config
|
||||
class BedrockLLMConfigEntry(LLMConfigEntry):
|
||||
api_type: Literal["bedrock"] = "bedrock"
|
||||
aws_region: str
|
||||
aws_access_key: Optional[SecretStr] = None
|
||||
aws_secret_key: Optional[SecretStr] = None
|
||||
aws_session_token: Optional[SecretStr] = None
|
||||
aws_profile_name: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None # noqa: N815
|
||||
maxTokens: Optional[int] = None # noqa: N815
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
k: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
cache_seed: Optional[int] = None
|
||||
supports_system_prompts: bool = True
|
||||
stream: bool = False
|
||||
price: Optional[list[float]] = Field(default=None, min_length=2, max_length=2)
|
||||
timeout: Optional[int] = None
|
||||
|
||||
@field_serializer("aws_access_key", "aws_secret_key", "aws_session_token", when_used="unless-none")
|
||||
def serialize_aws_secrets(self, v: SecretStr) -> str:
|
||||
return v.get_secret_value()
|
||||
|
||||
def create_client(self):
|
||||
raise NotImplementedError("BedrockLLMConfigEntry.create_client must be implemented.")
|
||||
|
||||
|
||||
@require_optional_import("boto3", "bedrock")
|
||||
class BedrockClient:
|
||||
"""Client for Amazon's Bedrock Converse API."""
|
||||
|
||||
_retries = 5
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialises BedrockClient for Amazon's Bedrock Converse API"""
|
||||
self._aws_access_key = kwargs.get("aws_access_key")
|
||||
self._aws_secret_key = kwargs.get("aws_secret_key")
|
||||
self._aws_session_token = kwargs.get("aws_session_token")
|
||||
self._aws_region = kwargs.get("aws_region")
|
||||
self._aws_profile_name = kwargs.get("aws_profile_name")
|
||||
self._timeout = kwargs.get("timeout")
|
||||
|
||||
if not self._aws_access_key:
|
||||
self._aws_access_key = os.getenv("AWS_ACCESS_KEY")
|
||||
|
||||
if not self._aws_secret_key:
|
||||
self._aws_secret_key = os.getenv("AWS_SECRET_KEY")
|
||||
|
||||
if not self._aws_session_token:
|
||||
self._aws_session_token = os.getenv("AWS_SESSION_TOKEN")
|
||||
|
||||
if not self._aws_region:
|
||||
self._aws_region = os.getenv("AWS_REGION")
|
||||
|
||||
if self._aws_region is None:
|
||||
raise ValueError("Region is required to use the Amazon Bedrock API.")
|
||||
|
||||
if self._timeout is None:
|
||||
self._timeout = 60
|
||||
|
||||
# Initialize Bedrock client, session, and runtime
|
||||
bedrock_config = Config(
|
||||
region_name=self._aws_region,
|
||||
signature_version="v4",
|
||||
retries={"max_attempts": self._retries, "mode": "standard"},
|
||||
read_timeout=self._timeout,
|
||||
)
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=self._aws_access_key,
|
||||
aws_secret_access_key=self._aws_secret_key,
|
||||
aws_session_token=self._aws_session_token,
|
||||
profile_name=self._aws_profile_name,
|
||||
)
|
||||
|
||||
if "response_format" in kwargs and kwargs["response_format"] is not None:
|
||||
warnings.warn("response_format is not supported for Bedrock, it will be ignored.", UserWarning)
|
||||
|
||||
# if haven't got any access_key or secret_key in environment variable or via arguments then
|
||||
if (
|
||||
self._aws_access_key is None
|
||||
or self._aws_access_key == ""
|
||||
or self._aws_secret_key is None
|
||||
or self._aws_secret_key == ""
|
||||
):
|
||||
# attempts to get client from attached role of managed service (lambda, ec2, ecs, etc.)
|
||||
self.bedrock_runtime = boto3.client(service_name="bedrock-runtime", config=bedrock_config)
|
||||
else:
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=self._aws_access_key,
|
||||
aws_secret_access_key=self._aws_secret_key,
|
||||
aws_session_token=self._aws_session_token,
|
||||
profile_name=self._aws_profile_name,
|
||||
)
|
||||
self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)
|
||||
|
||||
def message_retrieval(self, response):
|
||||
"""Retrieve the messages from the response."""
|
||||
return [choice.message for choice in response.choices]
|
||||
|
||||
def parse_custom_params(self, params: dict[str, Any]):
|
||||
"""Parses custom parameters for logic in this client class"""
|
||||
# Should we separate system messages into its own request parameter, default is True
|
||||
# This is required because not all models support a system prompt (e.g. Mistral Instruct).
|
||||
self._supports_system_prompts = params.get("supports_system_prompts", True)
|
||||
|
||||
def parse_params(self, params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Loads the valid parameters required to invoke Bedrock Converse
|
||||
Returns a tuple of (base_params, additional_params)
|
||||
"""
|
||||
base_params = {}
|
||||
additional_params = {}
|
||||
|
||||
# Amazon Bedrock base model IDs are here:
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
|
||||
self._model_id = params.get("model")
|
||||
assert self._model_id, "Please provide the 'model` in the config_list to use Amazon Bedrock"
|
||||
|
||||
# Parameters vary based on the model used.
|
||||
# As we won't cater for all models and parameters, it's the developer's
|
||||
# responsibility to implement the parameters and they will only be
|
||||
# included if the developer has it in the config.
|
||||
#
|
||||
# Important:
|
||||
# No defaults will be used (as they can vary per model)
|
||||
# No ranges will be used (as they can vary)
|
||||
# We will cover all the main parameters but there may be others
|
||||
# that need to be added later
|
||||
#
|
||||
# Here are some pages that show the parameters available for different models
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-chat-completion.html
|
||||
|
||||
# Here are the possible "base" parameters and their suitable types
|
||||
base_parameters = [["temperature", (float, int)], ["topP", (float, int)], ["maxTokens", (int)]]
|
||||
|
||||
for param_name, suitable_types in base_parameters:
|
||||
if param_name in params:
|
||||
base_params[param_name] = validate_parameter(
|
||||
params, param_name, suitable_types, False, None, None, None
|
||||
)
|
||||
|
||||
# Here are the possible "model-specific" parameters and their suitable types, known as additional parameters
|
||||
additional_parameters = [
|
||||
["top_p", (float, int)],
|
||||
["top_k", (int)],
|
||||
["k", (int)],
|
||||
["seed", (int)],
|
||||
]
|
||||
|
||||
for param_name, suitable_types in additional_parameters:
|
||||
if param_name in params:
|
||||
additional_params[param_name] = validate_parameter(
|
||||
params, param_name, suitable_types, False, None, None, None
|
||||
)
|
||||
|
||||
# Streaming
|
||||
self._streaming = params.get("stream", False)
|
||||
|
||||
# For this release we will not support streaming as many models do not support streaming with tool use
|
||||
if self._streaming:
|
||||
warnings.warn(
|
||||
"Streaming is not currently supported, streaming will be disabled.",
|
||||
UserWarning,
|
||||
)
|
||||
self._streaming = False
|
||||
|
||||
return base_params, additional_params
|
||||
|
||||
def create(self, params) -> ChatCompletion:
|
||||
"""Run Amazon Bedrock inference and return AG2 response"""
|
||||
# Set custom client class settings
|
||||
self.parse_custom_params(params)
|
||||
|
||||
# Parse the inference parameters
|
||||
base_params, additional_params = self.parse_params(params)
|
||||
|
||||
has_tools = "tools" in params
|
||||
messages = oai_messages_to_bedrock_messages(params["messages"], has_tools, self._supports_system_prompts)
|
||||
|
||||
if self._supports_system_prompts:
|
||||
system_messages = extract_system_messages(params["messages"])
|
||||
|
||||
tool_config = format_tools(params["tools"] if has_tools else [])
|
||||
|
||||
request_args = {"messages": messages, "modelId": self._model_id}
|
||||
|
||||
# Base and additional args
|
||||
if len(base_params) > 0:
|
||||
request_args["inferenceConfig"] = base_params
|
||||
|
||||
if len(additional_params) > 0:
|
||||
request_args["additionalModelRequestFields"] = additional_params
|
||||
|
||||
if self._supports_system_prompts:
|
||||
request_args["system"] = system_messages
|
||||
|
||||
if len(tool_config["tools"]) > 0:
|
||||
request_args["toolConfig"] = tool_config
|
||||
|
||||
response = self.bedrock_runtime.converse(**request_args)
|
||||
if response is None:
|
||||
raise RuntimeError(f"Failed to get response from Bedrock after retrying {self._retries} times.")
|
||||
|
||||
finish_reason = convert_stop_reason_to_finish_reason(response["stopReason"])
|
||||
response_message = response["output"]["message"]
|
||||
|
||||
tool_calls = format_tool_calls(response_message["content"]) if finish_reason == "tool_calls" else None
|
||||
|
||||
text = ""
|
||||
for content in response_message["content"]:
|
||||
if "text" in content:
|
||||
text = content["text"]
|
||||
# NOTE: other types of output may be dealt with here
|
||||
|
||||
message = ChatCompletionMessage(role="assistant", content=text, tool_calls=tool_calls)
|
||||
|
||||
response_usage = response["usage"]
|
||||
usage = CompletionUsage(
|
||||
prompt_tokens=response_usage["inputTokens"],
|
||||
completion_tokens=response_usage["outputTokens"],
|
||||
total_tokens=response_usage["totalTokens"],
|
||||
)
|
||||
|
||||
return ChatCompletion(
|
||||
id=response["ResponseMetadata"]["RequestId"],
|
||||
choices=[Choice(finish_reason=finish_reason, index=0, message=message)],
|
||||
created=int(time.time()),
|
||||
model=self._model_id,
|
||||
object="chat.completion",
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def cost(self, response: ChatCompletion) -> float:
|
||||
"""Calculate the cost of the response."""
|
||||
return calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens, response.model)
|
||||
|
||||
@staticmethod
|
||||
def get_usage(response) -> dict:
|
||||
"""Get the usage of tokens and their cost information."""
|
||||
return {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
"cost": response.cost,
|
||||
"model": response.model,
|
||||
}
|
||||
|
||||
|
||||
def extract_system_messages(messages: list[dict[str, Any]]) -> list:
|
||||
"""Extract the system messages from the list of messages.
|
||||
|
||||
Args:
|
||||
messages (list[dict[str, Any]]): List of messages.
|
||||
|
||||
Returns:
|
||||
List[SystemMessage]: List of System messages.
|
||||
"""
|
||||
"""
|
||||
system_messages = [message.get("content")[0]["text"] for message in messages if message.get("role") == "system"]
|
||||
return system_messages # ''.join(system_messages)
|
||||
"""
|
||||
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
if isinstance(message["content"], str):
|
||||
return [{"text": message.get("content")}]
|
||||
else:
|
||||
return [{"text": message.get("content")[0]["text"]}]
|
||||
return []
|
||||
|
||||
|
||||
def oai_messages_to_bedrock_messages(
|
||||
messages: list[dict[str, Any]], has_tools: bool, supports_system_prompts: bool
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert messages from OAI format to Bedrock format.
|
||||
We correct for any specific role orders and types, etc.
|
||||
AWS Bedrock requires messages to alternate between user and assistant roles. This function ensures that the messages
|
||||
are in the correct order and format for Bedrock by inserting "Please continue" messages as needed.
|
||||
This is the same method as the one in the Autogen Anthropic client
|
||||
"""
|
||||
# Track whether we have tools passed in. If not, tool use / result messages should be converted to text messages.
|
||||
# Bedrock requires a tools parameter with the tools listed, if there are other messages with tool use or tool results.
|
||||
# This can occur when we don't need tool calling, such as for group chat speaker selection
|
||||
|
||||
# Convert messages to Bedrock compliant format
|
||||
|
||||
# Take out system messages if the model supports it, otherwise leave them in.
|
||||
if supports_system_prompts:
|
||||
messages = [x for x in messages if x["role"] != "system"]
|
||||
else:
|
||||
# Replace role="system" with role="user"
|
||||
for msg in messages:
|
||||
if msg["role"] == "system":
|
||||
msg["role"] = "user"
|
||||
|
||||
processed_messages = []
|
||||
|
||||
# Used to interweave user messages to ensure user/assistant alternating
|
||||
user_continue_message = {"content": [{"text": "Please continue."}], "role": "user"}
|
||||
assistant_continue_message = {
|
||||
"content": [{"text": "Please continue."}],
|
||||
"role": "assistant",
|
||||
}
|
||||
|
||||
tool_use_messages = 0
|
||||
tool_result_messages = 0
|
||||
last_tool_use_index = -1
|
||||
last_tool_result_index = -1
|
||||
# user_role_index = 0 if supports_system_prompts else 1 # If system prompts are supported, messages start with user, otherwise they'll be the second message
|
||||
for message in messages:
|
||||
# New messages will be added here, manage role alternations
|
||||
expected_role = "user" if len(processed_messages) % 2 == 0 else "assistant"
|
||||
|
||||
if "tool_calls" in message:
|
||||
# Map the tool call options to Bedrock's format
|
||||
tool_uses = []
|
||||
tool_names = []
|
||||
for tool_call in message["tool_calls"]:
|
||||
tool_uses.append({
|
||||
"toolUse": {
|
||||
"toolUseId": tool_call["id"],
|
||||
"name": tool_call["function"]["name"],
|
||||
"input": json.loads(tool_call["function"]["arguments"]),
|
||||
}
|
||||
})
|
||||
if has_tools:
|
||||
tool_use_messages += 1
|
||||
tool_names.append(tool_call["function"]["name"])
|
||||
|
||||
if expected_role == "user":
|
||||
# Insert an extra user message as we will append an assistant message
|
||||
processed_messages.append(user_continue_message)
|
||||
|
||||
if has_tools:
|
||||
processed_messages.append({"role": "assistant", "content": tool_uses})
|
||||
last_tool_use_index = len(processed_messages) - 1
|
||||
else:
|
||||
# Not using tools, so put in a plain text message
|
||||
processed_messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"text": f"Some internal function(s) that could be used: [{', '.join(tool_names)}]"}],
|
||||
})
|
||||
elif "tool_call_id" in message:
|
||||
if has_tools:
|
||||
# Map the tool usage call to tool_result for Bedrock
|
||||
tool_result = {
|
||||
"toolResult": {
|
||||
"toolUseId": message["tool_call_id"],
|
||||
"content": [{"text": message["content"]}],
|
||||
}
|
||||
}
|
||||
|
||||
# If the previous message also had a tool_result, add it to that
|
||||
# Otherwise append a new message
|
||||
if last_tool_result_index == len(processed_messages) - 1:
|
||||
processed_messages[-1]["content"].append(tool_result)
|
||||
else:
|
||||
if expected_role == "assistant":
|
||||
# Insert an extra assistant message as we will append a user message
|
||||
processed_messages.append(assistant_continue_message)
|
||||
|
||||
processed_messages.append({"role": "user", "content": [tool_result]})
|
||||
last_tool_result_index = len(processed_messages) - 1
|
||||
|
||||
tool_result_messages += 1
|
||||
else:
|
||||
# Not using tools, so put in a plain text message
|
||||
processed_messages.append({
|
||||
"role": "user",
|
||||
"content": [{"text": f"Running the function returned: {message['content']}"}],
|
||||
})
|
||||
elif message["content"] == "":
|
||||
# Ignoring empty messages
|
||||
pass
|
||||
else:
|
||||
if expected_role != message["role"] and not (len(processed_messages) == 0 and message["role"] == "system"):
|
||||
# Inserting the alternating continue message (ignore if it's the first message and a system message)
|
||||
processed_messages.append(
|
||||
user_continue_message if expected_role == "user" else assistant_continue_message
|
||||
)
|
||||
|
||||
processed_messages.append({
|
||||
"role": message["role"],
|
||||
"content": parse_content_parts(message=message),
|
||||
})
|
||||
|
||||
# We'll replace the last tool_use if there's no tool_result (occurs if we finish the conversation before running the function)
|
||||
if has_tools and tool_use_messages != tool_result_messages:
|
||||
processed_messages[last_tool_use_index] = assistant_continue_message
|
||||
|
||||
# name is not a valid field on messages
|
||||
for message in processed_messages:
|
||||
if "name" in message:
|
||||
message.pop("name", None)
|
||||
|
||||
# Note: When using reflection_with_llm we may end up with an "assistant" message as the last message and that may cause a blank response
|
||||
# So, if the last role is not user, add a 'user' continue message at the end
|
||||
if processed_messages[-1]["role"] != "user":
|
||||
processed_messages.append(user_continue_message)
|
||||
|
||||
return processed_messages
|
||||
|
||||
|
||||
def parse_content_parts(
|
||||
message: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
content: str | list[dict[str, Any]] = message.get("content")
|
||||
if isinstance(content, str):
|
||||
return [
|
||||
{
|
||||
"text": content,
|
||||
}
|
||||
]
|
||||
content_parts = []
|
||||
for part in content:
|
||||
# part_content: Dict = part.get("content")
|
||||
if "text" in part: # part_content:
|
||||
content_parts.append({
|
||||
"text": part.get("text"),
|
||||
})
|
||||
elif "image_url" in part: # part_content:
|
||||
image_data, content_type = parse_image(part.get("image_url").get("url"))
|
||||
content_parts.append({
|
||||
"image": {
|
||||
"format": content_type[6:], # image/
|
||||
"source": {"bytes": image_data},
|
||||
},
|
||||
})
|
||||
else:
|
||||
# Ignore..
|
||||
continue
|
||||
return content_parts
|
||||
|
||||
|
||||
def parse_image(image_url: str) -> tuple[bytes, str]:
|
||||
"""Try to get the raw data from an image url.
|
||||
|
||||
Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html
|
||||
returns a tuple of (Image Data, Content Type)
|
||||
"""
|
||||
pattern = r"^data:(image/[a-z]*);base64,\s*"
|
||||
content_type = re.search(pattern, image_url)
|
||||
# if already base64 encoded.
|
||||
# Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp'
|
||||
if content_type:
|
||||
image_data = re.sub(pattern, "", image_url)
|
||||
return base64.b64decode(image_data), content_type.group(1)
|
||||
|
||||
# Send a request to the image URL
|
||||
response = requests.get(image_url)
|
||||
# Check if the request was successful
|
||||
if response.status_code == 200:
|
||||
content_type = response.headers.get("Content-Type")
|
||||
if not content_type.startswith("image"):
|
||||
content_type = "image/jpeg"
|
||||
# Get the image content
|
||||
image_content = response.content
|
||||
return image_content, content_type
|
||||
else:
|
||||
raise RuntimeError("Unable to access the image url")
|
||||
|
||||
|
||||
def format_tools(tools: list[dict[str, Any]]) -> dict[Literal["tools"], list[dict[str, Any]]]:
|
||||
converted_schema = {"tools": []}
|
||||
|
||||
for tool in tools:
|
||||
if tool["type"] == "function":
|
||||
function = tool["function"]
|
||||
converted_tool = {
|
||||
"toolSpec": {
|
||||
"name": function["name"],
|
||||
"description": function["description"],
|
||||
"inputSchema": {"json": {"type": "object", "properties": {}, "required": []}},
|
||||
}
|
||||
}
|
||||
|
||||
for prop_name, prop_details in function["parameters"]["properties"].items():
|
||||
converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name] = {
|
||||
"type": prop_details["type"],
|
||||
"description": prop_details.get("description", ""),
|
||||
}
|
||||
if "enum" in prop_details:
|
||||
converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["enum"] = prop_details[
|
||||
"enum"
|
||||
]
|
||||
if "default" in prop_details:
|
||||
converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["default"] = (
|
||||
prop_details["default"]
|
||||
)
|
||||
|
||||
if "required" in function["parameters"]:
|
||||
converted_tool["toolSpec"]["inputSchema"]["json"]["required"] = function["parameters"]["required"]
|
||||
|
||||
converted_schema["tools"].append(converted_tool)
|
||||
|
||||
return converted_schema
|
||||
|
||||
|
||||
def format_tool_calls(content):
|
||||
"""Converts Converse API response tool calls to AG2 format"""
|
||||
tool_calls = []
|
||||
for tool_request in content:
|
||||
if "toolUse" in tool_request:
|
||||
tool = tool_request["toolUse"]
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=tool["toolUseId"],
|
||||
function={
|
||||
"name": tool["name"],
|
||||
"arguments": json.dumps(tool["input"]),
|
||||
},
|
||||
type="function",
|
||||
)
|
||||
)
|
||||
return tool_calls
|
||||
|
||||
|
||||
def convert_stop_reason_to_finish_reason(
|
||||
stop_reason: str,
|
||||
) -> Literal["stop", "length", "tool_calls", "content_filter"]:
|
||||
"""Converts Bedrock finish reasons to our finish reasons, according to OpenAI:
|
||||
|
||||
- stop: if the model hit a natural stop point or a provided stop sequence,
|
||||
- length: if the maximum number of tokens specified in the request was reached,
|
||||
- content_filter: if content was omitted due to a flag from our content filters,
|
||||
- tool_calls: if the model called a tool
|
||||
"""
|
||||
if stop_reason:
|
||||
finish_reason_mapping = {
|
||||
"tool_use": "tool_calls",
|
||||
"finished": "stop",
|
||||
"end_turn": "stop",
|
||||
"max_tokens": "length",
|
||||
"stop_sequence": "stop",
|
||||
"complete": "stop",
|
||||
"content_filtered": "content_filter",
|
||||
}
|
||||
return finish_reason_mapping.get(stop_reason.lower(), stop_reason.lower())
|
||||
|
||||
warnings.warn(f"Unsupported stop reason: {stop_reason}", UserWarning)
|
||||
return None
|
||||
|
||||
|
||||
# NOTE: As this will be quite dynamic, it's expected that the developer will use the "price" parameter in their config
|
||||
# These may be removed.
|
||||
PRICES_PER_K_TOKENS = {
|
||||
"meta.llama3-8b-instruct-v1:0": (0.0003, 0.0006),
|
||||
"meta.llama3-70b-instruct-v1:0": (0.00265, 0.0035),
|
||||
"mistral.mistral-7b-instruct-v0:2": (0.00015, 0.0002),
|
||||
"mistral.mixtral-8x7b-instruct-v0:1": (0.00045, 0.0007),
|
||||
"mistral.mistral-large-2402-v1:0": (0.004, 0.012),
|
||||
"mistral.mistral-small-2402-v1:0": (0.001, 0.003),
|
||||
}
|
||||
|
||||
|
||||
def calculate_cost(input_tokens: int, output_tokens: int, model_id: str) -> float:
|
||||
"""Calculate the cost of the completion using the Bedrock pricing."""
|
||||
if model_id in PRICES_PER_K_TOKENS:
|
||||
input_cost_per_k, output_cost_per_k = PRICES_PER_K_TOKENS[model_id]
|
||||
input_cost = (input_tokens / 1000) * input_cost_per_k
|
||||
output_cost = (output_tokens / 1000) * output_cost_per_k
|
||||
return input_cost + output_cost
|
||||
else:
|
||||
warnings.warn(
|
||||
f'Cannot get the costs for {model_id}. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.',
|
||||
UserWarning,
|
||||
)
|
||||
return 0
|
||||
Reference in New Issue
Block a user