CoACT initialize (#292)

This commit is contained in:
Linxin Song
2025-07-30 19:35:20 -07:00
committed by GitHub
parent 862d704b8c
commit b968155757
228 changed files with 42386 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Explicitly empty public API: this package intentionally re-exports nothing.
__all__: list[str] = []

View File

@@ -0,0 +1,5 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Explicitly empty public API: this package intentionally re-exports nothing.
__all__: list[str] = []

View File

@@ -0,0 +1,20 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from ...assistant_agent import ConversableAgent
class AgentCapability:
    """Abstract base for composable capabilities that can be attached to an agent.

    Concrete capabilities attach behavior to a ConversableAgent by overriding
    `add_to_agent()`.
    """

    def __init__(self) -> None:
        pass

    def add_to_agent(self, agent: ConversableAgent):
        """Attach this capability to the given agent.

        Must be overridden by the capability subclass. A typical implementation
        calls agent.register_hook() one or more times (see teachability.py for
        an example).
        """
        raise NotImplementedError

View File

@@ -0,0 +1,301 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import re
from typing import Any, Literal, Optional, Protocol, Union
from .... import Agent, ConversableAgent, code_utils
from ....cache import AbstractCache
from ....import_utils import optional_import_block, require_optional_import
from ....llm_config import LLMConfig
from .. import img_utils
from ..capabilities.agent_capability import AgentCapability
from ..text_analyzer_agent import TextAnalyzerAgent
with optional_import_block():
from PIL.Image import Image
from openai import OpenAI
# Appended to the host agent's system message by ImageGeneration.add_to_agent().
SYSTEM_MESSAGE = "You've been given the special ability to generate images."
# Appended to the host agent's description (helpful e.g. in group chats).
DESCRIPTION_MESSAGE = "This agent has the ability to generate images."
# Default instructions for the TextAnalyzerAgent that distills an image prompt
# from an incoming message; override via the text_analyzer_instructions argument.
PROMPT_INSTRUCTIONS = """In detail, please summarize the provided prompt to generate the image described in the TEXT.
DO NOT include any advice. RESPOND like the following example:
EXAMPLE: Blue background, 3D shapes, ...
"""
class ImageGenerator(Protocol):
    """Structural interface for image generators.

    Any conforming implementation exposes `generate_image` to turn a text
    prompt into a PIL Image, and `cache_key` to derive a stable cache key for
    a given prompt.

    NOTE: Current implementation does not allow you to edit a previously existing image.
    """

    def generate_image(self, prompt: str) -> "Image":
        """Produce an image for the given textual prompt.

        Args:
            prompt: Description of the desired image.

        Returns:
            The generated image as a PIL Image object.

        Raises:
            ValueError: If generation fails.
        """
        ...

    def cache_key(self, prompt: str) -> str:
        """Derive a unique, stable cache key for the given prompt.

        Args:
            prompt: Description of the desired image.

        Returns:
            A string suitable for storing and retrieving the generated image.
        """
        ...
@require_optional_import("PIL", "unknown")
@require_optional_import("openai>=1.66.2", "openai")
class DalleImageGenerator:
    """Image generator backed by OpenAI's DALL-E models.

    Wraps the OpenAI images API: pick the DALL-E model, output resolution,
    quality, and number of images at construction time, then call
    `generate_image()` with a prompt.

    Note: Current implementation does not allow you to edit a previously existing image.
    """

    def __init__(
        self,
        llm_config: Union[LLMConfig, dict[str, Any]],
        resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
        quality: Literal["standard", "hd"] = "standard",
        num_images: int = 1,
    ):
        """Args:
        llm_config (LLMConfig or dict): llm config, must contain a valid dalle model and OpenAI API key in config_list.
        resolution (str): The resolution of the image you want to generate. Must be one of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792".
        quality (str): The quality of the image you want to generate. Must be one of "standard", "hd".
        num_images (int): The number of images to generate.
        """
        # Only the first config entry is consulted for the model and API key.
        first_config = llm_config["config_list"][0]
        _validate_dalle_model(first_config["model"])
        _validate_resolution_format(resolution)

        self._model = first_config["model"]
        self._resolution = resolution
        self._quality = quality
        self._num_images = num_images
        self._dalle_client = OpenAI(api_key=first_config["api_key"])

    def generate_image(self, prompt: str) -> "Image":
        """Call the DALL-E API and return the first generated image as a PIL Image."""
        result = self._dalle_client.images.generate(
            model=self._model,
            prompt=prompt,
            size=self._resolution,
            quality=self._quality,
            n=self._num_images,
        )
        url = result.data[0].url
        if url is None:
            raise ValueError("Failed to generate image.")
        return img_utils.get_pil_image(url)

    def cache_key(self, prompt: str) -> str:
        """Build a cache key from the prompt plus every generation parameter."""
        return ",".join(map(str, (prompt, self._model, self._resolution, self._quality, self._num_images)))
@require_optional_import("PIL", "unknown")
@require_optional_import("PIL", "unknown")
class ImageGeneration(AgentCapability):
    """This capability allows a ConversableAgent to generate images based on the message received from other Agents.

    1. Utilizes a TextAnalyzerAgent to analyze incoming messages to identify requests for image generation and
       extract relevant details.
    2. Leverages the provided ImageGenerator (e.g., DalleImageGenerator) to create the image.
    3. Optionally caches generated images for faster retrieval in future conversations.

    NOTE: This capability increases the token usage of the agent, as it uses TextAnalyzerAgent to analyze every
    message received by the agent.

    Example:
        ```python
        import autogen
        from autogen.agentchat.contrib.capabilities.image_generation import ImageGeneration, DalleImageGenerator

        # Assuming you have llm configs configured for the LLMs you want to use and Dalle.
        # Create the agent
        agent = autogen.ConversableAgent(
            name="dalle", llm_config={...}, max_consecutive_auto_reply=3, human_input_mode="NEVER"
        )

        # Create an ImageGenerator with desired settings
        dalle_gen = DalleImageGenerator(llm_config={...})

        # Add the ImageGeneration capability to the agent
        agent.add_capability(ImageGeneration(image_generator=dalle_gen))
        ```
    """

    def __init__(
        self,
        image_generator: ImageGenerator,
        cache: Optional[AbstractCache] = None,
        text_analyzer_llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None,
        text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
        verbosity: int = 0,
        register_reply_position: int = 2,
    ):
        """Args:
        image_generator (ImageGenerator): The image generator you would like to use to generate images.
        cache (None or AbstractCache): The cache client to use to store and retrieve generated images. If None,
            no caching will be used.
        text_analyzer_llm_config (LLMConfig or Dict or None): The LLM config for the text analyzer. If None, the LLM config will
            be retrieved from the agent you're adding the ability to.
        text_analyzer_instructions (str): Instructions provided to the TextAnalyzerAgent used to analyze
            incoming messages and extract the prompt for image generation. The default instructions focus on
            summarizing the prompt. You can customize the instructions to achieve more granular control over prompt
            extraction.
            Example: 'Extract specific details from the message, like desired objects, styles, or backgrounds.'
        verbosity (int): The verbosity level. Defaults to 0 and must be greater than or equal to 0. The text
            analyzer llm calls will be silent if verbosity is less than 2.
        register_reply_position (int): The position of the reply function in the agent's list of reply functions.
            This capability registers a new reply function to handle messages with image generation requests.
            Defaults to 2 to place it after the check termination and human reply for a ConversableAgent.
        """
        self._image_generator = image_generator
        self._cache = cache
        self._text_analyzer_llm_config = text_analyzer_llm_config
        self._text_analyzer_instructions = text_analyzer_instructions
        self._verbosity = verbosity
        self._register_reply_position = register_reply_position

        # Populated by add_to_agent(); None until the capability is attached.
        self._agent: Optional[ConversableAgent] = None
        self._text_analyzer: Optional[TextAnalyzerAgent] = None

    def add_to_agent(self, agent: ConversableAgent):
        """Adds the Image Generation capability to the specified ConversableAgent.

        This function performs the following modifications to the agent:

        1. Registers a reply function: A new reply function is registered with the agent to handle messages that
           potentially request image generation. This function analyzes the message and triggers image generation if
           necessary.
        2. Creates an Agent (TextAnalyzerAgent): This is used to analyze messages for image generation requirements.
        3. Updates System Message: The agent's system message is updated to include a message indicating the
           capability to generate images has been added.
        4. Updates Description: The agent's description is updated to reflect the addition of the Image Generation
           capability. This might be helpful in certain use cases, like group chats.

        Args:
            agent (ConversableAgent): The ConversableAgent to add the capability to.
        """
        self._agent = agent

        agent.register_reply([Agent, None], self._image_gen_reply, position=self._register_reply_position)

        # Fall back to the host agent's llm_config when none was supplied.
        self._text_analyzer_llm_config = self._text_analyzer_llm_config or agent.llm_config
        self._text_analyzer = TextAnalyzerAgent(llm_config=self._text_analyzer_llm_config)

        agent.update_system_message(agent.system_message + "\n" + SYSTEM_MESSAGE)
        agent.description += "\n" + DESCRIPTION_MESSAGE

    def _image_gen_reply(
        self,
        recipient: ConversableAgent,
        messages: Optional[list[dict[str, Any]]],
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
        # Reply hook: returns (True, reply) when the last message asks for an
        # image, otherwise (False, None) so later reply functions can run.
        if messages is None:
            return False, None

        last_message = code_utils.content_str(messages[-1]["content"])

        if not last_message:
            return False, None

        if self._should_generate_image(last_message):
            prompt = self._extract_prompt(last_message)

            # Serve from cache when possible; otherwise generate and cache.
            image = self._cache_get(prompt)
            if image is None:
                image = self._image_generator.generate_image(prompt)
                self._cache_set(prompt, image)

            return True, self._generate_content_message(prompt, image)

        else:
            return False, None

    def _should_generate_image(self, message: str) -> bool:
        # Ask the text analyzer whether the message explicitly requests an image.
        assert self._text_analyzer is not None

        instructions = """
        Does any part of the TEXT ask the agent to generate an image?
        The TEXT must explicitly mention that the image must be generated.
        Answer with just one word, yes or no.
        """
        analysis = self._text_analyzer.analyze_text(message, instructions)

        return "yes" in self._extract_analysis(analysis).lower()

    def _extract_prompt(self, last_message: str) -> str:
        # Distill the image-generation prompt from the message via the analyzer.
        assert self._text_analyzer is not None

        analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
        return self._extract_analysis(analysis)

    def _cache_get(self, prompt: str) -> Optional["Image"]:
        # Returns the cached image for the prompt; falls through (implicitly
        # returning None) on a cache miss or when caching is disabled.
        if self._cache:
            key = self._image_generator.cache_key(prompt)
            cached_value = self._cache.get(key)

            if cached_value:
                return img_utils.get_pil_image(cached_value)

    def _cache_set(self, prompt: str, image: "Image"):
        # Stores the image under the generator-provided key, encoded as a data URI.
        if self._cache:
            key = self._image_generator.cache_key(prompt)
            self._cache.set(key, img_utils.pil_to_data_uri(image))

    def _extract_analysis(self, analysis: Optional[Union[str, dict[str, Any]]]) -> str:
        # The analyzer may return either a message dict or a plain string.
        if isinstance(analysis, dict):
            return code_utils.content_str(analysis["content"])
        else:
            return code_utils.content_str(analysis)

    def _generate_content_message(self, prompt: str, image: "Image") -> dict[str, Any]:
        # Multimodal reply: a text part plus the image embedded as a data URI.
        return {
            "content": [
                {"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
                {"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}},
            ]
        }
# Helpers
def _validate_resolution_format(resolution: str):
    """Raise ValueError unless *resolution* is of the form "<digits>x<digits>" (e.g., "1024x768")."""
    # Anchored match: the whole string must be digits, a literal "x", then digits.
    if not re.match(r"^\d+x\d+$", resolution):
        raise ValueError(f"Invalid resolution format: {resolution}")
def _validate_dalle_model(model: str):
    """Raise ValueError unless *model* names a supported DALL-E model."""
    if model not in ("dall-e-3", "dall-e-2"):
        raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'")

View File

@@ -0,0 +1,393 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import os
import pickle
from typing import Any, Optional, Union
from ....formatting_utils import colored
from ....import_utils import optional_import_block, require_optional_import
from ....llm_config import LLMConfig
from ...assistant_agent import ConversableAgent
from ..text_analyzer_agent import TextAnalyzerAgent
from .agent_capability import AgentCapability
with optional_import_block():
import chromadb
from chromadb.config import Settings
class Teachability(AgentCapability):
    """Teachability uses a vector database to give an agent the ability to remember user teachings,
    where the user is any caller (human or not) sending messages to the teachable agent.
    Teachability is designed to be composable with other agent capabilities.
    To make any conversable agent teachable, instantiate both the agent and the Teachability class,
    then pass the agent to teachability.add_to_agent(agent).
    Note that teachable agents in a group chat must be given unique path_to_db_dir values.

    When adding Teachability to an agent, the following are modified:
    - The agent's system message is appended with a note about the agent's new ability.
    - A hook is added to the agent's `process_last_received_message` hookable method,
      and the hook potentially modifies the last of the received messages to include earlier teachings related to the message.
      Added teachings do not propagate into the stored message history.
      If new user teachings are detected, they are added to new memos in the vector database.
    """

    def __init__(
        self,
        verbosity: Optional[int] = 0,
        reset_db: Optional[bool] = False,
        path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
        recall_threshold: Optional[float] = 1.5,
        max_num_retrievals: Optional[int] = 10,
        llm_config: Optional[Union[LLMConfig, dict[str, Any], bool]] = None,
    ):
        """Args:
        verbosity (Optional, int): # 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
        reset_db (Optional, bool): True to clear the DB before starting. Default False.
        path_to_db_dir (Optional, str): path to the directory where this particular agent's DB is stored. Default "./tmp/teachable_agent_db"
        recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
        max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10.
        llm_config (LLMConfig or dict or False): llm inference configuration passed to TextAnalyzerAgent.
            If None, TextAnalyzerAgent uses llm_config from the teachable agent.
        """
        self.verbosity = verbosity
        self.path_to_db_dir = path_to_db_dir
        self.recall_threshold = recall_threshold
        self.max_num_retrievals = max_num_retrievals
        self.llm_config = llm_config

        # Both are populated by add_to_agent(); None until then.
        self.analyzer = None
        self.teachable_agent = None

        # Create the memo store.
        self.memo_store = MemoStore(self.verbosity, reset_db, self.path_to_db_dir)

    def add_to_agent(self, agent: ConversableAgent):
        """Adds teachability to the given agent."""
        self.teachable_agent = agent

        # Register a hook for processing the last message.
        agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

        # Was an llm_config passed to the constructor?
        if self.llm_config is None:
            # No. Use the agent's llm_config.
            self.llm_config = agent.llm_config
        assert self.llm_config, "Teachability requires a valid llm_config."

        # Create the analyzer agent.
        self.analyzer = TextAnalyzerAgent(llm_config=self.llm_config)

        # Append extra info to the system message.
        agent.update_system_message(
            agent.system_message
            + "\nYou've been given the special ability to remember user teachings from prior conversations."
        )

    def prepopulate_db(self):
        """Adds a few arbitrary memos to the DB."""
        self.memo_store.prepopulate()

    def process_last_received_message(self, text: Union[dict[str, Any], str]):
        """Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
        Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
        """
        # NOTE(review): _consider_memo_retrieval concatenates memo text onto the
        # message with `+`, which assumes *text* is a str by the time it is used —
        # confirm dict payloads cannot reach this hook with memos present.
        # Try to retrieve relevant memos from the DB.
        expanded_text = text
        if self.memo_store.last_memo_id > 0:
            expanded_text = self._consider_memo_retrieval(text)

        # Try to store any user teachings in new memos to be used in the future.
        self._consider_memo_storage(text)

        # Return the (possibly) expanded message text.
        return expanded_text

    def _consider_memo_storage(self, comment: Union[dict[str, Any], str]):
        """Decides whether to store something from one user comment in the DB."""
        memo_added = False

        # Check for a problem-solution pair.
        response = self._analyze(
            comment,
            "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            # Can we extract advice?
            advice = self._analyze(
                comment,
                "Briefly copy any advice from the TEXT that may be useful for a similar but different task in the future. But if no advice is present, just respond with 'none'.",
            )
            if "none" not in advice.lower():
                # Yes. Extract the task.
                task = self._analyze(
                    comment,
                    "Briefly copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice.",
                )
                # Generalize the task.
                general_task = self._analyze(
                    task,
                    "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
                )
                # Add the task-advice (problem-solution) pair to the vector DB.
                if self.verbosity >= 1:
                    print(colored("\nREMEMBER THIS TASK-ADVICE PAIR", "light_yellow"))
                self.memo_store.add_input_output_pair(general_task, advice)
                memo_added = True

        # Check for information to be learned.
        response = self._analyze(
            comment,
            "Does the TEXT contain information that could be committed to memory? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            # Yes. What question would this information answer?
            question = self._analyze(
                comment,
                "Imagine that the user forgot this information in the TEXT. How would they ask you for this information? Include no other text in your response.",
            )
            # Extract the information.
            answer = self._analyze(
                comment, "Copy the information from the TEXT that should be committed to memory. Add no explanation."
            )
            # Add the question-answer pair to the vector DB.
            if self.verbosity >= 1:
                print(colored("\nREMEMBER THIS QUESTION-ANSWER PAIR", "light_yellow"))
            self.memo_store.add_input_output_pair(question, answer)
            memo_added = True

        # Were any memos added?
        if memo_added:
            # Yes. Save them to disk.
            self.memo_store._save_memos()

    def _consider_memo_retrieval(self, comment: Union[dict[str, Any], str]):
        """Decides whether to retrieve memos from the DB, and add them to the chat context."""
        # First, use the comment directly as the lookup key.
        if self.verbosity >= 1:
            print(colored("\nLOOK FOR RELEVANT MEMOS, AS QUESTION-ANSWER PAIRS", "light_yellow"))
        memo_list = self._retrieve_relevant_memos(comment)

        # Next, if the comment involves a task, then extract and generalize the task before using it as the lookup key.
        response = self._analyze(
            comment,
            "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            if self.verbosity >= 1:
                print(colored("\nLOOK FOR RELEVANT MEMOS, AS TASK-ADVICE PAIRS", "light_yellow"))
            # Extract the task.
            task = self._analyze(
                comment, "Copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice."
            )
            # Generalize the task.
            general_task = self._analyze(
                task,
                "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
            )
            # Append any relevant memos.
            memo_list.extend(self._retrieve_relevant_memos(general_task))

        # De-duplicate the memo list.
        memo_list = list(set(memo_list))

        # Append the memos to the text of the last message.
        return comment + self._concatenate_memo_texts(memo_list)

    def _retrieve_relevant_memos(self, input_text: str) -> list:
        """Returns semantically related memos from the DB."""
        memo_list = self.memo_store.get_related_memos(
            input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold
        )

        if self.verbosity >= 1:  # noqa: SIM102
            # Was anything retrieved?
            if len(memo_list) == 0:
                # No. Look at the closest memo.
                print(colored("\nTHE CLOSEST MEMO IS BEYOND THE THRESHOLD:", "light_yellow"))
                self.memo_store.get_nearest_memo(input_text)
                print()  # Print a blank line. The memo details were printed by get_nearest_memo().

        # Create a list of just the memo output_text strings.
        memo_list = [memo[1] for memo in memo_list]
        return memo_list

    def _concatenate_memo_texts(self, memo_list: list) -> str:
        """Concatenates the memo texts into a single string for inclusion in the chat context."""
        memo_texts = ""
        if len(memo_list) > 0:
            info = "\n# Memories that might help\n"
            for memo in memo_list:
                info = info + "- " + memo + "\n"
            if self.verbosity >= 1:
                print(colored("\nMEMOS APPENDED TO LAST MESSAGE...\n" + info + "\n", "light_yellow"))
            memo_texts = memo_texts + "\n" + info
        return memo_texts

    def _analyze(self, text_to_analyze: Union[dict[str, Any], str], analysis_instructions: Union[dict[str, Any], str]):
        """Asks TextAnalyzerAgent to analyze the given text according to specific instructions."""
        self.analyzer.reset()  # Clear the analyzer's list of messages.
        self.teachable_agent.send(
            recipient=self.analyzer, message=text_to_analyze, request_reply=False, silent=(self.verbosity < 2)
        )  # Put the message in the analyzer's list.
        self.teachable_agent.send(
            recipient=self.analyzer, message=analysis_instructions, request_reply=True, silent=(self.verbosity < 2)
        )  # Request the reply.
        return self.teachable_agent.last_message(self.analyzer)["content"]
@require_optional_import("chromadb", "teachable")
class MemoStore:
    """Provides memory storage and retrieval for a teachable agent, using a vector database.

    Each DB entry (called a memo) is a pair of strings: an input text and an output text.
    The input text might be a question, or a task to perform.
    The output text might be an answer to the question, or advice on how to perform the task.
    Vector embeddings are currently supplied by Chroma's default Sentence Transformers.
    """

    def __init__(
        self,
        verbosity: Optional[int] = 0,
        reset: Optional[bool] = False,
        path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
    ):
        """Args:
        - verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists.
        - reset (Optional, bool): True to clear the DB before starting. Default False.
        - path_to_db_dir (Optional, str): path to the directory where the DB is stored.
        """
        self.verbosity = verbosity
        self.path_to_db_dir = path_to_db_dir

        # Load or create the vector DB on disk.
        settings = Settings(
            anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir
        )
        self.db_client = chromadb.Client(settings)
        self.vec_db = self.db_client.create_collection("memos", get_or_create=True)  # The collection is the DB.

        # Load or create the associated memo dict on disk.
        self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl")
        self.uid_text_dict = {}  # Maps memo uid (str) -> (input_text, output_text).
        self.last_memo_id = 0  # Highest memo id assigned so far; ids are 1-based.
        if (not reset) and os.path.exists(self.path_to_dict):
            print(colored("\nLOADING MEMORY FROM DISK", "light_green"))
            print(colored(f" Location = {self.path_to_dict}", "light_green"))
            with open(self.path_to_dict, "rb") as f:
                # SECURITY NOTE: pickle.load executes arbitrary code from untrusted
                # files; this file is assumed to be written only by _save_memos().
                self.uid_text_dict = pickle.load(f)
                self.last_memo_id = len(self.uid_text_dict)
                if self.verbosity >= 3:
                    self.list_memos()

        # Clear the DB if requested.
        if reset:
            self.reset_db()

    def list_memos(self):
        """Prints the contents of MemoStore."""
        print(colored("LIST OF MEMOS", "light_green"))
        for uid, text in self.uid_text_dict.items():
            input_text, output_text = text
            print(
                colored(
                    f" ID: {uid}\n INPUT TEXT: {input_text}\n OUTPUT TEXT: {output_text}",
                    "light_green",
                )
            )

    def _save_memos(self):
        """Saves self.uid_text_dict to disk."""
        with open(self.path_to_dict, "wb") as file:
            pickle.dump(self.uid_text_dict, file)

    def reset_db(self):
        """Forces immediate deletion of the DB's contents, in memory and on disk."""
        print(colored("\nCLEARING MEMORY", "light_green"))
        self.db_client.delete_collection("memos")
        self.vec_db = self.db_client.create_collection("memos")
        self.uid_text_dict = {}
        # FIX: also reset the id counter. Previously last_memo_id kept its old
        # value after a reset, so callers checking `last_memo_id > 0` (e.g.
        # Teachability.process_last_received_message) would query an empty store.
        self.last_memo_id = 0
        self._save_memos()

    def add_input_output_pair(self, input_text: str, output_text: str):
        """Adds an input-output pair to the vector DB."""
        self.last_memo_id += 1
        self.vec_db.add(documents=[input_text], ids=[str(self.last_memo_id)])
        self.uid_text_dict[str(self.last_memo_id)] = input_text, output_text
        if self.verbosity >= 1:
            print(
                colored(
                    f"\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {self.last_memo_id}\n INPUT\n {input_text}\n OUTPUT\n {output_text}\n",
                    "light_yellow",
                )
            )
        if self.verbosity >= 3:
            self.list_memos()

    def get_nearest_memo(self, query_text: str):
        """Retrieves the nearest memo to the given query text."""
        results = self.vec_db.query(query_texts=[query_text], n_results=1)
        uid, input_text, distance = results["ids"][0][0], results["documents"][0][0], results["distances"][0][0]
        input_text_2, output_text = self.uid_text_dict[uid]
        assert input_text == input_text_2
        if self.verbosity >= 1:
            print(
                colored(
                    f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {input_text}\n OUTPUT\n {output_text}\n DISTANCE\n {distance}",
                    "light_yellow",
                )
            )
        return input_text, output_text, distance

    def get_related_memos(self, query_text: str, n_results: int, threshold: Union[int, float]):
        """Retrieves memos that are related to the given query text within the specified distance threshold."""
        # Never ask Chroma for more results than there are stored memos.
        if n_results > len(self.uid_text_dict):
            n_results = len(self.uid_text_dict)
        if n_results == 0:
            # Empty store: nothing to retrieve (Chroma rejects n_results == 0).
            return []
        results = self.vec_db.query(query_texts=[query_text], n_results=n_results)
        memos = []
        num_results = len(results["ids"][0])
        for i in range(num_results):
            uid, input_text, distance = results["ids"][0][i], results["documents"][0][i], results["distances"][0][i]
            # Keep only memos within the caller's distance threshold.
            if distance < threshold:
                input_text_2, output_text = self.uid_text_dict[uid]
                assert input_text == input_text_2
                if self.verbosity >= 1:
                    print(
                        colored(
                            f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {input_text}\n OUTPUT\n {output_text}\n DISTANCE\n {distance}",
                            "light_yellow",
                        )
                    )
                memos.append((input_text, output_text, distance))
        return memos

    def prepopulate(self):
        """Adds a few arbitrary examples to the vector DB, just to make retrieval less trivial."""
        if self.verbosity >= 1:
            print(colored("\nPREPOPULATING MEMORY", "light_green"))
        examples = []
        examples.append({"text": "When I say papers I mean research papers, which are typically pdfs.", "label": "yes"})
        examples.append({"text": "Please verify that each paper you listed actually uses langchain.", "label": "no"})
        examples.append({"text": "Tell gpt the output should still be latex code.", "label": "no"})
        examples.append({"text": "Hint: convert pdfs to text and then answer questions based on them.", "label": "yes"})
        examples.append({
            "text": "To create a good PPT, include enough content to make it interesting.",
            "label": "yes",
        })
        examples.append({
            "text": "No, for this case the columns should be aspects and the rows should be frameworks.",
            "label": "no",
        })
        examples.append({"text": "When writing code, remember to include any libraries that are used.", "label": "yes"})
        examples.append({"text": "Please summarize the papers by Eric Horvitz on bounded rationality.", "label": "no"})
        examples.append({"text": "Compare the h-index of Daniel Weld and Oren Etzioni.", "label": "no"})
        examples.append({
            "text": "Double check to be sure that the columns in a table correspond to what was asked for.",
            "label": "yes",
        })
        for example in examples:
            self.add_input_output_pair(example["text"], example["label"])
        self._save_memos()

View File

@@ -0,0 +1,66 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import Any, Optional, Protocol
from ....import_utils import optional_import_block, require_optional_import
with optional_import_block() as result:
import llmlingua
from llmlingua import PromptCompressor
class TextCompressor(Protocol):
    """Structural interface for text compression used to optimize agent interactions."""

    def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
        """Compress *text* and return a result dictionary.

        The compressed string must be stored under the 'compressed_text' key.
        To allow token-savings accounting, the dictionary should also include
        'origin_tokens' and 'compressed_tokens' entries.
        """
        ...
@require_optional_import("llmlingua", "long-context")
class LLMLingua:
    """Compresses text messages using LLMLingua for improved efficiency in processing and response generation.

    NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of the messages
    and the specific configurations used for the PromptCompressor.
    """

    # Defaults applied when the caller passes no prompt_compressor_kwargs.
    _DEFAULT_PROMPT_COMPRESSOR_KWARGS = {
        "model_name": "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
        "use_llmlingua2": True,
        "device_map": "cpu",
    }

    def __init__(
        self,
        prompt_compressor_kwargs: Optional[dict] = None,
        structured_compression: bool = False,
    ) -> None:
        """Args:
        prompt_compressor_kwargs (dict or None): A dictionary of keyword arguments for the PromptCompressor. Defaults to a
            dictionary with model_name set to "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
            use_llmlingua2 set to True, and device_map set to "cpu".
        structured_compression (bool): A flag indicating whether to use structured compression. If True, the
            structured_compress_prompt method of the PromptCompressor is used. Otherwise, the compress_prompt method
            is used. Defaults to False.

        Raises:
            ImportError: If the llmlingua library is not installed.
        """
        # FIX: the previous signature used a mutable dict literal as the default
        # argument, which is shared across all calls; use a None sentinel and
        # copy the class-level defaults instead.
        if prompt_compressor_kwargs is None:
            prompt_compressor_kwargs = dict(self._DEFAULT_PROMPT_COMPRESSOR_KWARGS)
        self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)

        assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
        # Bind the chosen compression entry point once, at construction time.
        self._compression_method = (
            self._prompt_compressor.structured_compress_prompt
            if structured_compression
            else self._prompt_compressor.compress_prompt
        )

    def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
        """Compress *text*; the result dict carries 'compressed_text' plus token counts."""
        return self._compression_method([text], **compression_params)

View File

@@ -0,0 +1,22 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from ....agentchat import ConversableAgent
from ....tools import Tool
class ToolsCapability:
    """Adding a list of tools as composable capabilities to a single agent.

    This class can be inherited from to allow code to run at the point of creating or adding the capability.
    Note: both caller and executor of the tools are the same agent.
    """

    def __init__(self, tool_list: list[Tool]):
        """Args:
        tool_list (list[Tool]): The tools to register on the agent this capability is added to.
        """
        # list(...) copies defensively (idiomatic form of the previous
        # [tool for tool in tool_list] comprehension), so later mutation of the
        # caller's list does not affect this capability.
        self.tools = list(tool_list)

    def add_to_agent(self, agent: ConversableAgent):
        """Add tools to the given agent (which acts as both caller and executor)."""
        for tool in self.tools:
            tool.register_tool(agent=agent)

View File

@@ -0,0 +1,93 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import copy
from typing import TYPE_CHECKING, Any, Optional

from ....formatting_utils import colored

from .transforms import MessageTransform
if TYPE_CHECKING:
from ...conversable_agent import ConversableAgent
class TransformMessages:
    """Agent capability for transforming messages before reply generation.

    This capability allows you to apply a series of message transformations to
    a ConversableAgent's incoming messages before they are processed for response
    generation. This is useful for tasks such as:

    - Limiting the number of messages considered for context.
    - Truncating messages to meet token limits.
    - Filtering sensitive information.
    - Customizing message formatting.

    To use `TransformMessages`:

    1. Create message transformations (e.g., `MessageHistoryLimiter`, `MessageTokenLimiter`).
    2. Instantiate `TransformMessages` with a list of these transformations.
    3. Add the `TransformMessages` instance to your `ConversableAgent` using `add_to_agent`.

    NOTE: Order of message transformations is important. You could get different results based on
    the order of transformations.

    Example:
        ```python
        from autogen.agentchat import ConversableAgent
        from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
        from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter

        max_messages = MessageHistoryLimiter(max_messages=2)
        truncate_messages = MessageTokenLimiter(max_tokens=500)
        transform_messages = TransformMessages(transforms=[max_messages, truncate_messages])

        agent = ConversableAgent(...)
        transform_messages.add_to_agent(agent)
        ```
    """

    def __init__(self, *, transforms: Optional[list[MessageTransform]] = None, verbose: bool = True):
        """Args:
        transforms: A list of message transformations to apply, in order. Defaults to no transformations.
        verbose: Whether to print logs of each transformation or not.
        """
        # Use None (not a mutable []) as the default so one shared list is never reused
        # across instances.
        self._transforms = transforms if transforms is not None else []
        self._verbose = verbose

    def add_to_agent(self, agent: "ConversableAgent"):
        """Adds the message transformations capability to the specified ConversableAgent.

        This function performs the following modifications to the agent:

        1. Registers a hook that automatically transforms all messages before they are processed for
            response generation.
        """
        agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)

    def _transform_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Run every configured transform over a deep copy of *messages*.

        A leading system message (when present) is removed before the transforms run and
        re-inserted afterwards, so no transform can drop or alter it.
        """
        post_transform_messages = copy.deepcopy(messages)
        system_message = None

        # Guard against an empty history (messages[0] would raise IndexError) and against a
        # first message without a "role" key (previously a KeyError).
        if messages and messages[0].get("role") == "system":
            system_message = copy.deepcopy(messages[0])
            post_transform_messages.pop(0)

        for transform in self._transforms:
            # deepcopy in case pre_transform_messages will later be used for logs printing
            pre_transform_messages = (
                copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages
            )
            post_transform_messages = transform.apply_transform(pre_transform_messages)

            if self._verbose:
                logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages)
                if had_effect:
                    print(colored(logs_str, "yellow"))

        if system_message:
            post_transform_messages.insert(0, system_message)

        return post_transform_messages

View File

@@ -0,0 +1,579 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import copy
import sys
from typing import Any, Optional, Protocol, Union
import tiktoken
from termcolor import colored
from .... import token_count_utils
from ....cache import AbstractCache, Cache
from ....types import MessageContentType
from . import transforms_util
from .text_compressors import LLMLingua, TextCompressor
class MessageTransform(Protocol):
    """Defines a contract for message transformation.

    Classes implementing this protocol should provide an `apply_transform` method
    that takes a list of messages and returns the transformed list.
    """

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies a transformation to a list of messages.

        Args:
            messages: A list of dictionaries representing messages.

        Returns:
            A new list of dictionaries containing the transformed messages.
        """
        ...

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Creates the string including the logs of the transformation.

        Alongside the string, it returns a boolean indicating whether the transformation had an effect or not.

        Args:
            pre_transform_messages: A list of dictionaries representing messages before the transformation.
            post_transform_messages: A list of dictionaries representing messages after the transformation.

        Returns:
            A tuple with a string with the logs and a flag indicating whether the transformation had an effect or not.
        """
        ...
class MessageHistoryLimiter:
    """Limits the number of messages considered by an agent for response generation.

    This transform keeps only the most recent messages up to the specified maximum number of messages (max_messages).
    It trims the conversation history by removing older messages, retaining only the most recent messages.
    """

    def __init__(
        self,
        max_messages: Optional[int] = None,
        keep_first_message: bool = False,
        exclude_names: Optional[list[str]] = None,
    ):
        """Args:
        max_messages Optional[int]: Maximum number of messages to keep in the context. Must be greater than 0 if not None.
        keep_first_message bool: Whether to keep the original first message in the conversation history.
            Defaults to False.
        exclude_names Optional[list[str]]: List of message sender names to exclude from the message history.
            Messages from these senders will be filtered out before applying the message limit. Defaults to None.

        Raises:
            ValueError: If max_messages is not None and is less than 1.
        """
        self._validate_max_messages(max_messages)
        self._max_messages = max_messages
        self._keep_first_message = keep_first_message
        self._exclude_names = exclude_names

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Truncates the conversation history to the specified maximum number of messages.

        This method returns a new list containing the most recent messages up to the specified
        maximum number of messages (max_messages). If max_messages is None, it returns the
        original list of messages unmodified.

        Args:
            messages (List[Dict]): The list of messages representing the conversation history.

        Returns:
            List[Dict]: A new list containing the most recent messages up to the specified maximum.
        """
        # getattr guards against instances created/unpickled before exclude_names existed.
        exclude_names = getattr(self, "_exclude_names", None)
        filtered = [msg for msg in messages if msg.get("name") not in exclude_names] if exclude_names else messages

        if self._max_messages is None or len(filtered) <= self._max_messages:
            return filtered

        truncated_messages = []
        remaining_count = self._max_messages

        # Start with the first message if we need to keep it
        if self._keep_first_message and filtered:
            truncated_messages = [filtered[0]]
            remaining_count -= 1

        # When the first message is kept it must stay in front, so all other insertions go at
        # index 1; otherwise they go at index 0. Walking the history in reverse with a fixed
        # insertion point preserves chronological order among the kept messages.
        insert_at = 1 if self._keep_first_message else 0

        # Loop through messages in reverse (index 0 is only ever added via keep_first_message).
        for i in range(len(filtered) - 1, 0, -1):
            if remaining_count > 1:
                truncated_messages.insert(insert_at, filtered[i])
            if remaining_count == 1:  # noqa: SIM102
                # If there's only 1 slot left and it's a 'tool' message, ignore it: a tool
                # response without its preceding tool call would be invalid context.
                if filtered[i].get("role") != "tool":
                    # BUGFIX: previously inserted at index 1 unconditionally, which placed the
                    # oldest kept message out of order when keep_first_message was False.
                    truncated_messages.insert(insert_at, filtered[i])
            remaining_count -= 1
            # BUGFIX: use <= 0 so the loop also stops when keep_first_message consumed the only
            # slot (remaining_count would otherwise go negative and never hit 0).
            if remaining_count <= 0:
                break

        return truncated_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report how many messages the last transformation removed."""
        pre_transform_messages_len = len(pre_transform_messages)
        post_transform_messages_len = len(post_transform_messages)
        if post_transform_messages_len < pre_transform_messages_len:
            logs_str = (
                f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. "
                f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}."
            )
            return logs_str, True
        return "No messages were removed.", False

    def _validate_max_messages(self, max_messages: Optional[int]):
        # None means "no limit"; otherwise the limit must be at least 1.
        if max_messages is not None and max_messages < 1:
            raise ValueError("max_messages must be None or an integer greater than or equal to 1")
class MessageTokenLimiter:
    """Truncates messages to meet token limits for efficient processing and response generation.

    This transformation applies two levels of truncation to the conversation history:

    1. Truncates each individual message to the maximum number of tokens specified by max_tokens_per_message.
    2. Truncates the overall conversation history to the maximum number of tokens specified by max_tokens.

    NOTE: Tokens are counted using the encoder for the specified model. Different models may yield different token
    counts for the same text.

    NOTE: For multimodal LLMs, the token count may be inaccurate as it does not account for the non-text input
    (e.g images).

    The truncation process follows these steps in order:

    1. The minimum tokens threshold (`min_tokens`) is checked (0 by default). If the total number of tokens in messages
        is less than this threshold, then the messages are returned as is. In other case, the following process is applied.
    2. Messages are processed in reverse order (newest to oldest).
    3. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
        and other types of content, only the text content is truncated.
    4. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
        exceeds this limit, the current message being processed get truncated to meet the total token count and any
        remaining messages get discarded.
    5. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
        original message order.
    """

    def __init__(
        self,
        max_tokens_per_message: Optional[int] = None,
        max_tokens: Optional[int] = None,
        min_tokens: Optional[int] = None,
        model: str = "gpt-3.5-turbo-0613",
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        max_tokens_per_message (None or int): Maximum number of tokens to keep in each message.
            Must be greater than or equal to 0 if not None.
        max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history.
            Must be greater than or equal to 0 if not None.
        min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
            Must be greater than or equal to 0 if not None.
        model (str): The target OpenAI model for tokenization alignment.
        filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
            If None, no filters will be applied.
        exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
            excluded from token truncation. If False, messages that match the filter will be truncated.
        """
        self._model = model
        # The validators normalize None (no limit -> sys.maxsize, no minimum -> 0) and cap
        # limits exceeding the model's context window, so the fields below are never None.
        self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
        self._max_tokens = self._validate_max_tokens(max_tokens)
        self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies token truncation to the conversation history.

        Args:
            messages (List[Dict]): The list of messages representing the conversation history.

        Returns:
            List[Dict]: A new list containing the truncated messages up to the specified token limits.
        """
        # _validate_max_tokens/_validate_min_tokens replaced None with sys.maxsize / 0 in
        # __init__, so these invariants always hold here.
        assert self._max_tokens_per_message is not None
        assert self._max_tokens is not None
        assert self._min_tokens is not None

        # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
        if not transforms_util.min_tokens_reached(messages, self._min_tokens):
            return messages

        # Deep copy so truncation never mutates the caller's message dicts.
        temp_messages = copy.deepcopy(messages)
        processed_messages = []
        processed_messages_tokens = 0

        # Walk newest-to-oldest so the most recent messages are the ones retained.
        for msg in reversed(temp_messages):
            # Some messages may not have content.
            if not transforms_util.is_content_right_type(msg.get("content")):
                processed_messages.insert(0, msg)
                continue

            # Messages excluded by the filter are kept untruncated, but still count
            # toward the overall token budget.
            if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
                processed_messages.insert(0, msg)
                processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
                continue

            expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message

            # If adding this message would exceed the token limit, truncate the last message to meet the total token
            # limit and discard all remaining messages
            if expected_tokens_remained < 0:
                msg["content"] = self._truncate_str_to_tokens(
                    msg["content"], self._max_tokens - processed_messages_tokens
                )
                processed_messages.insert(0, msg)
                break

            msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
            msg_tokens = transforms_util.count_text_tokens(msg["content"])

            # prepend the message to the list to preserve order
            processed_messages_tokens += msg_tokens
            processed_messages.insert(0, msg)

        return processed_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report how many tokens the last transformation removed."""
        pre_transform_messages_tokens = sum(
            transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
        )
        post_transform_messages_tokens = sum(
            transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
        )

        if post_transform_messages_tokens < pre_transform_messages_tokens:
            logs_str = (
                f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. "
                f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}"
            )
            return logs_str, True
        return "No tokens were truncated.", False

    def _truncate_str_to_tokens(self, contents: Union[str, list], n_tokens: int) -> Union[str, list]:
        """Dispatch truncation based on content type: plain string or multimodal list."""
        if isinstance(contents, str):
            return self._truncate_tokens(contents, n_tokens)
        elif isinstance(contents, list):
            return self._truncate_multimodal_text(contents, n_tokens)
        else:
            raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

    def _truncate_multimodal_text(self, contents: list[dict[str, Any]], n_tokens: int) -> list[dict[str, Any]]:
        """Truncates text content within a list of multimodal elements, preserving the overall structure.

        NOTE: each text element is truncated to n_tokens independently; the budget is not
        shared across elements.
        """
        tmp_contents = []
        for content in contents:
            if content["type"] == "text":
                truncated_text = self._truncate_tokens(content["text"], n_tokens)
                tmp_contents.append({"type": "text", "text": truncated_text})
            else:
                # Non-text elements (e.g. images) pass through untouched.
                tmp_contents.append(content)
        return tmp_contents

    def _truncate_tokens(self, text: str, n_tokens: int) -> str:
        """Encode *text* with the model's tokenizer, keep the first n_tokens, and decode back."""
        encoding = tiktoken.encoding_for_model(self._model)  # Get the appropriate tokenizer

        encoded_tokens = encoding.encode(text)
        truncated_tokens = encoded_tokens[:n_tokens]
        truncated_text = encoding.decode(truncated_tokens)  # Decode back to text

        return truncated_text

    def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]:
        """Validate a token limit; cap it at the model's context window and map None to sys.maxsize."""
        if max_tokens is not None and max_tokens < 0:
            raise ValueError("max_tokens and max_tokens_per_message must be None or greater than or equal to 0")

        try:
            allowed_tokens = token_count_utils.get_max_token_limit(self._model)
        except Exception:
            # Unknown model: warn and skip the context-window cap below.
            print(colored(f"Model {self._model} not found in token_count_utils.", "yellow"))
            allowed_tokens = None

        if max_tokens is not None and allowed_tokens is not None and max_tokens > allowed_tokens:
            print(
                colored(
                    f"Max token was set to {max_tokens}, but {self._model} can only accept {allowed_tokens} tokens. Capping it to {allowed_tokens}.",
                    "yellow",
                )
            )
            return allowed_tokens

        return max_tokens if max_tokens is not None else sys.maxsize

    def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
        """Validate min_tokens against max_tokens; None normalizes to 0 (always apply)."""
        if min_tokens is None:
            return 0
        if min_tokens < 0:
            raise ValueError("min_tokens must be None or greater than or equal to 0.")
        if max_tokens is not None and min_tokens > max_tokens:
            raise ValueError("min_tokens must not be more than max_tokens.")
        return min_tokens
class TextMessageCompressor:
    """A transform for compressing text messages in a conversation history.

    It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
    processing and response generation by downstream models.
    """

    def __init__(
        self,
        text_compressor: Optional[TextCompressor] = None,
        min_tokens: Optional[int] = None,
        compression_params: Optional[dict] = None,
        cache: Optional[AbstractCache] = None,
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
            protocol. If None, it defaults to LLMLingua.
        min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater
            than or equal to 0 if not None. If None, no threshold-based compression is applied.
        compression_params (dict or None): A dictionary of arguments for the compression method. Defaults to None,
            which is treated as an empty dictionary.
        cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
            If None, a disk cache will be used.
        filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
            If None, no filters will be applied.
        exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
            excluded from compression. If False, messages that match the filter will be compressed.

        Raises:
            ValueError: If min_tokens is not None and is not greater than 0.
        """
        if text_compressor is None:
            text_compressor = LLMLingua()

        self._validate_min_tokens(min_tokens)

        self._text_compressor = text_compressor
        self._min_tokens = min_tokens
        # Avoid the shared-mutable-default pitfall: the previous `dict()` default was a single
        # dictionary shared by every instance constructed without explicit params.
        self._compression_args = compression_params if compression_params is not None else {}
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter

        if cache is None:
            self._cache = Cache.disk()
        else:
            self._cache = cache

        # Savings from the most recent apply_transform call, read by get_logs.
        self._recent_tokens_savings = 0

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies compression to messages in a conversation history based on the specified configuration.

        The function processes each message according to the `compression_params` and `min_tokens` settings, applying
        the specified compression configuration and returning a new list of messages with reduced token counts
        where possible.

        NOTE: `messages` is only shallow-copied, so the individual message dicts are mutated
        in place and the caller's message objects see the compressed content.

        Args:
            messages (List[Dict]): A list of message dictionaries to be compressed.

        Returns:
            List[Dict]: A list of dictionaries with the message content compressed according to the configured
                method and scope.
        """
        # Make sure there is at least one message
        if not messages:
            return messages

        # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
        if not transforms_util.min_tokens_reached(messages, self._min_tokens):
            return messages

        total_savings = 0
        processed_messages = messages.copy()
        for message in processed_messages:
            # Some messages may not have content.
            if not transforms_util.is_content_right_type(message.get("content")):
                continue

            if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
                continue

            if transforms_util.is_content_text_empty(message["content"]):
                continue

            # Reuse a previous compression of identical content when available; otherwise
            # compress and store the result (including the savings) in the cache.
            cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
            cached_content = transforms_util.cache_content_get(self._cache, cache_key)
            if cached_content is not None:
                message["content"], savings = cached_content
            else:
                message["content"], savings = self._compress(message["content"])
                transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)

            assert isinstance(savings, int)
            total_savings += savings

        self._recent_tokens_savings = total_savings
        return processed_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report the token savings of the most recent apply_transform call."""
        if self._recent_tokens_savings > 0:
            return f"{self._recent_tokens_savings} tokens saved with text compression.", True
        else:
            return "No tokens saved with text compression.", False

    def _compress(self, content: MessageContentType) -> tuple[MessageContentType, int]:
        """Compresses the given text or multimodal content using the specified compression method."""
        if isinstance(content, str):
            return self._compress_text(content)
        elif isinstance(content, list):
            return self._compress_multimodal(content)
        else:
            # Unknown content type: pass through unchanged with zero savings.
            return content, 0

    def _compress_multimodal(self, content: MessageContentType) -> tuple[MessageContentType, int]:
        """Compresses every text part of a multimodal content list in place."""
        tokens_saved = 0
        for idx, item in enumerate(content):
            if isinstance(item, dict) and "text" in item:
                item["text"], savings = self._compress_text(item["text"])
                tokens_saved += savings

            elif isinstance(item, str):
                # BUGFIX: write the compressed string back into the list. Rebinding the loop
                # variable alone discarded the result (str is immutable) while still counting
                # the savings.
                content[idx], savings = self._compress_text(item)
                tokens_saved += savings

        return content, tokens_saved

    def _compress_text(self, text: str) -> tuple[str, int]:
        """Compresses the given text using the specified compression method."""
        compressed_text = self._text_compressor.compress_text(text, **self._compression_args)

        # Savings are only computable when the compressor reports both token counts.
        savings = 0
        if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
            savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]

        return compressed_text["compressed_prompt"], savings

    def _validate_min_tokens(self, min_tokens: Optional[int]):
        # None disables the threshold entirely; otherwise it must be strictly positive.
        if min_tokens is not None and min_tokens <= 0:
            raise ValueError("min_tokens must be greater than 0 or None")
class TextMessageContentName:
    """A transform for including the agent's name in the content of a message.

    How to create and apply the transform:
        # Imports
        from autogen.agentchat.contrib.capabilities import transform_messages, transforms

        # Create Transform
        name_transform = transforms.TextMessageContentName(position="start", format_string="'{name}' said:\n")

        # Create the TransformMessages
        context_handling = transform_messages.TransformMessages(transforms=[name_transform])

        # Add it to an agent so when they run inference it will apply to the messages
        context_handling.add_to_agent(my_agent)
    """

    def __init__(
        self,
        position: str = "start",
        format_string: str = "{name}:\n",
        deduplicate: bool = True,
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        position (str): Where to place the name in the content, 'start' or 'end'. Defaults to 'start'.
        format_string (str): The f-string used to format the name; must contain '{name}'. Defaults to '{name}:\n'.
        deduplicate (bool): Skip messages whose content already carries the formatted name
            (an LLM sometimes adds it itself). Defaults to True.
        filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to transform.
            If None, no filters will be applied.
        exclude_filter (bool): If True (the default), messages matching the filter are excluded from the
            transformation. If False, only matching messages are transformed.
        """
        assert isinstance(position, str) and position in ["start", "end"]
        assert isinstance(format_string, str) and "{name}" in format_string
        assert isinstance(deduplicate, bool) and deduplicate is not None

        self._position = position
        self._format_string = format_string
        self._deduplicate = deduplicate
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter

        # Number of messages modified by the most recent apply_transform call (for get_logs).
        self._messages_changed = 0

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Inject each sender's formatted name into their message content.

        Args:
            messages (List[Dict]): A list of message dictionaries.

        Returns:
            List[Dict]: A list of dictionaries with the message content updated with names.
        """
        if not messages:
            return messages

        transformed = copy.deepcopy(messages)
        num_changed = 0

        for msg in transformed:
            # Skip anything without usable content and a sender name.
            if not transforms_util.is_content_right_type(
                msg.get("content")
            ) or not transforms_util.is_content_right_type(msg.get("name")):
                continue

            if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
                continue

            if transforms_util.is_content_text_empty(msg["content"]) or transforms_util.is_content_text_empty(
                msg["name"]
            ):
                continue

            body = msg["content"]
            tag = self._format_string.format(name=msg["name"])

            if self._position == "start":
                if self._deduplicate and body.startswith(tag):
                    continue
                msg["content"] = f"{tag}{body}"
            else:
                if self._deduplicate and body.endswith(tag):
                    continue
                msg["content"] = f"{body}{tag}"
            num_changed += 1

        self._messages_changed = num_changed
        return transformed

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report how many messages the last apply_transform modified."""
        count = self._messages_changed
        if count > 0:
            return f"{count} message(s) changed to incorporate name.", True
        return "No messages changed to incorporate name.", False

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from collections.abc import Hashable
from typing import Any, Optional
from .... import token_count_utils
from ....cache.abstract_cache_base import AbstractCache
from ....oai.openai_utils import filter_config
from ....types import MessageContentType
def cache_key(content: MessageContentType, *args: Hashable) -> str:
    """Build the cache key for the given message content plus any extra hashable args.

    Args:
        content (MessageContentType): The message content to calculate the cache key for.
        *args: Any additional hashable args to include in the cache key.
    """
    parts = (content, *args)
    return "".join(str(part) for part in parts)
def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[tuple[MessageContentType, ...]]:
    """Retrieves cached content from the cache.

    Returns None when no cache is configured or the key has no (truthy) stored value.

    Args:
        cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored.
        key (str): The key to retrieve the content from.
    """
    if not cache:
        return None
    value = cache.get(key)
    return value if value else None
def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values):
    """Sets content into the cache.

    Args:
        cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored.
        key (str): The key to set the content into.
        content (MessageContentType): The message content to set into the cache.
        *extra_values: Additional values to be stored alongside the content.
    """
    if not cache:
        return
    cache.set(key, (content, *extra_values))
def min_tokens_reached(messages: list[dict[str, Any]], min_tokens: Optional[int]) -> bool:
    """Returns True if the total number of tokens in the messages is at least *min_tokens*.

    A falsy threshold (None or 0) always passes.

    Args:
        messages (List[Dict]): A list of messages to check.
        min_tokens (None or int): The minimum number of tokens to check for.
    """
    if not min_tokens:
        return True

    total = 0
    for msg in messages:
        if "content" in msg:
            total += count_text_tokens(msg["content"])
    return total >= min_tokens
def count_text_tokens(content: MessageContentType) -> int:
    """Calculates the number of text tokens in the given message content.

    Non-str, non-list content contributes zero tokens.

    Args:
        content (MessageContentType): The message content to calculate the number of text tokens for.
    """
    if isinstance(content, str):
        return token_count_utils.count_token(content)
    if isinstance(content, list):
        # Multimodal content: strings are counted directly, dict-like items via their "text" field.
        return sum(
            token_count_utils.count_token(item) if isinstance(item, str) else count_text_tokens(item.get("text", ""))
            for item in content
        )
    return 0
def is_content_right_type(content: Any) -> bool:
    """Return True when *content* is a plain string or a (multimodal) list."""
    return isinstance(content, str) or isinstance(content, list)
def is_content_text_empty(content: MessageContentType) -> bool:
    """Checks if the content of the message does not contain any text.

    Args:
        content (MessageContentType): The message content to check.
    """
    if isinstance(content, str):
        return not content
    if isinstance(content, list):
        # A multimodal list is "empty" when no string item and no dict "text" field is non-empty;
        # items of any other type carry no text and are ignored.
        return not any(
            item if isinstance(item, str) else item.get("text", "")
            for item in content
            if isinstance(item, (str, dict))
        )
    # Anything else carries no text at all.
    return True
def should_transform_message(message: dict[str, Any], filter_dict: Optional[dict[str, Any]], exclude: bool) -> bool:
    """Validates whether the transform should be applied according to the filter dictionary.

    Args:
        message (Dict[str, Any]): The message to validate.
        filter_dict (None or Dict[str, Any]): The filter dictionary to validate against. If None, the transform is always applied.
        exclude (bool): Whether to exclude messages that match the filter dictionary.
    """
    if not filter_dict:
        return True
    matches = filter_config([message], filter_dict, exclude)
    return bool(matches)

View File

@@ -0,0 +1,212 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import copy
from typing import Any, Callable, Optional, Union
from ....code_utils import content_str
from ....oai.client import OpenAIWrapper
from ...assistant_agent import ConversableAgent
from ..img_utils import (
convert_base64_to_data_uri,
get_image_data,
get_pil_image,
gpt4v_formatter,
)
from .agent_capability import AgentCapability
# Default prompt sent to the LMM client when asking it to caption an image on behalf of a
# non-multimodal agent; a custom prompt can be supplied via the description_prompt parameter.
DEFAULT_DESCRIPTION_PROMPT = (
    "Write a detailed caption for this image. "
    "Pay special attention to any details that might be useful or relevant "
    "to the ongoing conversation."
)
class VisionCapability(AgentCapability):
    """Adds image understanding to a regular ConversableAgent whose own model is not
    multimodal (e.g. GPT-3.5-turbo, Llama, Orca, or Mistral agents). The capability
    invokes an LMM client to describe (caption) each image before the information
    reaches the agent's actual client.

    The vision capability hooks into the ConversableAgent's `process_last_received_message`.

    Some technical details:
    When the agent (who has the vision capability) received a message, it will:
    1. _process_received_message:
        a. _append_oai_message
    2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag.
        a. hook process_last_received_message (NOTE: this is where the vision capability is hooked.)
        b. hook process_all_messages_before_reply
    3. send:
        a. hook process_message_before_send
        b. _append_oai_message
    """

    def __init__(
        self,
        lmm_config: dict[str, Any],
        description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT,
        custom_caption_func: Optional[Callable] = None,
    ) -> None:
        """Set up the LMM client configuration and optional captioning overrides.

        Args:
            lmm_config (Dict): Configuration for the LMM client, used to call the LMM
                service for describing images. If `lmm_config` is False or an empty
                dictionary, no client is created and `custom_caption_func` must be
                provided instead.
            description_prompt (Optional[str]): The prompt passed to the LMM service
                when generating image descriptions. Defaults to
                `DEFAULT_DESCRIPTION_PROMPT` if not provided.
            custom_caption_func (Optional[Callable]): A callable that, if provided, is
                used instead of the default `self._get_image_caption`. It takes three
                parameters:
                1. an image URL (or local location)
                2. image_data (a PIL image)
                3. lmm_client (to call a remote LMM; may be None)
                and returns a description as a string.

        Raises:
            AssertionError: If neither a valid `lmm_config` nor a `custom_caption_func`
                is provided.
        """
        self._lmm_config = lmm_config
        self._description_prompt = description_prompt
        self._parent_agent = None
        # Only build a client when a truthy config is supplied; otherwise captioning
        # must come from `custom_caption_func`.
        self._lmm_client = OpenAIWrapper(**lmm_config) if lmm_config else None
        self._custom_caption_func = custom_caption_func
        assert self._lmm_config or custom_caption_func, (
            "Vision Capability requires a valid lmm_config or custom_caption_func."
        )

    def add_to_agent(self, agent: ConversableAgent) -> None:
        """Attach this capability to `agent`: extend its system message and register
        the captioning hook on `process_last_received_message`."""
        self._parent_agent = agent
        # Append extra info to the system message.
        agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.")
        # Register a hook for processing the last message.
        agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

    def process_last_received_message(self, content: Union[str, list[dict[str, Any]]]) -> str:
        """Normalize the last received message and append a caption after each image.

        String content is first normalized into the GPT-4V multimodal list format.
        For each image item, a caption is generated (via `custom_caption_func` or
        the LMM client) and inserted after an `<img ...>` tag so that text-only
        agents can still "see" the image.

        Args:
            content (Union[str, List[dict[str, Any]]]): The last received message
                content; either a plain string or a list of content items
                (e.g. text, image_url).

        Returns:
            str: The augmented message content.

        Raises:
            AssertionError: If an item in the content list is not a dictionary.

        Examples:
            Assuming `self._get_image_caption(img_data)` returns
            "A beautiful sunset over the mountains" for the image.

            - content = "Check out this cool photo!"
              Output: "Check out this cool photo!" (no image, unchanged).
            - content = "What's weather in this cool photo: `<img http://example.com/photo.jpg>`"
              Output: the same text with the caption appended after the image tag.
            - content = [{"type": "text", "text": "..."}, {"type": "image_url", "image_url": "http://example.com/photo.jpg"}]
              Output: text followed by "<img ...> in case you can not see, the caption of this image is: ..."
        """
        # Normalize string content into the gpt-4v format for multimodal input;
        # keep the URL format to stay concise.
        if isinstance(content, str):
            content = gpt4v_formatter(content, img_format="url")
        aug_content: str = ""
        for item in content:
            assert isinstance(item, dict)
            if item["type"] == "text":
                aug_content += item["text"]
            elif item["type"] == "image_url":
                img_url = item["image_url"]
                if self._custom_caption_func:
                    img_caption = self._custom_caption_func(img_url, get_pil_image(img_url), self._lmm_client)
                elif self._lmm_client:
                    img_data = get_image_data(img_url)
                    img_caption = self._get_image_caption(img_data)
                else:
                    img_caption = ""
                aug_content += f"<img {img_url}> in case you can not see, the caption of this image is: {img_caption}\n"
            else:
                # Fixed typo in the warning: valid types are `text` or `image_url` (was "test").
                print(f"Warning: the input type should either be `text` or `image_url`. Skip {item['type']} here.")
        return aug_content

    def _get_image_caption(self, img_data: str) -> str:
        """Ask the configured LMM client to caption one image.

        Args:
            img_data (str): base64 encoded image data.

        Returns:
            str: caption for the given image.
        """
        response = self._lmm_client.create(
            context=None,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": self._description_prompt},
                        {
                            "type": "image_url",
                            "image_url": convert_base64_to_data_uri(img_data),
                        },
                    ],
                }
            ],
        )
        description = response.choices[0].message.content
        return content_str(description)

