CoACT initialize (#292)
This commit is contained in:
206
mm_agents/coact/autogen/agentchat/utils.py
Normal file
206
mm_agents/coact/autogen/agentchat/utils.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import re
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from .agent import Agent
|
||||
|
||||
|
||||
def consolidate_chat_info(
    chat_info: Union[dict[str, Any], list[dict[str, Any]]], uniform_sender: Optional[Agent] = None
) -> None:
    """Validate one chat-configuration dict, or a list of them.

    Each entry must name a ``recipient`` and — unless ``uniform_sender`` is
    supplied — a ``sender``.  An optional ``summary_method`` must be ``None``,
    a callable, or one of the strings ``"last_msg"`` / ``"reflection_with_llm"``;
    the latter additionally requires an llm client on the sender or recipient.

    Args:
        chat_info: A single chat-info dict or a list of such dicts.
        uniform_sender: When given, used as the sender for every entry instead
            of each entry's own ``"sender"`` key.

    Raises:
        AssertionError: If any entry fails validation.
    """
    entries = [chat_info] if isinstance(chat_info, dict) else chat_info
    for entry in entries:
        if uniform_sender is not None:
            sender = uniform_sender
        else:
            assert "sender" in entry, "sender must be provided."
            sender = entry["sender"]
        assert "recipient" in entry, "recipient must be provided."
        method = entry.get("summary_method")
        assert (
            method is None or callable(method) or method in ("last_msg", "reflection_with_llm")
        ), "summary_method must be a string chosen from 'reflection_with_llm' or 'last_msg' or a callable, or None."
        if method == "reflection_with_llm":
            assert sender.client is not None or entry["recipient"].client is not None, (
                "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm."
            )
|
||||
|
||||
|
||||
@export_module("autogen")
def gather_usage_summary(agents: list[Agent]) -> dict[str, dict[str, Any]]:
    r"""Gather usage summary from all agents.

    Args:
        agents: (list): List of agents.

    Returns:
        dictionary: A dictionary containing two keys:
        - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
        - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".

    Example:
        ```python
        {
            "usage_including_cached_inference": {
                "total_cost": 0.0006090000000000001,
                "gpt-35-turbo": {
                    "cost": 0.0006090000000000001,
                    "prompt_tokens": 242,
                    "completion_tokens": 123,
                    "total_tokens": 365,
                },
            },
            "usage_excluding_cached_inference": {
                "total_cost": 0.0006090000000000001,
                "gpt-35-turbo": {
                    "cost": 0.0006090000000000001,
                    "prompt_tokens": 242,
                    "completion_tokens": 123,
                    "total_tokens": 365,
                },
            },
        }
        ```

    Note:
        If none of the agents incurred any cost (not having a client), then the usage_including_cached_inference and usage_excluding_cached_inference will be `{'total_cost': 0}`.
    """

    def _merge(target: dict[str, Any], source: dict[str, Any]) -> None:
        # A client that never ran inference reports None — nothing to add.
        if source is None:
            return
        target["total_cost"] += source.get("total_cost", 0)
        for model, stats in source.items():
            if model == "total_cost":
                continue
            if model in target:
                bucket = target[model]
                bucket["cost"] += stats.get("cost", 0)
                bucket["prompt_tokens"] += stats.get("prompt_tokens", 0)
                bucket["completion_tokens"] += stats.get("completion_tokens", 0)
                bucket["total_tokens"] += stats.get("total_tokens", 0)
            else:
                # First time we see this model: copy so later merges don't
                # mutate the agent's own summary dict.
                target[model] = stats.copy()

    including: dict[str, Any] = {"total_cost": 0}
    excluding: dict[str, Any] = {"total_cost": 0}

    for agent in agents:
        client = getattr(agent, "client", None)
        if client:
            _merge(including, client.total_usage_summary)  # type: ignore[attr-defined]
            _merge(excluding, client.actual_usage_summary)  # type: ignore[attr-defined]

    return {
        "usage_including_cached_inference": including,
        "usage_excluding_cached_inference": excluding,
    }
|
||||
|
||||
|
||||
def parse_tags_from_content(tag: str, content: Union[str, list[dict[str, Any]]]) -> list[dict[str, Any]]:
    """Parses HTML style tags from message contents.

    The parsing is done by looking for patterns in the text that match the format of HTML tags. The tag to be parsed is
    specified as an argument to the function. The function looks for this tag in the text and extracts its content. The
    content of a tag is everything that is inside the tag, between the opening and closing angle brackets. The content
    can be a single string or a set of attribute-value pairs.

    Examples:
        `<img http://example.com/image.png> -> [{"tag": "img", "attr": {"src": "http://example.com/image.png"}, "match": re.Match}]`
        ```<audio text="Hello I'm a robot" prompt="whisper"> ->
        [{"tag": "audio", "attr": {"text": "Hello I'm a robot", "prompt": "whisper"}, "match": re.Match}]```

    Args:
        tag (str): The HTML style tag to be parsed.
        content (Union[str, list[dict[str, Any]]]): The message content to parse. Can be a string or a list of content
            items.

    Returns:
        list[dict[str, str]]: A list of dictionaries, where each dictionary represents a parsed tag. Each dictionary
        contains three key-value pairs: 'type' which is the tag, 'attr' which is a dictionary of the parsed attributes,
        and 'match' which is a regular expression match object.

    Raises:
        ValueError: If the content is not a string or a list.
    """
    if isinstance(content, str):
        return _parse_tags_from_text(tag, content)

    # Multimodal messages arrive as a list of typed items; only "text" items
    # can contain tags.
    if isinstance(content, list):
        parsed: list[dict[str, Any]] = []
        for item in content:
            if item.get("type") == "text":
                parsed.extend(_parse_tags_from_text(tag, item["text"]))
        return parsed

    raise ValueError(f"content must be str or list, but got {type(content)}")
|
||||
|
||||
|
||||
def _parse_tags_from_text(tag: str, text: str) -> list[dict[str, Any]]:
    """Find every ``<tag ...>`` occurrence in *text* and parse its attributes.

    Returns one dict per occurrence with keys ``tag``, ``attr`` (parsed
    attribute mapping) and ``match`` (the ``re.Match`` object).
    """
    # NOTE(review): `tag` is interpolated unescaped — assumed to be a plain
    # identifier like "img"/"audio", never a regex metacharacter.
    tag_pattern = re.compile(f"<{tag} (.*?)>")

    parsed = []
    for match in tag_pattern.finditer(text):
        attributes = _parse_attributes_from_tags(match.group(1).strip())
        parsed.append({"tag": tag, "attr": attributes, "match": match})
    return parsed
|
||||
|
||||
|
||||
def _parse_attributes_from_tags(tag_content: str) -> dict[str, str]:
    """Turn the inside of a tag (e.g. ``src=... text="a b"``) into a dict.

    Quoted ``key="value"`` pairs become entries with the quotes stripped;
    every other token is accumulated, space-separated, under the ``"src"`` key.
    """
    tokens = re.findall(r"([^ ]+)", tag_content)
    merged = _reconstruct_attributes(tokens)

    def _add_to_src(attrs: dict[str, str], token: Any) -> None:
        # Unkeyed tokens pile up under "src", joined by single spaces.
        if "src" in attrs:
            attrs["src"] += f" {token}"
        else:
            attrs["src"] = token

    parsed: dict[str, str] = {}
    for token in merged:
        if "=" not in token:
            _add_to_src(parsed, token)
            continue

        name, raw_value = token.split("=", 1)
        if raw_value[:1] in ("'", '"'):
            parsed[name] = raw_value[1:-1]  # remove quotes
        else:
            # Unquoted "values" are not treated as attributes at all.
            _add_to_src(parsed, token)

    return parsed
|
||||
|
||||
|
||||
def _reconstruct_attributes(attrs: list[str]) -> list[str]:
|
||||
"""Reconstructs attributes from a list of strings where some attributes may be split across multiple elements."""
|
||||
|
||||
def is_attr(attr: str) -> bool:
|
||||
if "=" in attr:
|
||||
_, value = attr.split("=", 1)
|
||||
if value.startswith("'") or value.startswith('"'):
|
||||
return True
|
||||
return False
|
||||
|
||||
reconstructed = []
|
||||
found_attr = False
|
||||
for attr in attrs:
|
||||
if is_attr(attr):
|
||||
reconstructed.append(attr)
|
||||
found_attr = True
|
||||
else:
|
||||
if found_attr:
|
||||
reconstructed[-1] += f" {attr}"
|
||||
found_attr = True
|
||||
elif reconstructed:
|
||||
reconstructed[-1] += f" {attr}"
|
||||
else:
|
||||
reconstructed.append(attr)
|
||||
return reconstructed
|
||||
Reference in New Issue
Block a user