207 lines
8.0 KiB
Python
207 lines
8.0 KiB
Python
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
|
# SPDX-License-Identifier: MIT
|
|
import re
|
|
from typing import Any, Optional, Union
|
|
|
|
from ..doc_utils import export_module
|
|
from .agent import Agent
|
|
|
|
|
|
def consolidate_chat_info(
    chat_info: Union[dict[str, Any], list[dict[str, Any]]], uniform_sender: Optional[Agent] = None
) -> None:
    """Validate one chat config, or a list of them, in place.

    Each config must name a recipient (and a sender, unless ``uniform_sender``
    supplies one for all configs) and may carry a ``summary_method``.  When the
    summary method is ``"reflection_with_llm"``, at least one side of the chat
    must have an LLM client configured.

    Args:
        chat_info: A single chat config dict or a list of them.
        uniform_sender: If given, used as the sender for every config instead
            of requiring a per-config ``"sender"`` entry.

    Raises:
        AssertionError: If a required field is missing or inconsistent.
            NOTE(review): validation uses ``assert``, which is stripped under
            ``python -O`` — kept as-is to preserve the exception type callers see.
    """
    chats = [chat_info] if isinstance(chat_info, dict) else chat_info
    for chat in chats:
        if uniform_sender is None:
            assert "sender" in chat, "sender must be provided."
            sender = chat["sender"]
        else:
            sender = uniform_sender
        assert "recipient" in chat, "recipient must be provided."
        summary_method = chat.get("summary_method")
        assert (
            summary_method is None or callable(summary_method) or summary_method in ("last_msg", "reflection_with_llm")
        ), "summary_method must be a string chosen from 'reflection_with_llm' or 'last_msg' or a callable, or None."
        if summary_method == "reflection_with_llm":
            assert sender.client is not None or chat["recipient"].client is not None, (
                "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm."
            )
|
|
|
|
|
|
@export_module("autogen")
def gather_usage_summary(agents: list[Agent]) -> dict[str, dict[str, Any]]:
    r"""Gather usage summary from all agents.

    Args:
        agents: (list): List of agents.

    Returns:
        dictionary: A dictionary containing two keys:
            - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
            - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".

    Example:
        ```python
        {
            "usage_including_cached_inference": {
                "total_cost": 0.0006090000000000001,
                "gpt-35-turbo": {
                    "cost": 0.0006090000000000001,
                    "prompt_tokens": 242,
                    "completion_tokens": 123,
                    "total_tokens": 365,
                },
            },
            "usage_excluding_cached_inference": {
                "total_cost": 0.0006090000000000001,
                "gpt-35-turbo": {
                    "cost": 0.0006090000000000001,
                    "prompt_tokens": 242,
                    "completion_tokens": 123,
                    "total_tokens": 365,
                },
            },
        }
        ```

    Note:
        If none of the agents incurred any cost (not having a client), then the usage_including_cached_inference and usage_excluding_cached_inference will be `{'total_cost': 0}`.
    """

    def _merge(target: dict[str, Any], source: dict[str, Any]) -> None:
        # Fold one client's summary into the running totals; a missing
        # summary (None) contributes nothing.
        if source is None:
            return
        target["total_cost"] += source.get("total_cost", 0)
        for model, stats in source.items():
            if model == "total_cost":
                continue
            if model not in target:
                # First sighting of this model: copy so later += calls
                # don't mutate the client's own summary dict.
                target[model] = stats.copy()
                continue
            entry = target[model]
            entry["cost"] += stats.get("cost", 0)
            entry["prompt_tokens"] += stats.get("prompt_tokens", 0)
            entry["completion_tokens"] += stats.get("completion_tokens", 0)
            entry["total_tokens"] += stats.get("total_tokens", 0)

    including_cached: dict[str, Any] = {"total_cost": 0}
    excluding_cached: dict[str, Any] = {"total_cost": 0}

    for agent in agents:
        client = getattr(agent, "client", None)
        if client:
            _merge(including_cached, client.total_usage_summary)  # type: ignore[attr-defined]
            _merge(excluding_cached, client.actual_usage_summary)  # type: ignore[attr-defined]

    return {
        "usage_including_cached_inference": including_cached,
        "usage_excluding_cached_inference": excluding_cached,
    }