View File

@@ -0,0 +1,411 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import base64
import copy
import os
import re
from io import BytesIO
from math import ceil
from typing import Any, Union
import requests
from ...import_utils import optional_import_block, require_optional_import
from .. import utils
with optional_import_block():
from PIL import Image
# Parameters for token counting for images for different models.
# Field meanings (used by num_tokens_from_gpt_image):
#   max_edge         -- longest image edge is first scaled down to at most this many pixels
#   min_edge         -- shortest edge is then scaled down to at most this many pixels
#   tile_size        -- the scaled image is covered with square tiles of this edge length
#   base_token_count -- fixed per-image token cost (also the low-quality cost)
#   token_multiplier -- additional token cost per tile
MODEL_PARAMS = {
    "gpt-4-vision": {
        "max_edge": 2048,
        "min_edge": 768,
        "tile_size": 512,
        "base_token_count": 85,
        "token_multiplier": 170,
    },
    "gpt-4o-mini": {
        "max_edge": 2048,
        "min_edge": 768,
        "tile_size": 512,
        "base_token_count": 2833,
        "token_multiplier": 5667,
    },
    "gpt-4o": {"max_edge": 2048, "min_edge": 768, "tile_size": 512, "base_token_count": 85, "token_multiplier": 170},
}
@require_optional_import("PIL", "unknown")
def get_pil_image(image_file: Union[str, "Image.Image"]) -> "Image.Image":
    """Load an image from a URL, data URI, local path, or base64 string.

    Parameters:
        image_file (str, or Image): The filename, URL, URI, or base64 string of
            the image file; an existing PIL Image is returned unchanged.

    Returns:
        Image.Image: The PIL Image object (converted to RGB when loaded).
    """
    if isinstance(image_file, Image.Image):
        # Nothing to load.
        return image_file
    # Strip a single pair of surrounding double or single quotes, if present.
    for quote in ('"', "'"):
        if image_file.startswith(quote) and image_file.endswith(quote):
            image_file = image_file[1:-1]
    # Base64 data-URI prefix for supported formats: jpg, jpeg, png, gif, bmp, webp.
    uri_prefix = r"data:image/(?:jpg|jpeg|png|gif|bmp|webp);base64,"
    if image_file.startswith("http://") or image_file.startswith("https://"):
        # Remote URL: fetch the bytes and decode in memory.
        fetched = requests.get(image_file)
        loaded = Image.open(BytesIO(fetched.content))
    elif re.match(uri_prefix, image_file):
        # Data URI: drop the prefix and decode the base64 payload.
        loaded = _to_pil(re.sub(uri_prefix, "", image_file))
    elif os.path.exists(image_file):
        # Local file path.
        loaded = Image.open(image_file)
    else:
        # Assume a bare base64-encoded string.
        loaded = _to_pil(image_file)
    return loaded.convert("RGB")
