172 lines
5.5 KiB
Python
172 lines
5.5 KiB
Python
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from typing import Any, Callable, Literal, Optional, Union
|
|
from uuid import UUID
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from ..events import deprecated_by
|
|
from ..events.client_events import StreamEvent, UsageSummaryEvent
|
|
from .base_message import BaseMessage, wrap_message
|
|
|
|
# Names exported by `from <this module> import *`.
# NOTE(review): StreamMessage is also defined in this module but is not listed
# here — confirm whether that omission is intentional.
__all__ = ["UsageSummaryMessage"]
|
|
|
|
|
|
class ModelUsageSummary(BaseModel):
    """Model usage summary."""

    # One record per model. Field order is kept as declared because it is the
    # order pydantic uses when dumping the model.

    model: str
    """Name of the model this record belongs to."""

    completion_tokens: int
    """Count of tokens spent on completions."""

    cost: float
    """Cost recorded for this model."""

    prompt_tokens: int
    """Count of tokens spent on prompts."""

    total_tokens: int
    """Overall token count for this model."""
|
|
|
|
|
|
class ActualUsageSummary(BaseModel):
    """Actual usage summary."""

    # Both fields default to None, so an instance can be built with no data.

    usages: Optional[list[ModelUsageSummary]] = None
    """Per-model usage records, or None when nothing was supplied."""

    total_cost: Optional[float] = None
    """Aggregate cost, or None when nothing was supplied."""
|
|
|
|
|
|
class TotalUsageSummary(BaseModel):
    """Total usage summary."""

    # Both fields default to None, so an instance can be built with no data.

    usages: Optional[list[ModelUsageSummary]] = None
    """Per-model usage records, or None when nothing was supplied."""

    total_cost: Optional[float] = None
    """Aggregate cost, or None when nothing was supplied."""
