# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Create a compatible client for the Amazon Bedrock Converse API.
|
|
|
|
Example usage:
|
|
Install the `boto3` package by running `pip install --upgrade boto3`.
|
|
- https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
|
|
|
|
```python
|
|
import autogen
|
|
|
|
config_list = [
|
|
{
|
|
"api_type": "bedrock",
|
|
"model": "meta.llama3-1-8b-instruct-v1:0",
|
|
"aws_region": "us-west-2",
|
|
"aws_access_key": "",
|
|
"aws_secret_key": "",
|
|
"price": [0.003, 0.015],
|
|
}
|
|
]
|
|
|
|
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
|
|
```
|
|
"""

from __future__ import annotations

import base64
import json
import os
import re
import time
import warnings
from typing import Any, Literal, Optional

import requests
from pydantic import Field, SecretStr, field_serializer

from ..import_utils import optional_import_block, require_optional_import
from ..llm_config import LLMConfigEntry, register_llm_config
from .client_utils import validate_parameter
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage

with optional_import_block():
    import boto3
    from botocore.config import Config


@register_llm_config
class BedrockLLMConfigEntry(LLMConfigEntry):
    api_type: Literal["bedrock"] = "bedrock"
    aws_region: str
    aws_access_key: Optional[SecretStr] = None
    aws_secret_key: Optional[SecretStr] = None
    aws_session_token: Optional[SecretStr] = None
    aws_profile_name: Optional[str] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None  # noqa: N815
    maxTokens: Optional[int] = None  # noqa: N815
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    k: Optional[int] = None
    seed: Optional[int] = None
    cache_seed: Optional[int] = None
    supports_system_prompts: bool = True
    stream: bool = False
    price: Optional[list[float]] = Field(default=None, min_length=2, max_length=2)
    timeout: Optional[int] = None

    @field_serializer("aws_access_key", "aws_secret_key", "aws_session_token", when_used="unless-none")
    def serialize_aws_secrets(self, v: SecretStr) -> str:
        return v.get_secret_value()

    def create_client(self):
        raise NotImplementedError("BedrockLLMConfigEntry.create_client must be implemented.")


@require_optional_import("boto3", "bedrock")
class BedrockClient:
    """Client for Amazon's Bedrock Converse API."""

    _retries = 5

    def __init__(self, **kwargs: Any):
        """Initialises BedrockClient for Amazon's Bedrock Converse API"""
        self._aws_access_key = kwargs.get("aws_access_key")
        self._aws_secret_key = kwargs.get("aws_secret_key")
        self._aws_session_token = kwargs.get("aws_session_token")
        self._aws_region = kwargs.get("aws_region")
        self._aws_profile_name = kwargs.get("aws_profile_name")
        self._timeout = kwargs.get("timeout")

        if not self._aws_access_key:
            self._aws_access_key = os.getenv("AWS_ACCESS_KEY")

        if not self._aws_secret_key:
            self._aws_secret_key = os.getenv("AWS_SECRET_KEY")

        if not self._aws_session_token:
            self._aws_session_token = os.getenv("AWS_SESSION_TOKEN")

        if not self._aws_region:
            self._aws_region = os.getenv("AWS_REGION")

        if self._aws_region is None:
            raise ValueError("Region is required to use the Amazon Bedrock API.")

        if self._timeout is None:
            self._timeout = 60

        # Initialize Bedrock client configuration and runtime
        bedrock_config = Config(
            region_name=self._aws_region,
            signature_version="v4",
            retries={"max_attempts": self._retries, "mode": "standard"},
            read_timeout=self._timeout,
        )

        if "response_format" in kwargs and kwargs["response_format"] is not None:
            warnings.warn("response_format is not supported for Bedrock, it will be ignored.", UserWarning)

        # If we haven't got an access key or secret key from arguments or environment variables
        if (
            self._aws_access_key is None
            or self._aws_access_key == ""
            or self._aws_secret_key is None
            or self._aws_secret_key == ""
        ):
            # attempt to get a client from the attached role of a managed service (Lambda, EC2, ECS, etc.)
            self.bedrock_runtime = boto3.client(service_name="bedrock-runtime", config=bedrock_config)
        else:
            session = boto3.Session(
                aws_access_key_id=self._aws_access_key,
                aws_secret_access_key=self._aws_secret_key,
                aws_session_token=self._aws_session_token,
                profile_name=self._aws_profile_name,
            )
            self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)

    def message_retrieval(self, response):
        """Retrieve the messages from the response."""
        return [choice.message for choice in response.choices]

    def parse_custom_params(self, params: dict[str, Any]):
        """Parses custom parameters for logic in this client class"""
        # Whether to pass system messages as a separate request parameter (default: True).
        # This is required because not all models support a system prompt (e.g. Mistral Instruct).
        self._supports_system_prompts = params.get("supports_system_prompts", True)

    def parse_params(self, params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
        """Loads the valid parameters required to invoke Bedrock Converse.

        Returns a tuple of (base_params, additional_params)
        """
        base_params = {}
        additional_params = {}

        # Amazon Bedrock base model IDs are here:
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
        self._model_id = params.get("model")
        assert self._model_id, "Please provide the 'model' in the config_list to use Amazon Bedrock"

        # Parameters vary based on the model used.
        # As we won't cater for all models and parameters, it's the developer's
        # responsibility to supply the parameters, and they will only be
        # included if present in the config.
        #
        # Important:
        # No defaults will be used (as they can vary per model)
        # No ranges will be used (as they can vary)
        # We will cover all the main parameters, but there may be others
        # that need to be added later.
        #
        # Here are some pages that show the parameters available for different models:
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-chat-completion.html

        # Here are the possible "base" parameters and their suitable types
        base_parameters = [["temperature", (float, int)], ["topP", (float, int)], ["maxTokens", (int)]]

        for param_name, suitable_types in base_parameters:
            if param_name in params:
                base_params[param_name] = validate_parameter(
                    params, param_name, suitable_types, False, None, None, None
                )

        # Here are the possible "model-specific" parameters and their suitable types, known as additional parameters
        additional_parameters = [
            ["top_p", (float, int)],
            ["top_k", (int)],
            ["k", (int)],
            ["seed", (int)],
        ]

        for param_name, suitable_types in additional_parameters:
            if param_name in params:
                additional_params[param_name] = validate_parameter(
                    params, param_name, suitable_types, False, None, None, None
                )
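
        # Illustrative example (not from the original source): a config entry such as
        #   {"model": "...", "temperature": 0.1, "maxTokens": 512, "top_k": 50}
        # would be split by the loops above into
        #   base_params = {"temperature": 0.1, "maxTokens": 512}   -> sent as "inferenceConfig"
        #   additional_params = {"top_k": 50}                      -> sent as "additionalModelRequestFields"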

        # Streaming
        self._streaming = params.get("stream", False)

        # For this release we will not support streaming as many models do not support streaming with tool use
        if self._streaming:
            warnings.warn(
                "Streaming is not currently supported, streaming will be disabled.",
                UserWarning,
            )
            self._streaming = False

        return base_params, additional_params

    def create(self, params) -> ChatCompletion:
        """Run Amazon Bedrock inference and return AG2 response"""
        # Set custom client class settings
        self.parse_custom_params(params)

        # Parse the inference parameters
        base_params, additional_params = self.parse_params(params)

        has_tools = "tools" in params
        messages = oai_messages_to_bedrock_messages(params["messages"], has_tools, self._supports_system_prompts)

        if self._supports_system_prompts:
            system_messages = extract_system_messages(params["messages"])

        tool_config = format_tools(params["tools"] if has_tools else [])

        request_args = {"messages": messages, "modelId": self._model_id}

        # Base and additional args
        if len(base_params) > 0:
            request_args["inferenceConfig"] = base_params

        if len(additional_params) > 0:
            request_args["additionalModelRequestFields"] = additional_params

        if self._supports_system_prompts:
            request_args["system"] = system_messages

        if len(tool_config["tools"]) > 0:
            request_args["toolConfig"] = tool_config

        response = self.bedrock_runtime.converse(**request_args)
        if response is None:
            raise RuntimeError(f"Failed to get response from Bedrock after retrying {self._retries} times.")

        finish_reason = convert_stop_reason_to_finish_reason(response["stopReason"])
        response_message = response["output"]["message"]

        tool_calls = format_tool_calls(response_message["content"]) if finish_reason == "tool_calls" else None

        text = ""
        for content in response_message["content"]:
            if "text" in content:
                text = content["text"]
                # NOTE: other types of output may be dealt with here

        message = ChatCompletionMessage(role="assistant", content=text, tool_calls=tool_calls)

        response_usage = response["usage"]
        usage = CompletionUsage(
            prompt_tokens=response_usage["inputTokens"],
            completion_tokens=response_usage["outputTokens"],
            total_tokens=response_usage["totalTokens"],
        )

        return ChatCompletion(
            id=response["ResponseMetadata"]["RequestId"],
            choices=[Choice(finish_reason=finish_reason, index=0, message=message)],
            created=int(time.time()),
            model=self._model_id,
            object="chat.completion",
            usage=usage,
        )

    def cost(self, response: ChatCompletion) -> float:
        """Calculate the cost of the response."""
        return calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens, response.model)

    @staticmethod
    def get_usage(response) -> dict:
        """Get the usage of tokens and their cost information."""
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }


def extract_system_messages(messages: list[dict[str, Any]]) -> list:
    """Extract the system messages from the list of messages.

    Args:
        messages (list[dict[str, Any]]): List of messages.

    Returns:
        list[dict[str, Any]]: System messages in Bedrock format, e.g. [{"text": "..."}].
    """
    for message in messages:
        if message.get("role") == "system":
            if isinstance(message["content"], str):
                return [{"text": message.get("content")}]
            else:
                return [{"text": message.get("content")[0]["text"]}]
    return []


def oai_messages_to_bedrock_messages(
    messages: list[dict[str, Any]], has_tools: bool, supports_system_prompts: bool
) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Bedrock format.

    We correct for any specific role orders and types, etc.
    AWS Bedrock requires messages to alternate between user and assistant roles. This function ensures that
    the messages are in the correct order and format for Bedrock by inserting "Please continue" messages as needed.
    This is the same method as the one in the Autogen Anthropic client.
    """
    # Track whether we have tools passed in. If not, tool use / result messages should be converted to text messages.
    # Bedrock requires a tools parameter with the tools listed, if there are other messages with tool use or tool results.
    # This can occur when we don't need tool calling, such as for group chat speaker selection.

    # Convert messages to Bedrock compliant format

    # Take out system messages if the model supports it, otherwise leave them in.
    if supports_system_prompts:
        messages = [x for x in messages if x["role"] != "system"]
    else:
        # Replace role="system" with role="user"
        for msg in messages:
            if msg["role"] == "system":
                msg["role"] = "user"

    processed_messages = []

    # Used to interweave user messages to ensure user/assistant alternating
    user_continue_message = {"content": [{"text": "Please continue."}], "role": "user"}
    assistant_continue_message = {
        "content": [{"text": "Please continue."}],
        "role": "assistant",
    }

    tool_use_messages = 0
    tool_result_messages = 0
    last_tool_use_index = -1
    last_tool_result_index = -1
    # user_role_index = 0 if supports_system_prompts else 1  # If system prompts are supported, messages start with user, otherwise they'll be the second message
    for message in messages:
        # New messages will be added here, manage role alternations
        expected_role = "user" if len(processed_messages) % 2 == 0 else "assistant"

        if "tool_calls" in message:
            # Map the tool call options to Bedrock's format
            tool_uses = []
            tool_names = []
            for tool_call in message["tool_calls"]:
                tool_uses.append({
                    "toolUse": {
                        "toolUseId": tool_call["id"],
                        "name": tool_call["function"]["name"],
                        "input": json.loads(tool_call["function"]["arguments"]),
                    }
                })
                if has_tools:
                    tool_use_messages += 1
                tool_names.append(tool_call["function"]["name"])

            if expected_role == "user":
                # Insert an extra user message as we will append an assistant message
                processed_messages.append(user_continue_message)

            if has_tools:
                processed_messages.append({"role": "assistant", "content": tool_uses})
                last_tool_use_index = len(processed_messages) - 1
            else:
                # Not using tools, so put in a plain text message
                processed_messages.append({
                    "role": "assistant",
                    "content": [{"text": f"Some internal function(s) that could be used: [{', '.join(tool_names)}]"}],
                })
        elif "tool_call_id" in message:
            if has_tools:
                # Map the tool usage call to tool_result for Bedrock
                tool_result = {
                    "toolResult": {
                        "toolUseId": message["tool_call_id"],
                        "content": [{"text": message["content"]}],
                    }
                }

                # If the previous message also had a tool_result, add it to that
                # Otherwise append a new message
                if last_tool_result_index == len(processed_messages) - 1:
                    processed_messages[-1]["content"].append(tool_result)
                else:
                    if expected_role == "assistant":
                        # Insert an extra assistant message as we will append a user message
                        processed_messages.append(assistant_continue_message)

                    processed_messages.append({"role": "user", "content": [tool_result]})
                    last_tool_result_index = len(processed_messages) - 1

                tool_result_messages += 1
            else:
                # Not using tools, so put in a plain text message
                processed_messages.append({
                    "role": "user",
                    "content": [{"text": f"Running the function returned: {message['content']}"}],
                })
        elif message["content"] == "":
            # Ignoring empty messages
            pass
        else:
            if expected_role != message["role"] and not (len(processed_messages) == 0 and message["role"] == "system"):
                # Inserting the alternating continue message (ignore if it's the first message and a system message)
                processed_messages.append(
                    user_continue_message if expected_role == "user" else assistant_continue_message
                )

            processed_messages.append({
                "role": message["role"],
                "content": parse_content_parts(message=message),
            })

    # We'll replace the last tool_use if there's no tool_result (occurs if we finish the conversation before running the function)
    if has_tools and tool_use_messages != tool_result_messages:
        processed_messages[last_tool_use_index] = assistant_continue_message

    # name is not a valid field on messages
    for message in processed_messages:
        if "name" in message:
            message.pop("name", None)

    # Note: When using reflection_with_llm we may end up with an "assistant" message as the last message and that may cause a blank response
    # So, if the last role is not user, add a 'user' continue message at the end
    if processed_messages[-1]["role"] != "user":
        processed_messages.append(user_continue_message)

    return processed_messages


def parse_content_parts(
    message: dict[str, Any],
) -> list[dict[str, Any]]:
    content: str | list[dict[str, Any]] = message.get("content")
    if isinstance(content, str):
        return [
            {
                "text": content,
            }
        ]
    content_parts = []
    for part in content:
        # part_content: Dict = part.get("content")
        if "text" in part:  # part_content:
            content_parts.append({
                "text": part.get("text"),
            })
        elif "image_url" in part:  # part_content:
            image_data, content_type = parse_image(part.get("image_url").get("url"))
            content_parts.append({
                "image": {
                    "format": content_type[6:],  # image/
                    "source": {"bytes": image_data},
                },
            })
        else:
            # Ignore..
            continue
    return content_parts


def parse_image(image_url: str) -> tuple[bytes, str]:
    """Try to get the raw data from an image url.

    Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html
    Returns a tuple of (Image Data, Content Type).
    """
    pattern = r"^data:(image/[a-z]*);base64,\s*"
    content_type = re.search(pattern, image_url)
    # if already base64 encoded.
    # Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp'
    if content_type:
        image_data = re.sub(pattern, "", image_url)
        return base64.b64decode(image_data), content_type.group(1)

    # Send a request to the image URL
    response = requests.get(image_url)
    # Check if the request was successful
    if response.status_code == 200:
        content_type = response.headers.get("Content-Type")
        if not content_type.startswith("image"):
            content_type = "image/jpeg"
        # Get the image content
        image_content = response.content
        return image_content, content_type
    else:
        raise RuntimeError("Unable to access the image url")


def format_tools(tools: list[dict[str, Any]]) -> dict[Literal["tools"], list[dict[str, Any]]]:
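    """Convert OpenAI-style tool/function definitions into the Bedrock Converse toolConfig schema.

    Illustrative sketch (hedged, the tool name is hypothetical): an OAI tool such as
        {"type": "function", "function": {"name": "get_weather", "description": "...", "parameters": {...}}}
    is mapped to
        {"tools": [{"toolSpec": {"name": "get_weather", "description": "...", "inputSchema": {"json": {...}}}}]}
    carrying across property types, descriptions, enums, defaults and the required field list.
    """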
    converted_schema = {"tools": []}

    for tool in tools:
        if tool["type"] == "function":
            function = tool["function"]
            converted_tool = {
                "toolSpec": {
                    "name": function["name"],
                    "description": function["description"],
                    "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}},
                }
            }

            for prop_name, prop_details in function["parameters"]["properties"].items():
                converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name] = {
                    "type": prop_details["type"],
                    "description": prop_details.get("description", ""),
                }
                if "enum" in prop_details:
                    converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["enum"] = prop_details[
                        "enum"
                    ]
                if "default" in prop_details:
                    converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["default"] = (
                        prop_details["default"]
                    )

            if "required" in function["parameters"]:
                converted_tool["toolSpec"]["inputSchema"]["json"]["required"] = function["parameters"]["required"]

            converted_schema["tools"].append(converted_tool)

    return converted_schema


def format_tool_calls(content):
    """Converts Converse API response tool calls to AG2 format"""
    tool_calls = []
    for tool_request in content:
        if "toolUse" in tool_request:
            tool = tool_request["toolUse"]

            tool_calls.append(
                ChatCompletionMessageToolCall(
                    id=tool["toolUseId"],
                    function={
                        "name": tool["name"],
                        "arguments": json.dumps(tool["input"]),
                    },
                    type="function",
                )
            )
    return tool_calls


def convert_stop_reason_to_finish_reason(
    stop_reason: str,
) -> Literal["stop", "length", "tool_calls", "content_filter"]:
    """Converts Bedrock finish reasons to our finish reasons, according to OpenAI:

    - stop: if the model hit a natural stop point or a provided stop sequence,
    - length: if the maximum number of tokens specified in the request was reached,
    - content_filter: if content was omitted due to a flag from our content filters,
    - tool_calls: if the model called a tool
    """
    if stop_reason:
        finish_reason_mapping = {
            "tool_use": "tool_calls",
            "finished": "stop",
            "end_turn": "stop",
            "max_tokens": "length",
            "stop_sequence": "stop",
            "complete": "stop",
            "content_filtered": "content_filter",
        }
        return finish_reason_mapping.get(stop_reason.lower(), stop_reason.lower())

    warnings.warn(f"Unsupported stop reason: {stop_reason}", UserWarning)
    return None


# NOTE: As this will be quite dynamic, it's expected that the developer will use the "price" parameter in their config
# These may be removed.
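# Example of specifying custom pricing in the config entry (see the module docstring above):
#   "price": [0.003, 0.015]  # [prompt price per 1K tokens, completion price per 1K tokens]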
PRICES_PER_K_TOKENS = {
    "meta.llama3-8b-instruct-v1:0": (0.0003, 0.0006),
    "meta.llama3-70b-instruct-v1:0": (0.00265, 0.0035),
    "mistral.mistral-7b-instruct-v0:2": (0.00015, 0.0002),
    "mistral.mixtral-8x7b-instruct-v0:1": (0.00045, 0.0007),
    "mistral.mistral-large-2402-v1:0": (0.004, 0.012),
    "mistral.mistral-small-2402-v1:0": (0.001, 0.003),
}


def calculate_cost(input_tokens: int, output_tokens: int, model_id: str) -> float:
    """Calculate the cost of the completion using the Bedrock pricing."""
    if model_id in PRICES_PER_K_TOKENS:
        input_cost_per_k, output_cost_per_k = PRICES_PER_K_TOKENS[model_id]
        input_cost = (input_tokens / 1000) * input_cost_per_k
        output_cost = (output_tokens / 1000) * output_cost_per_k
        return input_cost + output_cost
    else:
        warnings.warn(
            f'Cannot get the costs for {model_id}. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.',
            UserWarning,
        )
        return 0