@require_optional_import("PIL", "unknown")
def get_image_data(image_file: Union[str, "Image.Image"], use_b64=True) -> bytes:
    """Load an image and return its PNG-encoded data, raw or base64-encoded.

    The image is loaded via `get_pil_image` (file path, URL, data URI, base64
    string, or PIL Image), re-encoded as PNG in memory, and returned either as
    raw bytes or as a base64 string.

    Parameters:
        image_file (str, or Image): The path, URL, base64 string, or PIL Image.
        use_b64 (bool): If True (default), return a base64-encoded string;
            otherwise return the raw PNG bytes.

    Returns:
        bytes: Raw PNG bytes when `use_b64` is False, or a base64-encoded
            string when `use_b64` is True.
    """
    pil_image = get_pil_image(image_file)
    sink = BytesIO()
    pil_image.save(sink, format="PNG")
    raw_png = sink.getvalue()
    return base64.b64encode(raw_png).decode("utf-8") if use_b64 else raw_png
@require_optional_import("PIL", "unknown")
def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str, list[str]]:
    """Replace `<img ...>` tags in a prompt with LLaVA-style image tokens.

    Parameters:
        prompt (str): The input string that may contain image tags like `<img ...>`.
        order_image_tokens (bool, optional): Whether to number the image tokens
            (`<image 0>`, `<image 1>`, ...). Useful for GPT-4V. Defaults to False.

    Returns:
        Tuple[str, List[str]]: The formatted prompt and the list of images loaded
            in base64 format. Tags whose image cannot be loaded are removed from
            the prompt with a warning.
    """
    new_prompt = prompt
    images = []
    image_count = 0
    # Regular expression pattern for matching <img ...> tags.
    img_tag_pattern = re.compile(r"<img ([^>]+)>")
    for match in img_tag_pattern.finditer(prompt):
        image_location = match.group(1)
        try:
            img_data = get_image_data(image_location)
        except Exception as e:
            # Drop the tag when the image cannot be loaded.
            print(f"Warning! Unable to load image from {image_location}, because of {e}")
            new_prompt = new_prompt.replace(match.group(0), "", 1)
            continue
        images.append(img_data)
        # Replace the tag with an (optionally numbered) image token.
        new_token = f"<image {image_count}>" if order_image_tokens else "<image>"
        new_prompt = new_prompt.replace(match.group(0), new_token, 1)
        image_count += 1
    # Removed dead local `image_locations`: it was appended to but never used or returned.
    return new_prompt, images