|
|
|
|
|
|
# Selects what UsageSummaryMessage.print displays: "actual" prints only the
# actual summary, "total" prints only the total summary, and "both" prints the
# actual summary followed by the total one (or a note when the two are equal).
Mode = Literal["both", "total", "actual"]
|
|
|
|
|
|
def _change_usage_summary_format(
|
|
actual_usage_summary: Optional[dict[str, Any]] = None, total_usage_summary: Optional[dict[str, Any]] = None
|
|
) -> dict[str, dict[str, Any]]:
|
|
summary: dict[str, Any] = {}
|
|
|
|
for usage_type, usage_summary in {"actual": actual_usage_summary, "total": total_usage_summary}.items():
|
|
if usage_summary is None:
|
|
summary[usage_type] = {"usages": None, "total_cost": None}
|
|
continue
|
|
|
|
usage_summary_altered_format: dict[str, list[dict[str, Any]]] = {"usages": []}
|
|
for k, v in usage_summary.items():
|
|
if isinstance(k, str) and isinstance(v, dict):
|
|
current_usage = {key: value for key, value in v.items()}
|
|
current_usage["model"] = k
|
|
usage_summary_altered_format["usages"].append(current_usage)
|
|
else:
|
|
usage_summary_altered_format[k] = v
|
|
summary[usage_type] = usage_summary_altered_format
|
|
|
|
return summary
|
|
|
|
|
|
@deprecated_by(UsageSummaryEvent)
@wrap_message
class UsageSummaryMessage(BaseMessage):
    """Usage summary message."""

    actual: ActualUsageSummary
    """Actual usage summary."""
    total: TotalUsageSummary
    """Total usage summary."""
    mode: Mode
    """Mode to display the usage summary."""

    def __init__(
        self,
        *,
        uuid: Optional[UUID] = None,
        actual_usage_summary: Optional[dict[str, Any]] = None,
        total_usage_summary: Optional[dict[str, Any]] = None,
        mode: Mode = "both",
    ) -> None:
        """Build the message from raw usage dictionaries.

        Args:
            uuid: Identifier forwarded to the base class; may be None.
            actual_usage_summary: Raw "actual" usage mapping, or None.
            total_usage_summary: Raw "total" usage mapping, or None.
            mode: Which summaries ``print`` should display.
        """
        # The helper returns a dict with "actual" and "total" keys, which are
        # passed on as the keyword arguments for the fields of the same names.
        summary_dict = _change_usage_summary_format(actual_usage_summary, total_usage_summary)

        super().__init__(uuid=uuid, **summary_dict, mode=mode)

    def _print_usage(
        self,
        usage_summary: Union[ActualUsageSummary, TotalUsageSummary],
        usage_type: str = "total",
        f: Optional[Callable[..., Any]] = None,
    ) -> None:
        """Write one usage summary through ``f``.

        Args:
            usage_summary: The summary to display.
            usage_type: ``"total"`` labels the output as including cached
                usage; any other value labels it as excluding cached usage.
            f: Print-like callable; the builtin ``print`` is used when falsy.
        """
        f = f or print
        word_from_type = "including" if usage_type == "total" else "excluding"
        # Covers both a missing (None) and an empty usage list.
        if not usage_summary.usages:
            f("No actual cost incurred (all completions are using cache).", flush=True)
            return

        f(f"Usage summary {word_from_type} cached usage: ", flush=True)
        f(f"Total cost: {round(usage_summary.total_cost, 5)}", flush=True)  # type: ignore [arg-type]

        for usage in usage_summary.usages:
            f(
                f"* Model '{usage.model}': cost: {round(usage.cost, 5)}, prompt_tokens: {usage.prompt_tokens}, completion_tokens: {usage.completion_tokens}, total_tokens: {usage.total_tokens}",
                flush=True,
            )

    def print(self, f: Optional[Callable[..., Any]] = None) -> None:
        """Write the usage summary selected by ``self.mode`` through ``f``.

        Args:
            f: Print-like callable; the builtin ``print`` is used when falsy.
                It is called with ``flush``, and once with no arguments.

        Raises:
            ValueError: If ``self.mode`` is not "both", "total" or "actual".
        """
        f = f or print

        if self.total.usages is None:
            f('No usage summary. Please call "create" first.', flush=True)
            return

        f("-" * 100, flush=True)
        if self.mode == "both":
            self._print_usage(self.actual, "actual", f)
            f()
            # Only print the total summary when it differs from the actual one.
            if self.total.model_dump_json() != self.actual.model_dump_json():
                self._print_usage(self.total, "total", f)
            else:
                f(
                    "All completions are non-cached: the total cost with cached completions is the same as actual cost.",
                    flush=True,
                )
        elif self.mode == "total":
            self._print_usage(self.total, "total", f)
        elif self.mode == "actual":
            self._print_usage(self.actual, "actual", f)
        else:
            # The listed choices mirror the values of the Mode literal.
            raise ValueError(f'Invalid mode: {self.mode}, choose from "both", "total", "actual"')
        f("-" * 100, flush=True)
|
|
|
|
|
|
@deprecated_by(StreamEvent)
@wrap_message
class StreamMessage(BaseMessage):
    """Stream message."""

    content: str
    """Text carried by the message."""

    def __init__(self, *, uuid: Optional[UUID] = None, content: str) -> None:
        """Forward ``uuid`` and ``content`` to the base-class initializer."""
        super().__init__(uuid=uuid, content=content)

    def print(self, f: Optional[Callable[..., Any]] = None) -> None:
        """Write the content wrapped in ANSI green/reset escape codes.

        Args:
            f: Print-like callable accepting ``end`` and ``flush``; the builtin
                ``print`` is used when falsy.
        """
        emit = f if f else print

        green_on = "\033[32m"  # switch the terminal text color to green
        color_off = "\033[0m\n"  # reset the terminal text color, then newline

        emit(green_on, end="")
        emit(self.content, end="", flush=True)
        emit(color_off)
|