CoACT initialize (#292)
This commit is contained in:
5
mm_agents/coact/autogen/agentchat/contrib/__init__.py
Normal file
5
mm_agents/coact/autogen/agentchat/contrib/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# No public names are exported from this package yet; the explicit (typed)
# __all__ declares the public API surface from the start.
__all__: list[str] = []
|
||||
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# No public names are exported from this package yet; the explicit (typed)
# __all__ declares the public API surface from the start.
__all__: list[str] = []
|
||||
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from ...assistant_agent import ConversableAgent
|
||||
|
||||
|
||||
class AgentCapability:
    """Base class for composable capabilities that can be added to an agent."""

    def __init__(self):
        # The base capability carries no state of its own; subclasses add theirs.
        pass

    def add_to_agent(self, agent: ConversableAgent):
        """Attach this capability to *agent*. Must be implemented by the capability subclass.

        An implementation will typically call agent.register_hook() one or more times. See teachability.py as an example.
        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,301 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import re
|
||||
from typing import Any, Literal, Optional, Protocol, Union
|
||||
|
||||
from .... import Agent, ConversableAgent, code_utils
|
||||
from ....cache import AbstractCache
|
||||
from ....import_utils import optional_import_block, require_optional_import
|
||||
from ....llm_config import LLMConfig
|
||||
from .. import img_utils
|
||||
from ..capabilities.agent_capability import AgentCapability
|
||||
from ..text_analyzer_agent import TextAnalyzerAgent
|
||||
|
||||
with optional_import_block():
|
||||
from PIL.Image import Image
|
||||
from openai import OpenAI
|
||||
|
||||
# Text appended to the agent's system message / description when the
# ImageGeneration capability is added (see ImageGeneration.add_to_agent below).
SYSTEM_MESSAGE = "You've been given the special ability to generate images."
DESCRIPTION_MESSAGE = "This agent has the ability to generate images."

# Default instructions handed to the TextAnalyzerAgent to distill an incoming
# message into a concise image-generation prompt.
PROMPT_INSTRUCTIONS = """In detail, please summarize the provided prompt to generate the image described in the TEXT.
DO NOT include any advice. RESPOND like the following example:
EXAMPLE: Blue background, 3D shapes, ...
"""
|
||||
|
||||
|
||||
class ImageGenerator(Protocol):
    """Structural (duck-typed) interface for image generators.

    Any object that supplies a `generate_image` method (string prompt -> PIL
    Image) and a `cache_key` method (string prompt -> unique string) satisfies
    this protocol and can be plugged into the ImageGeneration capability.

    NOTE: Current implementation does not allow you to edit a previously existing image.
    """

    def generate_image(self, prompt: str) -> "Image":
        """Produce an image for the given textual prompt.

        Args:
            prompt: A string describing the desired image.

        Returns:
            A PIL Image object representing the generated image.

        Raises:
            ValueError: If the image generation fails.
        """
        ...

    def cache_key(self, prompt: str) -> str:
        """Return a unique, deterministic key for caching results of *prompt*.

        This key can be used to store and retrieve generated images based on the prompt.

        Args:
            prompt: A string describing the desired image.

        Returns:
            A unique string that can be used as a cache key.
        """
        ...
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
@require_optional_import("openai>=1.66.2", "openai")
class DalleImageGenerator:
    """Image generator backed by OpenAI's DALL-E models.

    Wraps the OpenAI images API so textual prompts can be turned into PIL
    images, with configurable model, resolution, quality and image count.

    Note: Current implementation does not allow you to edit a previously existing image.
    """

    def __init__(
        self,
        llm_config: Union[LLMConfig, dict[str, Any]],
        resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
        quality: Literal["standard", "hd"] = "standard",
        num_images: int = 1,
    ):
        """Args:
        llm_config (LLMConfig or dict): llm config, must contain a valid dalle model and OpenAI API key in config_list.
        resolution (str): The resolution of the image you want to generate. Must be one of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792".
        quality (str): The quality of the image you want to generate. Must be one of "standard", "hd".
        num_images (int): The number of images to generate.
        """
        # Only the first config_list entry is consulted for model and API key.
        first_entry = llm_config["config_list"][0]
        _validate_dalle_model(first_entry["model"])
        _validate_resolution_format(resolution)

        self._model = first_entry["model"]
        self._resolution = resolution
        self._quality = quality
        self._num_images = num_images
        self._dalle_client = OpenAI(api_key=first_entry["api_key"])

    def generate_image(self, prompt: str) -> "Image":
        """Call the DALL-E API for *prompt* and return the result as a PIL Image.

        Raises:
            ValueError: If the API response carries no image URL.
        """
        api_response = self._dalle_client.images.generate(
            model=self._model,
            prompt=prompt,
            size=self._resolution,
            quality=self._quality,
            n=self._num_images,
        )

        url = api_response.data[0].url
        if url is None:
            raise ValueError("Failed to generate image.")

        return img_utils.get_pil_image(url)

    def cache_key(self, prompt: str) -> str:
        """Build a deterministic cache key from the prompt plus every generator setting."""
        parts = (prompt, self._model, self._resolution, self._quality, self._num_images)
        return ",".join(str(part) for part in parts)
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
class ImageGeneration(AgentCapability):
    """This capability allows a ConversableAgent to generate images based on the message received from other Agents.

    1. Utilizes a TextAnalyzerAgent to analyze incoming messages to identify requests for image generation and
       extract relevant details.
    2. Leverages the provided ImageGenerator (e.g., DalleImageGenerator) to create the image.
    3. Optionally caches generated images for faster retrieval in future conversations.

    NOTE: This capability increases the token usage of the agent, as it uses TextAnalyzerAgent to analyze every
    message received by the agent.

    Example:
        ```python
        import autogen
        from autogen.agentchat.contrib.capabilities.image_generation import ImageGeneration

        # Assuming you have llm configs configured for the LLMs you want to use and Dalle.
        # Create the agent
        agent = autogen.ConversableAgent(
            name="dalle", llm_config={...}, max_consecutive_auto_reply=3, human_input_mode="NEVER"
        )

        # Create an ImageGenerator with desired settings
        dalle_gen = generate_images.DalleImageGenerator(llm_config={...})

        # Add the ImageGeneration capability to the agent
        agent.add_capability(ImageGeneration(image_generator=dalle_gen))
        ```
    """

    def __init__(
        self,
        image_generator: ImageGenerator,
        cache: Optional[AbstractCache] = None,
        text_analyzer_llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None,
        text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
        verbosity: int = 0,
        register_reply_position: int = 2,
    ):
        """Args:
        image_generator (ImageGenerator): The image generator you would like to use to generate images.
        cache (None or AbstractCache): The cache client to use to store and retrieve generated images. If None,
            no caching will be used.
        text_analyzer_llm_config (LLMConfig or Dict or None): The LLM config for the text analyzer. If None, the LLM config will
            be retrieved from the agent you're adding the ability to.
        text_analyzer_instructions (str): Instructions provided to the TextAnalyzerAgent used to analyze
            incoming messages and extract the prompt for image generation. The default instructions focus on
            summarizing the prompt. You can customize the instructions to achieve more granular control over prompt
            extraction.
            Example: 'Extract specific details from the message, like desired objects, styles, or backgrounds.'
        verbosity (int): The verbosity level. Defaults to 0 and must be greater than or equal to 0. The text
            analyzer llm calls will be silent if verbosity is less than 2.
        register_reply_position (int): The position of the reply function in the agent's list of reply functions.
            This capability registers a new reply function to handle messages with image generation requests.
            Defaults to 2 to place it after the check termination and human reply for a ConversableAgent.
        """
        self._image_generator = image_generator
        self._cache = cache
        self._text_analyzer_llm_config = text_analyzer_llm_config
        self._text_analyzer_instructions = text_analyzer_instructions
        self._verbosity = verbosity
        self._register_reply_position = register_reply_position

        # Populated by add_to_agent(); both stay None until the capability is attached.
        self._agent: Optional[ConversableAgent] = None
        self._text_analyzer: Optional[TextAnalyzerAgent] = None

    def add_to_agent(self, agent: ConversableAgent):
        """Adds the Image Generation capability to the specified ConversableAgent.

        This function performs the following modifications to the agent:

        1. Registers a reply function: A new reply function is registered with the agent to handle messages that
           potentially request image generation. This function analyzes the message and triggers image generation if
           necessary.
        2. Creates an Agent (TextAnalyzerAgent): This is used to analyze messages for image generation requirements.
        3. Updates System Message: The agent's system message is updated to include a message indicating the
           capability to generate images has been added.
        4. Updates Description: The agent's description is updated to reflect the addition of the Image Generation
           capability. This might be helpful in certain use cases, like group chats.

        Args:
            agent (ConversableAgent): The ConversableAgent to add the capability to.
        """
        self._agent = agent

        agent.register_reply([Agent, None], self._image_gen_reply, position=self._register_reply_position)

        # Fall back to the host agent's llm_config when none was supplied at construction.
        self._text_analyzer_llm_config = self._text_analyzer_llm_config or agent.llm_config
        self._text_analyzer = TextAnalyzerAgent(llm_config=self._text_analyzer_llm_config)

        agent.update_system_message(agent.system_message + "\n" + SYSTEM_MESSAGE)
        agent.description += "\n" + DESCRIPTION_MESSAGE

    def _image_gen_reply(
        self,
        recipient: ConversableAgent,
        messages: Optional[list[dict[str, Any]]],
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
        """Reply function registered on the host agent.

        Returns (False, None) when there is nothing to do (no messages, empty last
        message, or the analyzer decides no image is requested); otherwise
        generates (or fetches from cache) an image and returns
        (True, multimodal-content message).
        """
        if messages is None:
            return False, None

        last_message = code_utils.content_str(messages[-1]["content"])

        if not last_message:
            return False, None

        if self._should_generate_image(last_message):
            prompt = self._extract_prompt(last_message)

            # Cache first; only hit the generator on a miss, then store the result.
            image = self._cache_get(prompt)
            if image is None:
                image = self._image_generator.generate_image(prompt)
                self._cache_set(prompt, image)

            return True, self._generate_content_message(prompt, image)

        else:
            return False, None

    def _should_generate_image(self, message: str) -> bool:
        """Ask the text analyzer whether *message* explicitly requests image generation."""
        assert self._text_analyzer is not None

        instructions = """
        Does any part of the TEXT ask the agent to generate an image?
        The TEXT must explicitly mention that the image must be generated.
        Answer with just one word, yes or no.
        """
        analysis = self._text_analyzer.analyze_text(message, instructions)

        # Substring check: any answer containing "yes" counts as affirmative.
        return "yes" in self._extract_analysis(analysis).lower()

    def _extract_prompt(self, last_message: str) -> str:
        """Distill *last_message* into an image-generation prompt via the text analyzer."""
        assert self._text_analyzer is not None

        analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
        return self._extract_analysis(analysis)

    def _cache_get(self, prompt: str) -> Optional["Image"]:
        """Return the cached image for *prompt*, or None on miss / when caching is disabled."""
        if self._cache:
            key = self._image_generator.cache_key(prompt)
            cached_value = self._cache.get(key)

            if cached_value:
                return img_utils.get_pil_image(cached_value)

    def _cache_set(self, prompt: str, image: "Image"):
        """Store *image* under the generator's cache key for *prompt* (no-op without a cache)."""
        if self._cache:
            key = self._image_generator.cache_key(prompt)
            # Images are stored as data URIs so the cache only has to handle strings.
            self._cache.set(key, img_utils.pil_to_data_uri(image))

    def _extract_analysis(self, analysis: Optional[Union[str, dict[str, Any]]]) -> str:
        """Normalize an analyzer reply (str or message dict) to plain text."""
        if isinstance(analysis, dict):
            return code_utils.content_str(analysis["content"])
        else:
            return code_utils.content_str(analysis)

    def _generate_content_message(self, prompt: str, image: "Image") -> dict[str, Any]:
        """Build a multimodal message dict containing the prompt text and the image as a data URI."""
        return {
            "content": [
                {"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
                {"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}},
            ]
        }
|
||||
|
||||
|
||||
# Helpers
|
||||
def _validate_resolution_format(resolution: str):
|
||||
"""Checks if a string is in a valid resolution format (e.g., "1024x768")."""
|
||||
pattern = r"^\d+x\d+$" # Matches a pattern of digits, "x", and digits
|
||||
matched_resolution = re.match(pattern, resolution)
|
||||
if matched_resolution is None:
|
||||
raise ValueError(f"Invalid resolution format: {resolution}")
|
||||
|
||||
|
||||
def _validate_dalle_model(model: str):
|
||||
if model not in ["dall-e-3", "dall-e-2"]:
|
||||
raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'")
|
||||
@@ -0,0 +1,393 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from ....formatting_utils import colored
|
||||
from ....import_utils import optional_import_block, require_optional_import
|
||||
from ....llm_config import LLMConfig
|
||||
from ...assistant_agent import ConversableAgent
|
||||
from ..text_analyzer_agent import TextAnalyzerAgent
|
||||
from .agent_capability import AgentCapability
|
||||
|
||||
with optional_import_block():
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
|
||||
class Teachability(AgentCapability):
    """Teachability uses a vector database to give an agent the ability to remember user teachings,
    where the user is any caller (human or not) sending messages to the teachable agent.
    Teachability is designed to be composable with other agent capabilities.
    To make any conversable agent teachable, instantiate both the agent and the Teachability class,
    then pass the agent to teachability.add_to_agent(agent).
    Note that teachable agents in a group chat must be given unique path_to_db_dir values.

    When adding Teachability to an agent, the following are modified:
    - The agent's system message is appended with a note about the agent's new ability.
    - A hook is added to the agent's `process_last_received_message` hookable method,
    and the hook potentially modifies the last of the received messages to include earlier teachings related to the message.
    Added teachings do not propagate into the stored message history.
    If new user teachings are detected, they are added to new memos in the vector database.
    """

    def __init__(
        self,
        verbosity: Optional[int] = 0,
        reset_db: Optional[bool] = False,
        path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
        recall_threshold: Optional[float] = 1.5,
        max_num_retrievals: Optional[int] = 10,
        llm_config: Optional[Union[LLMConfig, dict[str, Any], bool]] = None,
    ):
        """Args:
        verbosity (Optional, int): # 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
        reset_db (Optional, bool): True to clear the DB before starting. Default False.
        path_to_db_dir (Optional, str): path to the directory where this particular agent's DB is stored. Default "./tmp/teachable_agent_db"
        recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
        max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10.
        llm_config (LLMConfig or dict or False): llm inference configuration passed to TextAnalyzerAgent.
            If None, TextAnalyzerAgent uses llm_config from the teachable agent.
        """
        self.verbosity = verbosity
        self.path_to_db_dir = path_to_db_dir
        self.recall_threshold = recall_threshold
        self.max_num_retrievals = max_num_retrievals
        self.llm_config = llm_config

        # Both are populated by add_to_agent(); None until the capability is attached.
        self.analyzer = None
        self.teachable_agent = None

        # Create the memo store.
        self.memo_store = MemoStore(self.verbosity, reset_db, self.path_to_db_dir)

    def add_to_agent(self, agent: ConversableAgent):
        """Adds teachability to the given agent."""
        self.teachable_agent = agent

        # Register a hook for processing the last message.
        agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

        # Was an llm_config passed to the constructor?
        if self.llm_config is None:
            # No. Use the agent's llm_config.
            self.llm_config = agent.llm_config
        assert self.llm_config, "Teachability requires a valid llm_config."

        # Create the analyzer agent.
        self.analyzer = TextAnalyzerAgent(llm_config=self.llm_config)

        # Append extra info to the system message.
        agent.update_system_message(
            agent.system_message
            + "\nYou've been given the special ability to remember user teachings from prior conversations."
        )

    def prepopulate_db(self):
        """Adds a few arbitrary memos to the DB."""
        self.memo_store.prepopulate()

    def process_last_received_message(self, text: Union[dict[str, Any], str]):
        """Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
        Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.

        # NOTE(review): when `text` is a dict, _consider_memo_retrieval concatenates it
        # with a string, which would raise — presumably hooks deliver plain strings
        # here; confirm against the hook caller.
        """
        # Try to retrieve relevant memos from the DB.
        expanded_text = text
        # Skip retrieval entirely while the store is empty.
        if self.memo_store.last_memo_id > 0:
            expanded_text = self._consider_memo_retrieval(text)

        # Try to store any user teachings in new memos to be used in the future.
        self._consider_memo_storage(text)

        # Return the (possibly) expanded message text.
        return expanded_text

    def _consider_memo_storage(self, comment: Union[dict[str, Any], str]):
        """Decides whether to store something from one user comment in the DB."""
        memo_added = False

        # Check for a problem-solution pair.
        response = self._analyze(
            comment,
            "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            # Can we extract advice?
            advice = self._analyze(
                comment,
                "Briefly copy any advice from the TEXT that may be useful for a similar but different task in the future. But if no advice is present, just respond with 'none'.",
            )
            if "none" not in advice.lower():
                # Yes. Extract the task.
                task = self._analyze(
                    comment,
                    "Briefly copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice.",
                )
                # Generalize the task.
                general_task = self._analyze(
                    task,
                    "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
                )
                # Add the task-advice (problem-solution) pair to the vector DB.
                if self.verbosity >= 1:
                    print(colored("\nREMEMBER THIS TASK-ADVICE PAIR", "light_yellow"))
                self.memo_store.add_input_output_pair(general_task, advice)
                memo_added = True

        # Check for information to be learned.
        response = self._analyze(
            comment,
            "Does the TEXT contain information that could be committed to memory? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            # Yes. What question would this information answer?
            question = self._analyze(
                comment,
                "Imagine that the user forgot this information in the TEXT. How would they ask you for this information? Include no other text in your response.",
            )
            # Extract the information.
            answer = self._analyze(
                comment, "Copy the information from the TEXT that should be committed to memory. Add no explanation."
            )
            # Add the question-answer pair to the vector DB.
            if self.verbosity >= 1:
                print(colored("\nREMEMBER THIS QUESTION-ANSWER PAIR", "light_yellow"))
            self.memo_store.add_input_output_pair(question, answer)
            memo_added = True

        # Were any memos added?
        if memo_added:
            # Yes. Save them to disk.
            self.memo_store._save_memos()

    def _consider_memo_retrieval(self, comment: Union[dict[str, Any], str]):
        """Decides whether to retrieve memos from the DB, and add them to the chat context."""
        # First, use the comment directly as the lookup key.
        if self.verbosity >= 1:
            print(colored("\nLOOK FOR RELEVANT MEMOS, AS QUESTION-ANSWER PAIRS", "light_yellow"))
        memo_list = self._retrieve_relevant_memos(comment)

        # Next, if the comment involves a task, then extract and generalize the task before using it as the lookup key.
        response = self._analyze(
            comment,
            "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            if self.verbosity >= 1:
                print(colored("\nLOOK FOR RELEVANT MEMOS, AS TASK-ADVICE PAIRS", "light_yellow"))
            # Extract the task.
            task = self._analyze(
                comment, "Copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice."
            )
            # Generalize the task.
            general_task = self._analyze(
                task,
                "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
            )
            # Append any relevant memos.
            memo_list.extend(self._retrieve_relevant_memos(general_task))

        # De-duplicate the memo list.
        memo_list = list(set(memo_list))

        # Append the memos to the text of the last message.
        return comment + self._concatenate_memo_texts(memo_list)

    def _retrieve_relevant_memos(self, input_text: str) -> list:
        """Returns semantically related memos from the DB."""
        memo_list = self.memo_store.get_related_memos(
            input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold
        )

        if self.verbosity >= 1:  # noqa: SIM102
            # Was anything retrieved?
            if len(memo_list) == 0:
                # No. Look at the closest memo.
                print(colored("\nTHE CLOSEST MEMO IS BEYOND THE THRESHOLD:", "light_yellow"))
                self.memo_store.get_nearest_memo(input_text)
                print()  # Print a blank line. The memo details were printed by get_nearest_memo().

        # Create a list of just the memo output_text strings.
        memo_list = [memo[1] for memo in memo_list]
        return memo_list

    def _concatenate_memo_texts(self, memo_list: list) -> str:
        """Concatenates the memo texts into a single string for inclusion in the chat context."""
        memo_texts = ""
        if len(memo_list) > 0:
            info = "\n# Memories that might help\n"
            for memo in memo_list:
                info = info + "- " + memo + "\n"
            if self.verbosity >= 1:
                print(colored("\nMEMOS APPENDED TO LAST MESSAGE...\n" + info + "\n", "light_yellow"))
            memo_texts = memo_texts + "\n" + info
        return memo_texts

    def _analyze(self, text_to_analyze: Union[dict[str, Any], str], analysis_instructions: Union[dict[str, Any], str]):
        """Asks TextAnalyzerAgent to analyze the given text according to specific instructions.

        Requires add_to_agent() to have run first (self.analyzer and
        self.teachable_agent must both be set).
        """
        self.analyzer.reset()  # Clear the analyzer's list of messages.
        self.teachable_agent.send(
            recipient=self.analyzer, message=text_to_analyze, request_reply=False, silent=(self.verbosity < 2)
        )  # Put the message in the analyzer's list.
        self.teachable_agent.send(
            recipient=self.analyzer, message=analysis_instructions, request_reply=True, silent=(self.verbosity < 2)
        )  # Request the reply.
        return self.teachable_agent.last_message(self.analyzer)["content"]
|
||||
|
||||
|
||||
@require_optional_import("chromadb", "teachable")
class MemoStore:
    """Provides memory storage and retrieval for a teachable agent, using a vector database.
    Each DB entry (called a memo) is a pair of strings: an input text and an output text.
    The input text might be a question, or a task to perform.
    The output text might be an answer to the question, or advice on how to perform the task.
    Vector embeddings are currently supplied by Chroma's default Sentence Transformers.
    """

    def __init__(
        self,
        verbosity: Optional[int] = 0,
        reset: Optional[bool] = False,
        path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
    ):
        """Args:
        - verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists.
        - reset (Optional, bool): True to clear the DB before starting. Default False.
        - path_to_db_dir (Optional, str): path to the directory where the DB is stored.
        """
        self.verbosity = verbosity
        self.path_to_db_dir = path_to_db_dir

        # Load or create the vector DB on disk.
        settings = Settings(
            anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir
        )
        self.db_client = chromadb.Client(settings)
        self.vec_db = self.db_client.create_collection("memos", get_or_create=True)  # The collection is the DB.

        # Load or create the associated memo dict on disk.
        self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl")
        # uid (str) -> (input_text, output_text); the vector DB stores only input_text.
        self.uid_text_dict = {}
        self.last_memo_id = 0
        if (not reset) and os.path.exists(self.path_to_dict):
            print(colored("\nLOADING MEMORY FROM DISK", "light_green"))
            print(colored(f" Location = {self.path_to_dict}", "light_green"))
            with open(self.path_to_dict, "rb") as f:
                # NOTE(review): pickle.load on a file from disk — only safe if the
                # DB directory is trusted; confirm no untrusted data lands here.
                self.uid_text_dict = pickle.load(f)
                self.last_memo_id = len(self.uid_text_dict)
                if self.verbosity >= 3:
                    self.list_memos()

        # Clear the DB if requested.
        if reset:
            self.reset_db()

    def list_memos(self):
        """Prints the contents of MemoStore."""
        print(colored("LIST OF MEMOS", "light_green"))
        for uid, text in self.uid_text_dict.items():
            input_text, output_text = text
            print(
                colored(
                    f"  ID: {uid}\n  INPUT TEXT: {input_text}\n  OUTPUT TEXT: {output_text}",
                    "light_green",
                )
            )

    def _save_memos(self):
        """Saves self.uid_text_dict to disk."""
        with open(self.path_to_dict, "wb") as file:
            pickle.dump(self.uid_text_dict, file)

    def reset_db(self):
        """Forces immediate deletion of the DB's contents, in memory and on disk."""
        print(colored("\nCLEARING MEMORY", "light_green"))
        self.db_client.delete_collection("memos")
        self.vec_db = self.db_client.create_collection("memos")
        self.uid_text_dict = {}
        # NOTE(review): last_memo_id is not reset here, so ids keep growing after a
        # reset and will disagree with len(uid_text_dict) computed on reload — confirm intended.
        self._save_memos()

    def add_input_output_pair(self, input_text: str, output_text: str):
        """Adds an input-output pair to the vector DB."""
        self.last_memo_id += 1
        # Only input_text is embedded; output_text lives solely in uid_text_dict.
        self.vec_db.add(documents=[input_text], ids=[str(self.last_memo_id)])
        self.uid_text_dict[str(self.last_memo_id)] = input_text, output_text
        if self.verbosity >= 1:
            print(
                colored(
                    f"\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {self.last_memo_id}\n INPUT\n {input_text}\n OUTPUT\n {output_text}\n",
                    "light_yellow",
                )
            )
        if self.verbosity >= 3:
            self.list_memos()

    def get_nearest_memo(self, query_text: str):
        """Retrieves the nearest memo to the given query text.

        Returns:
            (input_text, output_text, distance) for the single closest memo.
        """
        results = self.vec_db.query(query_texts=[query_text], n_results=1)
        uid, input_text, distance = results["ids"][0][0], results["documents"][0][0], results["distances"][0][0]
        input_text_2, output_text = self.uid_text_dict[uid]
        # Sanity check: the embedded document must match the stored input text.
        assert input_text == input_text_2
        if self.verbosity >= 1:
            print(
                colored(
                    f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {input_text}\n OUTPUT\n {output_text}\n DISTANCE\n {distance}",
                    "light_yellow",
                )
            )
        return input_text, output_text, distance

    def get_related_memos(self, query_text: str, n_results: int, threshold: Union[int, float]):
        """Retrieves memos that are related to the given query text within the specified distance threshold.

        Returns:
            A list of (input_text, output_text, distance) tuples, closest first.
        """
        # Chroma rejects n_results larger than the collection size; clamp it.
        if n_results > len(self.uid_text_dict):
            n_results = len(self.uid_text_dict)
        results = self.vec_db.query(query_texts=[query_text], n_results=n_results)
        memos = []
        num_results = len(results["ids"][0])
        for i in range(num_results):
            uid, input_text, distance = results["ids"][0][i], results["documents"][0][i], results["distances"][0][i]
            if distance < threshold:
                input_text_2, output_text = self.uid_text_dict[uid]
                assert input_text == input_text_2
                if self.verbosity >= 1:
                    print(
                        colored(
                            f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {input_text}\n OUTPUT\n {output_text}\n DISTANCE\n {distance}",
                            "light_yellow",
                        )
                    )
                memos.append((input_text, output_text, distance))
        return memos

    def prepopulate(self):
        """Adds a few arbitrary examples to the vector DB, just to make retrieval less trivial."""
        if self.verbosity >= 1:
            print(colored("\nPREPOPULATING MEMORY", "light_green"))
        examples = []
        examples.append({"text": "When I say papers I mean research papers, which are typically pdfs.", "label": "yes"})
        examples.append({"text": "Please verify that each paper you listed actually uses langchain.", "label": "no"})
        examples.append({"text": "Tell gpt the output should still be latex code.", "label": "no"})
        examples.append({"text": "Hint: convert pdfs to text and then answer questions based on them.", "label": "yes"})
        examples.append({
            "text": "To create a good PPT, include enough content to make it interesting.",
            "label": "yes",
        })
        examples.append({
            "text": "No, for this case the columns should be aspects and the rows should be frameworks.",
            "label": "no",
        })
        examples.append({"text": "When writing code, remember to include any libraries that are used.", "label": "yes"})
        examples.append({"text": "Please summarize the papers by Eric Horvitz on bounded rationality.", "label": "no"})
        examples.append({"text": "Compare the h-index of Daniel Weld and Oren Etzioni.", "label": "no"})
        examples.append({
            "text": "Double check to be sure that the columns in a table correspond to what was asked for.",
            "label": "yes",
        })
        for example in examples:
            self.add_input_output_pair(example["text"], example["label"])
        self._save_memos()
|
||||
@@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from typing import Any, Protocol
|
||||
|
||||
from ....import_utils import optional_import_block, require_optional_import
|
||||
|
||||
with optional_import_block() as result:
|
||||
import llmlingua
|
||||
from llmlingua import PromptCompressor
|
||||
|
||||
|
||||
class TextCompressor(Protocol):
    """Defines a protocol for text compression to optimize agent interactions."""

    def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
        """Compress `text` and return a dictionary describing the result.

        The compressed text is stored under the 'compressed_prompt' key of the returned
        dictionary (this is the key consumers such as TextMessageCompressor read).
        To allow callers to compute the number of saved tokens, the dictionary should also
        include 'origin_tokens' and 'compressed_tokens' keys.
        """
        ...
|
||||
|
||||
|
||||
@require_optional_import("llmlingua", "long-context")
class LLMLingua:
    """Compresses text messages using LLMLingua for improved efficiency in processing and response generation.

    NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of the messages
    and the specific configurations used for the PromptCompressor.
    """

    def __init__(
        self,
        prompt_compressor_kwargs: dict = dict(
            model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
            use_llmlingua2=True,
            device_map="cpu",
        ),
        structured_compression: bool = False,
    ) -> None:
        """Args:
        prompt_compressor_kwargs (dict): A dictionary of keyword arguments for the PromptCompressor. Defaults to a
            dictionary with model_name set to "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
            use_llmlingua2 set to True, and device_map set to "cpu".
            NOTE: the default is a shared mutable dict; it is only read here, never mutated.
        structured_compression (bool): A flag indicating whether to use structured compression. If True, the
            structured_compress_prompt method of the PromptCompressor is used. Otherwise, the compress_prompt method
            is used. Defaults to False.

        Raises:
            ImportError: If the llmlingua library is not installed.
        """
        self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)

        # Sanity check that the optional import resolved to llmlingua's real PromptCompressor.
        assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
        # Bind the chosen compression entry point once, so compress_text needs no branching.
        self._compression_method = (
            self._prompt_compressor.structured_compress_prompt
            if structured_compression
            else self._prompt_compressor.compress_prompt
        )

    def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
        """Compress `text`, delegating to the bound PromptCompressor method.

        The text is wrapped in a single-element list, matching the bound method's input form.
        """
        return self._compression_method([text], **compression_params)
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from ....agentchat import ConversableAgent
|
||||
from ....tools import Tool
|
||||
|
||||
|
||||
class ToolsCapability:
    """Composable capability that adds a list of tools to a single agent.

    This class can be inherited from to allow code to run at the point of creating or adding the capability.

    Note: both caller and executor of the tools are the same agent.
    """

    def __init__(self, tool_list: list[Tool]):
        # Keep an independent shallow copy of the supplied tools.
        self.tools = list(tool_list)

    def add_to_agent(self, agent: ConversableAgent):
        """Register every held tool on the given agent."""
        for current_tool in self.tools:
            current_tool.register_tool(agent=agent)
|
||||
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ....formatting_utils import colored
|
||||
from .transforms import MessageTransform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...conversable_agent import ConversableAgent
|
||||
|
||||
|
||||
class TransformMessages:
    """Agent capability for transforming messages before reply generation.

    This capability allows you to apply a series of message transformations to
    a ConversableAgent's incoming messages before they are processed for response
    generation. This is useful for tasks such as:

    - Limiting the number of messages considered for context.
    - Truncating messages to meet token limits.
    - Filtering sensitive information.
    - Customizing message formatting.

    To use `TransformMessages`:

    1. Create message transformations (e.g., `MessageHistoryLimiter`, `MessageTokenLimiter`).
    2. Instantiate `TransformMessages` with a list of these transformations.
    3. Add the `TransformMessages` instance to your `ConversableAgent` using `add_to_agent`.

    NOTE: Order of message transformations is important. You could get different results based on
    the order of transformations.

    Example:
        ```python
        from autogen.agentchat import ConversableAgent
        from autogen.agentchat.contrib.capabilities import TransformMessages, MessageHistoryLimiter, MessageTokenLimiter

        max_messages = MessageHistoryLimiter(max_messages=2)
        truncate_messages = MessageTokenLimiter(max_tokens=500)
        transform_messages = TransformMessages(transforms=[max_messages, truncate_messages])

        agent = ConversableAgent(...)
        transform_messages.add_to_agent(agent)
        ```
    """

    def __init__(self, *, transforms: list[MessageTransform] = [], verbose: bool = True):
        """Args:
        transforms: A list of message transformations to apply.
        verbose: Whether to print logs of each transformation or not.
        """
        # NOTE: the mutable default is safe here because the list is only read, never mutated.
        self._transforms = transforms
        self._verbose = verbose

    def add_to_agent(self, agent: "ConversableAgent"):
        """Adds the message transformations capability to the specified ConversableAgent.

        This function performs the following modifications to the agent:

        1. Registers a hook that automatically transforms all messages before they are processed for
        response generation.
        """
        agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)

    def _transform_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Apply every configured transform (in order) to a copy of `messages`.

        A leading system message, if present, is set aside so transforms never see
        or modify it, and is re-inserted at the front of the result.
        """
        # Guard: an empty history has nothing to transform (and messages[0] below would raise).
        if not messages:
            return messages

        post_transform_messages = copy.deepcopy(messages)
        system_message = None

        # Use .get: a message without a "role" key is simply not treated as a system message.
        if messages[0].get("role") == "system":
            system_message = copy.deepcopy(messages[0])
            post_transform_messages.pop(0)

        for transform in self._transforms:
            # deepcopy in case pre_transform_messages will later be used for logs printing
            pre_transform_messages = (
                copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages
            )
            post_transform_messages = transform.apply_transform(pre_transform_messages)

            if self._verbose:
                logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages)
                if had_effect:
                    print(colored(logs_str, "yellow"))

        if system_message:
            post_transform_messages.insert(0, system_message)

        return post_transform_messages
|
||||
@@ -0,0 +1,579 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import copy
|
||||
import sys
|
||||
from typing import Any, Optional, Protocol, Union
|
||||
|
||||
import tiktoken
|
||||
from termcolor import colored
|
||||
|
||||
from .... import token_count_utils
|
||||
from ....cache import AbstractCache, Cache
|
||||
from ....types import MessageContentType
|
||||
from . import transforms_util
|
||||
from .text_compressors import LLMLingua, TextCompressor
|
||||
|
||||
|
||||
class MessageTransform(Protocol):
    """Defines a contract for message transformation.

    Classes implementing this protocol should provide an `apply_transform` method
    that takes a list of messages and returns the transformed list.
    """

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies a transformation to a list of messages.

        Args:
            messages: A list of dictionaries representing messages.

        Returns:
            A new list of dictionaries containing the transformed messages.
        """
        ...

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Creates the string including the logs of the transformation.

        Alongside the string, it returns a boolean indicating whether the transformation had an effect or not.

        Args:
            pre_transform_messages: A list of dictionaries representing messages before the transformation.
            post_transform_messages: A list of dictionaries representing messages after the transformation.

        Returns:
            A tuple with a string with the logs and a flag indicating whether the transformation had an effect or not.
        """
        ...
|
||||
|
||||
|
||||
class MessageHistoryLimiter:
    """Limits the number of messages considered by an agent for response generation.

    This transform keeps only the most recent messages up to the specified maximum number of messages (max_messages).
    It trims the conversation history by removing older messages, retaining only the most recent messages.
    """

    def __init__(
        self,
        max_messages: Optional[int] = None,
        keep_first_message: bool = False,
        exclude_names: Optional[list[str]] = None,
    ):
        """Args:
        max_messages (Optional[int]): Maximum number of messages to keep in the context. Must be greater than 0 if not None.
        keep_first_message (bool): Whether to keep the original first message in the conversation history.
            Defaults to False.
        exclude_names (Optional[list[str]]): List of message sender names to exclude from the message history.
            Messages from these senders will be filtered out before applying the message limit. Defaults to None.

        Raises:
            ValueError: If max_messages is not None and is less than 1.
        """
        self._validate_max_messages(max_messages)
        self._max_messages = max_messages
        self._keep_first_message = keep_first_message
        self._exclude_names = exclude_names

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Truncates the conversation history to the specified maximum number of messages.

        This method returns a new list containing the most recent messages up to the specified
        maximum number of messages (max_messages). If max_messages is None, it returns the
        original list of messages (minus any excluded senders) unmodified.

        Args:
            messages (list[dict]): The list of messages representing the conversation history.

        Returns:
            list[dict]: A new list containing the most recent messages up to the specified maximum.
        """
        # getattr guards against instances deserialized from versions that lacked _exclude_names.
        exclude_names = getattr(self, "_exclude_names", None)

        # Drop excluded senders first, so the limit applies to the remaining messages only.
        filtered = [msg for msg in messages if msg.get("name") not in exclude_names] if exclude_names else messages

        if self._max_messages is None or len(filtered) <= self._max_messages:
            return filtered

        truncated_messages = []
        remaining_count = self._max_messages

        # Start with the first message if we need to keep it.
        if self._keep_first_message and filtered:
            truncated_messages = [filtered[0]]
            remaining_count -= 1

        # Insertion point for each kept message: right after the preserved first
        # message, or at the front of the list otherwise.
        insert_at = 1 if self._keep_first_message else 0

        # Walk the history newest-to-oldest; each older message is inserted ahead of
        # the ones already kept, preserving chronological order in the result.
        for i in range(len(filtered) - 1, 0, -1):
            if remaining_count > 1:
                truncated_messages.insert(insert_at, filtered[i])
            if remaining_count == 1:  # noqa: SIM102
                # If there's only 1 slot left and it's a 'tools' message, ignore it
                # (presumably to avoid the kept history starting with a tool response
                # whose tool call was truncated away — TODO confirm intent).
                if filtered[i].get("role") != "tool":
                    # BUGFIX: insert at the same front position as above. The previous
                    # hard-coded index 1 misordered the oldest kept message whenever
                    # keep_first_message was False (e.g. [a, b, c] with max=2 -> [c, b]).
                    truncated_messages.insert(insert_at, filtered[i])

            remaining_count -= 1
            # <= 0 (not == 0): with keep_first_message and max_messages == 1, the count
            # goes negative on the first iteration and the old == test never terminated early.
            if remaining_count <= 0:
                break

        return truncated_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report how many messages the last transform removed, and whether it had any effect."""
        pre_transform_messages_len = len(pre_transform_messages)
        post_transform_messages_len = len(post_transform_messages)

        if post_transform_messages_len < pre_transform_messages_len:
            logs_str = (
                f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. "
                f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}."
            )
            return logs_str, True
        return "No messages were removed.", False

    def _validate_max_messages(self, max_messages: Optional[int]):
        """Raise ValueError unless max_messages is None or >= 1."""
        if max_messages is not None and max_messages < 1:
            # BUGFIX: the old message said "greater than 1", contradicting the < 1 check.
            raise ValueError("max_messages must be None or greater than or equal to 1")
|
||||
|
||||
|
||||
class MessageTokenLimiter:
    """Truncates messages to meet token limits for efficient processing and response generation.

    This transformation applies two levels of truncation to the conversation history:

    1. Truncates each individual message to the maximum number of tokens specified by max_tokens_per_message.
    2. Truncates the overall conversation history to the maximum number of tokens specified by max_tokens.

    NOTE: Tokens are counted using the encoder for the specified model. Different models may yield different token
    counts for the same text.

    NOTE: For multimodal LLMs, the token count may be inaccurate as it does not account for the non-text input
    (e.g images).

    The truncation process follows these steps in order:

    1. The minimum tokens threshold (`min_tokens`) is checked (0 by default). If the total number of tokens in messages
    is less than this threshold, then the messages are returned as is. In other case, the following process is applied.
    2. Messages are processed in reverse order (newest to oldest).
    3. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
    and other types of content, only the text content is truncated.
    4. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
    exceeds this limit, the current message being processed get truncated to meet the total token count and any
    remaining messages get discarded.
    5. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
    original message order.
    """

    def __init__(
        self,
        max_tokens_per_message: Optional[int] = None,
        max_tokens: Optional[int] = None,
        min_tokens: Optional[int] = None,
        model: str = "gpt-3.5-turbo-0613",
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        max_tokens_per_message (None or int): Maximum number of tokens to keep in each message.
            Must be greater than or equal to 0 if not None.
        max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history.
            Must be greater than or equal to 0 if not None.
        min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
            Must be greater than or equal to 0 if not None.
        model (str): The target OpenAI model for tokenization alignment.
        filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
            If None, no filters will be applied.
        exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
            excluded from token truncation. If False, messages that match the filter will be truncated.
        """
        self._model = model
        # _validate_max_tokens maps None to sys.maxsize (or caps at the model's limit),
        # so the attributes below are always ints.
        self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
        self._max_tokens = self._validate_max_tokens(max_tokens)
        # _validate_min_tokens maps None to 0, so this attribute is always an int.
        self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies token truncation to the conversation history.

        Args:
            messages (list[dict]): The list of messages representing the conversation history.

        Returns:
            list[dict]: A new list containing the truncated messages up to the specified token limits.
        """
        # Invariants guaranteed by the validators in __init__ (None was mapped to an int).
        assert self._max_tokens_per_message is not None
        assert self._max_tokens is not None
        assert self._min_tokens is not None

        # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
        if not transforms_util.min_tokens_reached(messages, self._min_tokens):
            return messages

        # deepcopy so per-message truncation below never mutates the caller's dicts
        temp_messages = copy.deepcopy(messages)
        processed_messages = []
        processed_messages_tokens = 0

        # Newest-to-oldest, so the most recent messages are the ones retained.
        for msg in reversed(temp_messages):
            # Some messages may not have content.
            if not transforms_util.is_content_right_type(msg.get("content")):
                processed_messages.insert(0, msg)
                continue

            # Filtered-out messages are kept verbatim but still count toward the budget.
            if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
                processed_messages.insert(0, msg)
                processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
                continue

            expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message

            # If adding this message would exceed the token limit, truncate the last message to meet the total token
            # limit and discard all remaining messages
            if expected_tokens_remained < 0:
                msg["content"] = self._truncate_str_to_tokens(
                    msg["content"], self._max_tokens - processed_messages_tokens
                )
                processed_messages.insert(0, msg)
                break

            msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
            msg_tokens = transforms_util.count_text_tokens(msg["content"])

            # prepend the message to the list to preserve order
            processed_messages_tokens += msg_tokens
            processed_messages.insert(0, msg)

        return processed_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report how many tokens the last transform removed, and whether it had any effect."""
        pre_transform_messages_tokens = sum(
            transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
        )
        post_transform_messages_tokens = sum(
            transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
        )

        if post_transform_messages_tokens < pre_transform_messages_tokens:
            logs_str = (
                f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. "
                f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}"
            )
            return logs_str, True
        return "No tokens were truncated.", False

    def _truncate_str_to_tokens(self, contents: Union[str, list], n_tokens: int) -> Union[str, list]:
        """Dispatch truncation by content shape: plain string vs multimodal list."""
        if isinstance(contents, str):
            return self._truncate_tokens(contents, n_tokens)
        elif isinstance(contents, list):
            return self._truncate_multimodal_text(contents, n_tokens)
        else:
            raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

    def _truncate_multimodal_text(self, contents: list[dict[str, Any]], n_tokens: int) -> list[dict[str, Any]]:
        """Truncates text content within a list of multimodal elements, preserving the overall structure."""
        # NOTE(review): each "text" item is truncated to n_tokens independently, so a message
        # with several text parts can exceed n_tokens in total — confirm this is intended.
        tmp_contents = []
        for content in contents:
            if content["type"] == "text":
                truncated_text = self._truncate_tokens(content["text"], n_tokens)
                tmp_contents.append({"type": "text", "text": truncated_text})
            else:
                # Non-text items (e.g. images) are passed through unchanged.
                tmp_contents.append(content)
        return tmp_contents

    def _truncate_tokens(self, text: str, n_tokens: int) -> str:
        """Encode `text` with the model's tokenizer, keep the first n_tokens tokens, and decode back."""
        encoding = tiktoken.encoding_for_model(self._model)  # Get the appropriate tokenizer

        encoded_tokens = encoding.encode(text)
        truncated_tokens = encoded_tokens[:n_tokens]
        truncated_text = encoding.decode(truncated_tokens)  # Decode back to text

        return truncated_text

    def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]:
        """Validate a token limit and normalize it to an int.

        Returns the model's maximum if the requested limit exceeds it, sys.maxsize if no
        limit was given, or the limit itself otherwise.

        Raises:
            ValueError: If max_tokens is negative.
        """
        if max_tokens is not None and max_tokens < 0:
            raise ValueError("max_tokens and max_tokens_per_message must be None or greater than or equal to 0")

        try:
            allowed_tokens = token_count_utils.get_max_token_limit(self._model)
        except Exception:
            # Unknown model: warn and skip the cap check below.
            print(colored(f"Model {self._model} not found in token_count_utils.", "yellow"))
            allowed_tokens = None

        if max_tokens is not None and allowed_tokens is not None and max_tokens > allowed_tokens:
            print(
                colored(
                    f"Max token was set to {max_tokens}, but {self._model} can only accept {allowed_tokens} tokens. Capping it to {allowed_tokens}.",
                    "yellow",
                )
            )
            return allowed_tokens

        return max_tokens if max_tokens is not None else sys.maxsize

    def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
        """Validate the minimum-token threshold and normalize None to 0.

        Raises:
            ValueError: If min_tokens is negative or exceeds max_tokens.
        """
        if min_tokens is None:
            return 0
        if min_tokens < 0:
            raise ValueError("min_tokens must be None or greater than or equal to 0.")
        if max_tokens is not None and min_tokens > max_tokens:
            raise ValueError("min_tokens must not be more than max_tokens.")
        return min_tokens
|
||||
|
||||
|
||||
class TextMessageCompressor:
    """A transform for compressing text messages in a conversation history.

    It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
    processing and response generation by downstream models.
    """

    def __init__(
        self,
        text_compressor: Optional[TextCompressor] = None,
        min_tokens: Optional[int] = None,
        compression_params: dict = dict(),
        cache: Optional[AbstractCache] = None,
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
            protocol. If None, it defaults to LLMLingua.
        min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater
            than or equal to 0 if not None. If None, no threshold-based compression is applied.
        compression_params (dict): A dictionary of arguments for the compression method. Defaults to an empty
            dictionary. NOTE: the mutable default is safe here because it is only read, never mutated.
        cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
            If None, a disk cache is used.
        filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
            If None, no filters will be applied.
        exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
            excluded from compression. If False, messages that match the filter will be compressed.

        Raises:
            ValueError: If min_tokens is not None and not greater than 0.
        """
        if text_compressor is None:
            text_compressor = LLMLingua()

        self._validate_min_tokens(min_tokens)

        self._text_compressor = text_compressor
        self._min_tokens = min_tokens
        self._compression_args = compression_params
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter

        if cache is None:
            self._cache = Cache.disk()
        else:
            self._cache = cache

        # Optimizing savings calculations to optimize log generation
        self._recent_tokens_savings = 0

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies compression to messages in a conversation history based on the specified configuration.

        The function processes each message according to the `compression_args` and `min_tokens` settings, applying
        the specified compression configuration and returning a new list of messages with reduced token counts
        where possible.

        NOTE: `messages.copy()` is a shallow copy — the message dicts themselves are shared
        with (and mutated in) the caller's list.

        Args:
            messages (list[dict]): A list of message dictionaries to be compressed.

        Returns:
            list[dict]: A list of dictionaries with the message content compressed according to the configured
                method and scope.
        """
        # Make sure there is at least one message
        if not messages:
            return messages

        # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
        if not transforms_util.min_tokens_reached(messages, self._min_tokens):
            return messages

        total_savings = 0
        processed_messages = messages.copy()
        for message in processed_messages:
            # Some messages may not have content.
            if not transforms_util.is_content_right_type(message.get("content")):
                continue

            if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
                continue

            if transforms_util.is_content_text_empty(message["content"]):
                continue

            # Compression is expensive: reuse a previously compressed result when available.
            cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
            cached_content = transforms_util.cache_content_get(self._cache, cache_key)
            if cached_content is not None:
                message["content"], savings = cached_content
            else:
                message["content"], savings = self._compress(message["content"])
                transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)

            assert isinstance(savings, int)
            total_savings += savings

        # Remember the savings so get_logs can report them without recounting tokens.
        self._recent_tokens_savings = total_savings
        return processed_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report the token savings of the most recent apply_transform call."""
        if self._recent_tokens_savings > 0:
            return f"{self._recent_tokens_savings} tokens saved with text compression.", True
        else:
            return "No tokens saved with text compression.", False

    def _compress(self, content: MessageContentType) -> tuple[MessageContentType, int]:
        """Compresses the given text or multimodal content using the specified compression method."""
        if isinstance(content, str):
            return self._compress_text(content)
        elif isinstance(content, list):
            return self._compress_multimodal(content)
        else:
            # Unknown content shapes pass through untouched with zero savings.
            return content, 0

    def _compress_multimodal(self, content: MessageContentType) -> tuple[MessageContentType, int]:
        """Compress every text part of a multimodal content list in place.

        Returns the (mutated) content list and the total number of tokens saved.
        """
        tokens_saved = 0
        for index, item in enumerate(content):
            if isinstance(item, dict) and "text" in item:
                item["text"], savings = self._compress_text(item["text"])
                tokens_saved += savings

            elif isinstance(item, str):
                # BUGFIX: write the compressed string back into the list. The previous code
                # rebound the loop variable (`item, savings = ...`), which discarded the
                # compressed text while still counting its savings.
                content[index], savings = self._compress_text(item)
                tokens_saved += savings

        return content, tokens_saved

    def _compress_text(self, text: str) -> tuple[str, int]:
        """Compresses the given text using the specified compression method."""
        compressed_text = self._text_compressor.compress_text(text, **self._compression_args)

        # Savings can only be computed when the compressor reports both token counts.
        savings = 0
        if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
            savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]

        return compressed_text["compressed_prompt"], savings

    def _validate_min_tokens(self, min_tokens: Optional[int]):
        """Raise ValueError unless min_tokens is None or strictly positive."""
        if min_tokens is not None and min_tokens <= 0:
            raise ValueError("min_tokens must be greater than 0 or None")
|
||||
|
||||
|
||||
class TextMessageContentName:
    """A transform that incorporates the sending agent's name into each message's content.

    How to create and apply the transform:
    # Imports
    from autogen.agentchat.contrib.capabilities import transform_messages, transforms

    # Create Transform
    name_transform = transforms.TextMessageContentName(position="start", format_string="'{name}' said:\n")

    # Create the TransformMessages
    context_handling = transform_messages.TransformMessages(transforms=[name_transform])

    # Add it to an agent so when they run inference it will apply to the messages
    context_handling.add_to_agent(my_agent)
    """

    def __init__(
        self,
        position: str = "start",
        format_string: str = "{name}:\n",
        deduplicate: bool = True,
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        position (str): Where to insert the name, either 'start' or 'end' of the content. Defaults to 'start'.
        format_string (str): Format string for the inserted name; must contain the '{name}' placeholder. Defaults to '{name}:\n'.
        deduplicate (bool): Skip insertion when the formatted name is already present (LLMs sometimes add it themselves). Defaults to True.
        filter_dict (None or dict): Criteria selecting which messages the transform applies to. If None, it applies to all.
        exclude_filter (bool): When True (default), messages matching the filter are left untouched; when False, only matching messages are transformed.
        """
        assert isinstance(position, str) and position in ["start", "end"]
        assert isinstance(format_string, str) and "{name}" in format_string
        assert isinstance(deduplicate, bool) and deduplicate is not None

        self._position = position
        self._format_string = format_string
        self._deduplicate = deduplicate
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter

        # Count of messages modified by the last apply_transform call, for get_logs.
        self._messages_changed = 0

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Insert each eligible message's agent name into its content.

        Args:
            messages (List[Dict]): A list of message dictionaries.

        Returns:
            List[Dict]: A deep-copied list with names incorporated where applicable.
        """
        if not messages:
            return messages

        changed = 0
        transformed = copy.deepcopy(messages)
        for msg in transformed:
            # Skip messages whose content or name is missing / of an unexpected type.
            content_ok = transforms_util.is_content_right_type(msg.get("content"))
            name_ok = transforms_util.is_content_right_type(msg.get("name"))
            if not (content_ok and name_ok):
                continue

            if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
                continue

            if transforms_util.is_content_text_empty(msg["content"]) or transforms_util.is_content_text_empty(
                msg["name"]
            ):
                continue

            body = msg["content"]
            tag = self._format_string.format(name=msg["name"])

            if self._position == "start":
                if not self._deduplicate or not body.startswith(tag):
                    msg["content"] = f"{tag}{body}"
                    changed += 1
            else:
                if not self._deduplicate or not body.endswith(tag):
                    msg["content"] = f"{body}{tag}"
                    changed += 1

        self._messages_changed = changed
        return transformed

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report how many messages the last apply_transform call modified."""
        if self._messages_changed > 0:
            return f"{self._messages_changed} message(s) changed to incorporate name.", True
        return "No messages changed to incorporate name.", False
|
||||
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from collections.abc import Hashable
|
||||
from typing import Any, Optional
|
||||
|
||||
from .... import token_count_utils
|
||||
from ....cache.abstract_cache_base import AbstractCache
|
||||
from ....oai.openai_utils import filter_config
|
||||
from ....types import MessageContentType
|
||||
|
||||
|
||||
def cache_key(content: MessageContentType, *args: Hashable) -> str:
    """Builds the cache key for the given message content plus any extra hashable args.

    Args:
        content (MessageContentType): The message content to key on.
        *args: Additional hashable values folded into the key.
    """
    return "".join(str(part) for part in (content, *args))
|
||||
|
||||
|
||||
def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[tuple[MessageContentType, ...]]:
    """Looks up previously cached content.

    Args:
        cache (None or AbstractCache): Cache to read from; a ``None`` cache is a no-op.
        key (str): The key to retrieve the content from.
    """
    if not cache:
        return None
    hit = cache.get(key)
    # Falsy cached values are treated as a miss, matching the write-side behavior.
    return hit if hit else None
|
||||
|
||||
|
||||
def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values):
    """Stores message content (plus any extra values) under ``key``.

    Args:
        cache (None or AbstractCache): Cache to write to; a ``None`` cache is a no-op.
        key (str): The key to set the content into.
        content (MessageContentType): The message content to set into the cache.
        *extra_values: Additional values stored alongside the content.
    """
    if not cache:
        return
    cache.set(key, (content, *extra_values))
|
||||
|
||||
|
||||
def min_tokens_reached(messages: list[dict[str, Any]], min_tokens: Optional[int]) -> bool:
    """Checks whether the messages contain at least ``min_tokens`` text tokens.

    Args:
        messages (List[Dict]): Messages whose "content" fields are counted.
        min_tokens (None or int): The threshold; a falsy threshold always passes.
    """
    if not min_tokens:
        return True

    total = 0
    for msg in messages:
        if "content" in msg:
            total += count_text_tokens(msg["content"])
    return total >= min_tokens
|
||||
|
||||
|
||||
def count_text_tokens(content: MessageContentType) -> int:
    """Counts the text tokens in the given message content.

    Args:
        content (MessageContentType): A string, or a list of string/dict items;
            dict items contribute their "text" entry (empty string when absent).
    """
    if isinstance(content, str):
        return token_count_utils.count_token(content)
    if isinstance(content, list):
        total = 0
        for part in content:
            if isinstance(part, str):
                total += token_count_utils.count_token(part)
            else:
                # Recurse on the "text" field of dict-style multimodal items.
                total += count_text_tokens(part.get("text", ""))
        return total
    # Any other content type carries no countable text.
    return 0
|
||||
|
||||
|
||||
def is_content_right_type(content: Any) -> bool:
    """Returns True when ``content`` is a type the transforms can process (str or list)."""
    return isinstance(content, str) or isinstance(content, list)
|
||||
|
||||
|
||||
def is_content_text_empty(content: MessageContentType) -> bool:
    """Checks if the content of the message does not contain any text.

    Args:
        content (MessageContentType): The message content to check. Strings are
            checked directly; lists contribute their string items and the "text"
            entries of their dict items. Any other type counts as empty.
    """
    if isinstance(content, str):
        return content == ""
    if isinstance(content, list):
        texts = (
            part if isinstance(part, str) else part.get("text", "")
            for part in content
            if isinstance(part, (str, dict))
        )
        return not any(texts)
    return True
|
||||
|
||||
|
||||
def should_transform_message(message: dict[str, Any], filter_dict: Optional[dict[str, Any]], exclude: bool) -> bool:
    """Decides whether a transform applies to ``message`` according to the filter dictionary.

    Args:
        message (Dict[str, Any]): The message to validate.
        filter_dict (None or Dict[str, Any]): Filter criteria; a falsy filter means the
            transform is always applied.
        exclude (bool): Whether messages matching the filter are excluded.
    """
    if not filter_dict:
        return True
    return bool(filter_config([message], filter_dict, exclude))
|
||||
@@ -0,0 +1,212 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import copy
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from ....code_utils import content_str
|
||||
from ....oai.client import OpenAIWrapper
|
||||
from ...assistant_agent import ConversableAgent
|
||||
from ..img_utils import (
|
||||
convert_base64_to_data_uri,
|
||||
get_image_data,
|
||||
get_pil_image,
|
||||
gpt4v_formatter,
|
||||
)
|
||||
from .agent_capability import AgentCapability
|
||||
|
||||
# Default prompt sent to the LMM client when asking it to caption an image;
# used by VisionCapability unless a custom caption function is supplied.
DEFAULT_DESCRIPTION_PROMPT = (
    "Write a detailed caption for this image. "
    "Pay special attention to any details that might be useful or relevant "
    "to the ongoing conversation."
)
|
||||
|
||||
|
||||
class VisionCapability(AgentCapability):
    """We can add vision capability to regular ConversableAgent, even if the agent does not have the multimodal capability,
    such as GPT-3.5-turbo agent, Llama, Orca, or Mistral agents. This vision capability will invoke a LMM client to describe
    the image (captioning) before sending the information to the agent's actual client.

    The vision capability will hook to the ConversableAgent's `process_last_received_message`.

    Some technical details:
    When the agent (who has the vision capability) received an message, it will:
    1. _process_received_message:
        a. _append_oai_message
    2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag.
        a. hook process_last_received_message (NOTE: this is where the vision capability will be hooked to.)
        b. hook process_all_messages_before_reply
    3. send:
        a. hook process_message_before_send
        b. _append_oai_message
    """

    def __init__(
        self,
        lmm_config: dict[str, Any],
        description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT,
        custom_caption_func: Optional[Callable] = None,
    ) -> None:
        """Initializes the capability with an LMM client configuration and optional captioning hooks.

        Args:
            lmm_config (Dict): Configuration for the LMM client, which is used to call
                the LMM service for describing the image. If `lmm_config` is False or an
                empty dictionary, no LMM client is created and `custom_caption_func`
                must be supplied instead.
            description_prompt (Optional[str], optional): The prompt used when asking the
                LMM service to describe an image. Defaults to `DEFAULT_DESCRIPTION_PROMPT`.
            custom_caption_func (Optional[Callable], optional): A callable that, if provided,
                replaces the default LMM captioning. It is called with three arguments:
                1. an image URL (or local location)
                2. image_data (a PIL image)
                3. lmm_client (to call remote LMM)
                and must return the description as a string.

        Raises:
            AssertionError: If neither a valid `lmm_config` nor a `custom_caption_func`
                is provided.
        """
        self._lmm_config = lmm_config
        self._description_prompt = description_prompt
        self._parent_agent = None

        # Only build an LMM client when a usable (truthy) config is supplied.
        if lmm_config:
            self._lmm_client = OpenAIWrapper(**lmm_config)
        else:
            self._lmm_client = None

        self._custom_caption_func = custom_caption_func
        assert self._lmm_config or custom_caption_func, (
            "Vision Capability requires a valid lmm_config or custom_caption_func."
        )

    def add_to_agent(self, agent: ConversableAgent) -> None:
        """Attaches this capability to `agent`: extends its system message and registers
        the captioning hook on `process_last_received_message`."""
        self._parent_agent = agent

        # Append extra info to the system message.
        agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.")

        # Register a hook for processing the last message.
        agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

    def process_last_received_message(self, content: Union[str, list[dict[str, Any]]]) -> str:
        """Processes the last received message content by normalizing it and augmenting
        every contained image with a text caption.

        String content is first parsed into the GPT-4V multimodal format (keeping image
        URLs as-is). Each `image_url` item is then captioned via the custom caption
        function or the LMM client, and the caption is appended after the image tag so
        that clients that cannot see images still get a textual description.

        Args:
            content (Union[str, List[dict[str, Any]]]): The last received message content,
                which can be a plain text string or a list of dictionaries representing
                different types of content items (e.g., text, image_url).

        Returns:
            str: The augmented message content.

        Raises:
            AssertionError: If an item in the content list is not a dictionary.
        """
        # normalize the content into the gpt-4v format for multimodal
        # we want to keep the URL format to keep it concise.
        if isinstance(content, str):
            content = gpt4v_formatter(content, img_format="url")

        aug_content: str = ""
        for item in content:
            assert isinstance(item, dict)
            if item["type"] == "text":
                aug_content += item["text"]
            elif item["type"] == "image_url":
                # `image_url` may be a plain URL string, or — as produced by
                # gpt4v_formatter above — a dict of the form {"url": ...}.
                img_url = item["image_url"]
                if isinstance(img_url, dict):
                    img_url = img_url.get("url", "")
                img_caption = ""

                if self._custom_caption_func:
                    img_caption = self._custom_caption_func(img_url, get_pil_image(img_url), self._lmm_client)
                elif self._lmm_client:
                    img_data = get_image_data(img_url)
                    img_caption = self._get_image_caption(img_data)
                else:
                    img_caption = ""

                aug_content += f"<img {img_url}> in case you can not see, the caption of this image is: {img_caption}\n"
            else:
                # (Typo fix: the supported type is `text`, not `test`.)
                print(f"Warning: the input type should either be `text` or `image_url`. Skip {item['type']} here.")

        return aug_content

    def _get_image_caption(self, img_data: str) -> str:
        """Asks the configured LMM client to caption a base64-encoded image.

        Args:
            img_data (str): base64 encoded image data.

        Returns:
            str: caption for the given image.
        """
        response = self._lmm_client.create(
            context=None,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": self._description_prompt},
                        {
                            "type": "image_url",
                            "image_url": convert_base64_to_data_uri(img_data),
                        },
                    ],
                }
            ],
        )
        description = response.choices[0].message.content
        return content_str(description)