@require_optional_import("PIL", "unknown")
def pil_to_data_uri(image: "Image.Image") -> str:
    """Convert a PIL Image object to a base64 data URI (PNG-encoded).

    Parameters:
        image (Image.Image): The PIL Image object.

    Returns:
        str: The data URI string.
    """
    sink = BytesIO()
    image.save(sink, format="PNG")
    encoded = base64.b64encode(sink.getvalue()).decode("utf-8")
    return convert_base64_to_data_uri(encoded)
def convert_base64_to_data_uri(base64_image):
    """Wrap a base64-encoded image string in a data URI, sniffing its MIME type.

    The MIME type is inferred from the decoded image's leading magic bytes
    (JPEG, PNG, GIF, WEBP); unknown signatures fall back to "image/jpeg" as a
    best guess.
    """
    decoded = base64.b64decode(base64_image)
    if decoded.startswith(b"\xff\xd8\xff"):
        mime = "image/jpeg"
    elif decoded.startswith(b"\x89PNG\r\n\x1a\n"):
        mime = "image/png"
    elif decoded.startswith(b"GIF87a") or decoded.startswith(b"GIF89a"):
        mime = "image/gif"
    elif decoded.startswith(b"RIFF") and decoded[8:12] == b"WEBP":
        mime = "image/webp"
    else:
        mime = "image/jpeg"  # unknown signature: best-guess fallback
    return f"data:{mime};base64,{base64_image}"