|
|
|
|
|
|
def parse_tags_from_content(tag: str, content: Union[str, list[dict[str, Any]]]) -> list[dict[str, Any]]:
    """Parses HTML style tags from message contents.

    The parsing is done by looking for patterns in the text that match the format of HTML tags. The tag to be parsed is
    specified as an argument to the function. The function looks for this tag in the text and extracts its content. The
    content of a tag is everything that is inside the tag, between the opening and closing angle brackets. The content
    can be a single string or a set of attribute-value pairs.

    Examples:
        `<img http://example.com/image.png> -> [{"tag": "img", "attr": {"src": "http://example.com/image.png"}, "match": re.Match}]`
        ```<audio text="Hello I'm a robot" prompt="whisper"> ->
        [{"tag": "audio", "attr": {"text": "Hello I'm a robot", "prompt": "whisper"}, "match": re.Match}]```

    Args:
        tag (str): The HTML style tag to be parsed.
        content (Union[str, list[dict[str, Any]]]): The message content to parse. Can be a string or a list of content
            items.

    Returns:
        list[dict[str, str]]: A list of dictionaries, where each dictionary represents a parsed tag. Each dictionary
            contains three key-value pairs: 'type' which is the tag, 'attr' which is a dictionary of the parsed attributes,
            and 'match' which is a regular expression match object.

    Raises:
        ValueError: If the content is not a string or a list.
    """
    if isinstance(content, str):
        return _parse_tags_from_text(tag, content)
    if isinstance(content, list):
        # Multimodal message: only items of type "text" can carry tags.
        found: list[dict[str, Any]] = []
        for item in content:
            if item.get("type") == "text":
                found.extend(_parse_tags_from_text(tag, item["text"]))
        return found
    raise ValueError(f"content must be str or list, but got {type(content)}")
|
|
|
|
|
|
def _parse_tags_from_text(tag: str, text: str) -> list[dict[str, Any]]:
    """Scan *text* for every ``<tag ...>`` occurrence and parse its attributes.

    Returns one dict per occurrence with keys ``"tag"`` (the tag name),
    ``"attr"`` (the parsed attribute mapping) and ``"match"`` (the underlying
    ``re.Match`` object, which callers can use for positions).
    """
    # NOTE(review): `tag` is interpolated into the regex unescaped — assumes it
    # is a plain identifier such as "img" or "audio", not arbitrary input.
    tag_pattern = re.compile(f"<{tag} (.*?)>")

    parsed: list[dict[str, Any]] = []
    for match in tag_pattern.finditer(text):
        attributes = _parse_attributes_from_tags(match.group(1).strip())
        parsed.append({"tag": tag, "attr": attributes, "match": match})
    return parsed
|
|
|
|
|
|
def _parse_attributes_from_tags(tag_content: str) -> dict[str, str]:
    """Turn the inside of a tag into a dict of attribute name -> value.

    Tokens of the form ``key="value"`` or ``key='value'`` become entries with
    the surrounding quotes stripped.  Every other token — bare values and
    unquoted ``key=value`` pairs alike — is accumulated, space-separated, under
    the ``"src"`` key.
    """
    raw_tokens = re.findall(r"([^ ]+)", tag_content)
    merged_tokens = _reconstruct_attributes(raw_tokens)

    attributes: dict[str, str] = {}
    for token in merged_tokens:
        key, sep, value = token.partition("=")
        if sep and (value.startswith("'") or value.startswith('"')):
            attributes[key] = value[1:-1]  # drop the surrounding quotes
        elif "src" in attributes:
            # Fold additional bare tokens into a single space-joined src value.
            attributes["src"] += f" {token}"
        else:
            attributes["src"] = token
    return attributes
|
|
|
|
|
|
def _reconstruct_attributes(attrs: list[str]) -> list[str]:
|
|
"""Reconstructs attributes from a list of strings where some attributes may be split across multiple elements."""
|
|
|
|
def is_attr(attr: str) -> bool:
|
|
if "=" in attr:
|
|
_, value = attr.split("=", 1)
|
|
if value.startswith("'") or value.startswith('"'):
|
|
return True
|
|
return False
|
|
|
|
reconstructed = []
|
|
found_attr = False
|
|
for attr in attrs:
|
|
if is_attr(attr):
|
|
reconstructed.append(attr)
|
|
found_attr = True
|
|
else:
|
|
if found_attr:
|
|
reconstructed[-1] += f" {attr}"
|
|
found_attr = True
|
|
elif reconstructed:
|
|
reconstructed[-1] += f" {attr}"
|
|
else:
|
|
reconstructed.append(attr)
|
|
return reconstructed
|