|
||||
411
mm_agents/coact/autogen/agentchat/contrib/img_utils.py
Normal file
411
mm_agents/coact/autogen/agentchat/contrib/img_utils.py
Normal file
@@ -0,0 +1,411 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import base64
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
from math import ceil
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from ...import_utils import optional_import_block, require_optional_import
|
||||
from .. import utils
|
||||
|
||||
with optional_import_block():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Parameters for token counting for images for different models
# Consumed by num_tokens_from_gpt_image():
#   max_edge / min_edge: scaling bounds applied to the image before tiling
#   tile_size: edge length of the square tiles the scaled image is split into
#   base_token_count: fixed per-image token cost
#   token_multiplier: tokens charged per tile
MODEL_PARAMS = {
    "gpt-4-vision": {
        "max_edge": 2048,
        "min_edge": 768,
        "tile_size": 512,
        "base_token_count": 85,
        "token_multiplier": 170,
    },
    "gpt-4o-mini": {
        "max_edge": 2048,
        "min_edge": 768,
        "tile_size": 512,
        "base_token_count": 2833,
        "token_multiplier": 5667,
    },
    "gpt-4o": {"max_edge": 2048, "min_edge": 768, "tile_size": 512, "base_token_count": 85, "token_multiplier": 170},
}
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
def get_pil_image(image_file: Union[str, "Image.Image"]) -> "Image.Image":
    """Loads an image from a file and returns a PIL Image object.

    Parameters:
        image_file (str, or Image): The filename, URL, URI, or base64 string of the image file.

    Returns:
        Image.Image: The PIL Image object, converted to RGB.
    """
    if isinstance(image_file, Image.Image):
        # Already a PIL Image object
        return image_file

    # Remove surrounding quotes if present (e.g. paths pasted from a shell or chat).
    if image_file.startswith('"') and image_file.endswith('"'):
        image_file = image_file[1:-1]
    if image_file.startswith("'") and image_file.endswith("'"):
        image_file = image_file[1:-1]

    if image_file.startswith("http://") or image_file.startswith("https://"):
        # A URL file. A timeout prevents an unresponsive server from hanging the caller forever.
        response = requests.get(image_file, timeout=60)
        content = BytesIO(response.content)
        image = Image.open(content)
    # Match base64-encoded image URIs for supported formats: jpg, jpeg, png, gif, bmp, webp
    elif re.match(r"data:image/(?:jpg|jpeg|png|gif|bmp|webp);base64,", image_file):
        # A URI. Remove the prefix and decode the base64 string.
        base64_data = re.sub(r"data:image/(?:jpg|jpeg|png|gif|bmp|webp);base64,", "", image_file)
        image = _to_pil(base64_data)
    elif os.path.exists(image_file):
        # A local file
        image = Image.open(image_file)
    else:
        # Assume a bare base64 encoded string.
        image = _to_pil(image_file)

    return image.convert("RGB")
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
def get_image_data(image_file: Union[str, "Image.Image"], use_b64: bool = True) -> Union[bytes, str]:
    """Loads an image and returns its data either as raw bytes or in base64-encoded format.

    This function first loads an image from the specified file, URL, or base64 string using
    the `get_pil_image` function. It then saves this image in memory in PNG format and
    retrieves its binary content. Depending on the `use_b64` flag, this binary content is
    either returned directly or as a base64-encoded string.

    Parameters:
        image_file (str, or Image): The path to the image file, a URL to an image, or a base64-encoded
                                    string of the image.
        use_b64 (bool): If True, the function returns a base64-encoded string of the image data.
                        If False, it returns the raw byte data of the image. Defaults to True.

    Returns:
        bytes or str: The image data in raw bytes if `use_b64` is False, or a base64-encoded
                      `str` if `use_b64` is True (the default).
    """
    image = get_pil_image(image_file)

    buffered = BytesIO()
    image.save(buffered, format="PNG")
    content = buffered.getvalue()

    if use_b64:
        return base64.b64encode(content).decode("utf-8")
    else:
        return content
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str, list[str]]:
    """Formats the input prompt by replacing image tags and returns the new prompt along with image locations.

    Parameters:
        - prompt (str): The input string that may contain image tags like `<img ...>`.
        - order_image_tokens (bool, optional): Whether to order the image tokens with numbers.
            It will be useful for GPT-4V. Defaults to False.

    Returns:
        - Tuple[str, List[str]]: A tuple containing the formatted string and a list of images (loaded in b64 format).
    """
    formatted = prompt
    b64_images = []
    token_index = 0

    # Scan the original prompt for <img ...> tags; replacements are applied to `formatted`.
    for tag_match in re.finditer(r"<img ([^>]+)>", prompt):
        location = tag_match.group(1)

        try:
            image_b64 = get_image_data(location)
        except Exception as e:
            # Unloadable image: drop the tag from the prompt and keep going.
            print(f"Warning! Unable to load image from {location}, because of {e}")
            formatted = formatted.replace(tag_match.group(0), "", 1)
            continue

        b64_images.append(image_b64)

        # Replace the tag with a (possibly numbered) image token.
        replacement = f"<image {token_index}>" if order_image_tokens else "<image>"
        formatted = formatted.replace(tag_match.group(0), replacement, 1)
        token_index += 1

    return formatted, b64_images
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
def pil_to_data_uri(image: "Image.Image") -> str:
    """Converts a PIL Image object to a data URI.

    Parameters:
        image (Image.Image): The PIL Image object.

    Returns:
        str: The data URI string.
    """
    # Serialize to PNG in memory, then wrap the base64 payload into a data URI.
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    b64_payload = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    return convert_base64_to_data_uri(b64_payload)
|
||||
|
||||
|
||||
def convert_base64_to_data_uri(base64_image):
    """Wraps a base64-encoded image string into a data URI.

    The MIME type is sniffed from the decoded bytes' magic numbers; unknown
    formats fall back to image/jpeg as a best guess.
    """
    decoded = base64.b64decode(base64_image)

    if decoded.startswith(b"\xff\xd8\xff"):
        mime_type = "image/jpeg"
    elif decoded.startswith(b"\x89PNG\r\n\x1a\n"):
        mime_type = "image/png"
    elif decoded.startswith((b"GIF87a", b"GIF89a")):
        mime_type = "image/gif"
    elif decoded.startswith(b"RIFF") and decoded[8:12] == b"WEBP":
        mime_type = "image/webp"
    else:
        mime_type = "image/jpeg"  # use jpeg for unknown formats, best guess.

    return f"data:{mime_type};base64,{base64_image}"
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
def gpt4v_formatter(prompt: str, img_format: str = "uri") -> list[Union[str, dict[str, Any]]]:
    """Formats the input prompt by replacing image tags and returns a list of text and images.

    Args:
        prompt (str): The input string that may contain image tags like `<img ...>`.
        img_format (str): what image format should be used. One of "uri", "url", "pil".

    Returns:
        List[Union[str, dict[str, Any]]]: A list of alternating text and image dictionary items.
    """
    assert img_format in ["uri", "url", "pil"]

    parts: list[Union[str, dict[str, Any]]] = []
    cursor = 0

    for parsed_tag in utils.parse_tags_from_content("img", prompt):
        image_location = parsed_tag["attr"]["src"]
        try:
            if img_format == "pil":
                img_data = get_pil_image(image_location)
            elif img_format == "uri":
                img_data = convert_base64_to_data_uri(get_image_data(image_location))
            elif img_format == "url":
                img_data = image_location
            else:
                raise ValueError(f"Unknown image format {img_format}")
        except Exception as e:
            # Warning and skip this token
            print(f"Warning! Unable to load image from {image_location}, because {e}")
            continue

        # Text preceding this image tag, then the image itself.
        parts.append({"type": "text", "text": prompt[cursor : parsed_tag["match"].start()]})
        parts.append({"type": "image_url", "image_url": {"url": img_data}})
        cursor = parsed_tag["match"].end()

    # Trailing text after the final tag, if any.
    if cursor < len(prompt):
        parts.append({"type": "text", "text": prompt[cursor:]})
    return parts