@require_optional_import("PIL", "unknown")
def gpt4v_formatter(prompt: str, img_format: str = "uri") -> list[Union[str, dict[str, Any]]]:
    """Split a prompt containing `<img ...>` tags into alternating text/image items.

    Args:
        prompt (str): The input string that may contain image tags like `<img ...>`.
        img_format (str): How each image is represented; one of "uri" (base64
            data URI), "url" (the original location string), or "pil" (PIL Image).

    Returns:
        List[Union[str, dict[str, Any]]]: A list of `{"type": "text", ...}` and
            `{"type": "image_url", ...}` items in prompt order. Images that fail
            to load are skipped with a warning (their tag is dropped).
    """
    assert img_format in ["uri", "url", "pil"]
    output = []
    last_index = 0
    # Find all image tags.
    for parsed_tag in utils.parse_tags_from_content("img", prompt):
        image_location = parsed_tag["attr"]["src"]
        try:
            if img_format == "pil":
                img_data = get_pil_image(image_location)
            elif img_format == "uri":
                img_data = get_image_data(image_location)
                img_data = convert_base64_to_data_uri(img_data)
            elif img_format == "url":
                img_data = image_location
            else:
                raise ValueError(f"Unknown image format {img_format}")
        except Exception as e:
            # Warn and skip this token.
            print(f"Warning! Unable to load image from {image_location}, because {e}")
            continue
        # Text before this image tag, then the image itself.
        output.append({"type": "text", "text": prompt[last_index : parsed_tag["match"].start()]})
        output.append({"type": "image_url", "image_url": {"url": img_data}})
        last_index = parsed_tag["match"].end()
    # Removed dead local `image_count`: it was incremented but never read.
    # Trailing text after the last image tag.
    if last_index < len(prompt):
        output.append({"type": "text", "text": prompt[last_index:]})
    return output
