CoACT initialize (#292)
This commit is contained in:
@@ -0,0 +1,301 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import re
|
||||
from typing import Any, Literal, Optional, Protocol, Union
|
||||
|
||||
from .... import Agent, ConversableAgent, code_utils
|
||||
from ....cache import AbstractCache
|
||||
from ....import_utils import optional_import_block, require_optional_import
|
||||
from ....llm_config import LLMConfig
|
||||
from .. import img_utils
|
||||
from ..capabilities.agent_capability import AgentCapability
|
||||
from ..text_analyzer_agent import TextAnalyzerAgent
|
||||
|
||||
with optional_import_block():
|
||||
from PIL.Image import Image
|
||||
from openai import OpenAI
|
||||
|
||||
SYSTEM_MESSAGE = "You've been given the special ability to generate images."
|
||||
DESCRIPTION_MESSAGE = "This agent has the ability to generate images."
|
||||
|
||||
PROMPT_INSTRUCTIONS = """In detail, please summarize the provided prompt to generate the image described in the TEXT.
|
||||
DO NOT include any advice. RESPOND like the following example:
|
||||
EXAMPLE: Blue background, 3D shapes, ...
|
||||
"""
|
||||
|
||||
|
||||
class ImageGenerator(Protocol):
|
||||
"""This class defines an interface for image generators.
|
||||
|
||||
Concrete implementations of this protocol must provide a `generate_image` method that takes a string prompt as
|
||||
input and returns a PIL Image object.
|
||||
|
||||
NOTE: Current implementation does not allow you to edit a previously existing image.
|
||||
"""
|
||||
|
||||
def generate_image(self, prompt: str) -> "Image":
|
||||
"""Generates an image based on the provided prompt.
|
||||
|
||||
Args:
|
||||
prompt: A string describing the desired image.
|
||||
|
||||
Returns:
|
||||
A PIL Image object representing the generated image.
|
||||
|
||||
Raises:
|
||||
ValueError: If the image generation fails.
|
||||
"""
|
||||
...
|
||||
|
||||
def cache_key(self, prompt: str) -> str:
|
||||
"""Generates a unique cache key for the given prompt.
|
||||
|
||||
This key can be used to store and retrieve generated images based on the prompt.
|
||||
|
||||
Args:
|
||||
prompt: A string describing the desired image.
|
||||
|
||||
Returns:
|
||||
A unique string that can be used as a cache key.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
|
||||
@require_optional_import("openai>=1.66.2", "openai")
|
||||
class DalleImageGenerator:
|
||||
"""Generates images using OpenAI's DALL-E models.
|
||||
|
||||
This class provides a convenient interface for generating images based on textual prompts using OpenAI's DALL-E
|
||||
models. It allows you to specify the DALL-E model, resolution, quality, and the number of images to generate.
|
||||
|
||||
Note: Current implementation does not allow you to edit a previously existing image.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_config: Union[LLMConfig, dict[str, Any]],
|
||||
resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
|
||||
quality: Literal["standard", "hd"] = "standard",
|
||||
num_images: int = 1,
|
||||
):
|
||||
"""Args:
|
||||
llm_config (LLMConfig or dict): llm config, must contain a valid dalle model and OpenAI API key in config_list.
|
||||
resolution (str): The resolution of the image you want to generate. Must be one of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792".
|
||||
quality (str): The quality of the image you want to generate. Must be one of "standard", "hd".
|
||||
num_images (int): The number of images to generate.
|
||||
"""
|
||||
config_list = llm_config["config_list"]
|
||||
_validate_dalle_model(config_list[0]["model"])
|
||||
_validate_resolution_format(resolution)
|
||||
|
||||
self._model = config_list[0]["model"]
|
||||
self._resolution = resolution
|
||||
self._quality = quality
|
||||
self._num_images = num_images
|
||||
self._dalle_client = OpenAI(api_key=config_list[0]["api_key"])
|
||||
|
||||
def generate_image(self, prompt: str) -> "Image":
|
||||
response = self._dalle_client.images.generate(
|
||||
model=self._model,
|
||||
prompt=prompt,
|
||||
size=self._resolution,
|
||||
quality=self._quality,
|
||||
n=self._num_images,
|
||||
)
|
||||
|
||||
image_url = response.data[0].url
|
||||
if image_url is None:
|
||||
raise ValueError("Failed to generate image.")
|
||||
|
||||
return img_utils.get_pil_image(image_url)
|
||||
|
||||
def cache_key(self, prompt: str) -> str:
|
||||
keys = (prompt, self._model, self._resolution, self._quality, self._num_images)
|
||||
return ",".join([str(k) for k in keys])
|
||||
|
||||
|
||||
@require_optional_import("PIL", "unknown")
|
||||
class ImageGeneration(AgentCapability):
|
||||
"""This capability allows a ConversableAgent to generate images based on the message received from other Agents.
|
||||
|
||||
1. Utilizes a TextAnalyzerAgent to analyze incoming messages to identify requests for image generation and
|
||||
extract relevant details.
|
||||
2. Leverages the provided ImageGenerator (e.g., DalleImageGenerator) to create the image.
|
||||
3. Optionally caches generated images for faster retrieval in future conversations.
|
||||
|
||||
NOTE: This capability increases the token usage of the agent, as it uses TextAnalyzerAgent to analyze every
|
||||
message received by the agent.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import autogen
|
||||
from autogen.agentchat.contrib.capabilities.image_generation import ImageGeneration
|
||||
|
||||
# Assuming you have llm configs configured for the LLMs you want to use and Dalle.
|
||||
# Create the agent
|
||||
agent = autogen.ConversableAgent(
|
||||
name="dalle", llm_config={...}, max_consecutive_auto_reply=3, human_input_mode="NEVER"
|
||||
)
|
||||
|
||||
# Create an ImageGenerator with desired settings
|
||||
dalle_gen = generate_images.DalleImageGenerator(llm_config={...})
|
||||
|
||||
# Add the ImageGeneration capability to the agent
|
||||
agent.add_capability(ImageGeneration(image_generator=dalle_gen))
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_generator: ImageGenerator,
|
||||
cache: Optional[AbstractCache] = None,
|
||||
text_analyzer_llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None,
|
||||
text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
|
||||
verbosity: int = 0,
|
||||
register_reply_position: int = 2,
|
||||
):
|
||||
"""Args:
|
||||
image_generator (ImageGenerator): The image generator you would like to use to generate images.
|
||||
cache (None or AbstractCache): The cache client to use to store and retrieve generated images. If None,
|
||||
no caching will be used.
|
||||
text_analyzer_llm_config (LLMConfig or Dict or None): The LLM config for the text analyzer. If None, the LLM config will
|
||||
be retrieved from the agent you're adding the ability to.
|
||||
text_analyzer_instructions (str): Instructions provided to the TextAnalyzerAgent used to analyze
|
||||
incoming messages and extract the prompt for image generation. The default instructions focus on
|
||||
summarizing the prompt. You can customize the instructions to achieve more granular control over prompt
|
||||
extraction.
|
||||
Example: 'Extract specific details from the message, like desired objects, styles, or backgrounds.'
|
||||
verbosity (int): The verbosity level. Defaults to 0 and must be greater than or equal to 0. The text
|
||||
analyzer llm calls will be silent if verbosity is less than 2.
|
||||
register_reply_position (int): The position of the reply function in the agent's list of reply functions.
|
||||
This capability registers a new reply function to handle messages with image generation requests.
|
||||
Defaults to 2 to place it after the check termination and human reply for a ConversableAgent.
|
||||
"""
|
||||
self._image_generator = image_generator
|
||||
self._cache = cache
|
||||
self._text_analyzer_llm_config = text_analyzer_llm_config
|
||||
self._text_analyzer_instructions = text_analyzer_instructions
|
||||
self._verbosity = verbosity
|
||||
self._register_reply_position = register_reply_position
|
||||
|
||||
self._agent: Optional[ConversableAgent] = None
|
||||
self._text_analyzer: Optional[TextAnalyzerAgent] = None
|
||||
|
||||
def add_to_agent(self, agent: ConversableAgent):
|
||||
"""Adds the Image Generation capability to the specified ConversableAgent.
|
||||
|
||||
This function performs the following modifications to the agent:
|
||||
|
||||
1. Registers a reply function: A new reply function is registered with the agent to handle messages that
|
||||
potentially request image generation. This function analyzes the message and triggers image generation if
|
||||
necessary.
|
||||
2. Creates an Agent (TextAnalyzerAgent): This is used to analyze messages for image generation requirements.
|
||||
3. Updates System Message: The agent's system message is updated to include a message indicating the
|
||||
capability to generate images has been added.
|
||||
4. Updates Description: The agent's description is updated to reflect the addition of the Image Generation
|
||||
capability. This might be helpful in certain use cases, like group chats.
|
||||
|
||||
Args:
|
||||
agent (ConversableAgent): The ConversableAgent to add the capability to.
|
||||
"""
|
||||
self._agent = agent
|
||||
|
||||
agent.register_reply([Agent, None], self._image_gen_reply, position=self._register_reply_position)
|
||||
|
||||
self._text_analyzer_llm_config = self._text_analyzer_llm_config or agent.llm_config
|
||||
self._text_analyzer = TextAnalyzerAgent(llm_config=self._text_analyzer_llm_config)
|
||||
|
||||
agent.update_system_message(agent.system_message + "\n" + SYSTEM_MESSAGE)
|
||||
agent.description += "\n" + DESCRIPTION_MESSAGE
|
||||
|
||||
def _image_gen_reply(
|
||||
self,
|
||||
recipient: ConversableAgent,
|
||||
messages: Optional[list[dict[str, Any]]],
|
||||
sender: Optional[Agent] = None,
|
||||
config: Optional[Any] = None,
|
||||
) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
|
||||
if messages is None:
|
||||
return False, None
|
||||
|
||||
last_message = code_utils.content_str(messages[-1]["content"])
|
||||
|
||||
if not last_message:
|
||||
return False, None
|
||||
|
||||
if self._should_generate_image(last_message):
|
||||
prompt = self._extract_prompt(last_message)
|
||||
|
||||
image = self._cache_get(prompt)
|
||||
if image is None:
|
||||
image = self._image_generator.generate_image(prompt)
|
||||
self._cache_set(prompt, image)
|
||||
|
||||
return True, self._generate_content_message(prompt, image)
|
||||
|
||||
else:
|
||||
return False, None
|
||||
|
||||
def _should_generate_image(self, message: str) -> bool:
|
||||
assert self._text_analyzer is not None
|
||||
|
||||
instructions = """
|
||||
Does any part of the TEXT ask the agent to generate an image?
|
||||
The TEXT must explicitly mention that the image must be generated.
|
||||
Answer with just one word, yes or no.
|
||||
"""
|
||||
analysis = self._text_analyzer.analyze_text(message, instructions)
|
||||
|
||||
return "yes" in self._extract_analysis(analysis).lower()
|
||||
|
||||
def _extract_prompt(self, last_message) -> str:
|
||||
assert self._text_analyzer is not None
|
||||
|
||||
analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
|
||||
return self._extract_analysis(analysis)
|
||||
|
||||
def _cache_get(self, prompt: str) -> Optional["Image"]:
|
||||
if self._cache:
|
||||
key = self._image_generator.cache_key(prompt)
|
||||
cached_value = self._cache.get(key)
|
||||
|
||||
if cached_value:
|
||||
return img_utils.get_pil_image(cached_value)
|
||||
|
||||
def _cache_set(self, prompt: str, image: "Image"):
|
||||
if self._cache:
|
||||
key = self._image_generator.cache_key(prompt)
|
||||
self._cache.set(key, img_utils.pil_to_data_uri(image))
|
||||
|
||||
def _extract_analysis(self, analysis: Optional[Union[str, dict[str, Any]]]) -> str:
|
||||
if isinstance(analysis, dict):
|
||||
return code_utils.content_str(analysis["content"])
|
||||
else:
|
||||
return code_utils.content_str(analysis)
|
||||
|
||||
def _generate_content_message(self, prompt: str, image: "Image") -> dict[str, Any]:
|
||||
return {
|
||||
"content": [
|
||||
{"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
|
||||
{"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# Helpers
|
||||
def _validate_resolution_format(resolution: str):
|
||||
"""Checks if a string is in a valid resolution format (e.g., "1024x768")."""
|
||||
pattern = r"^\d+x\d+$" # Matches a pattern of digits, "x", and digits
|
||||
matched_resolution = re.match(pattern, resolution)
|
||||
if matched_resolution is None:
|
||||
raise ValueError(f"Invalid resolution format: {resolution}")
|
||||
|
||||
|
||||
def _validate_dalle_model(model: str):
|
||||
if model not in ["dall-e-3", "dall-e-2"]:
|
||||
raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'")
|
||||
Reference in New Issue
Block a user