|
||||
|
||||
|
||||
def extract_img_paths(paragraph: str) -> list:
    """Extract image paths (URLs or local paths) from a text paragraph.

    Parameters:
        paragraph (str): The input text paragraph.

    Returns:
        list: A list of extracted image paths.
    """
    # Detects URLs and bare file paths ending in a common image extension
    # (including webp), case-insensitively.
    pattern = (
        r"\b(?:http[s]?://\S+\.(?:jpg|jpeg|png|gif|bmp|webp)"
        r"|\S+\.(?:jpg|jpeg|png|gif|bmp|webp))\b"
    )
    return re.findall(pattern, paragraph, flags=re.IGNORECASE)
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
def _to_pil(data: str) -> "Image.Image":
    """Converts a base64 encoded image data string to a PIL Image object.

    Parameters:
        data (str): The base64-encoded image data string.

    Returns:
        Image.Image: The PIL Image object created from the decoded bytes.
    """
    raw_bytes = base64.b64decode(data)
    return Image.open(BytesIO(raw_bytes))
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
def message_formatter_pil_to_b64(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Converts the PIL image URLs in the messages to base64 encoded data URIs.

    Walks each message's 'content' list and replaces any 'image_url' item whose
    'url' value is a PIL Image with a base64 data URI. String contents are first
    parsed into typed parts via `gpt4v_formatter` (tool output may embed images
    as `<img ...>` tags).

    Parameters:
        messages (List[Dict]): A list of message dictionaries, possibly containing
                               PIL images in their 'image_url' items.

    Returns:
        List[Dict]: A new list of message dictionaries (deep copies) with PIL image
                    URLs converted to base64 encoded data URIs.
    """
    converted = []
    for original in messages:
        # deepcopy to avoid modifying the original message.
        msg = copy.deepcopy(original)
        if isinstance(msg, dict) and "content" in msg:
            # First, if the content is a string, parse it into a list of parts.
            # This is for tool output that contains images.
            if isinstance(msg["content"], str):
                msg["content"] = gpt4v_formatter(msg["content"], img_format="pil")

            # Second, if the content is a list, process any image parts.
            if isinstance(msg["content"], list):
                for part in msg["content"]:
                    is_pil_image_part = (
                        isinstance(part, dict)
                        and "image_url" in part
                        and isinstance(part["image_url"]["url"], Image.Image)
                    )
                    if is_pil_image_part:
                        part["image_url"]["url"] = pil_to_data_uri(part["image_url"]["url"])

        converted.append(msg)

    return converted
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
def num_tokens_from_gpt_image(
    image_data: Union[str, "Image.Image"], model: str = "gpt-4-vision", low_quality: bool = False
) -> int:
    """Estimate how many tokens one image costs for a GPT vision model.

    The image is first shrunk so its longest edge fits the model limit, then
    shrunk again so its shortest edge fits. The cost is a base fee plus a
    per-tile multiplier for every tile needed to cover the scaled image.
    Supports "gpt-4-vision" (and turbo/v aliases), "gpt-4o", and "gpt-4o-mini".

    Reference: https://openai.com/api/pricing/

    Args:
        image_data: Base64 string, URL, file path, or PIL Image object.
        model: Model name used to pick the scaling/pricing parameters.
        low_quality: When True, charge only the flat base token count.

    Returns:
        int: Total token count required to process the image.

    Raises:
        ValueError: If the model name matches none of the supported families.

    Examples:
    --------
    >>> from PIL import Image
    >>> img = Image.new("RGB", (2500, 2500), color="red")
    >>> num_tokens_from_gpt_image(img, model="gpt-4-vision")
    765
    """
    pil_img = get_pil_image(image_data)  # normalize input to a PIL Image
    w, h = pil_img.size

    # Map the model name onto one of the known parameter sets.
    if any(tag in model for tag in ("gpt-4-vision", "gpt-4-turbo", "gpt-4v", "gpt-4-v")):
        params = MODEL_PARAMS["gpt-4-vision"]
    elif "gpt-4o-mini" in model:
        params = MODEL_PARAMS["gpt-4o-mini"]
    elif "gpt-4o" in model:
        params = MODEL_PARAMS["gpt-4o"]
    else:
        raise ValueError(
            f"Model {model} is not supported. Choose 'gpt-4-vision', 'gpt-4-turbo', 'gpt-4v', 'gpt-4-v', 'gpt-4o', or 'gpt-4o-mini'."
        )

    # Low-detail processing is billed a flat fee regardless of dimensions.
    if low_quality:
        return params["base_token_count"]

    # 1. Constrain the longest edge.
    longest = max(w, h)
    if longest > params["max_edge"]:
        ratio = params["max_edge"] / longest
        w, h = int(w * ratio), int(h * ratio)

    # 2. Further constrain the shortest edge (after the first rescale).
    shortest = min(w, h)
    if shortest > params["min_edge"]:
        ratio = params["min_edge"] / shortest
        w, h = int(w * ratio), int(h * ratio)

    # 3. Base fee plus one multiplier per tile covering the scaled image.
    n_tiles = ceil(w / params["tile_size"]) * ceil(h / params["tile_size"])
    return params["base_token_count"] + params["token_multiplier"] * n_tiles
|
||||
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import copy
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from ... import OpenAIWrapper
|
||||
from ...code_utils import content_str
|
||||
from .. import Agent, ConversableAgent
|
||||
from ..contrib.img_utils import (
|
||||
gpt4v_formatter,
|
||||
message_formatter_pil_to_b64,
|
||||
)
|
||||
|
||||
# Default system prompt used when the caller does not supply one.
DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant."""
# Default vision-capable model placed in DEFAULT_CONFIG below.
DEFAULT_MODEL = "gpt-4-vision-preview"
|
||||
|
||||
|
||||
class MultimodalConversableAgent(ConversableAgent):
    """A `ConversableAgent` that handles multimodal (text + image) messages.

    Incoming string content is parsed with `gpt4v_formatter` into multimodal
    parts (PIL images), and outgoing requests convert those PIL images into
    base64 data URIs via `message_formatter_pil_to_b64` before calling the
    OpenAI client.
    """

    # Fallback LLM config entry; callers normally pass their own llm_config.
    DEFAULT_CONFIG = {
        "model": DEFAULT_MODEL,
    }

    def __init__(
        self,
        name: str,
        system_message: Optional[Union[str, list]] = DEFAULT_LMM_SYS_MSG,
        # NOTE(review): the `str` annotation looks wrong — this value is used
        # below as a predicate `message dict -> bool`; confirm and re-annotate
        # as Optional[Callable[[dict], bool]] once Callable is importable here.
        is_termination_msg: str = None,
        *args,
        **kwargs: Any,
    ):
        """Args:
        name (str): agent name.
        system_message (Optional[Union[str, list]]): system message for the
            OpenAIWrapper inference. Please override this attribute if you
            want to reprogram the agent.
        is_termination_msg: predicate taking a message dict and returning True
            when the conversation should stop; defaults to stopping on a
            message whose content equals "TERMINATE".
        **kwargs (dict): Please refer to other kwargs in
            [ConversableAgent](/docs/api-reference/autogen/ConversableAgent#conversableagent).
        """
        super().__init__(
            name,
            system_message,
            is_termination_msg=is_termination_msg,
            *args,
            **kwargs,
        )
        # call the setter to handle special format (string content is parsed
        # into multimodal parts by gpt4v_formatter).
        self.update_system_message(system_message)
        self._is_termination_msg = (
            is_termination_msg
            if is_termination_msg is not None
            else (lambda x: content_str(x.get("content")) == "TERMINATE")
        )

        # Override the `generate_oai_reply` so replies go through the
        # image-aware implementation below instead of the base class one.
        # NOTE(review): `a_generate_oai_reply` is not overridden in this class,
        # so the second call resolves to the inherited ConversableAgent method
        # and replaces it with itself — confirm an async override was intended.
        self.replace_reply_func(ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply)
        self.replace_reply_func(
            ConversableAgent.a_generate_oai_reply,
            MultimodalConversableAgent.a_generate_oai_reply,
        )

    def update_system_message(self, system_message: Union[dict[str, Any], list[str], str]):
        """Update the system message.

        Args:
            system_message: system message for the OpenAIWrapper inference;
                normalized into multimodal content via `_message_to_dict`.
        """
        self._oai_system_message[0]["content"] = self._message_to_dict(system_message)["content"]
        self._oai_system_message[0]["role"] = "system"

    @staticmethod
    def _message_to_dict(message: Union[dict[str, Any], list[str], str]) -> dict:
        """Convert a message to a dictionary. This implementation
        handles the GPT-4V formatting for easier prompts.

        The message can be a string, a dictionary, or a list of dictionaries:
        - If it's a string, it will be cast into a list and placed in the 'content' field.
        - If it's a list, it will be directly placed in the 'content' field.
        - If it's a dictionary, it is already in message dict format. The 'content' field of this dictionary
        will be processed using the gpt4v_formatter.
        """
        if isinstance(message, str):
            # Parse inline image tags into PIL image parts.
            return {"content": gpt4v_formatter(message, img_format="pil")}
        if isinstance(message, list):
            return {"content": message}
        if isinstance(message, dict):
            assert "content" in message, "The message dict must have a `content` field"
            if isinstance(message["content"], str):
                # deepcopy so the caller's dict is not mutated in place.
                message = copy.deepcopy(message)
                message["content"] = gpt4v_formatter(message["content"], img_format="pil")
            try:
                # Sanity check: content must be renderable as a string.
                content_str(message["content"])
            except (TypeError, ValueError) as e:
                print("The `content` field should be compatible with the content_str function!")
                raise e
            return message
        raise ValueError(f"Unsupported message type: {type(message)}")

    def generate_oai_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[OpenAIWrapper] = None,
    ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
        """Generate a reply using autogen.oai.

        Returns:
            (False, None) when no client is configured, otherwise
            (True, <model reply as str or message dict>).
        """
        client = self.client if config is None else config
        if client is None:
            return False, None
        if messages is None:
            messages = self._oai_messages[sender]

        # Convert any PIL images in the history into base64 data URIs.
        messages_with_b64_img = message_formatter_pil_to_b64(self._oai_system_message + messages)

        # OpenAI 'tool'-role messages cannot carry images, so tool outputs
        # containing a screenshot are split into a text-only tool message plus
        # a follow-up 'user' message carrying the image.
        new_messages = []
        for message in messages_with_b64_img:
            if 'tool_responses' in message:
                for tool_response in message['tool_responses']:
                    tmp_image = None
                    tmp_list = []
                    # Pick out the (last) image part of the tool output, if any.
                    for ctx in message['content']:
                        if ctx['type'] == 'image_url':
                            tmp_image = ctx
                    # NOTE(review): assumes the first content part is the text
                    # payload for every tool response — confirm behavior when a
                    # message carries multiple tool_responses.
                    tmp_list.append({
                        'role': 'tool',
                        'tool_call_id': tool_response['tool_call_id'],
                        'content': [message['content'][0]]
                    })
                    if tmp_image:
                        tmp_list.append({
                            'role': 'user',
                            'content': [
                                {'type': 'text', 'text': 'I take a screenshot for the current state for you.'},
                                tmp_image
                            ]
                        })
                    new_messages.extend(tmp_list)
            else:
                new_messages.append(message)
        messages_with_b64_img = new_messages.copy()

        # TODO: #1143 handle token limit exceeded error
        response = client.create(
            context=messages[-1].pop("context", None), messages=messages_with_b64_img, agent=self.name
        )

        # TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
        extracted_response = client.extract_text_or_completion_object(response)[0]
        if not isinstance(extracted_response, str):
            extracted_response = extracted_response.model_dump()
        return True, extracted_response
|
||||
Reference in New Issue
Block a user