def extract_img_paths(paragraph: str) -> list:
    """Find image locations (URLs or local file paths) mentioned in a paragraph.

    Parameters:
        paragraph (str): The input text paragraph.

    Returns:
        list: All substrings that look like image URLs or file paths ending in
            a common image extension (jpg, jpeg, png, gif, bmp, webp),
            matched case-insensitively.
    """
    # URLs first, then bare paths; both alternatives require a known image extension.
    pattern = re.compile(
        r"\b(?:http[s]?://\S+\.(?:jpg|jpeg|png|gif|bmp|webp)|\S+\.(?:jpg|jpeg|png|gif|bmp|webp))\b",
        re.IGNORECASE,
    )
    return pattern.findall(paragraph)
@require_optional_import("PIL", "unknown")
def _to_pil(data: str) -> "Image.Image":
    """Decode a base64-encoded image string into a PIL Image object.

    Parameters:
        data (str): The base64-encoded image data string.

    Returns:
        Image.Image: The PIL Image created from the decoded bytes.
    """
    raw_bytes = base64.b64decode(data)
    return Image.open(BytesIO(raw_bytes))
@require_optional_import("PIL", "unknown")
def message_formatter_pil_to_b64(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Replace PIL images inside messages with base64 data URIs.

    For every message with a 'content' key: string content is first parsed into
    a GPT-4V content list (this handles tool output containing images); then any
    'image_url' item whose 'url' is a PIL Image is converted to a base64 data
    URI via `pil_to_data_uri`. Input messages are deep-copied, never mutated.

    Parameters:
        messages (List[Dict]): Message dictionaries, each possibly containing a
            'content' list with image items.

    Returns:
        List[Dict]: New message dictionaries with PIL images converted to
            base64 data URIs.

    Example: a message item
        {'type': 'image_url', 'image_url': {'url': <PIL.Image.Image>}}
    becomes
        {'type': 'image_url', 'image_url': {'url': 'data:image/png;base64,...'}}
    """
    converted = []
    for original in messages:
        # Deep-copy so the caller's messages are left untouched.
        msg = copy.deepcopy(original)
        if isinstance(msg, dict) and "content" in msg:
            # String content may embed <img> tags (e.g. tool output); parse it first.
            if isinstance(msg["content"], str):
                msg["content"] = gpt4v_formatter(msg["content"], img_format="pil")
            if isinstance(msg["content"], list):
                for part in msg["content"]:
                    is_pil_image = (
                        isinstance(part, dict)
                        and "image_url" in part
                        and isinstance(part["image_url"]["url"], Image.Image)
                    )
                    if is_pil_image:
                        part["image_url"]["url"] = pil_to_data_uri(part["image_url"]["url"])
        converted.append(msg)
    return converted
@require_optional_import("PIL", "unknown")
def num_tokens_from_gpt_image(
    image_data: Union[str, "Image.Image"], model: str = "gpt-4-vision", low_quality: bool = False
) -> int:
    """Count the tokens needed to process an image for a given GPT vision model.

    The image is notionally scaled so its longest edge is at most `max_edge`
    pixels and its shortest edge at most `min_edge` pixels (per MODEL_PARAMS),
    then covered with `tile_size` x `tile_size` tiles; the token cost is the
    model's base cost plus a per-tile multiplier.

    Reference: https://openai.com/api/pricing/

    Args:
        image_data: Base64 string, URL, file path, or PIL Image object.
        model: One of the supported model families: "gpt-4-vision" (and its
            aliases "gpt-4-turbo", "gpt-4v", "gpt-4-v"), "gpt-4o", "gpt-4o-mini".
        low_quality: If True, return only the base token count.

    Returns:
        int: The total number of tokens required for processing the image.

    Raises:
        ValueError: If `model` matches none of the supported families.

    Examples:
    --------
    >>> from PIL import Image
    >>> img = Image.new("RGB", (2500, 2500), color="red")
    >>> num_tokens_from_gpt_image(img, model="gpt-4-vision")
    765
    """
    width, height = get_pil_image(image_data).size
    # Map the model name onto a parameter set; check "gpt-4o-mini" before "gpt-4o"
    # because the former contains the latter as a substring.
    legacy_aliases = ("gpt-4-vision", "gpt-4-turbo", "gpt-4v", "gpt-4-v")
    if any(alias in model for alias in legacy_aliases):
        params = MODEL_PARAMS["gpt-4-vision"]
    elif "gpt-4o-mini" in model:
        params = MODEL_PARAMS["gpt-4o-mini"]
    elif "gpt-4o" in model:
        params = MODEL_PARAMS["gpt-4o"]
    else:
        raise ValueError(
            f"Model {model} is not supported. Choose 'gpt-4-vision', 'gpt-4-turbo', 'gpt-4v', 'gpt-4-v', 'gpt-4o', or 'gpt-4o-mini'."
        )
    if low_quality:
        return params["base_token_count"]
    # 1. Constrain the longest edge.
    longest = max(width, height)
    if longest > params["max_edge"]:
        ratio = params["max_edge"] / longest
        width, height = int(width * ratio), int(height * ratio)
    # 2. Further constrain the shortest edge.
    shortest = min(width, height)
    if shortest > params["min_edge"]:
        ratio = params["min_edge"] / shortest
        width, height = int(width * ratio), int(height * ratio)
    # 3. Count tiles needed to cover the scaled image.
    tile = params["tile_size"]
    tile_count = ceil(width / tile) * ceil(height / tile)
    return params["base_token_count"] + params["token_multiplier"] * tile_count

View File

@@ -0,0 +1,153 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import copy
from typing import Any, Callable, Optional, Union

from ... import OpenAIWrapper
from ...code_utils import content_str
from .. import Agent, ConversableAgent
from ..contrib.img_utils import (
    gpt4v_formatter,
    message_formatter_pil_to_b64,
)
# Fallback system message used when the caller does not supply one.
DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant."""
# Default model name placed in DEFAULT_CONFIG below.
DEFAULT_MODEL = "gpt-4-vision-preview"
class MultimodalConversableAgent(ConversableAgent):
    """A ConversableAgent whose messages may mix text and images, using the
    GPT-4V-style content-list message format."""

    # Default client configuration; callers typically override `model`.
    DEFAULT_CONFIG = {
        "model": DEFAULT_MODEL,
    }
    def __init__(
        self,
        name: str,
        system_message: Optional[Union[str, list]] = DEFAULT_LMM_SYS_MSG,
        is_termination_msg: Optional[Callable[[dict[str, Any]], bool]] = None,
        *args,
        **kwargs: Any,
    ):
        """Initialize a multimodal conversable agent.

        Args:
            name (str): agent name.
            system_message (str): system message for the OpenAIWrapper inference.
                Please override this attribute if you want to reprogram the agent.
            is_termination_msg (Callable): predicate over a received message dict;
                when it returns True the conversation terminates. Defaults to
                terminating on a message whose content is exactly "TERMINATE".
            *args: forwarded to ConversableAgent.
            **kwargs (dict): Please refer to other kwargs in
                [ConversableAgent](/docs/api-reference/autogen/ConversableAgent#conversableagent).
        """
        super().__init__(
            name,
            system_message,
            is_termination_msg=is_termination_msg,
            *args,
            **kwargs,
        )
        # call the setter to handle special format.
        self.update_system_message(system_message)
        # Default termination: stringified content equals "TERMINATE".
        self._is_termination_msg = (
            is_termination_msg
            if is_termination_msg is not None
            else (lambda x: content_str(x.get("content")) == "TERMINATE")
        )
        # Override the `generate_oai_reply` (and its async counterpart) with the
        # multimodal-aware versions defined on this class.
        self.replace_reply_func(ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply)
        self.replace_reply_func(
            ConversableAgent.a_generate_oai_reply,
            MultimodalConversableAgent.a_generate_oai_reply,
        )
def update_system_message(self, system_message: Union[dict[str, Any], list[str], str]):
"""Update the system message.
Args:
system_message (str): system message for the OpenAIWrapper inference.
"""
self._oai_system_message[0]["content"] = self._message_to_dict(system_message)["content"]
self._oai_system_message[0]["role"] = "system"
@staticmethod
def _message_to_dict(message: Union[dict[str, Any], list[str], str]) -> dict:
"""Convert a message to a dictionary. This implementation
handles the GPT-4V formatting for easier prompts.
The message can be a string, a dictionary, or a list of dictionaries:
- If it's a string, it will be cast into a list and placed in the 'content' field.
- If it's a list, it will be directly placed in the 'content' field.
- If it's a dictionary, it is already in message dict format. The 'content' field of this dictionary
will be processed using the gpt4v_formatter.
"""
if isinstance(message, str):
return {"content": gpt4v_formatter(message, img_format="pil")}
if isinstance(message, list):
return {"content": message}
if isinstance(message, dict):
assert "content" in message, "The message dict must have a `content` field"
if isinstance(message["content"], str):
message = copy.deepcopy(message)
message["content"] = gpt4v_formatter(message["content"], img_format="pil")
try:
content_str(message["content"])
except (TypeError, ValueError) as e:
print("The `content` field should be compatible with the content_str function!")
raise e
return message
raise ValueError(f"Unsupported message type: {type(message)}")
    def generate_oai_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[OpenAIWrapper] = None,
    ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
        """Generate a reply using autogen.oai.

        Converts any PIL images in the pending messages to base64 data URIs,
        re-packs tool-response messages (see inline notes), and calls the client.

        Args:
            messages: conversation history; defaults to the stored history with `sender`.
            sender: peer agent whose history is used when `messages` is None.
            config: client wrapper to use instead of `self.client`.

        Returns:
            Tuple of (reply generated?, the reply or None).
        """
        client = self.client if config is None else config
        if client is None:
            return False, None
        if messages is None:
            messages = self._oai_messages[sender]
        # Prepend the system message; PIL images become base64 data URIs.
        messages_with_b64_img = message_formatter_pil_to_b64(self._oai_system_message + messages)
        new_messages = []
        for message in messages_with_b64_img:
            if 'tool_responses' in message:
                # Re-pack a tool-response message into a 'tool' message plus, when a
                # screenshot image is present, a follow-up 'user' message carrying it.
                for tool_response in message['tool_responses']:
                    tmp_image = None
                    tmp_list = []
                    for ctx in message['content']:
                        if ctx['type'] == 'image_url':
                            # NOTE(review): keeps only the LAST image_url item — confirm intended.
                            tmp_image = ctx
                    tmp_list.append({
                        'role': 'tool',
                        'tool_call_id': tool_response['tool_call_id'],
                        # NOTE(review): only the first content item is forwarded — confirm intended.
                        'content': [message['content'][0]]
                    })
                    if tmp_image:
                        tmp_list.append({
                            'role': 'user',
                            'content': [
                                {'type': 'text', 'text': 'I take a screenshot for the current state for you.'},
                                tmp_image
                            ]
                        })
                    new_messages.extend(tmp_list)
            else:
                new_messages.append(message)
        messages_with_b64_img = new_messages.copy()
        # TODO: #1143 handle token limit exceeded error
        # NOTE(review): pops "context" from the last message in place (mutates `messages`).
        response = client.create(
            context=messages[-1].pop("context", None), messages=messages_with_b64_img, agent=self.name
        )
        # TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
        extracted_response = client.extract_text_or_completion_object(response)[0]
        if not isinstance(extracted_response, str):
            extracted_response = extracted_response.model_dump()
        return True, extracted_response