CoACT initialize (#292)

This commit is contained in:
Linxin Song
2025-07-30 19:35:20 -07:00
committed by GitHub
parent 862d704b8c
commit b968155757
228 changed files with 42386 additions and 0 deletions

View File

@@ -136,6 +136,82 @@ class PythonController:
logger.error("Failed to execute command.")
return None
def run_python_script(self, script: str) -> Optional[Dict[str, Any]]:
    """Execute a Python script on the remote server via the ``/run_python`` endpoint.

    Args:
        script: Python source code to execute remotely.

    Returns:
        The server's JSON response on success, or a dict with
        ``status``/``message``/``output``/``error`` keys describing the failure.
    """
    payload = json.dumps({"code": script})
    for _ in range(self.retry_times):
        try:
            response = requests.post(
                self.http_server + "/run_python",
                headers={'Content-Type': 'application/json'},
                data=payload,
                timeout=90,
            )
            if response.status_code == 200:
                return response.json()
            # Non-200: report the server-side error instead of retrying.
            # Fix: use .get() so a body without an "error" key cannot raise
            # KeyError here (which would be swallowed and trigger a retry).
            return {
                "status": "error",
                "message": "Failed to execute command.",
                "output": None,
                "error": response.json().get("error"),
            }
        except requests.exceptions.ReadTimeout:
            # Fix: previously this broke out of the loop and fell through to the
            # generic "Retry limit reached." result, masking the real cause.
            logger.error("Failed to execute command.")
            return {"status": "error", "message": "Failed to execute command.",
                    "output": "", "error": "Execution timed out."}
        except Exception:
            logger.error("An error occurred while trying to execute the command: %s", traceback.format_exc())
            logger.info("Retrying to execute command.")
        time.sleep(self.retry_interval)
    logger.error("Failed to execute command.")
    return {"status": "error", "message": "Failed to execute command.", "output": "", "error": "Retry limit reached."}
def run_bash_script(self, script: str, timeout: int = 30, working_dir: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Execute a bash script on the server.

    Args:
        script: The bash script content (can be multi-line).
        timeout: Execution timeout in seconds (default: 30).
        working_dir: Working directory for script execution (optional).

    Returns:
        Dictionary with status, output, error, and returncode.
    """
    request_body = json.dumps({
        "script": script,
        "timeout": timeout,
        "working_dir": working_dir,
    })
    for _attempt in range(self.retry_times):
        try:
            resp = requests.post(
                self.http_server + "/run_bash_script",
                headers={'Content-Type': 'application/json'},
                data=request_body,
                # Generous buffer over the script timeout so the HTTP layer
                # does not give up before the server-side execution does.
                timeout=timeout + 100,
            )
            if resp.status_code == 200:
                parsed = resp.json()
                logger.info("Bash script executed successfully with return code: %d", parsed.get("returncode", -1))
                return parsed
            logger.error("Failed to execute bash script. Status code: %d, response: %s",
                         resp.status_code, resp.text)
            logger.info("Retrying to execute bash script.")
        except requests.exceptions.ReadTimeout:
            # A read timeout means the script itself overran; no point retrying.
            logger.error("Bash script execution timed out")
            return {
                "status": "error",
                "output": "",
                "error": f"Script execution timed out after {timeout} seconds",
                "returncode": -1,
            }
        except Exception as e:
            logger.error("An error occurred while trying to execute the bash script: %s", e)
            logger.info("Retrying to execute bash script.")
        time.sleep(self.retry_interval)
    logger.error("Failed to execute bash script after %d retries.", self.retry_times)
    return {
        "status": "error",
        "output": "",
        "error": f"Failed to execute bash script after {self.retry_times} retries",
        "returncode": -1,
    }
def execute_action(self, action: Dict[str, Any]):
    """

View File

@@ -1568,5 +1568,230 @@ def end_recording():
return abort(500, description=f"Recording failed. The output file is missing or empty. ffmpeg stderr: {error_output}")
@app.route("/run_python", methods=['POST'])
def run_python():
    """Run posted Python code in a subprocess and report its stdout/stderr.

    Expects JSON of the form ``{"code": "..."}``. The code is written to a
    uniquely named file under /tmp, executed with a 30 second timeout, and the
    response carries stdout, stderr, a combined message and the return code.
    """
    data = request.json
    code = data.get('code', None)
    if not code:
        return jsonify({'status': 'error', 'message': 'Code not supplied!'}), 400
    import uuid

    # Unique filename so concurrent requests never collide.
    temp_filename = f"/tmp/python_exec_{uuid.uuid4().hex}.py"
    try:
        # Write code to temporary file
        with open(temp_filename, 'w') as f:
            f.write(code)
        # Execute in a separate interpreter so all output is captured and the
        # user code cannot corrupt this server process.
        result = subprocess.run(
            ['/usr/bin/python3', temp_filename],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=30  # 30 second timeout
        )
        output = result.stdout
        error_output = result.stderr
        # Combine output and errors for the human-readable message.
        combined_message = output
        if error_output:
            combined_message += ('\n' + error_output) if output else error_output
        # Determine status based on return code and errors
        if result.returncode != 0:
            status = 'error'
            if not error_output:
                # Non-zero exit without stderr: synthesize a generic message.
                error_output = f"Process exited with code {result.returncode}"
                combined_message = combined_message + '\n' + error_output if combined_message else error_output
        else:
            status = 'success'
        return jsonify({
            'status': status,
            'message': combined_message,
            'need_more': False,  # Not applicable for file execution
            'output': output,  # stdout only
            'error': error_output,  # stderr only
            'return_code': result.returncode
        })
    except subprocess.TimeoutExpired:
        return jsonify({
            'status': 'error',
            'message': 'Execution timeout: Code took too long to execute',
            'error': 'TimeoutExpired',
            'need_more': False,
            'output': None,
        }), 500
    except Exception as e:
        return jsonify({
            'status': 'error',
            'message': f'Execution error: {str(e)}',
            'error': traceback.format_exc(),
            'need_more': False,
            'output': None,
        }), 500
    finally:
        # Fix: one cleanup point instead of three duplicated remove blocks, and
        # a narrow OSError instead of a bare `except:`. The file may not exist
        # if the write itself failed.
        try:
            os.remove(temp_filename)
        except OSError:
            pass
@app.route("/run_bash_script", methods=['POST'])
def run_bash_script():
    """Execute a posted bash script and return its merged stdout/stderr.

    Expects JSON with ``script`` (required), ``timeout`` (seconds, default 100)
    and ``working_dir`` (optional). stderr is merged into stdout; the response
    ``error`` field is intentionally always empty.
    """
    data = request.json
    script = data.get('script', None)
    timeout = data.get('timeout', 100)  # Fix: default is 100 seconds (old comment wrongly said 30)
    working_dir = data.get('working_dir', None)
    if not script:
        return jsonify({
            'status': 'error',
            'output': 'Script not supplied!',
            'error': "",  # Always empty as requested
            'returncode': -1
        }), 400
    # Expand user directory if provided
    if working_dir:
        working_dir = os.path.expanduser(working_dir)
        if not os.path.exists(working_dir):
            return jsonify({
                'status': 'error',
                'output': f'Working directory does not exist: {working_dir}',
                'error': "",  # Always empty as requested
                'returncode': -1
            }), 400
    # Create a temporary script file
    import tempfile
    with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as tmp_file:
        # Ensure a shebang so direct execution behaves predictably.
        if "#!/bin/bash" not in script:
            script = "#!/bin/bash\n\n" + script
        tmp_file.write(script)
        tmp_file_path = tmp_file.name
    try:
        # Make the script executable
        os.chmod(tmp_file_path, 0o755)
        if platform_name == "Windows":
            # On Windows, use Git Bash or WSL if available; suppress the
            # console window that would otherwise flash up.
            flags = subprocess.CREATE_NO_WINDOW
            result = subprocess.run(
                ['bash', tmp_file_path],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # Merge stderr into stdout
                text=True,
                timeout=timeout,
                cwd=working_dir,
                creationflags=flags,
                shell=False
            )
        else:
            # On Unix-like systems, use bash directly (creationflags=0 is a no-op).
            flags = 0
            result = subprocess.run(
                ['/bin/bash', tmp_file_path],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # Merge stderr into stdout
                text=True,
                timeout=timeout,
                cwd=working_dir,
                creationflags=flags,
                shell=False
            )
        # Log the command execution for trajectory recording
        _append_event("BashScript",
                      {"script": script, "output": result.stdout, "error": "", "returncode": result.returncode},
                      ts=time.time())
        return jsonify({
            'status': 'success' if result.returncode == 0 else 'error',
            'output': result.stdout,  # Contains both stdout and stderr merged
            'error': "",  # Always empty as requested
            'returncode': result.returncode
        })
    except subprocess.TimeoutExpired:
        return jsonify({
            'status': 'error',
            'output': f'Script execution timed out after {timeout} seconds',
            'error': "",  # Always empty as requested
            'returncode': -1
        }), 500
    except FileNotFoundError:
        # Bash not found, fall back to sh
        try:
            result = subprocess.run(
                ['sh', tmp_file_path],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # Merge stderr into stdout
                text=True,
                timeout=timeout,
                cwd=working_dir,
                shell=False
            )
            _append_event("BashScript",
                          {"script": script, "output": result.stdout, "error": "", "returncode": result.returncode},
                          ts=time.time())
            return jsonify({
                'status': 'success' if result.returncode == 0 else 'error',
                'output': result.stdout,  # Contains both stdout and stderr merged
                'error': "",  # Always empty as requested
                'returncode': result.returncode,
            })
        except Exception as e:
            return jsonify({
                'status': 'error',
                'output': f'Failed to execute script: {str(e)}',
                'error': "",  # Always empty as requested
                'returncode': -1
            }), 500
    except Exception as e:
        return jsonify({
            'status': 'error',
            'output': f'Failed to execute script: {str(e)}',
            'error': "",  # Always empty as requested
            'returncode': -1
        }), 500
    finally:
        # Clean up the temporary file. Fix: narrow OSError instead of a bare
        # `except:` that would also swallow KeyboardInterrupt/SystemExit.
        try:
            os.unlink(tmp_file_path)
        except OSError:
            pass
if __name__ == '__main__':
    app.run(debug=True, host="0.0.0.0")

View File

@@ -0,0 +1,27 @@
[
{
"model": "gpt-4o",
"api_key": "KEY",
"tags": ["gpt-4o", "code", "explainer"]
},
{
"model": "o3",
"api_key": "KEY",
"tags": ["o3", "coding", "explainer"]
},
{
"model": "gpt-4.1",
"api_key": "KEY",
"tags": ["gpt-4.1", "coding", "explainer"]
},
{
"model": "o4-mini",
"api_key": "KEY",
"tags": ["o4-mini", "coding", "explainer"]
},
{
"model": "o3-mini",
"api_key": "KEY",
"tags": ["o3-mini", "coding", "explainer"]
}
]

View File

View File

@@ -0,0 +1,81 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import logging
from .agentchat import (
Agent,
AssistantAgent,
ChatResult,
ConversableAgent,
GroupChat,
GroupChatManager,
UpdateSystemMessage,
UserProxyAgent,
gather_usage_summary,
initiate_chats,
register_function,
)
from .agentchat.group.context_expression import ContextExpression
from .code_utils import DEFAULT_MODEL, FAST_MODEL
from .exception_utils import (
AgentNameConflictError,
InvalidCarryOverTypeError,
NoEligibleSpeakerError,
SenderRequiredError,
UndefinedNextAgentError,
)
from .llm_config import LLMConfig
from .oai import (
Cache,
ModelClient,
OpenAIWrapper,
config_list_from_dotenv,
config_list_from_json,
config_list_from_models,
config_list_gpt4_gpt35,
config_list_openai_aoai,
filter_config,
get_config_list,
)
# Set the root logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
__all__ = [
"DEFAULT_MODEL",
"FAST_MODEL",
"Agent",
"AgentNameConflictError",
"AssistantAgent",
"Cache",
"ChatResult",
"ContextExpression",
"ConversableAgent",
"GroupChat",
"GroupChatManager",
"InvalidCarryOverTypeError",
"LLMConfig",
"ModelClient",
"NoEligibleSpeakerError",
"OpenAIWrapper",
"SenderRequiredError",
"UndefinedNextAgentError",
"UpdateSystemMessage",
"UserProxyAgent",
"config_list_from_dotenv",
"config_list_from_json",
"config_list_from_models",
"config_list_gpt4_gpt35",
"config_list_openai_aoai",
"filter_config",
"gather_usage_summary",
"get_config_list",
"initiate_chats",
"register_function",
]

View File

@@ -0,0 +1,38 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from .agent import Agent, LLMAgent
from .assistant_agent import AssistantAgent
from .chat import ChatResult, a_initiate_chats, initiate_chats
from .conversable_agent import ConversableAgent, UpdateSystemMessage, register_function
from .group.multi_agent_chat import a_initiate_group_chat, a_run_group_chat, initiate_group_chat, run_group_chat
from .groupchat import GroupChat, GroupChatManager
from .user_proxy_agent import UserProxyAgent
from .utils import gather_usage_summary

# Fix: `a_initiate_swarm_chat`, `a_run_swarm` and `run_swarm` were listed in
# __all__ but never imported above, which makes `from <pkg> import *` raise
# AttributeError. Only names actually defined in this module are exported.
__all__ = [
    "Agent",
    "AssistantAgent",
    "ChatResult",
    "ConversableAgent",
    "GroupChat",
    "GroupChatManager",
    "LLMAgent",
    "UpdateSystemMessage",
    "UserProxyAgent",
    "a_initiate_chats",
    "a_initiate_group_chat",
    "a_run_group_chat",
    "gather_usage_summary",
    "initiate_chats",
    "initiate_group_chat",
    "register_function",
    "run_group_chat",
]

View File

@@ -0,0 +1,182 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import TYPE_CHECKING, Any, Optional, Protocol, TypeVar, Union, runtime_checkable
from ..doc_utils import export_module
__all__ = ["Agent", "LLMAgent", "LLMMessageType"]
Tool = TypeVar("Tool")
LLMMessageType = dict[str, Any]
DEFAULT_SUMMARY_METHOD = "last_msg"
# NOTE: runtime_checkable makes isinstance() checks possible, but such checks
# only verify that the members exist — not their signatures.
@runtime_checkable
@export_module("autogen")
class Agent(Protocol):
    """(In preview) A protocol for Agent.

    An agent can communicate with other agents and perform actions.
    Different agents can differ in what actions they perform in the `receive` method.
    """

    @property
    def name(self) -> str:
        """The name of the agent."""
        ...

    @property
    def description(self) -> str:
        """The description of the agent. Used for the agent's introduction in
        a group chat setting.
        """
        ...

    def send(
        self,
        message: Union[dict[str, Any], str],
        recipient: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """Send a message to another agent.

        Args:
            message (dict or str): the message to send. If a dict, it should be
                a JSON-serializable and follows the OpenAI's ChatCompletion schema.
            recipient (Agent): the recipient of the message.
            request_reply (bool): whether to request a reply from the recipient.
        """
        ...

    async def a_send(
        self,
        message: Union[dict[str, Any], str],
        recipient: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """(Async) Send a message to another agent.

        Args:
            message (dict or str): the message to send. If a dict, it should be
                a JSON-serializable and follows the OpenAI's ChatCompletion schema.
            recipient (Agent): the recipient of the message.
            request_reply (bool): whether to request a reply from the recipient.
        """
        ...

    def receive(
        self,
        message: Union[dict[str, Any], str],
        sender: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """Receive a message from another agent.

        Args:
            message (dict or str): the message received. If a dict, it should be
                a JSON-serializable and follows the OpenAI's ChatCompletion schema.
            sender (Agent): the sender of the message.
            request_reply (bool): whether the sender requests a reply.
        """

    async def a_receive(
        self,
        message: Union[dict[str, Any], str],
        sender: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """(Async) Receive a message from another agent.

        Args:
            message (dict or str): the message received. If a dict, it should be
                a JSON-serializable and follows the OpenAI's ChatCompletion schema.
            sender (Agent): the sender of the message.
            request_reply (bool): whether the sender requests a reply.
        """
        ...

    def generate_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional["Agent"] = None,
        **kwargs: Any,
    ) -> Union[str, dict[str, Any], None]:
        """Generate a reply based on the received messages.

        Args:
            messages (list[dict[str, Any]]): a list of messages received from other agents.
                The messages are dictionaries that are JSON-serializable and
                follows the OpenAI's ChatCompletion schema.
            sender: sender of an Agent instance.
            **kwargs: Additional keyword arguments.

        Returns:
            str or dict or None: the generated reply. If None, no reply is generated.
        """

    async def a_generate_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional["Agent"] = None,
        **kwargs: Any,
    ) -> Union[str, dict[str, Any], None]:
        """(Async) Generate a reply based on the received messages.

        Args:
            messages (list[dict[str, Any]]): a list of messages received from other agents.
                The messages are dictionaries that are JSON-serializable and
                follows the OpenAI's ChatCompletion schema.
            sender: sender of an Agent instance.
            **kwargs: Additional keyword arguments.

        Returns:
            str or dict or None: the generated reply. If None, no reply is generated.
        """
        ...

    def set_ui_tools(self, tools: list[Tool]) -> None:
        """Set the UI tools for the agent.

        Args:
            tools: a list of UI tools to set.
        """
        ...

    def unset_ui_tools(self, tools: list[Tool]) -> None:
        """Unset the UI tools for the agent.

        Args:
            tools: a list of UI tools to unset.
        """
        ...
# NOTE: extends the Agent protocol with the system-message contract expected
# of LLM-backed agents; runtime isinstance() checks only verify member presence.
@runtime_checkable
@export_module("autogen")
class LLMAgent(Agent, Protocol):
    """(In preview) A protocol for an LLM agent."""

    @property
    def system_message(self) -> str:
        """The system message of this agent."""

    def update_system_message(self, system_message: str) -> None:
        """Update this agent's system message.

        Args:
            system_message (str): system message for inference.
        """
if TYPE_CHECKING:
    # mypy will fail if Conversable agent does not implement Agent protocol
    from .conversable_agent import ConversableAgent

    def _check_protocol_implementation(agent: ConversableAgent) -> Agent:
        # Static-analysis-only helper: returning a ConversableAgent where an
        # Agent is declared forces the type checker to verify conformance.
        # Never executed at runtime (guarded by TYPE_CHECKING).
        return agent

View File

@@ -0,0 +1,85 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import Any, Callable, Literal, Optional, Union
from ..doc_utils import export_module
from ..llm_config import LLMConfig
from ..runtime_logging import log_new_agent, logging_enabled
from .conversable_agent import ConversableAgent
@export_module("autogen")
class AssistantAgent(ConversableAgent):
    """(In preview) Assistant agent, designed to solve a task with LLM.

    AssistantAgent is a subclass of ConversableAgent configured with a default system message.
    The default system message is designed to solve a task with LLM,
    including suggesting python code blocks and debugging.
    `human_input_mode` is default to "NEVER"
    and `code_execution_config` is default to False.
    This agent doesn't execute code by default, and expects the user to execute the code.
    """

    DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant.
Solve tasks using your coding and language skills.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
    1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
    2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.
When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.
If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
Reply "TERMINATE" in the end when everything is done.
    """

    DEFAULT_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills."

    def __init__(
        self,
        name: str,
        system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
        llm_config: Optional[Union[LLMConfig, dict[str, Any], Literal[False]]] = None,
        is_termination_msg: Optional[Callable[[dict[str, Any]], bool]] = None,
        max_consecutive_auto_reply: Optional[int] = None,
        human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
        description: Optional[str] = None,
        **kwargs: Any,
    ):
        """Args:
        name (str): agent name.
        system_message (str): system message for the ChatCompletion inference.
            Please override this attribute if you want to reprogram the agent.
        llm_config (dict or False or None): llm inference configuration.
            Please refer to [OpenAIWrapper.create](https://docs.ag2.ai/latest/docs/api-reference/autogen/OpenAIWrapper/#autogen.OpenAIWrapper.create)
            for available options.
        is_termination_msg (function): a function that takes a message in the form of a dictionary
            and returns a boolean value indicating if this received message is a termination message.
            The dict can contain the following keys: "content", "role", "name", "function_call".
        max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
            default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
            The limit only plays a role when human_input_mode is not "ALWAYS".
        human_input_mode (str): whether to ask for human input; one of "ALWAYS", "NEVER",
            or "TERMINATE". Defaults to "NEVER" for this agent.
        description (str or None): a short description of the agent used in group chats.
            If None and the default system message is used, DEFAULT_DESCRIPTION is applied.
        **kwargs (dict): Please refer to other kwargs in
            [ConversableAgent](https://docs.ag2.ai/latest/docs/api-reference/autogen/ConversableAgent).
        """
        super().__init__(
            name,
            system_message,
            is_termination_msg,
            max_consecutive_auto_reply,
            human_input_mode,
            llm_config=llm_config,
            description=description,
            **kwargs,
        )
        if logging_enabled():
            log_new_agent(self, locals())
        # Update the provided description if None, and we are using the default system_message,
        # then use the default description.
        if description is None and system_message == self.DEFAULT_SYSTEM_MESSAGE:
            self.description = self.DEFAULT_DESCRIPTION

View File

@@ -0,0 +1,309 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import asyncio
import datetime
import logging
import warnings
from collections import defaultdict
from dataclasses import dataclass
from functools import partial
from typing import Any
from ..doc_utils import export_module
from ..events.agent_events import PostCarryoverProcessingEvent
from ..io.base import IOStream
from .utils import consolidate_chat_info
logger = logging.getLogger(__name__)
Prerequisite = tuple[int, int]
__all__ = ["ChatResult", "a_initiate_chats", "initiate_chats"]
@dataclass
@export_module("autogen")
class ChatResult:
    """(Experimental) The result of a chat. Almost certain to be changed."""

    # NOTE(review): every field defaults to None, so the annotations should be
    # Optional (e.g. `Optional[int]` / `int | None`); left unchanged pending a
    # decision on the minimum supported Python version — TODO confirm.
    chat_id: int = None
    """chat id"""
    chat_history: list[dict[str, Any]] = None
    """The chat history."""
    summary: str = None
    """A summary obtained from the chat."""
    cost: dict[str, dict[str, Any]] = (
        None  # keys: "usage_including_cached_inference", "usage_excluding_cached_inference"
    )
    """The cost of the chat.

    The value for each usage type is a dictionary containing cost information for that specific type.
        - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
        - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".
    """
    human_input: list[str] = None
    """A list of human input solicited during the chat."""
def _validate_recipients(chat_queue: list[dict[str, Any]]) -> None:
    """Check that every queue entry names a recipient; warn when one repeats.

    Fixes the old docstring's typo ("exits" -> "exists").
    """
    seen_recipients = set()
    for chat_info in chat_queue:
        assert "recipient" in chat_info, "recipient must be provided."
        seen_recipients.add(chat_info["recipient"])
    # Fewer distinct recipients than entries means at least one repeats.
    if len(seen_recipients) < len(chat_queue):
        warnings.warn(
            "Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
            UserWarning,
        )
def __create_async_prerequisites(chat_queue: list[dict[str, Any]]) -> list[Prerequisite]:
    """Create list of Prerequisite tuples (chat_id, prerequisite_chat_id).

    Note: the code appends `(chat_id, pre_chat_id)` — dependent chat first —
    which matches how `__find_async_chat_order` unpacks each pair; the
    previous docstring stated the reverse order.

    Raises:
        ValueError: if an entry is missing "chat_id", or a prerequisite id is not an int.
    """
    prerequisites = []
    for chat_info in chat_queue:
        if "chat_id" not in chat_info:
            raise ValueError("Each chat must have a unique id for async multi-chat execution.")
        chat_id = chat_info["chat_id"]
        pre_chats = chat_info.get("prerequisites", [])
        for pre_chat_id in pre_chats:
            if not isinstance(pre_chat_id, int):
                raise ValueError("Prerequisite chat id is not int.")
            prerequisites.append((chat_id, pre_chat_id))
    return prerequisites
def __find_async_chat_order(chat_ids: set[int], prerequisites: list[Prerequisite]) -> list[int]:
    """Topologically order chats for async execution based on prerequisites.

    Args:
        chat_ids: A set of all chat IDs that need to be scheduled
        prerequisites: A list of tuples (chat_id, prerequisite_chat_id) where each tuple indicates that chat_id depends on prerequisite_chat_id

    Returns:
        list: chat ids in a valid execution order, or [] if a dependency cycle exists.
    """
    dependents = defaultdict(set)   # prerequisite -> chats waiting on it
    unmet_count = defaultdict(int)  # chat -> number of distinct unmet prerequisites
    for chat, pre in prerequisites:
        # Guard against duplicate (chat, pre) pairs inflating the count.
        if chat not in dependents[pre]:
            unmet_count[chat] += 1
            dependents[pre].add(chat)
    # Start with every chat that has no prerequisites at all.
    frontier = [cid for cid in chat_ids if cid not in unmet_count]
    order = []
    while frontier:
        order.extend(frontier)
        released = []
        for finished in frontier:
            if finished in dependents:
                for waiting in dependents[finished]:
                    unmet_count[waiting] -= 1
                    if unmet_count[waiting] == 0:
                        released.append(waiting)
                        unmet_count.pop(waiting)
                dependents.pop(finished)
        frontier = released
    # Anything still holding unmet prerequisites is part of a cycle.
    if unmet_count:
        return []
    return order
def _post_process_carryover_item(carryover_item):
    """Normalize one carryover entry to a string.

    Strings pass through unchanged; dicts with a "content" key yield that
    content stringified; anything else is stringified wholesale.
    """
    if isinstance(carryover_item, str):
        return carryover_item
    if isinstance(carryover_item, dict) and "content" in carryover_item:
        return str(carryover_item["content"])
    return str(carryover_item)
def __post_carryover_processing(chat_info: dict[str, Any]) -> None:
    """Emit a PostCarryoverProcessingEvent for this chat entry.

    Warns when no initial "message" is present, because input() will then be
    used to obtain one.
    """
    iostream = IOStream.get_default()
    if "message" not in chat_info:
        warnings.warn(
            "message is not provided in a chat_queue entry. input() will be called to get the initial message.",
            UserWarning,
        )
    event = PostCarryoverProcessingEvent(chat_info=chat_info)
    iostream.send(event)
@export_module("autogen")
def initiate_chats(chat_queue: list[dict[str, Any]]) -> list[ChatResult]:
    """Initiate a list of chats sequentially, threading summaries forward as carryover.

    Args:
        chat_queue (List[Dict]): A list of dictionaries containing the information about the chats.

            Each dictionary should contain the input arguments for
            [`ConversableAgent.initiate_chat`](../ConversableAgent#initiate-chat).
            For example:
                - `"sender"` - the sender agent.
                - `"recipient"` - the recipient agent.
                - `"clear_history"` (bool) - whether to clear the chat history with the agent.
                    Default is True.
                - `"silent"` (bool or None) - (Experimental) whether to print the messages in this
                    conversation. Default is False.
                - `"cache"` (Cache or None) - the cache client to use for this conversation.
                    Default is None.
                - `"max_turns"` (int or None) - maximum number of turns for the chat. If None, the chat
                    will continue until a termination condition is met. Default is None.
                - `"summary_method"` (str or callable) - a string or callable specifying the method to get
                    a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
                - `"summary_args"` (dict) - a dictionary of arguments to be passed to the summary_method.
                    Default is {}.
                - `"message"` (str, callable or None) - if None, input() will be called to get the
                    initial message.
                - `**context` - additional context information to be passed to the chat.
                - `"carryover"` - It can be used to specify the carryover information to be passed
                    to this chat. If provided, we will combine this carryover with the "message" content when
                    generating the initial chat message in `generate_init_message`.
                - `"finished_chat_indexes_to_exclude_from_carryover"` - It can be used by specifying a list of indexes of the finished_chats list,
                    from which to exclude the summaries for carryover. If 'finished_chat_indexes_to_exclude_from_carryover' is not provided or an empty list,
                    then summary from all the finished chats will be taken.

    Returns:
        (list): a list of ChatResult objects corresponding to the finished chats in the chat_queue.
    """
    consolidate_chat_info(chat_queue)
    _validate_recipients(chat_queue)
    finished_chats = []
    # Iterate over a shallow copy so the caller's queue object is untouched;
    # the per-chat dicts themselves are updated with the computed carryover.
    for chat_info in list(chat_queue):
        extra_carryover = chat_info.get("carryover", [])
        if isinstance(extra_carryover, str):
            extra_carryover = [extra_carryover]
        excluded_indexes = chat_info.get("finished_chat_indexes_to_exclude_from_carryover", [])
        chat_info["carryover"] = extra_carryover + [
            result.summary
            for index, result in enumerate(finished_chats)
            if index not in excluded_indexes
        ]
        if not chat_info.get("silent", False):
            __post_carryover_processing(chat_info)
        initiating_agent = chat_info["sender"]
        finished_chats.append(initiating_agent.initiate_chat(**chat_info))
    return finished_chats
def __system_now_str():
    """Return the current system time formatted as a log-line suffix."""
    return f" System time at {datetime.datetime.now()}. "
def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int):
    """Stamp the chat id onto the ChatResult once its async task completes."""
    logger.debug(f"Update chat {chat_id} result on task completion." + __system_now_str())
    completed_result = chat_future.result()
    completed_result.chat_id = chat_id
async def _dependent_chat_future(
    chat_id: int, chat_info: dict[str, Any], prerequisite_chat_futures: dict[int, asyncio.Future]
) -> asyncio.Task:
    """Create an async Task for each chat.

    Awaits every prerequisite chat's future, folds their summaries into this
    chat's carryover, then launches `a_initiate_chat` as a new task whose
    completion callback records the chat id on its result.

    Args:
        chat_id: id of the chat being scheduled.
        chat_info: kwargs for `a_initiate_chat`, plus optional "carryover" and
            "finished_chat_indexes_to_exclude_from_carryover" entries.
        prerequisite_chat_futures: futures of the chats this one depends on.

    Raises:
        RuntimeError: if any prerequisite chat future was cancelled.
    """
    logger.debug(f"Create Task for chat {chat_id}." + __system_now_str())
    _chat_carryover = chat_info.get("carryover", [])
    finished_chat_indexes_to_exclude_from_carryover = chat_info.get(
        "finished_chat_indexes_to_exclude_from_carryover", []
    )
    finished_chats = dict()
    for chat in prerequisite_chat_futures:
        chat_future = prerequisite_chat_futures[chat]
        if chat_future.cancelled():
            raise RuntimeError(f"Chat {chat} is cancelled.")

        # wait for prerequisite chat results for the new chat carryover
        finished_chats[chat] = await chat_future

    if isinstance(_chat_carryover, str):
        _chat_carryover = [_chat_carryover]
    # NOTE(review): despite the name, this filters by prerequisite chat *ids*
    # (the dict keys), not list indexes as in the synchronous initiate_chats —
    # confirm this asymmetry is intentional.
    data = [
        chat_result.summary
        for chat_id, chat_result in finished_chats.items()
        if chat_id not in finished_chat_indexes_to_exclude_from_carryover
    ]
    chat_info["carryover"] = _chat_carryover + data

    if not chat_info.get("silent", False):
        __post_carryover_processing(chat_info)

    sender = chat_info["sender"]
    chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info))
    # Callback stamps the chat_id onto the ChatResult when the task finishes.
    call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id)
    chat_res_future.add_done_callback(call_back_with_args)
    logger.debug(f"Task for chat {chat_id} created." + __system_now_str())
    return chat_res_future
async def a_initiate_chats(chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]:
    """(async) Initiate a list of chats.

    Args:
        chat_queue (List[Dict]): A list of dictionaries containing the information about the chats.
            Each dictionary should contain the input arguments for
            [`ConversableAgent.initiate_chat`](../../../ConversableAgent#initiate-chat).
            For example:
            - `"sender"` - the sender agent.
            - `"recipient"` - the recipient agent.
            - `"clear_history"` (bool) - whether to clear the chat history with the agent.
                Default is True.
            - `"silent"` (bool or None) - (Experimental) whether to print the messages in this
                conversation. Default is False.
            - `"cache"` (Cache or None) - the cache client to use for this conversation.
                Default is None.
            - `"max_turns"` (int or None) - maximum number of turns for the chat. If None, the chat
                will continue until a termination condition is met. Default is None.
            - `"summary_method"` (str or callable) - a string or callable specifying the method to get
                a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
            - `"summary_args"` (dict) - a dictionary of arguments to be passed to the summary_method.
                Default is {}.
            - `"message"` (str, callable or None) - if None, input() will be called to get the
                initial message.
            - `**context` - additional context information to be passed to the chat.
            - `"carryover"` - It can be used to specify the carryover information to be passed
                to this chat. If provided, we will combine this carryover with the "message" content when
                generating the initial chat message in `generate_init_message`.
            - `"finished_chat_indexes_to_exclude_from_carryover"` - It can be used by specifying a list of indexes of the finished_chats list,
                from which to exclude the summaries for carryover. If 'finished_chat_indexes_to_exclude_from_carryover' is not provided or an empty list,
                then summary from all the finished chats will be taken.

    Returns:
        - (Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
    """
    consolidate_chat_info(chat_queue)
    _validate_recipients(chat_queue)
    # Index every chat by its unique chat_id for the prerequisite lookups below.
    chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue}
    # NOTE(review): despite the name, this is a view of chat ids (dict keys), not a count.
    num_chats = chat_book.keys()
    prerequisites = __create_async_prerequisites(chat_queue)
    # Order the chats so each one is scheduled after its prerequisites
    # (presumably a topological order — see __find_async_chat_order).
    chat_order_by_id = __find_async_chat_order(num_chats, prerequisites)
    finished_chat_futures = dict()
    for chat_id in chat_order_by_id:
        chat_info = chat_book[chat_id]
        prerequisite_chat_ids = chat_info.get("prerequisites", [])
        # Gather the already-created futures of this chat's prerequisites.
        pre_chat_futures = dict()
        for pre_chat_id in prerequisite_chat_ids:
            pre_chat_future = finished_chat_futures[pre_chat_id]
            pre_chat_futures[pre_chat_id] = pre_chat_future
        # Schedule the chat as a Task; it awaits its prerequisites internally.
        current_chat_future = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures)
        finished_chat_futures[chat_id] = current_chat_future
    # Block until every scheduled chat has completed.
    await asyncio.gather(*list(finished_chat_futures.values()))
    finished_chats = dict()
    for chat in finished_chat_futures:
        chat_result = finished_chat_futures[chat].result()
        finished_chats[chat] = chat_result
    return finished_chats

View File

@@ -0,0 +1,5 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0

# This package intentionally exports no public names.
__all__: list[str] = []

View File

@@ -0,0 +1,5 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0

# This package intentionally exports no public names.
__all__: list[str] = []

View File

@@ -0,0 +1,20 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from ...assistant_agent import ConversableAgent
class AgentCapability:
    """Base class for composable capabilities that can be added to an agent."""

    def __init__(self):
        pass

    def add_to_agent(self, agent: ConversableAgent):
        """Attach this capability to the given agent.

        Subclasses must override this. A typical implementation calls
        agent.register_hook() one or more times; see teachability.py for an example.
        """
        raise NotImplementedError

View File

@@ -0,0 +1,301 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import re
from typing import Any, Literal, Optional, Protocol, Union
from .... import Agent, ConversableAgent, code_utils
from ....cache import AbstractCache
from ....import_utils import optional_import_block, require_optional_import
from ....llm_config import LLMConfig
from .. import img_utils
from ..capabilities.agent_capability import AgentCapability
from ..text_analyzer_agent import TextAnalyzerAgent
with optional_import_block():
from PIL.Image import Image
from openai import OpenAI
# Appended to the agent's system message to advertise the new ability to its LLM.
SYSTEM_MESSAGE = "You've been given the special ability to generate images."
# Appended to the agent's description (useful e.g. for group-chat speaker selection).
DESCRIPTION_MESSAGE = "This agent has the ability to generate images."
# Default instructions for the TextAnalyzerAgent that distills a message into an image prompt.
PROMPT_INSTRUCTIONS = """In detail, please summarize the provided prompt to generate the image described in the TEXT.
DO NOT include any advice. RESPOND like the following example:
EXAMPLE: Blue background, 3D shapes, ...
"""
class ImageGenerator(Protocol):
    """Structural interface for image generators.

    An implementation must provide ``generate_image``, which turns a text prompt
    into a PIL Image, and ``cache_key``, which derives a stable key from a prompt
    so generated images can be cached and looked up later.

    NOTE: Current implementation does not allow you to edit a previously existing image.
    """

    def generate_image(self, prompt: str) -> "Image":
        """Generate an image for the given prompt.

        Args:
            prompt: Text describing the desired image.

        Returns:
            The generated image as a PIL Image object.

        Raises:
            ValueError: If the image generation fails.
        """
        ...

    def cache_key(self, prompt: str) -> str:
        """Derive a unique, deterministic cache key from the prompt.

        Args:
            prompt: Text describing the desired image.

        Returns:
            A string suitable for storing and retrieving generated images.
        """
        ...
@require_optional_import("PIL", "unknown")
@require_optional_import("openai>=1.66.2", "openai")
class DalleImageGenerator:
    """Image generator backed by OpenAI's DALL-E models.

    Wraps the OpenAI images API so that a textual prompt can be turned into a
    PIL Image with a chosen model, resolution, quality, and image count.

    Note: Current implementation does not allow you to edit a previously existing image.
    """

    def __init__(
        self,
        llm_config: Union[LLMConfig, dict[str, Any]],
        resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
        quality: Literal["standard", "hd"] = "standard",
        num_images: int = 1,
    ):
        """Args:
        llm_config (LLMConfig or dict): llm config, must contain a valid dalle model and OpenAI API key in config_list.
        resolution (str): The resolution of the image you want to generate. Must be one of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792".
        quality (str): The quality of the image you want to generate. Must be one of "standard", "hd".
        num_images (int): The number of images to generate.
        """
        config_list = llm_config["config_list"]
        # Only the first config entry is consulted for the model and API key.
        first_entry = config_list[0]
        _validate_dalle_model(first_entry["model"])
        _validate_resolution_format(resolution)
        self._model = first_entry["model"]
        self._resolution = resolution
        self._quality = quality
        self._num_images = num_images
        self._dalle_client = OpenAI(api_key=first_entry["api_key"])

    def generate_image(self, prompt: str) -> "Image":
        """Call the DALL-E API and return the first result as a PIL Image."""
        response = self._dalle_client.images.generate(
            model=self._model,
            prompt=prompt,
            size=self._resolution,
            quality=self._quality,
            n=self._num_images,
        )
        first_url = response.data[0].url
        if first_url is None:
            raise ValueError("Failed to generate image.")
        return img_utils.get_pil_image(first_url)

    def cache_key(self, prompt: str) -> str:
        """Build a deterministic cache key from the prompt and generation settings."""
        return ",".join(
            str(part) for part in (prompt, self._model, self._resolution, self._quality, self._num_images)
        )
@require_optional_import("PIL", "unknown")
class ImageGeneration(AgentCapability):
    """This capability allows a ConversableAgent to generate images based on the message received from other Agents.

    1. Utilizes a TextAnalyzerAgent to analyze incoming messages to identify requests for image generation and
       extract relevant details.
    2. Leverages the provided ImageGenerator (e.g., DalleImageGenerator) to create the image.
    3. Optionally caches generated images for faster retrieval in future conversations.

    NOTE: This capability increases the token usage of the agent, as it uses TextAnalyzerAgent to analyze every
    message received by the agent.

    Example:
    ```python
    import autogen
    from autogen.agentchat.contrib.capabilities.image_generation import ImageGeneration

    # Assuming you have llm configs configured for the LLMs you want to use and Dalle.
    # Create the agent
    agent = autogen.ConversableAgent(
        name="dalle", llm_config={...}, max_consecutive_auto_reply=3, human_input_mode="NEVER"
    )

    # Create an ImageGenerator with desired settings
    dalle_gen = generate_images.DalleImageGenerator(llm_config={...})

    # Add the ImageGeneration capability to the agent
    agent.add_capability(ImageGeneration(image_generator=dalle_gen))
    ```
    """

    def __init__(
        self,
        image_generator: ImageGenerator,
        cache: Optional[AbstractCache] = None,
        text_analyzer_llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None,
        text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
        verbosity: int = 0,
        register_reply_position: int = 2,
    ):
        """Args:
        image_generator (ImageGenerator): The image generator you would like to use to generate images.
        cache (None or AbstractCache): The cache client to use to store and retrieve generated images. If None,
            no caching will be used.
        text_analyzer_llm_config (LLMConfig or Dict or None): The LLM config for the text analyzer. If None, the LLM config will
            be retrieved from the agent you're adding the ability to.
        text_analyzer_instructions (str): Instructions provided to the TextAnalyzerAgent used to analyze
            incoming messages and extract the prompt for image generation. The default instructions focus on
            summarizing the prompt. You can customize the instructions to achieve more granular control over prompt
            extraction.
            Example: 'Extract specific details from the message, like desired objects, styles, or backgrounds.'
        verbosity (int): The verbosity level. Defaults to 0 and must be greater than or equal to 0. The text
            analyzer llm calls will be silent if verbosity is less than 2.
        register_reply_position (int): The position of the reply function in the agent's list of reply functions.
            This capability registers a new reply function to handle messages with image generation requests.
            Defaults to 2 to place it after the check termination and human reply for a ConversableAgent.
        """
        self._image_generator = image_generator
        self._cache = cache
        self._text_analyzer_llm_config = text_analyzer_llm_config
        self._text_analyzer_instructions = text_analyzer_instructions
        self._verbosity = verbosity
        self._register_reply_position = register_reply_position
        # Both are populated later by add_to_agent().
        self._agent: Optional[ConversableAgent] = None
        self._text_analyzer: Optional[TextAnalyzerAgent] = None

    def add_to_agent(self, agent: ConversableAgent):
        """Adds the Image Generation capability to the specified ConversableAgent.

        This function performs the following modifications to the agent:

        1. Registers a reply function: A new reply function is registered with the agent to handle messages that
           potentially request image generation. This function analyzes the message and triggers image generation if
           necessary.
        2. Creates an Agent (TextAnalyzerAgent): This is used to analyze messages for image generation requirements.
        3. Updates System Message: The agent's system message is updated to include a message indicating the
           capability to generate images has been added.
        4. Updates Description: The agent's description is updated to reflect the addition of the Image Generation
           capability. This might be helpful in certain use cases, like group chats.

        Args:
            agent (ConversableAgent): The ConversableAgent to add the capability to.
        """
        self._agent = agent
        agent.register_reply([Agent, None], self._image_gen_reply, position=self._register_reply_position)
        # Fall back to the host agent's llm_config when none was given at construction.
        self._text_analyzer_llm_config = self._text_analyzer_llm_config or agent.llm_config
        self._text_analyzer = TextAnalyzerAgent(llm_config=self._text_analyzer_llm_config)
        agent.update_system_message(agent.system_message + "\n" + SYSTEM_MESSAGE)
        agent.description += "\n" + DESCRIPTION_MESSAGE

    def _image_gen_reply(
        self,
        recipient: ConversableAgent,
        messages: Optional[list[dict[str, Any]]],
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
        # Reply hook: returns (True, reply) when it handled the message, (False, None) otherwise.
        if messages is None:
            return False, None

        last_message = code_utils.content_str(messages[-1]["content"])

        if not last_message:
            return False, None

        if self._should_generate_image(last_message):
            prompt = self._extract_prompt(last_message)

            # Serve from cache when possible; otherwise generate and cache the result.
            image = self._cache_get(prompt)
            if image is None:
                image = self._image_generator.generate_image(prompt)
                self._cache_set(prompt, image)

            return True, self._generate_content_message(prompt, image)

        else:
            return False, None

    def _should_generate_image(self, message: str) -> bool:
        # Ask the analyzer for a yes/no verdict on whether an image was requested.
        assert self._text_analyzer is not None

        instructions = """
        Does any part of the TEXT ask the agent to generate an image?
        The TEXT must explicitly mention that the image must be generated.
        Answer with just one word, yes or no.
        """
        analysis = self._text_analyzer.analyze_text(message, instructions)

        return "yes" in self._extract_analysis(analysis).lower()

    def _extract_prompt(self, last_message) -> str:
        # Distill the incoming message into an image-generation prompt.
        assert self._text_analyzer is not None

        analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
        return self._extract_analysis(analysis)

    def _cache_get(self, prompt: str) -> Optional["Image"]:
        # Returns None on a cache miss or when caching is disabled.
        if self._cache:
            key = self._image_generator.cache_key(prompt)
            cached_value = self._cache.get(key)

            if cached_value:
                return img_utils.get_pil_image(cached_value)

    def _cache_set(self, prompt: str, image: "Image"):
        # Images are stored as data URIs; no-op when caching is disabled.
        if self._cache:
            key = self._image_generator.cache_key(prompt)
            self._cache.set(key, img_utils.pil_to_data_uri(image))

    def _extract_analysis(self, analysis: Optional[Union[str, dict[str, Any]]]) -> str:
        # Analyzer replies may be a message dict or a plain string; normalize to str.
        if isinstance(analysis, dict):
            return code_utils.content_str(analysis["content"])
        else:
            return code_utils.content_str(analysis)

    def _generate_content_message(self, prompt: str, image: "Image") -> dict[str, Any]:
        # Multimodal message: a text part plus the image embedded as a data URI.
        return {
            "content": [
                {"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
                {"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}},
            ]
        }
# Helpers
def _validate_resolution_format(resolution: str):
    """Raise ValueError unless *resolution* has the form "<digits>x<digits>" (e.g., "1024x768")."""
    if re.match(r"^\d+x\d+$", resolution) is None:
        raise ValueError(f"Invalid resolution format: {resolution}")
def _validate_dalle_model(model: str):
    """Raise ValueError unless *model* names a supported DALL-E model."""
    if model not in ("dall-e-3", "dall-e-2"):
        raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'")

View File

@@ -0,0 +1,393 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import os
import pickle
from typing import Any, Optional, Union
from ....formatting_utils import colored
from ....import_utils import optional_import_block, require_optional_import
from ....llm_config import LLMConfig
from ...assistant_agent import ConversableAgent
from ..text_analyzer_agent import TextAnalyzerAgent
from .agent_capability import AgentCapability
with optional_import_block():
import chromadb
from chromadb.config import Settings
class Teachability(AgentCapability):
    """Teachability uses a vector database to give an agent the ability to remember user teachings,
    where the user is any caller (human or not) sending messages to the teachable agent.
    Teachability is designed to be composable with other agent capabilities.
    To make any conversable agent teachable, instantiate both the agent and the Teachability class,
    then pass the agent to teachability.add_to_agent(agent).
    Note that teachable agents in a group chat must be given unique path_to_db_dir values.

    When adding Teachability to an agent, the following are modified:
    - The agent's system message is appended with a note about the agent's new ability.
    - A hook is added to the agent's `process_last_received_message` hookable method,
    and the hook potentially modifies the last of the received messages to include earlier teachings related to the message.
    Added teachings do not propagate into the stored message history.
    If new user teachings are detected, they are added to new memos in the vector database.
    """

    def __init__(
        self,
        verbosity: Optional[int] = 0,
        reset_db: Optional[bool] = False,
        path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
        recall_threshold: Optional[float] = 1.5,
        max_num_retrievals: Optional[int] = 10,
        llm_config: Optional[Union[LLMConfig, dict[str, Any], bool]] = None,
    ):
        """Args:
        verbosity (Optional, int): # 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
        reset_db (Optional, bool): True to clear the DB before starting. Default False.
        path_to_db_dir (Optional, str): path to the directory where this particular agent's DB is stored. Default "./tmp/teachable_agent_db"
        recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
        max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10.
        llm_config (LLMConfig or dict or False): llm inference configuration passed to TextAnalyzerAgent.
            If None, TextAnalyzerAgent uses llm_config from the teachable agent.
        """
        self.verbosity = verbosity
        self.path_to_db_dir = path_to_db_dir
        self.recall_threshold = recall_threshold
        self.max_num_retrievals = max_num_retrievals
        self.llm_config = llm_config

        # Both are populated later by add_to_agent().
        self.analyzer = None
        self.teachable_agent = None

        # Create the memo store.
        self.memo_store = MemoStore(self.verbosity, reset_db, self.path_to_db_dir)

    def add_to_agent(self, agent: ConversableAgent):
        """Adds teachability to the given agent."""
        self.teachable_agent = agent

        # Register a hook for processing the last message.
        agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

        # Was an llm_config passed to the constructor?
        if self.llm_config is None:
            # No. Use the agent's llm_config.
            self.llm_config = agent.llm_config
        assert self.llm_config, "Teachability requires a valid llm_config."

        # Create the analyzer agent.
        self.analyzer = TextAnalyzerAgent(llm_config=self.llm_config)

        # Append extra info to the system message.
        agent.update_system_message(
            agent.system_message
            + "\nYou've been given the special ability to remember user teachings from prior conversations."
        )

    def prepopulate_db(self):
        """Adds a few arbitrary memos to the DB."""
        self.memo_store.prepopulate()

    def process_last_received_message(self, text: Union[dict[str, Any], str]):
        """Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
        Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
        """
        # NOTE(review): _consider_memo_retrieval concatenates `text` with a str, which would
        # raise for a dict `text` — this hook appears to assume str input; confirm upstream.
        # Try to retrieve relevant memos from the DB.
        expanded_text = text
        if self.memo_store.last_memo_id > 0:
            expanded_text = self._consider_memo_retrieval(text)

        # Try to store any user teachings in new memos to be used in the future.
        self._consider_memo_storage(text)

        # Return the (possibly) expanded message text.
        return expanded_text

    def _consider_memo_storage(self, comment: Union[dict[str, Any], str]):
        """Decides whether to store something from one user comment in the DB."""
        memo_added = False

        # Check for a problem-solution pair.
        response = self._analyze(
            comment,
            "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            # Can we extract advice?
            advice = self._analyze(
                comment,
                "Briefly copy any advice from the TEXT that may be useful for a similar but different task in the future. But if no advice is present, just respond with 'none'.",
            )
            if "none" not in advice.lower():
                # Yes. Extract the task.
                task = self._analyze(
                    comment,
                    "Briefly copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice.",
                )

                # Generalize the task.
                general_task = self._analyze(
                    task,
                    "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
                )

                # Add the task-advice (problem-solution) pair to the vector DB.
                if self.verbosity >= 1:
                    print(colored("\nREMEMBER THIS TASK-ADVICE PAIR", "light_yellow"))
                self.memo_store.add_input_output_pair(general_task, advice)
                memo_added = True

        # Check for information to be learned.
        response = self._analyze(
            comment,
            "Does the TEXT contain information that could be committed to memory? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            # Yes. What question would this information answer?
            question = self._analyze(
                comment,
                "Imagine that the user forgot this information in the TEXT. How would they ask you for this information? Include no other text in your response.",
            )

            # Extract the information.
            answer = self._analyze(
                comment, "Copy the information from the TEXT that should be committed to memory. Add no explanation."
            )

            # Add the question-answer pair to the vector DB.
            if self.verbosity >= 1:
                print(colored("\nREMEMBER THIS QUESTION-ANSWER PAIR", "light_yellow"))
            self.memo_store.add_input_output_pair(question, answer)
            memo_added = True

        # Were any memos added?
        if memo_added:
            # Yes. Save them to disk.
            self.memo_store._save_memos()

    def _consider_memo_retrieval(self, comment: Union[dict[str, Any], str]):
        """Decides whether to retrieve memos from the DB, and add them to the chat context."""
        # First, use the comment directly as the lookup key.
        if self.verbosity >= 1:
            print(colored("\nLOOK FOR RELEVANT MEMOS, AS QUESTION-ANSWER PAIRS", "light_yellow"))
        memo_list = self._retrieve_relevant_memos(comment)

        # Next, if the comment involves a task, then extract and generalize the task before using it as the lookup key.
        response = self._analyze(
            comment,
            "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            if self.verbosity >= 1:
                print(colored("\nLOOK FOR RELEVANT MEMOS, AS TASK-ADVICE PAIRS", "light_yellow"))

            # Extract the task.
            task = self._analyze(
                comment, "Copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice."
            )

            # Generalize the task.
            general_task = self._analyze(
                task,
                "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
            )

            # Append any relevant memos.
            memo_list.extend(self._retrieve_relevant_memos(general_task))

        # De-duplicate the memo list.
        memo_list = list(set(memo_list))

        # Append the memos to the text of the last message.
        return comment + self._concatenate_memo_texts(memo_list)

    def _retrieve_relevant_memos(self, input_text: str) -> list:
        """Returns semantically related memos from the DB."""
        memo_list = self.memo_store.get_related_memos(
            input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold
        )

        if self.verbosity >= 1:  # noqa: SIM102
            # Was anything retrieved?
            if len(memo_list) == 0:
                # No. Look at the closest memo.
                print(colored("\nTHE CLOSEST MEMO IS BEYOND THE THRESHOLD:", "light_yellow"))
                self.memo_store.get_nearest_memo(input_text)
                print()  # Print a blank line. The memo details were printed by get_nearest_memo().

        # Create a list of just the memo output_text strings.
        memo_list = [memo[1] for memo in memo_list]
        return memo_list

    def _concatenate_memo_texts(self, memo_list: list) -> str:
        """Concatenates the memo texts into a single string for inclusion in the chat context."""
        memo_texts = ""
        if len(memo_list) > 0:
            info = "\n# Memories that might help\n"
            for memo in memo_list:
                info = info + "- " + memo + "\n"
            if self.verbosity >= 1:
                print(colored("\nMEMOS APPENDED TO LAST MESSAGE...\n" + info + "\n", "light_yellow"))
            memo_texts = memo_texts + "\n" + info
        return memo_texts

    def _analyze(self, text_to_analyze: Union[dict[str, Any], str], analysis_instructions: Union[dict[str, Any], str]):
        """Asks TextAnalyzerAgent to analyze the given text according to specific instructions."""
        self.analyzer.reset()  # Clear the analyzer's list of messages.
        self.teachable_agent.send(
            recipient=self.analyzer, message=text_to_analyze, request_reply=False, silent=(self.verbosity < 2)
        )  # Put the message in the analyzer's list.
        self.teachable_agent.send(
            recipient=self.analyzer, message=analysis_instructions, request_reply=True, silent=(self.verbosity < 2)
        )  # Request the reply.
        return self.teachable_agent.last_message(self.analyzer)["content"]
@require_optional_import("chromadb", "teachable")
class MemoStore:
    """Provides memory storage and retrieval for a teachable agent, using a vector database.
    Each DB entry (called a memo) is a pair of strings: an input text and an output text.
    The input text might be a question, or a task to perform.
    The output text might be an answer to the question, or advice on how to perform the task.
    Vector embeddings are currently supplied by Chroma's default Sentence Transformers.
    """

    def __init__(
        self,
        verbosity: Optional[int] = 0,
        reset: Optional[bool] = False,
        path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
    ):
        """Args:
        - verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists.
        - reset (Optional, bool): True to clear the DB before starting. Default False.
        - path_to_db_dir (Optional, str): path to the directory where the DB is stored.
        """
        self.verbosity = verbosity
        self.path_to_db_dir = path_to_db_dir

        # Load or create the vector DB on disk.
        settings = Settings(
            anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir
        )
        self.db_client = chromadb.Client(settings)
        self.vec_db = self.db_client.create_collection("memos", get_or_create=True)  # The collection is the DB.

        # Load or create the associated memo dict on disk.
        self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl")
        self.uid_text_dict = {}
        self.last_memo_id = 0
        if (not reset) and os.path.exists(self.path_to_dict):
            print(colored("\nLOADING MEMORY FROM DISK", "light_green"))
            print(colored(f"    Location = {self.path_to_dict}", "light_green"))
            with open(self.path_to_dict, "rb") as f:
                self.uid_text_dict = pickle.load(f)
                self.last_memo_id = len(self.uid_text_dict)
                if self.verbosity >= 3:
                    self.list_memos()

        # Clear the DB if requested.
        if reset:
            self.reset_db()

    def list_memos(self):
        """Prints the contents of MemoStore."""
        print(colored("LIST OF MEMOS", "light_green"))
        for uid, text in self.uid_text_dict.items():
            input_text, output_text = text
            print(
                colored(
                    f"  ID: {uid}\n    INPUT TEXT: {input_text}\n    OUTPUT TEXT: {output_text}",
                    "light_green",
                )
            )

    def _save_memos(self):
        """Saves self.uid_text_dict to disk."""
        with open(self.path_to_dict, "wb") as file:
            pickle.dump(self.uid_text_dict, file)

    def reset_db(self):
        """Forces immediate deletion of the DB's contents, in memory and on disk."""
        print(colored("\nCLEARING MEMORY", "light_green"))
        self.db_client.delete_collection("memos")
        self.vec_db = self.db_client.create_collection("memos")
        self.uid_text_dict = {}
        # Fix: also reset the id counter. Leaving it stale meant new ids kept counting
        # upward, while a later reload sets last_memo_id = len(uid_text_dict) — the two
        # could disagree and allow id collisions in add_input_output_pair().
        self.last_memo_id = 0
        self._save_memos()

    def add_input_output_pair(self, input_text: str, output_text: str):
        """Adds an input-output pair to the vector DB."""
        self.last_memo_id += 1
        # The vector DB indexes the input text; the output text lives only in the dict.
        self.vec_db.add(documents=[input_text], ids=[str(self.last_memo_id)])
        self.uid_text_dict[str(self.last_memo_id)] = input_text, output_text
        if self.verbosity >= 1:
            print(
                colored(
                    f"\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n  ID\n    {self.last_memo_id}\n  INPUT\n    {input_text}\n  OUTPUT\n    {output_text}\n",
                    "light_yellow",
                )
            )
        if self.verbosity >= 3:
            self.list_memos()

    def get_nearest_memo(self, query_text: str):
        """Retrieves the nearest memo to the given query text."""
        results = self.vec_db.query(query_texts=[query_text], n_results=1)
        uid, input_text, distance = results["ids"][0][0], results["documents"][0][0], results["distances"][0][0]
        input_text_2, output_text = self.uid_text_dict[uid]
        assert input_text == input_text_2
        if self.verbosity >= 1:
            print(
                colored(
                    f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n  INPUT1\n    {input_text}\n  OUTPUT\n    {output_text}\n  DISTANCE\n    {distance}",
                    "light_yellow",
                )
            )
        return input_text, output_text, distance

    def get_related_memos(self, query_text: str, n_results: int, threshold: Union[int, float]):
        """Retrieves memos that are related to the given query text within the specified distance threshold."""
        if n_results > len(self.uid_text_dict):
            n_results = len(self.uid_text_dict)
        # Robustness: an empty store would clamp n_results to 0; skip the query entirely.
        if n_results <= 0:
            return []
        results = self.vec_db.query(query_texts=[query_text], n_results=n_results)
        memos = []
        num_results = len(results["ids"][0])
        for i in range(num_results):
            uid, input_text, distance = results["ids"][0][i], results["documents"][0][i], results["distances"][0][i]
            if distance < threshold:
                input_text_2, output_text = self.uid_text_dict[uid]
                assert input_text == input_text_2
                if self.verbosity >= 1:
                    print(
                        colored(
                            f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n  INPUT1\n    {input_text}\n  OUTPUT\n    {output_text}\n  DISTANCE\n    {distance}",
                            "light_yellow",
                        )
                    )
                memos.append((input_text, output_text, distance))
        return memos

    def prepopulate(self):
        """Adds a few arbitrary examples to the vector DB, just to make retrieval less trivial."""
        if self.verbosity >= 1:
            print(colored("\nPREPOPULATING MEMORY", "light_green"))
        examples = []
        examples.append({"text": "When I say papers I mean research papers, which are typically pdfs.", "label": "yes"})
        examples.append({"text": "Please verify that each paper you listed actually uses langchain.", "label": "no"})
        examples.append({"text": "Tell gpt the output should still be latex code.", "label": "no"})
        examples.append({"text": "Hint: convert pdfs to text and then answer questions based on them.", "label": "yes"})
        examples.append({
            "text": "To create a good PPT, include enough content to make it interesting.",
            "label": "yes",
        })
        examples.append({
            "text": "No, for this case the columns should be aspects and the rows should be frameworks.",
            "label": "no",
        })
        examples.append({"text": "When writing code, remember to include any libraries that are used.", "label": "yes"})
        examples.append({"text": "Please summarize the papers by Eric Horvitz on bounded rationality.", "label": "no"})
        examples.append({"text": "Compare the h-index of Daniel Weld and Oren Etzioni.", "label": "no"})
        examples.append({
            "text": "Double check to be sure that the columns in a table correspond to what was asked for.",
            "label": "yes",
        })
        for example in examples:
            self.add_input_output_pair(example["text"], example["label"])
        self._save_memos()

View File

@@ -0,0 +1,66 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import Any, Protocol
from ....import_utils import optional_import_block, require_optional_import
with optional_import_block() as result:
import llmlingua
from llmlingua import PromptCompressor
class TextCompressor(Protocol):
    """Protocol for text compressors used to optimize agent interactions."""

    def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
        """Compress *text* and report token statistics.

        Returns a dictionary whose 'compressed_text' key holds the compressed
        string; the 'origin_tokens' and 'compressed_tokens' keys let callers
        compute how many tokens the compression saved.
        """
        ...
@require_optional_import("llmlingua", "long-context")
class LLMLingua:
    """Compresses text messages using LLMLingua for improved efficiency in processing and
    response generation.

    NOTE: The effectiveness of compression and the resultant token savings can vary based
    on the content of the messages and the specific configurations used for the
    PromptCompressor.
    """

    # Default PromptCompressor configuration. Kept at class level so __init__ avoids a
    # mutable default argument (a dict default would be shared across all instances).
    _DEFAULT_PROMPT_COMPRESSOR_KWARGS = {
        "model_name": "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
        "use_llmlingua2": True,
        "device_map": "cpu",
    }

    def __init__(
        self,
        prompt_compressor_kwargs: Optional[dict] = None,
        structured_compression: bool = False,
    ) -> None:
        """Args:
        prompt_compressor_kwargs (dict or None): A dictionary of keyword arguments for the
            PromptCompressor. Defaults to a dictionary with model_name set to
            "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
            use_llmlingua2 set to True, and device_map set to "cpu".
        structured_compression (bool): A flag indicating whether to use structured
            compression. If True, the structured_compress_prompt method of the
            PromptCompressor is used. Otherwise, the compress_prompt method is used.
            Defaults to False.

        Raises:
            ImportError: If the llmlingua library is not installed.
        """
        if prompt_compressor_kwargs is None:
            prompt_compressor_kwargs = self._DEFAULT_PROMPT_COMPRESSOR_KWARGS
        self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)
        # Sanity check that the optional import actually resolved to llmlingua's class.
        assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
        # Bind the chosen compression entry point once so compress_text stays a thin call.
        self._compression_method = (
            self._prompt_compressor.structured_compress_prompt
            if structured_compression
            else self._prompt_compressor.compress_prompt
        )

    def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
        """Compress ``text`` via the configured LLMLingua compression method.

        Returns the dictionary produced by LLMLingua, including 'compressed_prompt' and
        the token-count fields used to compute savings.
        """
        return self._compression_method([text], **compression_params)

View File

@@ -0,0 +1,22 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from ....agentchat import ConversableAgent
from ....tools import Tool
class ToolsCapability:
    """Adding a list of tools as composable capabilities to a single agent.

    This class can be inherited from to allow code to run at the point of creating or
    adding the capability.

    Note: both caller and executor of the tools are the same agent.
    """

    def __init__(self, tool_list: list[Tool]):
        """Args:
        tool_list: Tools to register on an agent when this capability is added.
        """
        # Copy into a new list so later mutation of the caller's list has no effect here.
        self.tools = list(tool_list)

    def add_to_agent(self, agent: ConversableAgent) -> None:
        """Add tools to the given agent; the agent acts as both caller and executor."""
        for tool in self.tools:
            tool.register_tool(agent=agent)

View File

@@ -0,0 +1,93 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import copy
from typing import TYPE_CHECKING, Any, Optional

from ....formatting_utils import colored
from .transforms import MessageTransform

if TYPE_CHECKING:
    from ...conversable_agent import ConversableAgent
class TransformMessages:
    """Agent capability for transforming messages before reply generation.

    This capability allows you to apply a series of message transformations to
    a ConversableAgent's incoming messages before they are processed for response
    generation. This is useful for tasks such as:

    - Limiting the number of messages considered for context.
    - Truncating messages to meet token limits.
    - Filtering sensitive information.
    - Customizing message formatting.

    To use `TransformMessages`:

    1. Create message transformations (e.g., `MessageHistoryLimiter`, `MessageTokenLimiter`).
    2. Instantiate `TransformMessages` with a list of these transformations.
    3. Add the `TransformMessages` instance to your `ConversableAgent` using `add_to_agent`.

    NOTE: Order of message transformations is important. You could get different results based on
    the order of transformations.

    Example:
        ```python
        from agentchat import ConversableAgent
        from agentchat.contrib.capabilities import TransformMessages, MessageHistoryLimiter, MessageTokenLimiter

        max_messages = MessageHistoryLimiter(max_messages=2)
        truncate_messages = MessageTokenLimiter(max_tokens=500)
        transform_messages = TransformMessages(transforms=[max_messages, truncate_messages])

        agent = ConversableAgent(...)
        transform_messages.add_to_agent(agent)
        ```
    """

    def __init__(self, *, transforms: Optional[list[MessageTransform]] = None, verbose: bool = True):
        """Args:
        transforms: A list of message transformations to apply. Defaults to no transformations.
        verbose: Whether to print logs of each transformation or not.
        """
        # Build a fresh list per instance when none is given: a mutable default ([])
        # would be shared by every TransformMessages created without transforms.
        self._transforms = transforms if transforms is not None else []
        self._verbose = verbose

    def add_to_agent(self, agent: "ConversableAgent"):
        """Adds the message transformations capability to the specified ConversableAgent.

        This function performs the following modifications to the agent:

        1. Registers a hook that automatically transforms all messages before they are processed for
        response generation.
        """
        agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)

    def _transform_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies all registered transforms in order, preserving any leading system message.

        The input list is not mutated; a transformed deep copy is returned.
        """
        # Guard against an empty history: messages[0] below would raise IndexError.
        if not messages:
            return messages
        post_transform_messages = copy.deepcopy(messages)
        system_message = None
        # Hold the leading system message aside so no transform can drop or truncate it.
        if messages[0]["role"] == "system":
            system_message = copy.deepcopy(messages[0])
            post_transform_messages.pop(0)
        for transform in self._transforms:
            # deepcopy in case pre_transform_messages will later be used for logs printing
            pre_transform_messages = (
                copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages
            )
            post_transform_messages = transform.apply_transform(pre_transform_messages)
            if self._verbose:
                logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages)
                if had_effect:
                    print(colored(logs_str, "yellow"))
        if system_message:
            post_transform_messages.insert(0, system_message)
        return post_transform_messages

View File

@@ -0,0 +1,579 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import copy
import sys
from typing import Any, Optional, Protocol, Union
import tiktoken
from termcolor import colored
from .... import token_count_utils
from ....cache import AbstractCache, Cache
from ....types import MessageContentType
from . import transforms_util
from .text_compressors import LLMLingua, TextCompressor
class MessageTransform(Protocol):
    """Structural contract for message transformations.

    Any class providing a compatible `apply_transform` method satisfies this protocol;
    `get_logs` lets the pipeline report what a transformation changed.
    """

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Transform a conversation history.

        Args:
            messages: A list of dictionaries representing messages.

        Returns:
            A new list of dictionaries containing the transformed messages.
        """
        ...

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Describe what the transformation did.

        Args:
            pre_transform_messages: Messages as they were before the transformation.
            post_transform_messages: Messages as they are after the transformation.

        Returns:
            A tuple of a human-readable log string and a flag that is True when the
            transformation actually changed the messages.
        """
        ...
class MessageHistoryLimiter:
    """Limits the number of messages considered by an agent for response generation.

    This transform keeps only the most recent messages up to the specified maximum number of messages
    (max_messages). It trims the conversation history by removing older messages, retaining only the
    most recent messages.
    """

    def __init__(
        self,
        max_messages: Optional[int] = None,
        keep_first_message: bool = False,
        exclude_names: Optional[list[str]] = None,
    ):
        """Args:
        max_messages Optional[int]: Maximum number of messages to keep in the context. Must be greater than 0 if not None.
        keep_first_message bool: Whether to keep the original first message in the conversation history.
            Defaults to False.
        exclude_names Optional[list[str]]: List of message sender names to exclude from the message history.
            Messages from these senders will be filtered out before applying the message limit. Defaults to None.
        """
        self._validate_max_messages(max_messages)
        self._max_messages = max_messages
        self._keep_first_message = keep_first_message
        self._exclude_names = exclude_names

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Truncates the conversation history to the specified maximum number of messages.

        This method returns a new list containing the most recent messages up to the specified
        maximum number of messages (max_messages). If max_messages is None, it returns the
        original list of messages (less any excluded senders) unmodified.

        Args:
            messages (list[dict]): The list of messages representing the conversation history.

        Returns:
            list[dict]: A new list containing the most recent messages up to the specified maximum.
        """
        # Drop messages from excluded senders before applying the limit.
        filtered = (
            [msg for msg in messages if msg.get("name") not in self._exclude_names]
            if self._exclude_names
            else messages
        )
        if self._max_messages is None or len(filtered) <= self._max_messages:
            return filtered

        truncated_messages = []
        remaining_count = self._max_messages

        # Start with the first message if we need to keep it.
        if self._keep_first_message and filtered:
            truncated_messages = [filtered[0]]
            remaining_count -= 1

        # Walk newest-to-oldest, inserting each kept message after the preserved first
        # message (index 1) or at the front (index 0) so chronological order is kept.
        # BUG FIX: the last-slot branch previously always inserted at index 1, which
        # scrambled message order whenever keep_first_message was False.
        insert_at = 1 if self._keep_first_message else 0
        for i in range(len(filtered) - 1, 0, -1):
            if remaining_count > 1:
                truncated_messages.insert(insert_at, filtered[i])
            if remaining_count == 1:  # noqa: SIM102
                # If there's only 1 slot left and it's a 'tools' message, ignore it.
                if filtered[i].get("role") != "tool":
                    truncated_messages.insert(insert_at, filtered[i])
            remaining_count -= 1
            if remaining_count == 0:
                break
        return truncated_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report how many messages the last apply_transform removed."""
        pre_transform_messages_len = len(pre_transform_messages)
        post_transform_messages_len = len(post_transform_messages)
        if post_transform_messages_len < pre_transform_messages_len:
            logs_str = (
                f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. "
                f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}."
            )
            return logs_str, True
        return "No messages were removed.", False

    def _validate_max_messages(self, max_messages: Optional[int]):
        # max_messages == 1 is legal; only None or values >= 1 are accepted.
        if max_messages is not None and max_messages < 1:
            raise ValueError("max_messages must be None or greater than or equal to 1")
class MessageTokenLimiter:
    """Truncates messages to meet token limits for efficient processing and response generation.

    This transformation applies two levels of truncation to the conversation history:

    1. Truncates each individual message to the maximum number of tokens specified by max_tokens_per_message.
    2. Truncates the overall conversation history to the maximum number of tokens specified by max_tokens.

    NOTE: Tokens are counted using the encoder for the specified model. Different models may yield different token
    counts for the same text.

    NOTE: For multimodal LLMs, the token count may be inaccurate as it does not account for the non-text input
    (e.g images).

    The truncation process follows these steps in order:

    1. The minimum tokens threshold (`min_tokens`) is checked (0 by default). If the total number of tokens in messages
    is less than this threshold, then the messages are returned as is. In other case, the following process is applied.
    2. Messages are processed in reverse order (newest to oldest).
    3. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
    and other types of content, only the text content is truncated.
    4. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
    exceeds this limit, the current message being processed get truncated to meet the total token count and any
    remaining messages get discarded.
    5. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
    original message order.
    """

    def __init__(
        self,
        max_tokens_per_message: Optional[int] = None,
        max_tokens: Optional[int] = None,
        min_tokens: Optional[int] = None,
        model: str = "gpt-3.5-turbo-0613",
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        max_tokens_per_message (None or int): Maximum number of tokens to keep in each message.
            Must be greater than or equal to 0 if not None.
        max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history.
            Must be greater than or equal to 0 if not None.
        min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
            Must be greater than or equal to 0 if not None.
        model (str): The target OpenAI model for tokenization alignment.
        filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
            If None, no filters will be applied.
        exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
            excluded from token truncation. If False, messages that match the filter will be truncated.
        """
        self._model = model
        # Validation normalizes None limits to sys.maxsize ("no limit"), so
        # apply_transform can assume all three limits are concrete integers.
        self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
        self._max_tokens = self._validate_max_tokens(max_tokens)
        self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies token truncation to the conversation history.

        Args:
            messages (List[Dict]): The list of messages representing the conversation history.

        Returns:
            List[Dict]: A new list containing the truncated messages up to the specified token limits.
        """
        # Sanity checks: _validate_* replaced any None with sys.maxsize in __init__.
        assert self._max_tokens_per_message is not None
        assert self._max_tokens is not None
        assert self._min_tokens is not None
        # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
        if not transforms_util.min_tokens_reached(messages, self._min_tokens):
            return messages
        temp_messages = copy.deepcopy(messages)
        processed_messages = []
        processed_messages_tokens = 0
        # Newest-to-oldest so the most recent messages are kept when the budget runs out.
        for msg in reversed(temp_messages):
            # Some messages may not have content.
            if not transforms_util.is_content_right_type(msg.get("content")):
                processed_messages.insert(0, msg)
                continue
            # Filtered-out messages are kept verbatim but still count toward the budget.
            if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
                processed_messages.insert(0, msg)
                processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
                continue
            # Budget left after reserving a full per-message allowance for this message.
            expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
            # If adding this message would exceed the token limit, truncate the last message to meet the total token
            # limit and discard all remaining messages
            if expected_tokens_remained < 0:
                msg["content"] = self._truncate_str_to_tokens(
                    msg["content"], self._max_tokens - processed_messages_tokens
                )
                processed_messages.insert(0, msg)
                break
            msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
            msg_tokens = transforms_util.count_text_tokens(msg["content"])
            # prepend the message to the list to preserve order
            processed_messages_tokens += msg_tokens
            processed_messages.insert(0, msg)
        return processed_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report how many tokens the last apply_transform removed."""
        pre_transform_messages_tokens = sum(
            transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
        )
        post_transform_messages_tokens = sum(
            transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
        )
        if post_transform_messages_tokens < pre_transform_messages_tokens:
            logs_str = (
                f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. "
                f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}"
            )
            return logs_str, True
        return "No tokens were truncated.", False

    def _truncate_str_to_tokens(self, contents: Union[str, list], n_tokens: int) -> Union[str, list]:
        # Dispatch on content shape: plain text vs. multimodal list of parts.
        if isinstance(contents, str):
            return self._truncate_tokens(contents, n_tokens)
        elif isinstance(contents, list):
            return self._truncate_multimodal_text(contents, n_tokens)
        else:
            raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

    def _truncate_multimodal_text(self, contents: list[dict[str, Any]], n_tokens: int) -> list[dict[str, Any]]:
        """Truncates text content within a list of multimodal elements, preserving the overall structure."""
        # NOTE(review): each text element gets the full n_tokens budget individually, not
        # a shared budget across elements — confirm this is the intended semantics.
        tmp_contents = []
        for content in contents:
            if content["type"] == "text":
                truncated_text = self._truncate_tokens(content["text"], n_tokens)
                tmp_contents.append({"type": "text", "text": truncated_text})
            else:
                tmp_contents.append(content)
        return tmp_contents

    def _truncate_tokens(self, text: str, n_tokens: int) -> str:
        """Truncate ``text`` to at most ``n_tokens`` tokens of the configured model's encoding."""
        encoding = tiktoken.encoding_for_model(self._model)  # Get the appropriate tokenizer
        encoded_tokens = encoding.encode(text)
        truncated_tokens = encoded_tokens[:n_tokens]
        truncated_text = encoding.decode(truncated_tokens)  # Decode back to text
        return truncated_text

    def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]:
        # Rejects negatives, caps values above the model's context window, and maps
        # None to sys.maxsize so callers can treat "no limit" uniformly.
        if max_tokens is not None and max_tokens < 0:
            raise ValueError("max_tokens and max_tokens_per_message must be None or greater than or equal to 0")
        try:
            allowed_tokens = token_count_utils.get_max_token_limit(self._model)
        except Exception:
            print(colored(f"Model {self._model} not found in token_count_utils.", "yellow"))
            allowed_tokens = None
        if max_tokens is not None and allowed_tokens is not None and max_tokens > allowed_tokens:
            print(
                colored(
                    f"Max token was set to {max_tokens}, but {self._model} can only accept {allowed_tokens} tokens. Capping it to {allowed_tokens}.",
                    "yellow",
                )
            )
            return allowed_tokens
        return max_tokens if max_tokens is not None else sys.maxsize

    def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
        # None means "no minimum"; otherwise must be non-negative and <= max_tokens.
        if min_tokens is None:
            return 0
        if min_tokens < 0:
            raise ValueError("min_tokens must be None or greater than or equal to 0.")
        if max_tokens is not None and min_tokens > max_tokens:
            raise ValueError("min_tokens must not be more than max_tokens.")
        return min_tokens
class TextMessageCompressor:
    """A transform for compressing text messages in a conversation history.

    It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
    processing and response generation by downstream models.
    """

    def __init__(
        self,
        text_compressor: Optional[TextCompressor] = None,
        min_tokens: Optional[int] = None,
        compression_params: Optional[dict] = None,
        cache: Optional[AbstractCache] = None,
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
            protocol. If None, it defaults to LLMLingua.
        min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater
            than or equal to 0 if not None. If None, no threshold-based compression is applied.
        compression_params (dict or None): A dictionary of arguments for the compression method. Defaults to an
            empty dictionary.
        cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
            If None, a disk cache is used.
        filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
            If None, no filters will be applied.
        exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
            excluded from compression. If False, messages that match the filter will be compressed.
        """
        if text_compressor is None:
            text_compressor = LLMLingua()
        self._validate_min_tokens(min_tokens)
        self._text_compressor = text_compressor
        self._min_tokens = min_tokens
        # Build a fresh dict per instance: a mutable default (dict()) would be shared
        # by every TextMessageCompressor created without explicit compression_params.
        self._compression_args = compression_params if compression_params is not None else {}
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter
        self._cache = Cache.disk() if cache is None else cache
        # Savings from the most recent apply_transform call; reported by get_logs.
        self._recent_tokens_savings = 0

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies compression to messages in a conversation history based on the specified configuration.

        The function processes each message according to the `compression_args` and `min_tokens` settings, applying
        the specified compression configuration and returning a new list of messages with reduced token counts
        where possible.

        Args:
            messages (List[Dict]): A list of message dictionaries to be compressed.

        Returns:
            List[Dict]: A list of dictionaries with the message content compressed according to the configured
                method and scope.
        """
        # Make sure there is at least one message
        if not messages:
            return messages
        # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
        if not transforms_util.min_tokens_reached(messages, self._min_tokens):
            return messages
        total_savings = 0
        # NOTE(review): shallow copy — the message dicts themselves are mutated in place;
        # the transform pipeline deep-copies before calling us, but direct callers should
        # be aware. Confirm this is acceptable for all call sites.
        processed_messages = messages.copy()
        for message in processed_messages:
            # Some messages may not have content.
            if not transforms_util.is_content_right_type(message.get("content")):
                continue
            if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
                continue
            if transforms_util.is_content_text_empty(message["content"]):
                continue
            # Reuse a previous compression of identical content when available.
            cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
            cached_content = transforms_util.cache_content_get(self._cache, cache_key)
            if cached_content is not None:
                message["content"], savings = cached_content
            else:
                message["content"], savings = self._compress(message["content"])
                transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)
            assert isinstance(savings, int)
            total_savings += savings
        self._recent_tokens_savings = total_savings
        return processed_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report the token savings from the most recent apply_transform call."""
        if self._recent_tokens_savings > 0:
            return f"{self._recent_tokens_savings} tokens saved with text compression.", True
        else:
            return "No tokens saved with text compression.", False

    def _compress(self, content: MessageContentType) -> tuple[MessageContentType, int]:
        """Compresses the given text or multimodal content using the specified compression method."""
        if isinstance(content, str):
            return self._compress_text(content)
        elif isinstance(content, list):
            return self._compress_multimodal(content)
        else:
            return content, 0

    def _compress_multimodal(self, content: MessageContentType) -> tuple[MessageContentType, int]:
        """Compresses every text part of a multimodal content list in place."""
        tokens_saved = 0
        for index, item in enumerate(content):
            if isinstance(item, dict) and "text" in item:
                item["text"], savings = self._compress_text(item["text"])
                tokens_saved += savings
            elif isinstance(item, str):
                # BUG FIX: write the result back into the list; rebinding the loop
                # variable previously discarded the compressed text for plain strings.
                content[index], savings = self._compress_text(item)
                tokens_saved += savings
        return content, tokens_saved

    def _compress_text(self, text: str) -> tuple[str, int]:
        """Compresses the given text using the specified compression method."""
        compressed_text = self._text_compressor.compress_text(text, **self._compression_args)
        savings = 0
        if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
            savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]
        return compressed_text["compressed_prompt"], savings

    def _validate_min_tokens(self, min_tokens: Optional[int]):
        # None disables the threshold entirely; explicit values must be positive.
        if min_tokens is not None and min_tokens <= 0:
            raise ValueError("min_tokens must be greater than 0 or None")
class TextMessageContentName:
    """A transform that embeds the sending agent's name into each message's content.

    How to create and apply the transform:
        # Imports
        from autogen.agentchat.contrib.capabilities import transform_messages, transforms

        # Create Transform
        name_transform = transforms.TextMessageContentName(position="start", format_string="'{name}' said:\n")

        # Create the TransformMessages
        context_handling = transform_messages.TransformMessages(
            transforms=[
                name_transform
            ]
        )

        # Add it to an agent so when they run inference it will apply to the messages
        context_handling.add_to_agent(my_agent)
    """

    def __init__(
        self,
        position: str = "start",
        format_string: str = "{name}:\n",
        deduplicate: bool = True,
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        position (str): The position to add the name to the content. The possible options are 'start' or 'end'. Defaults to 'start'.
        format_string (str): The f-string to format the message name with. Use '{name}' as a placeholder for the agent's name. Defaults to '{name}:\n' and must contain '{name}'.
        deduplicate (bool): Whether to deduplicate the formatted string so it doesn't appear twice (sometimes the LLM will add it to new messages itself). Defaults to True.
        filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
            If None, no filters will be applied.
        exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
            excluded from compression. If False, messages that match the filter will be compressed.
        """
        assert isinstance(position, str) and position in ["start", "end"]
        assert isinstance(format_string, str) and "{name}" in format_string
        assert isinstance(deduplicate, bool) and deduplicate is not None
        self._position = position
        self._format_string = format_string
        self._deduplicate = deduplicate
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter
        # Number of messages changed by the most recent apply_transform, for get_logs.
        self._messages_changed = 0

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Prefix or suffix each eligible message's content with its sender's formatted name.

        Args:
            messages (List[Dict]): A list of message dictionaries.

        Returns:
            List[Dict]: A deep copy of the messages with names embedded in the content.
        """
        if not messages:
            return messages
        changed = 0
        updated_messages = copy.deepcopy(messages)
        for msg in updated_messages:
            content = msg.get("content")
            sender = msg.get("name")
            # Skip messages whose content or name is missing / of an unsupported type.
            if not (transforms_util.is_content_right_type(content) and transforms_util.is_content_right_type(sender)):
                continue
            if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
                continue
            if transforms_util.is_content_text_empty(content) or transforms_util.is_content_text_empty(sender):
                continue
            # Build the formatted tag and splice it in at the configured position,
            # optionally skipping messages that already carry it.
            name_tag = self._format_string.format(name=sender)
            if self._position == "start":
                if not self._deduplicate or not content.startswith(name_tag):
                    msg["content"] = f"{name_tag}{content}"
                    changed += 1
            else:
                if not self._deduplicate or not content.endswith(name_tag):
                    msg["content"] = f"{content}{name_tag}"
                    changed += 1
        self._messages_changed = changed
        return updated_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report how many messages had a name incorporated in the last apply_transform."""
        if self._messages_changed > 0:
            return f"{self._messages_changed} message(s) changed to incorporate name.", True
        return "No messages changed to incorporate name.", False

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from collections.abc import Hashable
from typing import Any, Optional
from .... import token_count_utils
from ....cache.abstract_cache_base import AbstractCache
from ....oai.openai_utils import filter_config
from ....types import MessageContentType
def cache_key(content: MessageContentType, *args: Hashable) -> str:
    """Calculates the cache key for the given message content and any other hashable args.

    A NUL separator is inserted between components so adjacent values cannot run together
    and collide (previously ("ab", "c") and ("a", "bc") produced the same key "abc").

    Args:
        content (MessageContentType): The message content to calculate the cache key for.
        *args: Any additional hashable args to include in the cache key.
    """
    return "\x00".join(str(key) for key in (content, *args))
def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[tuple[MessageContentType, ...]]:
    """Retrieves cached content from the cache.

    Args:
        cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored.
        key (str): The key to retrieve the content from.

    Returns:
        The cached tuple of values, or None when there is no cache or no stored entry.
    """
    if not cache:
        return None
    hit = cache.get(key)
    return hit if hit else None
def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values):
    """Sets content into the cache.

    Args:
        cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored.
        key (str): The key to set the content into.
        content (MessageContentType): The message content to set into the cache.
        *extra_values: Additional values to be passed to the cache.
    """
    if not cache:
        return
    # Store content and any extras together as one tuple entry.
    cache.set(key, (content, *extra_values))
def min_tokens_reached(messages: list[dict[str, Any]], min_tokens: Optional[int]) -> bool:
    """Returns True if the total number of tokens in the messages is greater than or equal to the specified value.

    Args:
        messages (List[Dict]): A list of messages to check.
        min_tokens (None or int): The minimum number of tokens to check for. A falsy
            threshold (None or 0) always passes.
    """
    if not min_tokens:
        return True
    total = 0
    for msg in messages:
        if "content" in msg:
            total += count_text_tokens(msg["content"])
    return total >= min_tokens
def count_text_tokens(content: MessageContentType) -> int:
    """Calculates the number of text tokens in the given message content.

    Args:
        content (MessageContentType): The message content to calculate the number of text tokens for.

    Returns:
        int: Token count of the text; 0 for content that is neither str nor list.
    """
    if isinstance(content, str):
        return token_count_utils.count_token(content)
    if isinstance(content, list):
        total = 0
        for element in content:
            if isinstance(element, str):
                total += token_count_utils.count_token(element)
            else:
                # Multimodal part: count only its text payload (empty when absent).
                total += count_text_tokens(element.get("text", ""))
        return total
    return 0
def is_content_right_type(content: Any) -> bool:
    """Return True when ``content`` is a type the transforms can process (str or list)."""
    return isinstance(content, str) or isinstance(content, list)
def is_content_text_empty(content: MessageContentType) -> bool:
    """Checks if the content of the message does not contain any text.

    Args:
        content (MessageContentType): The message content to check.

    Returns:
        bool: True when no non-empty text is present (including unsupported types).
    """
    if isinstance(content, str):
        return not content
    if isinstance(content, list):
        # Any non-empty string item or dict 'text' payload means there is text.
        for part in content:
            if isinstance(part, str):
                if part:
                    return False
            elif isinstance(part, dict):
                if part.get("text", ""):
                    return False
        return True
    # Anything that is not str or list carries no text.
    return True
def should_transform_message(message: dict[str, Any], filter_dict: Optional[dict[str, Any]], exclude: bool) -> bool:
    """Decide whether a transform applies to a message given a filter dictionary.

    Args:
        message (Dict[str, Any]): The message to validate.
        filter_dict (None or Dict[str, Any]): Filter criteria to validate against.
            When falsy, the transform always applies.
        exclude (bool): Whether messages matching the filter are excluded.
    """
    if not filter_dict:
        return True
    surviving = filter_config([message], filter_dict, exclude)
    return len(surviving) > 0

View File

@@ -0,0 +1,212 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import copy
from typing import Any, Callable, Optional, Union
from ....code_utils import content_str
from ....oai.client import OpenAIWrapper
from ...assistant_agent import ConversableAgent
from ..img_utils import (
convert_base64_to_data_uri,
get_image_data,
get_pil_image,
gpt4v_formatter,
)
from .agent_capability import AgentCapability
DEFAULT_DESCRIPTION_PROMPT = (
"Write a detailed caption for this image. "
"Pay special attention to any details that might be useful or relevant "
"to the ongoing conversation."
)
class VisionCapability(AgentCapability):
    """Adds vision to a regular ConversableAgent whose model is not multimodal
    (such as a GPT-3.5-turbo, Llama, Orca, or Mistral agent). The capability invokes
    a LMM client to describe (caption) each image in the last received message before
    the information is sent to the agent's actual client.

    The vision capability hooks into the ConversableAgent's `process_last_received_message`.

    Some technical details:
    When the agent (who has the vision capability) received an message, it will:
    1. _process_received_message:
        a. _append_oai_message
    2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag.
        a. hook process_last_received_message (NOTE: this is where the vision capability will be hooked to.)
        b. hook process_all_messages_before_reply
    3. send:
        a. hook process_message_before_send
        b. _append_oai_message
    """

    def __init__(
        self,
        lmm_config: dict[str, Any],
        description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT,
        # Annotation fixed: the default is None, so the parameter is Optional.
        custom_caption_func: Optional[Callable] = None,
    ) -> None:
        """Initializes a new instance, setting up the configuration for interacting with
        a Language Multimodal (LMM) client and specifying optional parameters for image
        description and captioning.

        Args:
            lmm_config (Dict): Configuration for the LMM client, which is used to call
                the LMM service for describing the image. This must be a dictionary containing
                the necessary configuration parameters. If `lmm_config` is False or an empty
                dictionary, it is considered invalid, and initialization will assert.
            description_prompt (Optional[str], optional): The prompt to use for generating
                descriptions of the image. Defaults to `DEFAULT_DESCRIPTION_PROMPT` if not provided.
            custom_caption_func (Optional[Callable], optional): A callable that, if provided,
                will be used to generate captions for images instead of the default
                `self._get_image_caption` method. It is called with three arguments:
                1. an image URL (or local location)
                2. image_data (a PIL image)
                3. lmm_client (to call remote LMM)
                and must return a description string.

        Raises:
            AssertionError: If neither a valid `lmm_config` nor a `custom_caption_func`
                is provided, since the Vision Capability requires one of them to operate.
        """
        self._lmm_config = lmm_config
        self._description_prompt = description_prompt
        self._parent_agent = None
        # Only build an LMM client when a config is given; a custom caption
        # function may work without any client at all.
        self._lmm_client = OpenAIWrapper(**lmm_config) if lmm_config else None
        self._custom_caption_func = custom_caption_func
        assert self._lmm_config or custom_caption_func, (
            "Vision Capability requires a valid lmm_config or custom_caption_func."
        )

    def add_to_agent(self, agent: ConversableAgent) -> None:
        """Attach this capability to `agent` and register the captioning hook."""
        self._parent_agent = agent
        # Append extra info to the system message.
        agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.")
        # Register a hook for processing the last message.
        agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

    def process_last_received_message(self, content: Union[str, list[dict[str, Any]]]) -> str:
        """Return `content` as plain text, with a caption appended after each image.

        String content is first normalized into the GPT-4V multimodal list format
        (image references are kept in URL form to stay concise). Each image item is
        then captioned via `custom_caption_func` or the LMM client, and the caption
        is inlined after the `<img ...>` tag so a text-only model can "see" it.

        Args:
            content (Union[str, List[dict[str, Any]]]): The last received message content,
                either a plain text string or a list of content-item dicts
                (e.g. text, image_url).

        Returns:
            str: The augmented message content.

        Raises:
            AssertionError: If an item in the content list is not a dictionary.

        Example:
            Assuming `self._get_image_caption(img_data)` returns
            "A beautiful sunset over the mountains" for the image:

            content = "What's weather in this cool photo: `<img http://example.com/photo.jpg>`"
            Output: "What's weather in this cool photo: `<img http://example.com/photo.jpg>`
            in case you can not see, the caption of this image is:
            A beautiful sunset over the mountains\n"
        """
        # NOTE: the original implementation called copy.deepcopy(content) and
        # discarded the result — a no-op. This function only reads `content`,
        # so no copy is needed at all.
        if isinstance(content, str):
            # Normalize into the gpt-4v multimodal format; keep the URL format
            # to keep it concise.
            content = gpt4v_formatter(content, img_format="url")

        aug_content: str = ""
        for item in content:
            assert isinstance(item, dict)
            if item["type"] == "text":
                aug_content += item["text"]
            elif item["type"] == "image_url":
                img_url = item["image_url"]
                if self._custom_caption_func:
                    img_caption = self._custom_caption_func(img_url, get_pil_image(img_url), self._lmm_client)
                elif self._lmm_client:
                    img_caption = self._get_image_caption(get_image_data(img_url))
                else:
                    img_caption = ""
                aug_content += f"<img {img_url}> in case you can not see, the caption of this image is: {img_caption}\n"
            else:
                # Fixed typo in the original warning message ("test" -> "text").
                print(f"Warning: the input type should either be `text` or `image_url`. Skip {item['type']} here.")

        return aug_content

    def _get_image_caption(self, img_data: str) -> str:
        """Ask the configured LMM client to caption an image.

        Args:
            img_data (str): base64 encoded image data.

        Returns:
            str: caption for the given image.
        """
        response = self._lmm_client.create(
            context=None,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": self._description_prompt},
                        {
                            "type": "image_url",
                            "image_url": convert_base64_to_data_uri(img_data),
                        },
                    ],
                }
            ],
        )
        description = response.choices[0].message.content
        return content_str(description)

View File

@@ -0,0 +1,411 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import base64
import copy
import os
import re
from io import BytesIO
from math import ceil
from typing import Any, Union
import requests
from ...import_utils import optional_import_block, require_optional_import
from .. import utils
with optional_import_block():
from PIL import Image
# Parameters for token counting for images for different models.
# Each entry drives num_tokens_from_gpt_image:
#   max_edge / min_edge: scaling bounds applied to the image before tiling
#   tile_size: edge length of the square tiles the scaled image is split into
#   base_token_count: flat per-image token cost (also the low-quality cost)
#   token_multiplier: additional token cost per covering tile
MODEL_PARAMS = {
    "gpt-4-vision": {
        "max_edge": 2048,
        "min_edge": 768,
        "tile_size": 512,
        "base_token_count": 85,
        "token_multiplier": 170,
    },
    "gpt-4o-mini": {
        "max_edge": 2048,
        "min_edge": 768,
        "tile_size": 512,
        "base_token_count": 2833,
        "token_multiplier": 5667,
    },
    "gpt-4o": {"max_edge": 2048, "min_edge": 768, "tile_size": 512, "base_token_count": 85, "token_multiplier": 170},
}
@require_optional_import("PIL", "unknown")
def get_pil_image(image_file: Union[str, "Image.Image"]) -> "Image.Image":
    """Load an image from a filename, URL, data URI, or base64 string as a PIL Image.

    Parameters:
        image_file (str, or Image): The filename, URL, URI, or base64 string of the image file.

    Returns:
        Image.Image: The PIL Image object (converted to RGB for loaded sources).
    """
    if isinstance(image_file, Image.Image):
        # Already a PIL Image object; return as-is.
        return image_file

    # Strip a surrounding pair of double quotes, then single quotes, if present.
    for quote in ('"', "'"):
        if image_file.startswith(quote) and image_file.endswith(quote):
            image_file = image_file[1:-1]

    # Base64 data URIs for the supported formats: jpg, jpeg, png, gif, bmp, webp.
    data_uri_prefix = r"data:image/(?:jpg|jpeg|png|gif|bmp|webp);base64,"

    if image_file.startswith("http://") or image_file.startswith("https://"):
        # Remote image: download and decode the response body.
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content))
    elif re.match(data_uri_prefix, image_file):
        # Data URI: drop the prefix and decode the base64 payload.
        image = _to_pil(re.sub(data_uri_prefix, "", image_file))
    elif os.path.exists(image_file):
        # A local file path.
        image = Image.open(image_file)
    else:
        # Assume a bare base64-encoded string.
        image = _to_pil(image_file)

    return image.convert("RGB")
@require_optional_import("PIL", "unknown")
def get_image_data(image_file: Union[str, "Image.Image"], use_b64=True) -> bytes:
    """Load an image and return its PNG-encoded data, raw or base64-encoded.

    The image is loaded via `get_pil_image` (file path, URL, data URI, base64
    string, or PIL Image), re-encoded as PNG in memory, and the resulting bytes
    are returned either directly or as a base64 string.

    Parameters:
        image_file (str, or Image): The path to the image file, a URL to an image,
            or a base64-encoded string of the image.
        use_b64 (bool): If True, return a base64-encoded string of the image data;
            if False, return the raw PNG bytes. Defaults to True.

    Returns:
        bytes: The raw PNG bytes when `use_b64` is False, otherwise a
            base64-encoded string.
    """
    pil_image = get_pil_image(image_file)
    with BytesIO() as buffer:
        pil_image.save(buffer, format="PNG")
        raw = buffer.getvalue()
    return base64.b64encode(raw).decode("utf-8") if use_b64 else raw
@require_optional_import("PIL", "unknown")
def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str, list[str]]:
    """Replace `<img ...>` tags in the prompt with image tokens and collect the images.

    Parameters:
    - prompt (str): The input string that may contain image tags like `<img ...>`.
    - order_image_tokens (bool, optional): Whether to number the image tokens
      (useful for GPT-4V). Defaults to False.

    Returns:
    - Tuple[str, List[str]]: The formatted string and the list of images (b64 format).
    """
    formatted = prompt
    locations: list[str] = []
    loaded_images = []
    index = 0

    # Matches <img ...> tags and captures the location inside.
    tag_re = re.compile(r"<img ([^>]+)>")

    for tag in tag_re.finditer(prompt):
        location = tag.group(1)
        try:
            data = get_image_data(location)
        except Exception as e:
            # Could not load the image: drop the tag from the prompt and move on.
            print(f"Warning! Unable to load image from {location}, because of {e}")
            formatted = formatted.replace(tag.group(0), "", 1)
            continue
        locations.append(location)
        loaded_images.append(data)
        # Substitute the tag with a (possibly numbered) image token.
        replacement = f"<image {index}>" if order_image_tokens else "<image>"
        formatted = formatted.replace(tag.group(0), replacement, 1)
        index += 1

    return formatted, loaded_images
@require_optional_import("PIL", "unknown")
def pil_to_data_uri(image: "Image.Image") -> str:
    """Convert a PIL Image object to a base64 PNG data URI.

    Parameters:
        image (Image.Image): The PIL Image object.

    Returns:
        str: The data URI string.
    """
    stream = BytesIO()
    image.save(stream, format="PNG")
    encoded = base64.b64encode(stream.getvalue()).decode("utf-8")
    return convert_base64_to_data_uri(encoded)
def convert_base64_to_data_uri(base64_image):
    """Wrap a base64-encoded image string into a data URI with a sniffed MIME type."""

    def _sniff_mime(encoded):
        # Decode and inspect the leading magic bytes to guess the image format.
        raw = base64.b64decode(encoded)
        if raw.startswith(b"\x89PNG\r\n\x1a\n"):
            return "image/png"
        if raw.startswith((b"GIF87a", b"GIF89a")):
            return "image/gif"
        if raw.startswith(b"RIFF") and raw[8:12] == b"WEBP":
            return "image/webp"
        # JPEG signature, and also the best-guess fallback for unknown formats.
        return "image/jpeg"

    return f"data:{_sniff_mime(base64_image)};base64,{base64_image}"
@require_optional_import("PIL", "unknown")
def gpt4v_formatter(prompt: str, img_format: str = "uri") -> list[Union[str, dict[str, Any]]]:
    """Split the prompt into alternating text and image items (GPT-4V format).

    Args:
        prompt (str): The input string that may contain image tags like `<img ...>`.
        img_format (str): What image representation to emit. One of "uri", "url", "pil".

    Returns:
        List[Union[str, dict[str, Any]]]: A list of alternating text and image dictionary items.
    """
    assert img_format in ["uri", "url", "pil"]

    parts: list[Union[str, dict[str, Any]]] = []
    cursor = 0
    n_images = 0

    for tag in utils.parse_tags_from_content("img", prompt):
        src = tag["attr"]["src"]
        try:
            if img_format == "pil":
                payload = get_pil_image(src)
            elif img_format == "uri":
                payload = convert_base64_to_data_uri(get_image_data(src))
            elif img_format == "url":
                payload = src
            else:
                raise ValueError(f"Unknown image format {img_format}")
        except Exception as e:
            # Warn and skip this tag entirely when the image cannot be loaded.
            print(f"Warning! Unable to load image from {src}, because {e}")
            continue
        # Emit the text that precedes this tag, then the image itself.
        parts.append({"type": "text", "text": prompt[cursor : tag["match"].start()]})
        parts.append({"type": "image_url", "image_url": {"url": payload}})
        cursor = tag["match"].end()
        n_images += 1

    # Trailing text after the last image tag.
    if cursor < len(prompt):
        parts.append({"type": "text", "text": prompt[cursor:]})
    return parts
def extract_img_paths(paragraph: str) -> list:
    """Extract image paths (URLs or local paths) from a text paragraph.

    Parameters:
        paragraph (str): The input text paragraph.

    Returns:
        list: A list of extracted image paths.
    """
    # Detect URLs and file paths ending in a common image extension,
    # including support for the webp format; matching is case-insensitive.
    pattern = re.compile(
        r"\b(?:http[s]?://\S+\.(?:jpg|jpeg|png|gif|bmp|webp)|\S+\.(?:jpg|jpeg|png|gif|bmp|webp))\b",
        re.IGNORECASE,
    )
    return pattern.findall(paragraph)
@require_optional_import("PIL", "unknown")
def _to_pil(data: str) -> "Image.Image":
    """Decode a base64-encoded image string into a PIL Image object.

    The base64 string is decoded to bytes, wrapped in a BytesIO stream, and
    opened with PIL.

    Parameters:
        data (str): The encoded image data string.

    Returns:
        Image.Image: The PIL Image object created from the input data.
    """
    raw_bytes = base64.b64decode(data)
    return Image.open(BytesIO(raw_bytes))
@require_optional_import("PIL", "unknown")
def message_formatter_pil_to_b64(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Converts the PIL image URLs in the messages to base64 encoded data URIs.

    This function iterates over a list of message dictionaries. For each message,
    if it contains a 'content' key with a list of items, it looks for items
    with an 'image_url' key. The function then converts the PIL image URL
    (pointed to by 'image_url') to a base64 encoded data URI.

    Parameters:
        messages (List[Dict]): A list of message dictionaries. Each dictionary
            may contain a 'content' key with a list of items, some of which
            might be image URLs.

    Returns:
        List[Dict]: A new list of message dictionaries with PIL image URLs in the
            'image_url' key converted to base64 encoded data URIs. The input
            list is not modified (each message is deep-copied).

    Example Input:
    ```python
    [
        {'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
        {'content': [
            {'type': 'text', 'text': "What's the breed of this dog here?"},
            {'type': 'image_url', 'image_url': {'url': a PIL.Image.Image}},
            {'type': 'text', 'text': '.'}],
        'role': 'user'}
    ]
    ```

    Example Output:
    ```python
    [
        {'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
        {'content': [
            {'type': 'text', 'text': "What's the breed of this dog here?"},
            {'type': 'image_url', 'image_url': {'url': a B64 Image}},
            {'type': 'text', 'text': '.'}],
        'role': 'user'}
    ]
    ```
    """
    new_messages = []
    for message in messages:
        # deepcopy to avoid modifying the original message.
        message = copy.deepcopy(message)
        if isinstance(message, dict) and "content" in message:
            # First, if the content is a string, parse it into a list of parts.
            # This is for tool output that contains images.
            if isinstance(message["content"], str):
                message["content"] = gpt4v_formatter(message["content"], img_format="pil")
            # Second, if the content is a list, process any image parts.
            if isinstance(message["content"], list):
                for item in message["content"]:
                    if (
                        isinstance(item, dict)
                        and "image_url" in item
                        and isinstance(item["image_url"]["url"], Image.Image)
                    ):
                        # Replace the in-memory PIL image with a serializable
                        # base64 data URI (mutates only the deep copy).
                        item["image_url"]["url"] = pil_to_data_uri(item["image_url"]["url"])
        new_messages.append(message)
    return new_messages
@require_optional_import("PIL", "unknown")
def num_tokens_from_gpt_image(
    image_data: Union[str, "Image.Image"], model: str = "gpt-4-vision", low_quality: bool = False
) -> int:
    """Calculate the number of tokens required to process an image based on its dimensions
    after scaling for different GPT models. Supports "gpt-4-vision", "gpt-4o", and "gpt-4o-mini".

    The image is scaled so that its longest edge is at most 2048 pixels and its shortest
    edge is at most 768 pixels (for "gpt-4-vision"); the cost is then the base token count
    plus a multiplier per 512x512 tile needed to cover the scaled image.

    Reference: https://openai.com/api/pricing/

    Args:
        image_data : Union[str, Image.Image]: The image data which can either be a base64
            encoded string, a URL, a file path, or a PIL Image object.
        model: str: The model being used for image processing. Can be "gpt-4-vision",
            "gpt-4o", or "gpt-4o-mini".
        low_quality: bool: Whether to use low-quality processing. Defaults to False.

    Returns:
        int: The total number of tokens required for processing the image.

    Examples:
    --------
    >>> from PIL import Image
    >>> img = Image.new("RGB", (2500, 2500), color="red")
    >>> num_tokens_from_gpt_image(img, model="gpt-4-vision")
    765
    """
    width, height = get_pil_image(image_data).size

    # Pick the parameter set for the model family. "gpt-4o-mini" must be matched
    # before "gpt-4o" because the former contains the latter as a substring.
    if any(alias in model for alias in ("gpt-4-vision", "gpt-4-turbo", "gpt-4v", "gpt-4-v")):
        params = MODEL_PARAMS["gpt-4-vision"]
    elif "gpt-4o-mini" in model:
        params = MODEL_PARAMS["gpt-4o-mini"]
    elif "gpt-4o" in model:
        params = MODEL_PARAMS["gpt-4o"]
    else:
        raise ValueError(
            f"Model {model} is not supported. Choose 'gpt-4-vision', 'gpt-4-turbo', 'gpt-4v', 'gpt-4-v', 'gpt-4o', or 'gpt-4o-mini'."
        )

    if low_quality:
        # Low-detail mode has a flat cost regardless of image size.
        return params["base_token_count"]

    # 1. Constrain the longest edge.
    longest = max(width, height)
    if longest > params["max_edge"]:
        factor = params["max_edge"] / longest
        width, height = int(width * factor), int(height * factor)

    # 2. Further constrain the shortest edge.
    shortest = min(width, height)
    if shortest > params["min_edge"]:
        factor = params["min_edge"] / shortest
        width, height = int(width * factor), int(height * factor)

    # 3. Cost = base + multiplier per tile needed to cover the scaled image.
    n_tiles = ceil(width / params["tile_size"]) * ceil(height / params["tile_size"])
    return params["base_token_count"] + params["token_multiplier"] * n_tiles

View File

@@ -0,0 +1,153 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import copy
from typing import Any, Optional, Union
from ... import OpenAIWrapper
from ...code_utils import content_str
from .. import Agent, ConversableAgent
from ..contrib.img_utils import (
gpt4v_formatter,
message_formatter_pil_to_b64,
)
DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant."""
DEFAULT_MODEL = "gpt-4-vision-preview"
class MultimodalConversableAgent(ConversableAgent):
    """A ConversableAgent whose messages may mix text and images (LMM / GPT-4V style).

    System messages and incoming messages are normalized into the GPT-4V
    multimodal content format, and PIL images are converted to base64 data URIs
    before calling the LLM client.
    """

    # Fallback llm config: only pins the default multimodal model.
    DEFAULT_CONFIG = {
        "model": DEFAULT_MODEL,
    }

    def __init__(
        self,
        name: str,
        system_message: Optional[Union[str, list]] = DEFAULT_LMM_SYS_MSG,
        # Annotation fixed: a callable predicate over a message dict (or None) is
        # expected here — see the lambda fallback below; `str` was wrong.
        is_termination_msg: Optional[Any] = None,
        *args,
        **kwargs: Any,
    ):
        """Args:
        name (str): agent name.
        system_message (str): system message for the OpenAIWrapper inference.
            Please override this attribute if you want to reprogram the agent.
        is_termination_msg: a function taking a message dict and returning True
            when the conversation should terminate; defaults to terminating when
            the message content equals "TERMINATE".
        **kwargs (dict): Please refer to other kwargs in
            [ConversableAgent](/docs/api-reference/autogen/ConversableAgent#conversableagent).
        """
        super().__init__(
            name,
            system_message,
            is_termination_msg=is_termination_msg,
            *args,
            **kwargs,
        )
        # call the setter to handle special format.
        self.update_system_message(system_message)
        self._is_termination_msg = (
            is_termination_msg
            if is_termination_msg is not None
            else (lambda x: content_str(x.get("content")) == "TERMINATE")
        )
        # Override the `generate_oai_reply`
        self.replace_reply_func(ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply)
        self.replace_reply_func(
            ConversableAgent.a_generate_oai_reply,
            MultimodalConversableAgent.a_generate_oai_reply,
        )

    def update_system_message(self, system_message: Union[dict[str, Any], list[str], str]):
        """Update the system message.

        Args:
            system_message (str): system message for the OpenAIWrapper inference.
        """
        self._oai_system_message[0]["content"] = self._message_to_dict(system_message)["content"]
        self._oai_system_message[0]["role"] = "system"

    @staticmethod
    def _message_to_dict(message: Union[dict[str, Any], list[str], str]) -> dict:
        """Convert a message to a dictionary. This implementation
        handles the GPT-4V formatting for easier prompts.

        The message can be a string, a dictionary, or a list of dictionaries:
        - If it's a string, it will be cast into a list and placed in the 'content' field.
        - If it's a list, it will be directly placed in the 'content' field.
        - If it's a dictionary, it is already in message dict format. The 'content' field of this dictionary
          will be processed using the gpt4v_formatter.
        """
        if isinstance(message, str):
            return {"content": gpt4v_formatter(message, img_format="pil")}
        if isinstance(message, list):
            return {"content": message}
        if isinstance(message, dict):
            assert "content" in message, "The message dict must have a `content` field"
            if isinstance(message["content"], str):
                # deepcopy so the caller's dict is not mutated by the reformat.
                message = copy.deepcopy(message)
                message["content"] = gpt4v_formatter(message["content"], img_format="pil")
            try:
                content_str(message["content"])
            except (TypeError, ValueError) as e:
                print("The `content` field should be compatible with the content_str function!")
                raise e
            return message
        raise ValueError(f"Unsupported message type: {type(message)}")

    def generate_oai_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[OpenAIWrapper] = None,
    ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
        """Generate a reply using autogen.oai."""
        client = self.client if config is None else config
        if client is None:
            return False, None
        if messages is None:
            messages = self._oai_messages[sender]
        # Serialize any PIL images (system message included) to base64 data URIs.
        messages_with_b64_img = message_formatter_pil_to_b64(self._oai_system_message + messages)
        new_messages = []
        for message in messages_with_b64_img:
            if 'tool_responses' in message:
                # Reshape tool-response messages: one `tool`-role message per
                # tool_call_id, plus a follow-up `user` message carrying any
                # screenshot image found in the content.
                for tool_response in message['tool_responses']:
                    tmp_image = None
                    tmp_list = []
                    for ctx in message['content']:
                        if ctx['type'] == 'image_url':
                            # NOTE(review): if several image parts exist, only the
                            # last one is kept — confirm this is intended.
                            tmp_image = ctx
                    # NOTE(review): only content[0] is forwarded as the tool
                    # result — presumably the text part; verify against callers.
                    tmp_list.append({
                        'role': 'tool',
                        'tool_call_id': tool_response['tool_call_id'],
                        'content': [message['content'][0]]
                    })
                    if tmp_image:
                        tmp_list.append({
                            'role': 'user',
                            'content': [
                                {'type': 'text', 'text': 'I take a screenshot for the current state for you.'},
                                tmp_image
                            ]
                        })
                    new_messages.extend(tmp_list)
            else:
                new_messages.append(message)
        messages_with_b64_img = new_messages.copy()
        # TODO: #1143 handle token limit exceeded error
        response = client.create(
            context=messages[-1].pop("context", None), messages=messages_with_b64_img, agent=self.name
        )
        # TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
        extracted_response = client.extract_text_or_completion_object(response)[0]
        if not isinstance(extracted_response, str):
            extracted_response = extracted_response.model_dump()
        return True, extracted_response

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,64 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
__all__: list[str] = []
from .available_condition import ExpressionAvailableCondition, StringAvailableCondition
from .context_condition import ExpressionContextCondition, StringContextCondition
from .context_expression import ContextExpression
from .context_str import ContextStr
from .context_variables import ContextVariables
from .handoffs import Handoffs
from .llm_condition import ContextStrLLMCondition, StringLLMCondition
from .on_condition import OnCondition
from .on_context_condition import OnContextCondition
from .reply_result import ReplyResult
from .speaker_selection_result import SpeakerSelectionResult
from .targets.group_chat_target import GroupChatConfig, GroupChatTarget
"""
from .targets.group_manager_target import (
GroupManagerSelectionMessageContextStr,
GroupManagerSelectionMessageString,
GroupManagerTarget,
)
"""
from .targets.transition_target import (
AgentNameTarget,
AgentTarget,
AskUserTarget,
NestedChatTarget,
RevertToUserTarget,
StayTarget,
TerminateTarget,
)
__all__ = [
"AgentNameTarget",
"AgentTarget",
"AskUserTarget",
"ContextExpression",
"ContextStr",
"ContextStrLLMCondition",
"ContextVariables",
"ExpressionAvailableCondition",
"ExpressionContextCondition",
"GroupChatConfig",
"GroupChatTarget",
# "GroupManagerSelectionMessageContextStr",
# "GroupManagerSelectionMessageString",
# "GroupManagerTarget",
"Handoffs",
"NestedChatTarget",
"OnCondition",
"OnContextCondition",
"ReplyResult",
"RevertToUserTarget",
"SpeakerSelectionResult",
"StayTarget",
"StringAvailableCondition",
"StringContextCondition",
"StringLLMCondition",
"TerminateTarget",
]

View File

@@ -0,0 +1,91 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
from .context_expression import ContextExpression
if TYPE_CHECKING:
# Avoid circular import
from ..conversable_agent import ConversableAgent
__all__ = ["AvailableCondition", "ExpressionAvailableCondition", "StringAvailableCondition"]
class AvailableCondition(BaseModel):
    """Protocol for determining if a condition is available to be evaluated."""

    def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool:
        """Determine if the condition should be considered for evaluation.

        Args:
            agent: The agent evaluating the condition
            messages: The conversation history

        Returns:
            True if the condition should be evaluated, False otherwise

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Requires subclasses to implement.")
class StringAvailableCondition(AvailableCondition):
    """String-based available condition.

    The condition is considered available when the named context variable
    exists on the agent and is truthy.
    """

    context_variable: str

    def __init__(self, context_variable: str, **data: Any) -> None:
        """Allow the context variable name to be passed positionally.

        Args:
            context_variable: The name of the context variable to check
            data: Additional data for the parent class
        """
        super().__init__(context_variable=context_variable, **data)

    def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool:
        """Report whether the named context variable is truthy.

        Args:
            agent: The agent with context variables
            messages: The conversation history (not used)

        Returns:
            True if the variable exists and is truthy, False otherwise
        """
        value = agent.context_variables.get(self.context_variable, False)
        return bool(value)
class ExpressionAvailableCondition(AvailableCondition):
    """Expression-based available condition.

    The condition is available when its ContextExpression evaluates to True
    against the agent's context variables.
    """

    expression: ContextExpression

    def __init__(self, expression: ContextExpression, **data: Any) -> None:
        """Allow the expression to be passed positionally.

        Args:
            expression: The context expression to evaluate
            data: Additional data for the parent class
        """
        super().__init__(expression=expression, **data)

    def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool:
        """Evaluate the expression against the agent's context variables.

        Args:
            agent: The agent with context variables
            messages: The conversation history (not used)

        Returns:
            Boolean result of the expression evaluation
        """
        result = self.expression.evaluate(agent.context_variables)
        return result

View File

@@ -0,0 +1,77 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from pydantic import BaseModel
from .context_expression import ContextExpression
from .context_variables import ContextVariables
__all__ = ["ContextCondition", "ExpressionContextCondition", "StringContextCondition"]
class ContextCondition(BaseModel):
    """Protocol for conditions evaluated directly using context variables."""

    def evaluate(self, context_variables: ContextVariables) -> bool:
        """Evaluate the condition to a boolean result.

        Args:
            context_variables: The context variables to evaluate against

        Returns:
            Boolean result of the condition evaluation

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Requires subclasses to implement.")
class StringContextCondition(ContextCondition):
    """Simple string-based context condition.

    Checks whether a named context variable exists and holds a truthy value.
    """

    variable_name: str

    def evaluate(self, context_variables: ContextVariables) -> bool:
        """Check if the named context variable is truthy.

        Args:
            context_variables: The context variables to check against

        Returns:
            True if the variable exists and is truthy, False otherwise
        """
        current_value = context_variables.get(self.variable_name, False)
        return bool(current_value)
class ExpressionContextCondition(ContextCondition):
    """Complex expression-based context condition.

    Evaluates a ContextExpression against the context variables.
    """

    expression: ContextExpression

    def __init__(self, expression: ContextExpression, **data: Any) -> None:
        """Initialize with an expression as a positional parameter.

        Args:
            expression: The context expression to evaluate
            data: Additional data for the parent class
        """
        super().__init__(expression=expression, **data)

    def evaluate(self, context_variables: ContextVariables) -> bool:
        """Evaluate the expression against the context variables.

        Args:
            context_variables: The context variables to evaluate against

        Returns:
            Boolean result of the expression evaluation
        """
        outcome = self.expression.evaluate(context_variables)
        return outcome

View File

@@ -0,0 +1,238 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import ast
import re
from dataclasses import dataclass
from ...doc_utils import export_module
from .context_variables import ContextVariables
@dataclass
@export_module("autogen")
class ContextExpression:
    """A class to evaluate logical expressions using context variables.

    Args:
        expression (str): A string containing a logical expression with context variable references.
            - Variable references use ${var_name} syntax: ${logged_in}, ${attempts}
            - String literals can use normal quotes: 'hello', "world"
            - Supported operators:
                - Logical: not/!, and/&, or/|
                - Comparison: >, <, >=, <=, ==, !=
            - Supported functions:
                - len(${var_name}): Gets the length of a list, string, or other collection
            - Parentheses can be used for grouping
            - Examples:
                - "not ${logged_in} and ${is_admin} or ${guest_checkout}"
                - "!${logged_in} & ${is_admin} | ${guest_checkout}"
                - "len(${orders}) > 0 & ${user_active}"
                - "len(${cart_items}) == 0 | ${checkout_started}"

    Raises:
        SyntaxError: If the expression cannot be parsed
        ValueError: If the expression contains disallowed operations
    """

    expression: str

    def __post_init__(self) -> None:
        # Validate the expression immediately upon creation
        try:
            # Extract variable references so evaluate() knows what to substitute
            self._variable_names = self._extract_variable_names(self.expression)
            # Convert symbolic operators (!, &, |) to Python keywords
            python_expr = self._convert_to_python_syntax(self.expression)
            # Replace ${var} references with bare names so the AST can parse it
            sanitized_expr = self._prepare_for_ast(python_expr)
            # Use ast to parse and validate the expression
            self._ast = ast.parse(sanitized_expr, mode="eval")
            # Verify it only contains allowed operations
            self._validate_operations(self._ast.body)
            # Store the Python-syntax version for evaluation
            self._python_expr = python_expr
        except SyntaxError as e:
            # Chain the original exception so the underlying parse error is preserved
            raise SyntaxError(f"Invalid expression syntax in '{self.expression}': {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Error validating expression '{self.expression}': {str(e)}") from e

    def _extract_variable_names(self, expr: str) -> list[str]:
        """Extract all variable references ${var_name} from the expression."""
        # Find all patterns like ${var_name}
        matches = re.findall(r"\${([^}]*)}", expr)
        return matches

    def _convert_to_python_syntax(self, expr: str) -> str:
        """Convert symbolic operators to Python keywords."""
        # We need to be careful about operators inside string literals
        # First, temporarily replace string literals with placeholders
        string_literals = []

        def replace_string_literal(match: re.Match[str]) -> str:
            string_literals.append(match.group(0))
            return f"__STRING_LITERAL_{len(string_literals) - 1}__"

        # Replace both single and double quoted strings
        expr_without_strings = re.sub(r"'[^']*'|\"[^\"]*\"", replace_string_literal, expr)
        # Handle the NOT operator (!) - no parentheses handling needed
        # Replace standalone ! before variables or expressions
        expr_without_strings = re.sub(r"!\s*(\${|\()", "not \\1", expr_without_strings)
        # Handle AND and OR operators - simpler approach without parentheses handling
        expr_without_strings = re.sub(r"\s+&\s+", " and ", expr_without_strings)
        expr_without_strings = re.sub(r"\s+\|\s+", " or ", expr_without_strings)
        # Now put string literals back
        for i, literal in enumerate(string_literals):
            expr_without_strings = expr_without_strings.replace(f"__STRING_LITERAL_{i}__", literal)
        return expr_without_strings

    def _prepare_for_ast(self, expr: str) -> str:
        """Convert the expression to valid Python for AST parsing by replacing variables with placeholders."""
        # Replace ${var_name} with var_name for AST parsing
        processed_expr = expr
        for var_name in self._variable_names:
            processed_expr = processed_expr.replace(f"${{{var_name}}}", var_name)
        return processed_expr

    def _validate_operations(self, node: ast.AST) -> None:
        """Recursively validate that only allowed operations exist in the AST."""
        allowed_node_types = (
            # Boolean operations
            ast.BoolOp,
            ast.UnaryOp,
            ast.And,
            ast.Or,
            ast.Not,
            # Comparison operations
            ast.Compare,
            ast.Eq,
            ast.NotEq,
            ast.Lt,
            ast.LtE,
            ast.Gt,
            ast.GtE,
            # Basic nodes. Note: ast.Constant covers numbers, strings and
            # True/False/None since Python 3.8; the deprecated Num/Str/NameConstant
            # aliases (removed in newer Pythons) are intentionally not listed.
            ast.Name,
            ast.Load,
            ast.Constant,
            ast.Expression,
            # Support for negative numbers
            ast.USub,
            # Support for function calls (specifically len())
            ast.Call,
        )
        if not isinstance(node, allowed_node_types):
            raise ValueError(f"Operation type {type(node).__name__} is not allowed in logical expressions")
        # Special validation for function calls - only allow len()
        if isinstance(node, ast.Call):
            if not (isinstance(node.func, ast.Name) and node.func.id == "len"):
                raise ValueError(f"Only the len() function is allowed, got: {getattr(node.func, 'id', 'unknown')}")
            if len(node.args) != 1:
                raise ValueError(f"len() function must have exactly one argument, got {len(node.args)}")
        # Special validation for Compare nodes
        if isinstance(node, ast.Compare):
            for op in node.ops:
                if not isinstance(op, (ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE)):
                    raise ValueError(f"Comparison operator {type(op).__name__} is not allowed")
        # Recursively check child nodes
        for child in ast.iter_child_nodes(node):
            self._validate_operations(child)

    def evaluate(self, context_variables: ContextVariables) -> bool:
        """Evaluate the expression using the provided context variables.

        Args:
            context_variables: Dictionary of context variables to use for evaluation

        Returns:
            bool: The result of evaluating the expression

        Raises:
            KeyError: If a variable referenced in the expression is not found in the context
            ValueError: If the substituted expression fails to evaluate
        """
        # Create a modified expression that we can safely evaluate
        eval_expr = self._python_expr  # Use the Python-syntax version
        # First, handle len() functions with variable references inside
        len_pattern = r"len\(\${([^}]*)}\)"
        len_matches = list(re.finditer(len_pattern, eval_expr))
        # Process all len() operations first
        for match in len_matches:
            var_name = match.group(1)
            # Check if variable exists in context, raise KeyError if not
            if not context_variables.contains(var_name):
                raise KeyError(f"Missing context variable: '{var_name}'")
            var_value = context_variables.get(var_name)
            # Calculate the length - works for lists, strings, dictionaries, etc.
            try:
                length_value = len(var_value)  # type: ignore[arg-type]
            except TypeError:
                # If the value doesn't support len(), treat as 0
                length_value = 0
            # Replace the len() expression with the actual length
            full_match = match.group(0)
            eval_expr = eval_expr.replace(full_match, str(length_value))
        # Then replace remaining variable references with their values
        for var_name in self._variable_names:
            # Skip variables that were already processed in len() expressions
            if any(m.group(1) == var_name for m in len_matches):
                continue
            # Check if variable exists in context, raise KeyError if not
            if not context_variables.contains(var_name):
                raise KeyError(f"Missing context variable: '{var_name}'")
            # Get the value from context
            var_value = context_variables.get(var_name)
            # Format the value appropriately based on its type
            if isinstance(var_value, (bool, int, float)):
                formatted_value = str(var_value)
            elif isinstance(var_value, str):
                # repr() produces a correctly quoted/escaped Python literal, so
                # values containing quotes can't break (or inject into) the
                # generated expression the way naive f"'{...}'" quoting could
                formatted_value = repr(var_value)
            elif isinstance(var_value, (list, dict, tuple)):
                # For collections, convert to their boolean evaluation
                formatted_value = str(bool(var_value))
            else:
                formatted_value = str(var_value)
            # Replace the variable reference with the formatted value
            eval_expr = eval_expr.replace(f"${{{var_name}}}", formatted_value)
        try:
            # The AST was validated in __post_init__; expose only len() and no
            # builtins so evaluation of the substituted expression stays contained
            return eval(eval_expr, {"__builtins__": {}, "len": len})  # type: ignore[no-any-return]
        except Exception as e:
            raise ValueError(
                f"Error evaluating expression '{self.expression}' (are you sure you're using ${{my_context_variable_key}}): {str(e)}"
            ) from e

    def __str__(self) -> str:
        return f"ContextExpression('{self.expression}')"

View File

@@ -0,0 +1,41 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from pydantic import BaseModel
from .context_variables import ContextVariables
__all__ = ["ContextStr"]
class ContextStr(BaseModel):
    """A string that requires context variable substitution.

    Use the format method to substitute context variables into the string.
    """

    # The string to be substituted with context variables. It is expected that the
    # string will contain `{var}` placeholders and that string format will be able
    # to replace all values.
    template: str

    def format(self, context_variables: ContextVariables) -> Optional[str]:
        """Substitute context variables into the string.

        Args:
            context_variables (ContextVariables): The context variables to substitute into the string.

        Returns:
            Optional[str]: The formatted string with context variables substituted
                (the raw template when no context variables are set).

        Raises:
            KeyError: If the template references a placeholder that is missing
                from the context variables.
        """
        context = context_variables.to_dict()
        if not context:
            # Nothing to substitute; hand back the template unchanged.
            return self.template
        return self.template.format(**context)

    def __str__(self) -> str:
        return f"ContextStr, unformatted: {self.template}"

View File

@@ -0,0 +1,192 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Generator, Iterable, Optional
from pydantic import BaseModel, Field
__all__ = ["ContextVariables"]
# Parameter name for context variables
# Use the value in functions and they will be substituted with the context variables:
# e.g. def my_function(context_variables: ContextVariables, my_other_parameters: Any) -> Any:
__CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables"
class ContextVariables(BaseModel):
    """
    Stores and manages context variables for agentic workflows.
    Utilises a dictionary-like interface for setting, getting, and removing variables.
    """
    # Internal storage for context variables
    data: dict[str, Any] = Field(default_factory=dict)
    def __init__(self, data: Optional[dict[str, Any]] = None, **kwargs: Any) -> None:
        """Initialize with data dictionary as an optional positional parameter.
        Args:
            data: Initial dictionary of context variables (optional)
            kwargs: Additional keyword arguments for the parent class
        """
        # `data or {}` means None (and any falsy dict) becomes a fresh empty dict
        init_data = data or {}
        super().__init__(data=init_data, **kwargs)
    def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
        """
        Get a value from the context by key.
        Args:
            key: The key to retrieve
            default: The default value to return if key is not found
        Returns:
            The value associated with the key or default if not found
        """
        return self.data.get(key, default)
    def set(self, key: str, value: Any) -> None:
        """
        Set a value in the context by key.
        Args:
            key: The key to set
            value: The value to store
        """
        self.data[key] = value
    def remove(self, key: str) -> bool:
        """
        Remove a key from the context.
        Args:
            key: The key to remove
        Returns:
            True if the key was removed, False if it didn't exist
        """
        if key in self.data:
            del self.data[key]
            return True
        return False
    def keys(self) -> Iterable[str]:
        """
        Get all keys in the context.
        Returns:
            An iterable of all keys
        """
        return self.data.keys()
    def values(self) -> Iterable[Any]:
        """
        Get all values in the context.
        Returns:
            An iterable of all values
        """
        return self.data.values()
    def items(self) -> Iterable[tuple[str, Any]]:
        """
        Get all key-value pairs in the context.
        Returns:
            An iterable of all key-value pairs
        """
        return self.data.items()
    def clear(self) -> None:
        """Clear all keys and values from the context."""
        self.data.clear()
    def contains(self, key: str) -> bool:
        """
        Check if a key exists in the context.
        Args:
            key: The key to check
        Returns:
            True if the key exists, False otherwise
        """
        return key in self.data
    def update(self, other: dict[str, Any]) -> None:
        """
        Update context with key-value pairs from another dictionary.
        Args:
            other: Dictionary containing key-value pairs to add
        """
        self.data.update(other)
    def to_dict(self) -> dict[str, Any]:
        """
        Convert context variables to a dictionary.
        Returns:
            Shallow copy of all context variables (values are not deep-copied)
        """
        return self.data.copy()
    # Dictionary-compatible interface
    def __getitem__(self, key: str) -> Any:
        """Get a value using dictionary syntax: context[key]"""
        try:
            return self.data[key]
        except KeyError:
            # Re-raise with a friendlier message
            raise KeyError(f"Context variable '{key}' not found")
    def __setitem__(self, key: str, value: Any) -> None:
        """Set a value using dictionary syntax: context[key] = value"""
        self.data[key] = value
    def __delitem__(self, key: str) -> None:
        """Delete a key using dictionary syntax: del context[key]"""
        try:
            del self.data[key]
        except KeyError:
            # Re-raise with a friendlier message
            raise KeyError(f"Cannot delete non-existent context variable '{key}'")
    def __contains__(self, key: str) -> bool:
        """Check if key exists using 'in' operator: key in context"""
        return key in self.data
    def __len__(self) -> int:
        """Get the number of items: len(context)"""
        return len(self.data)
    def __iter__(self) -> Generator[tuple[str, Any], None, None]:
        """Iterate over (key, value) pairs: for key, value in context.
        NOTE: unlike a plain dict, iteration yields (key, value) tuples, not keys.
        """
        for key in self.data:
            yield (key, self.data[key])
    def __str__(self) -> str:
        """String representation of context variables."""
        return f"ContextVariables({self.data})"
    def __repr__(self) -> str:
        """Detailed representation of context variables."""
        return f"ContextVariables(data={self.data!r})"
    # Utility methods
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ContextVariables":
        """
        Create a new ContextVariables instance from a dictionary.
        E.g.:
        my_context = {"user_id": "12345", "settings": {"theme": "dark"}}
        context = ContextVariables.from_dict(my_context)
        Args:
            data: Dictionary of key-value pairs
        Returns:
            New ContextVariables instance
        """
        return cls(data=data)

View File

@@ -0,0 +1,202 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import inspect
from copy import deepcopy
from typing import Annotated, Any, Callable, Optional
from ...oai import OpenAIWrapper
from ...tools import Depends, Tool
from ...tools.dependency_injection import inject_params, on
from ..agent import Agent
from ..conversable_agent import ConversableAgent
from .context_variables import __CONTEXT_VARIABLES_PARAM_NAME__, ContextVariables
from .reply_result import ReplyResult
from .targets.transition_target import TransitionTarget
__TOOL_EXECUTOR_NAME__ = "_Group_Tool_Executor"
class GroupToolExecutor(ConversableAgent):
    """Tool executor for the group chat initiated with initiate_group_chat"""

    def __init__(self) -> None:
        super().__init__(
            name=__TOOL_EXECUTOR_NAME__,
            system_message="Tool Execution, do not use this agent directly.",
            human_input_mode="NEVER",
            code_execution_config=False,
        )
        # Store the next target from a tool call
        self._group_next_target: Optional[TransitionTarget] = None
        # Primary tool reply function for handling the tool reply and the ReplyResult and TransitionTarget returns
        self.register_reply([Agent, None], self._generate_group_tool_reply, remove_other_reply_funcs=True)

    def set_next_target(self, next_target: TransitionTarget) -> None:
        """Sets the next target to transition to, used in the determine_next_agent function."""
        self._group_next_target = next_target

    def get_next_target(self) -> TransitionTarget:
        """Returns the next target to transition to, if it exists.

        Raises:
            ValueError: If no next target has been set. Use has_next_target() first.
        """
        if self._group_next_target is None:
            raise ValueError(
                "No next target set. Please set a next target before calling this method. Use has_next_target() to check if a next target exists."
            )
        return self._group_next_target

    def has_next_target(self) -> bool:
        """Checks if there is a next target to transition to."""
        return self._group_next_target is not None

    def clear_next_target(self) -> None:
        """Clears the next target to transition to."""
        self._group_next_target = None

    def _modify_context_variables_param(
        self, f: Callable[..., Any], context_variables: ContextVariables
    ) -> Callable[..., Any]:
        """Modifies the context_variables parameter to use dependency injection and link it to the group context variables.

        This essentially changes:
        def some_function(some_variable: int, context_variables: ContextVariables) -> str:
        to:
        def some_function(some_variable: int, context_variables: Annotated[ContextVariables, Depends(on(self.context_variables))]) -> str:
        """
        sig = inspect.signature(f)
        # Check if context_variables parameter exists and update it if so
        if __CONTEXT_VARIABLES_PARAM_NAME__ in sig.parameters:
            new_params = []
            for name, param in sig.parameters.items():
                if name == __CONTEXT_VARIABLES_PARAM_NAME__:
                    # Replace with new annotation using Depends
                    new_param = param.replace(annotation=Annotated[ContextVariables, Depends(on(context_variables))])
                    new_params.append(new_param)
                else:
                    new_params.append(param)
            # Update signature
            new_sig = sig.replace(parameters=new_params)
            f.__signature__ = new_sig  # type: ignore[attr-defined]
        return f

    def _change_tool_context_variables_to_depends(
        self, agent: ConversableAgent, current_tool: Tool, context_variables: ContextVariables
    ) -> None:
        """Checks for the context_variables parameter in the tool and updates it to use dependency injection."""
        # If the tool has a context_variables parameter, remove the tool and reregister it without the parameter
        if __CONTEXT_VARIABLES_PARAM_NAME__ in current_tool.tool_schema["function"]["parameters"]["properties"]:
            # We'll replace the tool, so start with getting the underlying function
            tool_func = current_tool._func
            # Remove the Tool from the agent
            name = current_tool._name
            description = current_tool._description
            agent.remove_tool_for_llm(current_tool)
            # Recreate the tool without the context_variables parameter
            tool_func = self._modify_context_variables_param(current_tool._func, context_variables)
            tool_func = inject_params(tool_func)
            new_tool = ConversableAgent._create_tool_if_needed(
                func_or_tool=tool_func, name=name, description=description
            )
            # Re-register with the agent
            agent.register_for_llm()(new_tool)

    def register_agents_functions(self, agents: list[ConversableAgent], context_variables: ContextVariables) -> None:
        """Adds the functions of the agents to the group tool executor."""
        for agent in agents:
            # As we're moving towards tools and away from function maps, this may not be used
            self._function_map.update(agent._function_map)
            # Update any agent tools that have context_variables parameters to use Dependency Injection
            for tool in agent.tools:
                self._change_tool_context_variables_to_depends(agent, tool, context_variables)
            # Add all tools to the Tool Executor agent
            for tool in agent.tools:
                self.register_for_execution(serialize=False, silent_override=True)(tool)

    def _generate_group_tool_reply(
        self,
        agent: ConversableAgent,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[OpenAIWrapper] = None,
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """Pre-processes and generates tool call replies.

        This function:
        1. Adds context_variables back to the tool call for the function, if necessary.
        2. Generates the tool calls reply.
        3. Updates context_variables and next_agent based on the tool call response."""
        if config is None:
            config = agent  # type: ignore[assignment]
        if messages is None:
            messages = agent._oai_messages[sender]
        message = messages[-1]
        if "tool_calls" in message:
            tool_call_count = len(message["tool_calls"])
            # Loop through tool calls individually (so context can be updated after each function call)
            next_target: Optional[TransitionTarget] = None
            tool_responses_inner = []
            contents = []
            # Initialise so an empty tool_calls list raises the explicit ValueError
            # below instead of a NameError when tool_message is referenced.
            tool_message: Optional[dict[str, Any]] = None
            for index in range(tool_call_count):
                message_copy = deepcopy(message)
                # 1. add context_variables to the tool call arguments
                tool_call = message_copy["tool_calls"][index]
                # Ensure we are only executing the one tool at a time
                message_copy["tool_calls"] = [tool_call]
                # 2. generate tool calls reply
                _, tool_message = agent.generate_tool_calls_reply([message_copy])
                if tool_message is None:
                    raise ValueError("Tool call did not return a message")
                # 3. update context_variables and next_agent, convert content to string
                for tool_response in tool_message["tool_responses"]:
                    content = tool_response.get("content")
                    # Tool Call returns that are a target are either a ReplyResult or a TransitionTarget are the next agent
                    if isinstance(content, ReplyResult):
                        if content.context_variables and content.context_variables.to_dict() != {}:
                            agent.context_variables.update(content.context_variables.to_dict())
                        if content.target is not None:
                            next_target = content.target
                    elif isinstance(content, TransitionTarget):
                        next_target = content
                    # Serialize the content to a string
                    if content is not None:
                        tool_response["content"] = str(content)
                    tool_responses_inner.append(tool_response)
                    contents.append(str(tool_response["content"]))
            # Record the target (possibly None) for determine_next_agent to pick up
            self._group_next_target = next_target
            # Put the tool responses and content strings back into the response message
            # Caters for multiple tool calls
            if tool_message is None:
                raise ValueError("Tool call did not return a message")
            tool_message["tool_responses"] = tool_responses_inner
            tool_message["content"] = "\n".join(contents)
            return True, tool_message
        return False, None

View File

@@ -0,0 +1,636 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import copy
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from ..agent import Agent
from ..groupchat import GroupChat, GroupChatManager
from .context_variables import ContextVariables
from .group_tool_executor import GroupToolExecutor
from .targets.group_manager_target import GroupManagerTarget
from .targets.transition_target import (
AgentNameTarget,
AgentTarget,
TransitionTarget,
)
if TYPE_CHECKING:
from ..conversable_agent import ConversableAgent
# Utility functions for group chat preparation and management
# These are extracted from multi_agent_chat.py to avoid circular imports
def update_conditional_functions(agent: "ConversableAgent", messages: list[dict[str, Any]]) -> None:
    """Refresh the agent's handoff functions from each OnCondition's availability.

    Every handoff function is removed from the agent's tools and re-registered
    only when its availability condition (if any) currently holds, keeping the
    function signature and prompt up to date.
    """
    for handoff in agent.handoffs.llm_conditions:
        if handoff.available:
            currently_available = handoff.available.is_available(agent, messages)
        else:
            currently_available = True
        # Drop any previously registered tool for this handoff
        existing_tool = next((t for t in agent.tools if t.name == handoff.llm_function_name), None)
        if existing_tool is not None:
            agent.remove_tool_for_llm(existing_tool)
        # Re-add while available so the function signature stays current
        if currently_available:
            agent._add_single_function(
                _create_on_condition_handoff_function(handoff.target),
                handoff.llm_function_name,
                handoff.condition.get_prompt(agent, messages),
            )
def establish_group_agent(agent: "ConversableAgent") -> None:
    """Establish the group agent with the group-related attributes and hooks. Not for the tool executor.

    Args:
        agent ("ConversableAgent"): The agent to establish as a group agent.
    """

    def _group_agent_str(self: "ConversableAgent") -> str:
        """Show the agent name in transition messages."""
        return f"Group agent --> {self.name}"

    # Keep conditional handoff functions in sync with agent state (tool executor excluded)
    agent.register_hook("update_agent_state", update_conditional_functions)
    # Run Python function-based OnContextConditions ahead of every other reply function
    agent.register_reply(trigger=([Agent, None]), reply_func=_run_oncontextconditions, position=0)
    agent._get_display_name = MethodType(_group_agent_str, agent)  # type: ignore[method-assign]
    # Flag the agent so repeated preparation passes skip it
    agent._group_is_established = True  # type: ignore[attr-defined]
def link_agents_to_group_manager(agents: list[Agent], group_chat_manager: Agent) -> None:
    """Attach the GroupChatManager to every agent so each can reach the GroupChat and its peers.

    Primarily lets agents locate the tool executor in order to set the next agent.
    The Tool Executor agent itself is not linked here.
    """
    for member in agents:
        member._group_manager = group_chat_manager  # type: ignore[attr-defined]
def _evaluate_after_works_conditions(
    agent: "ConversableAgent",
    groupchat: GroupChat,
    user_agent: Optional["ConversableAgent"],
) -> Optional[Union[Agent, str]]:
    """Evaluate after_works context conditions for an agent.

    Args:
        agent: The agent to evaluate after_works conditions for
        groupchat: The current group chat
        user_agent: Optional user proxy agent

    Returns:
        The resolved speaker selection result if a condition matches, None otherwise
    """
    handoffs = getattr(agent, "handoffs", None)
    if handoffs is None or not handoffs.after_works:
        return None
    for candidate in handoffs.after_works:
        # Skip candidates whose availability check fails
        if candidate.available and not candidate.available.is_available(agent, groupchat.messages):
            continue
        # A missing condition counts as always-true
        if candidate.condition is not None and not candidate.condition.evaluate(agent.context_variables):
            continue
        # First matching candidate wins: resolve its target and return
        resolved_target = candidate.target.resolve(groupchat, agent, user_agent)
        return resolved_target.get_speaker_selection_result(groupchat)
    return None
def _run_oncontextconditions(
    agent: "ConversableAgent",
    messages: Optional[list[dict[str, Any]]] = None,
    sender: Optional[Agent] = None,
    config: Optional[Any] = None,
) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
    """Run OnContextConditions for an agent before any other reply function.

    Args:
        agent: The agent whose context conditions are evaluated.
        messages: The conversation history, if any.
        sender: The sending agent (unused, reply-function signature).
        config: Optional config (unused, reply-function signature).

    Returns:
        (True, handoff message) when a condition matches, otherwise (False, None).
    """
    for on_condition in agent.handoffs.context_conditions:  # type: ignore[attr-defined]
        is_available = (
            on_condition.available.is_available(agent, messages if messages else []) if on_condition.available else True
        )
        if is_available and (
            on_condition.condition is None or on_condition.condition.evaluate(agent.context_variables)
        ):
            # Condition has been met, we'll set the Tool Executor's next target
            # attribute and that will be picked up on the next iteration when
            # _determine_next_agent is called.
            # Use a distinct loop variable: the original code reused `agent`,
            # shadowing and rebinding the function parameter.
            for group_member in agent._group_manager.groupchat.agents:  # type: ignore[attr-defined]
                if isinstance(group_member, GroupToolExecutor):
                    group_member.set_next_target(on_condition.target)
                    break
            transfer_name = on_condition.target.display_name()
            return True, "[Handing off to " + transfer_name + "]"
    return False, None
def _create_on_condition_handoff_function(target: TransitionTarget) -> Callable[[], TransitionTarget]:
    """Build the zero-argument function invoked by the tool-call reply when a condition is met.

    Args:
        target (TransitionTarget): The target to transfer to.

    Returns:
        Callable: A closure that returns the captured target.
    """

    def transfer_to_target() -> TransitionTarget:
        # `target` is captured by closure; the tool framework calls this with no arguments
        return target

    return transfer_to_target
def create_on_condition_handoff_functions(agent: "ConversableAgent") -> None:
    """Create the tool functions backing each OnCondition handoff.

    Args:
        agent ("ConversableAgent"): The agent to create the functions for.
    """
    # Make sure every handoff has an LLM function name assigned first
    agent.handoffs.set_llm_function_names()
    # Register one transfer function per OnCondition, invoked when the condition is met
    for handoff in agent.handoffs.llm_conditions:
        agent._add_single_function(
            _create_on_condition_handoff_function(handoff.target),
            handoff.llm_function_name,
            handoff.condition.get_prompt(agent, []),
        )
def ensure_handoff_agents_in_group(agents: list["ConversableAgent"]) -> None:
    """Ensure every agent referenced by a handoff is part of the group chat.

    Raises:
        ValueError: If any handoff targets an agent outside the agents list.
    """
    known_names = {member.name for member in agents}

    def _targets_unknown_agent(target: Any) -> bool:
        # Only agent-style targets carry an agent_name worth validating
        return isinstance(target, (AgentTarget, AgentNameTarget)) and target.agent_name not in known_names

    for member in agents:
        for llm_condition in member.handoffs.llm_conditions:
            if _targets_unknown_agent(llm_condition.target):
                raise ValueError("Agent in OnCondition Hand-offs must be in the agents list")
        for context_condition in member.handoffs.context_conditions:
            if _targets_unknown_agent(context_condition.target):
                raise ValueError("Agent in OnContextCondition Hand-offs must be in the agents list")
        # Check after_works targets
        for after_work_condition in member.handoffs.after_works:
            if _targets_unknown_agent(after_work_condition.target):
                raise ValueError("Agent in after work target Hand-offs must be in the agents list")
def prepare_exclude_transit_messages(agents: list["ConversableAgent"]) -> None:
    """Install hooks that strip handoff (transit) tool messages from agents' histories.

    Collects every handoff function name across the agents, builds a removal
    function for them, and registers it as a message pre-processing hook.

    Raises:
        ValueError: If an OnCondition has no function name assigned.
    """
    # Gather all transit function names
    transit_function_names: list[str] = []
    for member in agents:
        for on_condition in member.handoffs.llm_conditions:
            if not on_condition.llm_function_name:
                raise ValueError("OnCondition must have a function name")
            transit_function_names.append(on_condition.llm_function_name)
    remove_function = make_remove_function(transit_function_names)
    # Register the hook on every group agent
    for member in agents:
        member.register_hook("process_all_messages_before_reply", remove_function)
def prepare_group_agents(
    agents: list["ConversableAgent"],
    context_variables: ContextVariables,
    exclude_transit_message: bool = True,
) -> tuple[GroupToolExecutor, list["ConversableAgent"]]:
    """Validates agents, create the tool executor, wrap necessary targets in agents.

    Args:
        agents (list["ConversableAgent"]): List of all agents in the conversation.
        context_variables (ContextVariables): Context variables to assign to all agents.
        exclude_transit_message (bool): Whether to exclude transit messages from the agents.

    Returns:
        "ConversableAgent": The tool executor agent.
        list["ConversableAgent"]: List of wrapped agents.
    """
    # Turn every agent into a group agent exactly once
    for member in agents:
        if not hasattr(member, "_group_is_established"):
            establish_group_agent(member)
    # All handoff / after-work targets must reference agents within this group
    ensure_handoff_agents_in_group(agents)
    # Create Tool Executor for the group
    tool_execution = GroupToolExecutor()
    # Some handoff targets (e.g. nested chats) need wrapper agents in the chat
    wrapped_chat_agents: list["ConversableAgent"] = []
    for member in agents:
        wrap_agent_handoff_targets(member, wrapped_chat_agents)
    # OnConditions are exposed as tools, so create their backing functions
    for member in agents:
        create_on_condition_handoff_functions(member)
    # The executor needs every agent function, with context variables injected
    tool_execution.register_agents_functions(agents + wrapped_chat_agents, context_variables)
    if exclude_transit_message:
        prepare_exclude_transit_messages(agents)
    return tool_execution, wrapped_chat_agents
def wrap_agent_handoff_targets(agent: "ConversableAgent", wrapped_agent_list: list["ConversableAgent"]) -> None:
    """Wrap handoff targets that cannot join the group chat directly.

    Targets such as NestedChatTarget are replaced with a wrapper agent and the
    handoff is re-pointed at that wrapper so it behaves like a normal agent
    transition.

    Args:
        agent ("ConversableAgent"): The agent whose handoff targets are inspected.
        wrapped_agent_list (list["ConversableAgent"]): Accumulator each new wrapper agent is appended to.
    """
    # LLM-evaluated handoffs first, then context-evaluated ones; each group
    # keeps its own index sequence so wrapper naming stays consistent per kind.
    for condition_group in (
        agent.handoffs.get_llm_conditions_requiring_wrapping(),
        agent.handoffs.get_context_conditions_requiring_wrapping(),
    ):
        for index, handoff in enumerate(condition_group):
            # Build the stand-in agent for the non-agent target
            wrapper = handoff.target.create_wrapper_agent(parent_agent=agent, index=index)
            wrapped_agent_list.append(wrapper)
            # Redirect the handoff at the newly created wrapper agent
            handoff.target = AgentTarget(wrapper)
def process_initial_messages(
    messages: Union[list[dict[str, Any]], str],
    user_agent: Optional["ConversableAgent"],
    agents: list["ConversableAgent"],
    wrapped_agents: list["ConversableAgent"],
) -> tuple[list[dict[str, Any]], Optional["ConversableAgent"], list[str], list[Agent]]:
    """Process initial messages, validating agent names against messages, and determining the last agent to speak.

    Args:
        messages: Initial messages to process.
        user_agent: Optional user proxy agent passed in to a_/initiate_group_chat.
        agents: Agents in the group.
        wrapped_agents: List of wrapped agents.

    Returns:
        list[dict[str, Any]]: Processed message(s).
        Agent: Last agent to speak.
        list[str]: List of agent names.
        list[Agent]: List of temporary user proxy agents to add to GroupChat.
    """
    from ..conversable_agent import ConversableAgent  # NEED SOLUTION

    # Normalise a bare string into a single user message
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    all_group_agents = agents + wrapped_agents
    group_agent_names = [member.name for member in all_group_agents]

    last_agent: Optional[ConversableAgent] = None
    temp_user_list: list[Agent] = []

    if len(messages) == 1 and "name" not in messages[0] and not user_agent:
        # Single unattributed message and no user proxy supplied:
        # create a throwaway user proxy to start the conversation.
        temp_user_proxy = ConversableAgent(name="_User", code_execution_config=False, human_input_mode="ALWAYS")
        last_agent = temp_user_proxy
        temp_user_list.append(temp_user_proxy)
        return messages, last_agent, group_agent_names, temp_user_list

    first_message = messages[0]
    if "name" not in first_message:
        # No sender recorded: fall back to the user proxy if one exists
        last_agent = user_agent if user_agent else None
        return messages, last_agent, group_agent_names, temp_user_list

    sender_name = first_message["name"]
    if sender_name in group_agent_names:
        last_agent = next(member for member in all_group_agents if member.name == sender_name)  # type: ignore[assignment]
    elif user_agent and sender_name == user_agent.name:
        last_agent = user_agent
    else:
        raise ValueError(f"Invalid group agent name in last message: {sender_name}")

    return messages, last_agent, group_agent_names, temp_user_list
def setup_context_variables(
    tool_execution: "ConversableAgent",
    agents: list["ConversableAgent"],
    manager: GroupChatManager,
    user_agent: Optional["ConversableAgent"],
    context_variables: ContextVariables,
) -> None:
    """Share one ContextVariables instance across every group participant.

    All group agents, the tool executor, the group chat manager and the
    optional user proxy end up referencing the *same* object, so a change
    made through one of them is visible to all.

    Args:
        tool_execution: The tool execution agent.
        agents: List of all agents in the conversation.
        manager: GroupChatManager instance.
        user_agent: Optional user proxy agent.
        context_variables: Context variables to assign to all agents.
    """
    participants = [*agents, tool_execution, manager]
    if user_agent:
        participants.append(user_agent)
    for participant in participants:
        participant.context_variables = context_variables
def cleanup_temp_user_messages(chat_result: Any) -> None:
    """Strip the temporary "_User" proxy name from a chat history, in place.

    Args:
        chat_result: ChatResult instance whose ``chat_history`` entries are scrubbed.
    """
    for entry in chat_result.chat_history:
        # Only the synthetic proxy's attribution is removed; real names stay.
        if entry.get("name") == "_User":
            del entry["name"]
def get_last_agent_speaker(
    groupchat: GroupChat, group_agent_names: list[str], tool_executor: GroupToolExecutor
) -> Agent:
    """Return the most recent group-agent speaker, ignoring the tool executor.

    Walks the message history newest-to-oldest and resolves the first sender
    that is a known group agent other than the tool executor.

    Raises:
        ValueError: If no group agent appears in the message history.
    """
    for entry in reversed(groupchat.messages):
        speaker_name = entry.get("name")
        # Skip unattributed messages, the tool executor, and non-group senders
        if speaker_name is None or speaker_name == tool_executor.name or speaker_name not in group_agent_names:
            continue
        resolved = groupchat.agent_by_name(name=speaker_name)
        if resolved:
            return resolved
    raise ValueError("No group agent found in the message history")
def determine_next_agent(
    last_speaker: "ConversableAgent",
    groupchat: GroupChat,
    initial_agent: "ConversableAgent",
    use_initial_agent: bool,
    tool_executor: GroupToolExecutor,
    group_agent_names: list[str],
    user_agent: Optional["ConversableAgent"],
    group_after_work: TransitionTarget,
) -> Optional[Union[Agent, str]]:
    """Determine the next agent in the conversation.

    Args:
        last_speaker ("ConversableAgent"): The last agent to speak.
        groupchat (GroupChat): GroupChat instance.
        initial_agent ("ConversableAgent"): The initial agent in the conversation.
        use_initial_agent (bool): Whether to use the initial agent straight away.
        tool_executor (GroupToolExecutor): The tool execution agent.
        group_agent_names (list[str]): List of agent names.
        user_agent (Optional["ConversableAgent"]): Optional user proxy agent.
        group_after_work (TransitionTarget): Group-level Transition option when an agent doesn't select the next agent.

    Returns:
        Optional[Union[Agent, str]]: The next agent or speaker selection method.
    """
    # Logic for determining the next target (anything based on Transition Target: an agent, wrapped agent, TerminateTarget, StayTarget, RevertToUserTarget, GroupManagerTarget, etc.
    # 1. If it's the first response -> initial agent
    # 2. If the last message is a tool call -> tool execution agent
    # 3. If the Tool Executor has determined a next target (e.g. ReplyResult specified target) -> transition to tool reply target
    # 4. If the user last spoke -> return to the previous agent
    # NOW "AFTER WORK":
    # 5. Get the After Work condition (if the agent doesn't have one, get the group-level one)
    # 6. Resolve and return the After Work condition -> agent / wrapped agent / TerminateTarget / StayTarget / RevertToUserTarget / GroupManagerTarget / etc.

    # 1. If it's the first response, return the initial agent
    if use_initial_agent:
        return initial_agent

    # 2. If the last message is a tool call, return the tool execution agent
    if "tool_calls" in groupchat.messages[-1]:
        return tool_executor

    # 3. If the Tool Executor has determined a next target, return that
    if tool_executor.has_next_target():
        next_agent = tool_executor.get_next_target()
        # Clear before resolving so a stale target never leaks into later turns
        tool_executor.clear_next_target()

        if next_agent.can_resolve_for_speaker_selection():
            return next_agent.resolve(groupchat, last_speaker, user_agent).get_speaker_selection_result(groupchat)
        else:
            raise ValueError(
                "Tool Executor next target must be a valid TransitionTarget that can resolve for speaker selection."
            )

    # get the last group agent
    last_agent_speaker = get_last_agent_speaker(groupchat, group_agent_names, tool_executor)

    # If we are returning from a tool execution, return to the last agent that spoke
    if groupchat.messages[-1]["role"] == "tool":
        return last_agent_speaker

    # If the user last spoke, return to the agent prior to them (if they don't have an after work, otherwise it's treated like any other agent)
    if user_agent and last_speaker == user_agent:
        if not user_agent.handoffs.after_works:
            return last_agent_speaker
        else:
            last_agent_speaker = user_agent

    # AFTER WORK:

    # First, try to evaluate after_works context conditions
    after_works_result = _evaluate_after_works_conditions(
        last_agent_speaker,  # type: ignore[arg-type]
        groupchat,
        user_agent,
    )
    if after_works_result is not None:
        return after_works_result

    # If no after_works conditions matched, use the group-level after_work
    # Resolve the next agent, termination, or speaker selection method
    resolved_speaker_selection_result = group_after_work.resolve(
        groupchat,
        last_agent_speaker,  # type: ignore[arg-type]
        user_agent,
    ).get_speaker_selection_result(groupchat)

    return resolved_speaker_selection_result
def create_group_transition(
    initial_agent: "ConversableAgent",
    tool_execution: GroupToolExecutor,
    group_agent_names: list[str],
    user_agent: Optional["ConversableAgent"],
    group_after_work: TransitionTarget,
) -> Callable[["ConversableAgent", GroupChat], Optional[Union[Agent, str]]]:
    """Build the speaker-selection function used by the group chat.

    The returned callable delegates to determine_next_agent; a closed-over
    flag guarantees the initial agent is only forced on the very first call.

    Args:
        initial_agent ("ConversableAgent"): The first agent to speak
        tool_execution (GroupToolExecutor): The tool execution agent
        group_agent_names (list[str]): List of all agent names
        user_agent (Optional["ConversableAgent"]): Optional user proxy agent
        group_after_work (TransitionTarget): Group-level after work

    Returns:
        Callable[["ConversableAgent", GroupChat], Optional[Union[Agent, str]]]: The transition function
    """
    # Closure state: True only until the first invocation completes
    first_call = True

    def group_transition(last_speaker: "ConversableAgent", groupchat: GroupChat) -> Optional[Union[Agent, str]]:
        nonlocal first_call
        chosen = determine_next_agent(
            last_speaker=last_speaker,
            groupchat=groupchat,
            initial_agent=initial_agent,
            use_initial_agent=first_call,
            tool_executor=tool_execution,
            group_agent_names=group_agent_names,
            user_agent=user_agent,
            group_after_work=group_after_work,
        )
        first_call = False
        return chosen

    return group_transition
def create_group_manager(
    groupchat: GroupChat,
    group_manager_args: Optional[dict[str, Any]],
    agents: list["ConversableAgent"],
    group_after_work: TransitionTarget,
) -> GroupChatManager:
    """Create a GroupChatManager for the group chat utilising any arguments passed in and ensure an LLM Config exists if needed

    Args:
        groupchat (GroupChat): The groupchat.
        group_manager_args (dict[str, Any]): Group manager arguments to create the GroupChatManager.
        agents (list["ConversableAgent"]): List of agents in the group to check handoffs and after work.
        group_after_work (TransitionTarget): Group-level after work to check.

    Returns:
        GroupChatManager: GroupChatManager instance.
    """
    manager_kwargs = dict(group_manager_args or {})
    if "groupchat" in manager_kwargs:
        raise ValueError("'groupchat' cannot be specified in group_manager_args as it is set by initiate_group_chat")

    manager = GroupChatManager(groupchat, **manager_kwargs)

    if manager.llm_config is not False:
        return manager

    # The manager has no LLM Config: make sure nothing in the group needs a
    # GroupManagerTarget, which the manager can only service with an LLM.
    def _uses_group_manager_target() -> bool:
        # One-line purpose: detect any GroupManagerTarget in the group-level
        # after work or in any agent's handoffs/after-works.
        if isinstance(group_after_work, GroupManagerTarget):
            return True
        for agent in agents:
            if agent.handoffs.get_context_conditions_by_target_type(GroupManagerTarget):
                return True
            if agent.handoffs.get_llm_conditions_by_target_type(GroupManagerTarget):
                return True
            if any(isinstance(aw.target, GroupManagerTarget) for aw in agent.handoffs.after_works):
                return True
        return False

    if _uses_group_manager_target():
        raise ValueError(
            "The group manager doesn't have an LLM Config and it is required for any targets or after works using a GroupManagerTarget. Use the 'llm_config' in the group_manager_args parameter to specify the LLM Config for the group manager."
        )

    return manager
def make_remove_function(tool_msgs_to_remove: list[str]) -> Callable[[list[dict[str, Any]]], list[dict[str, Any]]]:
    """Create a hook that strips transition tool-call traffic from a message history.

    The returned callable is suitable for registration under
    "process_all_messages_before_reply": it removes tool calls whose function
    name is in ``tool_msgs_to_remove`` together with their matching tool
    responses, working on a deep copy so the caller's list is never mutated.
    """

    def _remove_transit_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        working = copy.deepcopy(messages)
        kept: list[dict[str, Any]] = []
        dropped_call_ids: list[str] = []

        for message in working:
            if message.get("tool_calls") is not None:
                # Partition the calls: transit calls are dropped (ids recorded
                # so their responses can be dropped too), the rest survive.
                surviving_calls = []
                for call in message["tool_calls"]:
                    if call.get("function") is not None and call["function"]["name"] in tool_msgs_to_remove:
                        dropped_call_ids.append(call["id"])
                    else:
                        surviving_calls.append(call)
                if surviving_calls:
                    message["tool_calls"] = surviving_calls
                else:
                    del message["tool_calls"]
                    content = message.get("content")
                    if content is None or content == "" or content == "None":
                        # Nothing left in this message at all: drop it entirely
                        continue
            elif message.get("tool_responses") is not None:
                # Keep only responses whose call survived the pass above
                surviving_responses = [
                    response
                    for response in message["tool_responses"]
                    if response["tool_call_id"] not in dropped_call_ids
                ]
                if not surviving_responses:
                    continue
                message["tool_responses"] = surviving_responses
            kept.append(message)

        return kept

    return _remove_transit_messages

View File

@@ -0,0 +1,320 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Union, overload
from pydantic import BaseModel, Field
from .on_condition import OnCondition
from .on_context_condition import OnContextCondition
from .targets.transition_target import TransitionTarget
__all__ = ["Handoffs"]
class Handoffs(BaseModel):
    """
    Container for all handoff transition conditions of a ConversableAgent.

    Three types of conditions can be added, each with a different order and time of use:
    1. OnContextConditions (evaluated without an LLM)
    2. OnConditions (evaluated with an LLM)
    3. After work TransitionTarget (if no other transition is triggered)

    Supports method chaining:
    agent.handoffs.add_context_conditions([condition1]) \
        .add_llm_condition(condition2) \
        .set_after_work(after_work)
    """

    # Conditions evaluated without an LLM, in list order
    context_conditions: list[OnContextCondition] = Field(default_factory=list)
    # Conditions evaluated by an LLM (translated into tools on the agent)
    llm_conditions: list[OnCondition] = Field(default_factory=list)
    # Fallback transitions; invariant: at most one entry with condition=None, kept last
    after_works: list[OnContextCondition] = Field(default_factory=list)

    def add_context_condition(self, condition: OnContextCondition) -> "Handoffs":
        """
        Add a single context condition.

        Args:
            condition: The OnContextCondition to add

        Raises:
            TypeError: If condition is not an OnContextCondition

        Returns:
            Self for method chaining
        """
        # Validate that it is an OnContextCondition
        if not isinstance(condition, OnContextCondition):
            raise TypeError(f"Expected an OnContextCondition instance, got {type(condition).__name__}")

        self.context_conditions.append(condition)
        return self

    def add_context_conditions(self, conditions: list[OnContextCondition]) -> "Handoffs":
        """
        Add multiple context conditions.

        Args:
            conditions: List of OnContextConditions to add

        Raises:
            TypeError: If any element is not an OnContextCondition

        Returns:
            Self for method chaining
        """
        # Validate that it is a list of OnContextConditions
        if not all(isinstance(condition, OnContextCondition) for condition in conditions):
            raise TypeError("All conditions must be of type OnContextCondition")

        self.context_conditions.extend(conditions)
        return self

    def add_llm_condition(self, condition: OnCondition) -> "Handoffs":
        """
        Add a single LLM condition.

        Args:
            condition: The OnCondition to add

        Raises:
            TypeError: If condition is not an OnCondition

        Returns:
            Self for method chaining
        """
        # Validate that it is an OnCondition
        if not isinstance(condition, OnCondition):
            raise TypeError(f"Expected an OnCondition instance, got {type(condition).__name__}")

        self.llm_conditions.append(condition)
        return self

    def add_llm_conditions(self, conditions: list[OnCondition]) -> "Handoffs":
        """
        Add multiple LLM conditions.

        Args:
            conditions: List of OnConditions to add

        Raises:
            TypeError: If any element is not an OnCondition

        Returns:
            Self for method chaining
        """
        # Validate that it is a list of OnConditions
        if not all(isinstance(condition, OnCondition) for condition in conditions):
            raise TypeError("All conditions must be of type OnCondition")

        self.llm_conditions.extend(conditions)
        return self

    def set_after_work(self, target: TransitionTarget) -> "Handoffs":
        """
        Set the after work target (replaces all after_works with single entry).

        For backward compatibility, this creates an OnContextCondition with no condition (always true).

        Args:
            target: The after work TransitionTarget to set

        Raises:
            TypeError: If target is not a TransitionTarget

        Returns:
            Self for method chaining
        """
        if not isinstance(target, TransitionTarget):
            raise TypeError(f"Expected a TransitionTarget instance, got {type(target).__name__}")

        # Create OnContextCondition with no condition (always true)
        after_work_condition = OnContextCondition(target=target, condition=None)
        self.after_works = [after_work_condition]
        return self

    def add_after_work(self, condition: OnContextCondition) -> "Handoffs":
        """
        Add a single after-work condition.

        If the condition has condition=None, it will replace any existing
        condition=None entry and be placed at the end.

        Args:
            condition: The OnContextCondition to add

        Raises:
            TypeError: If condition is not an OnContextCondition

        Returns:
            Self for method chaining
        """
        if not isinstance(condition, OnContextCondition):
            raise TypeError(f"Expected an OnContextCondition instance, got {type(condition).__name__}")

        if condition.condition is None:
            # Remove any existing condition=None entries
            self.after_works = [c for c in self.after_works if c.condition is not None]
            # Add the new one at the end
            self.after_works.append(condition)
        else:
            # For regular conditions, check if we need to move condition=None to the end
            none_conditions = [c for c in self.after_works if c.condition is None]
            if none_conditions:
                # Remove the None condition temporarily
                self.after_works = [c for c in self.after_works if c.condition is not None]
                # Add the new regular condition
                self.after_works.append(condition)
                # Re-add the None condition at the end
                self.after_works.append(none_conditions[0])
            else:
                # No None condition exists, just append
                self.after_works.append(condition)

        return self

    def add_after_works(self, conditions: list[OnContextCondition]) -> "Handoffs":
        """
        Add multiple after-work conditions.

        Special handling for condition=None entries:
        - Only one condition=None entry is allowed (the fallback)
        - It will always be placed at the end of the list
        - If multiple condition=None entries are provided, only the last one is kept

        Args:
            conditions: List of OnContextConditions to add

        Raises:
            TypeError: If any element is not an OnContextCondition

        Returns:
            Self for method chaining
        """
        # Validate that it is a list of OnContextConditions
        if not all(isinstance(condition, OnContextCondition) for condition in conditions):
            raise TypeError("All conditions must be of type OnContextCondition")

        # Separate conditions with None and without None
        none_conditions = [c for c in conditions if c.condition is None]
        regular_conditions = [c for c in conditions if c.condition is not None]

        # Remove any existing condition=None entries
        self.after_works = [c for c in self.after_works if c.condition is not None]

        # Add regular conditions
        self.after_works.extend(regular_conditions)

        # Add at most one None condition at the end
        if none_conditions:
            self.after_works.append(none_conditions[-1])  # Use the last one if multiple provided

        return self

    @overload
    def add(self, condition: OnContextCondition) -> "Handoffs": ...
    @overload
    def add(self, condition: OnCondition) -> "Handoffs": ...
    def add(self, condition: Union[OnContextCondition, OnCondition]) -> "Handoffs":
        """
        Add a single condition (OnContextCondition or OnCondition).

        Args:
            condition: The condition to add (OnContextCondition or OnCondition)

        Raises:
            TypeError: If the condition type is not supported

        Returns:
            Self for method chaining
        """
        # This add method is a helper method designed to make it easier for
        # adding handoffs without worrying about the specific type.
        if isinstance(condition, OnContextCondition):
            return self.add_context_condition(condition)
        elif isinstance(condition, OnCondition):
            return self.add_llm_condition(condition)
        else:
            raise TypeError(f"Unsupported condition type: {type(condition).__name__}")

    def add_many(self, conditions: list[Union[OnContextCondition, OnCondition]]) -> "Handoffs":
        """
        Add multiple conditions of any supported types (OnContextCondition and OnCondition).

        Args:
            conditions: List of conditions to add

        Raises:
            TypeError: If an unsupported condition type is provided

        Returns:
            Self for method chaining
        """
        # This add_many method is a helper method designed to make it easier for
        # adding handoffs without worrying about the specific type.
        context_conditions = []
        llm_conditions = []

        for condition in conditions:
            if isinstance(condition, OnContextCondition):
                context_conditions.append(condition)
            elif isinstance(condition, OnCondition):
                llm_conditions.append(condition)
            else:
                raise TypeError(f"Unsupported condition type: {type(condition).__name__}")

        if context_conditions:
            self.add_context_conditions(context_conditions)
        if llm_conditions:
            self.add_llm_conditions(llm_conditions)

        return self

    def clear(self) -> "Handoffs":
        """
        Clear all handoff conditions.

        Returns:
            Self for method chaining
        """
        self.context_conditions.clear()
        self.llm_conditions.clear()
        self.after_works.clear()
        return self

    def get_llm_conditions_by_target_type(self, target_type: type) -> list[OnCondition]:
        """
        Get OnConditions for a specific target type.

        Args:
            target_type: The target type to filter by

        Returns:
            List of matching conditions (empty list if none exist)
        """
        return [on_condition for on_condition in self.llm_conditions if on_condition.has_target_type(target_type)]

    def get_context_conditions_by_target_type(self, target_type: type) -> list[OnContextCondition]:
        """
        Get OnContextConditions for a specific target type.

        Args:
            target_type: The target type to filter by

        Returns:
            List of matching conditions (empty list if none exist)
        """
        return [
            on_context_condition
            for on_context_condition in self.context_conditions
            if on_context_condition.has_target_type(target_type)
        ]

    def get_llm_conditions_requiring_wrapping(self) -> list[OnCondition]:
        """
        Get LLM conditions that have targets that require wrapping.

        Returns:
            List of LLM conditions that require wrapping (empty list if none)
        """
        return [condition for condition in self.llm_conditions if condition.target_requires_wrapping()]

    def get_context_conditions_requiring_wrapping(self) -> list[OnContextCondition]:
        """
        Get context conditions that have targets that require wrapping.

        Returns:
            List of context conditions that require wrapping (empty list if none)
        """
        return [condition for condition in self.context_conditions if condition.target_requires_wrapping()]

    def set_llm_function_names(self) -> None:
        """
        Set the LLM function names for all LLM conditions, creating unique names for each function.
        """
        for i, condition in enumerate(self.llm_conditions):
            # Function names are made unique and allow multiple OnCondition's to the same agent
            condition.llm_function_name = f"transfer_to_{condition.target.normalized_name()}_{i + 1}"

View File

@@ -0,0 +1,93 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
from .context_str import ContextStr
if TYPE_CHECKING:
# Avoid circular import
from ..conversable_agent import ConversableAgent
__all__ = ["ContextStrLLMCondition", "LLMCondition", "StringLLMCondition"]
class LLMCondition(BaseModel):
    """Base class for conditions evaluated by an LLM.

    Subclasses override get_prompt to supply the text the LLM evaluates.
    (Despite the original wording, this is an abstract base class rather
    than a typing.Protocol — subclasses inherit from it.)
    """

    def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str:
        """Get the prompt text for LLM evaluation.

        Args:
            agent: The agent evaluating the condition
            messages: The conversation history

        Returns:
            The prompt text to be evaluated by the LLM

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError("Requires subclasses to implement.")
class StringLLMCondition(LLMCondition):
    """LLM condition backed by a fixed prompt string.

    The same prompt text is handed to the LLM regardless of the agent or
    the conversation history.
    """

    prompt: str

    def __init__(self, prompt: str, **data: Any) -> None:
        """Allow the prompt to be supplied positionally.

        Args:
            prompt: The static prompt string to evaluate
            data: Extra fields forwarded to the pydantic base class
        """
        super().__init__(prompt=prompt, **data)

    def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str:
        """Return the configured prompt.

        Args:
            agent: The agent evaluating the condition (ignored)
            messages: The conversation history (ignored)

        Returns:
            The static prompt string
        """
        return self.prompt
class ContextStrLLMCondition(LLMCondition):
    """LLM condition whose prompt is built from context variables.

    Holds a ContextStr with placeholder variables; the placeholders are
    substituted from the evaluating agent's context variables before the
    prompt is handed to the LLM.
    """

    context_str: ContextStr

    def __init__(self, context_str: ContextStr, **data: Any) -> None:
        """Allow the context string to be supplied positionally.

        Args:
            context_str: The ContextStr object with variable placeholders
            data: Extra fields forwarded to the pydantic base class
        """
        super().__init__(context_str=context_str, **data)

    def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str:
        """Return the prompt with context variables substituted.

        Args:
            agent: The agent evaluating the condition (provides context variables)
            messages: The conversation history (ignored)

        Returns:
            The substituted prompt, or "" if substitution produced None
        """
        substituted = self.context_str.format(agent.context_variables)
        return "" if substituted is None else substituted

View File

@@ -0,0 +1,237 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import threading
from typing import TYPE_CHECKING, Any, Union
from ...doc_utils import export_module
from ...events.agent_events import ErrorEvent, RunCompletionEvent
from ...io.base import IOStream
from ...io.run_response import AsyncRunResponse, AsyncRunResponseProtocol, RunResponse, RunResponseProtocol
from ...io.thread_io_stream import AsyncThreadIOStream, ThreadIOStream
from ..chat import ChatResult
from .context_variables import ContextVariables
from .group_utils import cleanup_temp_user_messages
if TYPE_CHECKING:
from ..agent import Agent
from .patterns.pattern import Pattern
__all__ = [
"a_initiate_group_chat",
"a_run_group_chat",
"initiate_group_chat",
"run_group_chat",
]
@export_module("autogen")
def initiate_group_chat(
    pattern: "Pattern",
    messages: Union[list[dict[str, Any]], str],
    max_rounds: int = 20,
) -> tuple[ChatResult, ContextVariables, "Agent"]:
    """Initialize and run a group chat using a pattern for configuration.

    Args:
        pattern: Pattern object that encapsulates the chat configuration.
        messages: Initial message(s).
        max_rounds: Maximum number of conversation rounds.

    Returns:
        ChatResult: Conversations chat history.
        ContextVariables: Updated Context variables.
        "ConversableAgent": Last speaker.
    """
    # The pattern assembles every component of the group chat; only a handful
    # of the returned items are needed here (see prepare_group_chat's contract
    # for the full 13-tuple layout).
    prepared = pattern.prepare_group_chat(
        max_rounds=max_rounds,
        messages=messages,
    )
    context_variables = prepared[3]
    manager = prepared[8]
    processed_messages = prepared[9]
    last_agent = prepared[10]

    # More than one processed message means we are resuming a prior conversation
    resuming = len(processed_messages) > 1
    if resuming:
        last_agent, last_message = manager.resume(messages=processed_messages)
    else:
        last_message = processed_messages[0]

    if last_agent is None:
        raise ValueError("No agent selected to start the conversation")

    chat_result = last_agent.initiate_chat(
        manager,
        message=last_message,
        clear_history=not resuming,
        summary_method=pattern.summary_method,
    )

    # Strip the temporary "_User" attribution before handing the history back
    cleanup_temp_user_messages(chat_result)

    return chat_result, context_variables, manager.last_speaker
@export_module("autogen.agentchat")
async def a_initiate_group_chat(
    pattern: "Pattern",
    messages: Union[list[dict[str, Any]], str],
    max_rounds: int = 20,
) -> tuple[ChatResult, ContextVariables, "Agent"]:
    """Initialize and run a group chat using a pattern for configuration, asynchronously.

    Args:
        pattern: Pattern object that encapsulates the chat configuration.
        messages: Initial message(s).
        max_rounds: Maximum number of conversation rounds.

    Returns:
        ChatResult: Conversations chat history.
        ContextVariables: Updated Context variables.
        "ConversableAgent": Last speaker.
    """
    # The pattern assembles every component of the group chat; only a handful
    # of the returned items are needed here (see prepare_group_chat's contract
    # for the full 13-tuple layout).
    prepared = pattern.prepare_group_chat(
        max_rounds=max_rounds,
        messages=messages,
    )
    context_variables = prepared[3]
    manager = prepared[8]
    processed_messages = prepared[9]
    last_agent = prepared[10]

    # More than one processed message means we are resuming a prior conversation
    resuming = len(processed_messages) > 1
    if resuming:
        last_agent, last_message = await manager.a_resume(messages=processed_messages)
    else:
        last_message = processed_messages[0]

    if last_agent is None:
        raise ValueError("No agent selected to start the conversation")

    chat_result = await last_agent.a_initiate_chat(
        manager,
        message=last_message,  # type: ignore[arg-type]
        clear_history=not resuming,
        summary_method=pattern.summary_method,
    )

    # Strip the temporary "_User" attribution before handing the history back
    cleanup_temp_user_messages(chat_result)

    return chat_result, context_variables, manager.last_speaker
@export_module("autogen.agentchat")
def run_group_chat(
    pattern: "Pattern",
    messages: Union[list[dict[str, Any]], str],
    max_rounds: int = 20,
) -> RunResponseProtocol:
    """Run a group chat on a background thread and return a live response handle.

    Args:
        pattern: Pattern object that encapsulates the chat configuration.
        messages: Initial message(s).
        max_rounds: Maximum number of conversation rounds.

    Returns:
        RunResponseProtocol: Handle whose event stream yields the chat's events
        and, finally, a RunCompletionEvent (or an ErrorEvent on failure).
    """
    iostream = ThreadIOStream()
    # todo: add agents
    response = RunResponse(iostream, agents=[])

    def _worker() -> None:
        # Redirect all console I/O produced by the chat onto the thread stream
        with IOStream.set_default(iostream):
            try:
                chat_result, context_vars, agent = initiate_group_chat(
                    pattern=pattern,
                    messages=messages,
                    max_rounds=max_rounds,
                )
                IOStream.get_default().send(
                    RunCompletionEvent(  # type: ignore[call-arg]
                        history=chat_result.chat_history,
                        summary=chat_result.summary,
                        cost=chat_result.cost,
                        last_speaker=agent.name,
                        context_variables=context_vars,
                    )
                )
            except Exception as e:
                # Surface failures as an event rather than killing the thread silently
                response.iostream.send(ErrorEvent(error=e))  # type: ignore[call-arg]

    threading.Thread(target=_worker).start()
    return response
@export_module("autogen.agentchat")
async def a_run_group_chat(
    pattern: "Pattern",
    messages: Union[list[dict[str, Any]], str],
    max_rounds: int = 20,
) -> AsyncRunResponseProtocol:
    """Run a group chat as a background asyncio task and return a live response handle.

    Args:
        pattern: Pattern object that encapsulates the chat configuration.
        messages: Initial message(s).
        max_rounds: Maximum number of conversation rounds.

    Returns:
        AsyncRunResponseProtocol: Handle whose event stream yields the chat's
        events and, finally, a RunCompletionEvent (or an ErrorEvent on failure).
    """
    iostream = AsyncThreadIOStream()
    # todo: add agents
    response = AsyncRunResponse(iostream, agents=[])

    async def _initiate_group_chat(
        pattern: "Pattern" = pattern,
        messages: Union[list[dict[str, Any]], str] = messages,
        max_rounds: int = max_rounds,
        iostream: AsyncThreadIOStream = iostream,
        response: AsyncRunResponse = response,
    ) -> None:
        # Redirect all console I/O produced by the chat onto the async stream
        with IOStream.set_default(iostream):
            try:
                chat_result, context_vars, agent = await a_initiate_group_chat(
                    pattern=pattern,
                    messages=messages,
                    max_rounds=max_rounds,
                )
                IOStream.get_default().send(
                    RunCompletionEvent(  # type: ignore[call-arg]
                        history=chat_result.chat_history,
                        summary=chat_result.summary,
                        cost=chat_result.cost,
                        last_speaker=agent.name,
                        context_variables=context_vars,
                    )
                )
            except Exception as e:
                # Surface failures as an event rather than losing them in the task
                response.iostream.send(ErrorEvent(error=e))  # type: ignore[call-arg]

    # NOTE(review): the Task reference is discarded; per the asyncio docs a task
    # with no remaining reference may be garbage-collected before it completes —
    # confirm whether the event loop/iostream keeps it alive, else retain a reference.
    asyncio.create_task(_initiate_group_chat())
    return response

View File

@@ -0,0 +1,58 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from pydantic import BaseModel
from ...doc_utils import export_module
from .available_condition import AvailableCondition
from .llm_condition import LLMCondition
from .targets.transition_target import TransitionTarget
__all__ = [
"OnCondition",
]
@export_module("autogen")
class OnCondition(BaseModel):  # noqa: N801
    """An LLM-evaluated handoff condition for transitioning to another agent or nested chats.

    Each OnCondition is translated into a tool attached to the agent so that the
    agent's LLM can decide whether the transition should happen. These are
    evaluated after the context-based OnContextCondition conditions and before
    the after-work condition.

    Args:
        target (TransitionTarget): The transition (essentially an agent) to hand off to.
        condition (LLMCondition): The condition the LLM evaluates to decide on the transition.
        available (AvailableCondition): Optional gate, based on context variables
            (e.g. StringAvailableCondition, ContextExpressionAvailableCondition),
            that controls whether this condition is offered to the LLM at all.
        llm_function_name (Optional[str]): Name of the LLM function used for this condition.
    """

    target: TransitionTarget
    condition: LLMCondition
    available: Optional[AvailableCondition] = None
    llm_function_name: Optional[str] = None

    def has_target_type(self, target_type: type) -> bool:
        """Return True when this condition's target is an instance of ``target_type``.

        Args:
            target_type (type): A TransitionTarget subclass to test against.

        Returns:
            bool: True if the target matches, False otherwise.
        """
        matches = isinstance(self.target, target_type)
        return matches

    def target_requires_wrapping(self) -> bool:
        """Return True when the target must be wrapped in an agent.

        Returns:
            bool: True if wrapping is required, False otherwise.
        """
        needs_wrapper = self.target.needs_agent_wrapper()
        return needs_wrapper

View File

@@ -0,0 +1,54 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from pydantic import BaseModel
from .available_condition import AvailableCondition
from .context_condition import ContextCondition
from .targets.transition_target import TransitionTarget
__all__ = [
"OnContextCondition",
]
class OnContextCondition(BaseModel):  # noqa: N801
    """A context-variable-based handoff condition (no LLM involved).

    Evaluated using context variables via ContextCondition / ContextExpression,
    these conditions run before the OnCondition conditions and the after-work
    condition.

    Args:
        target (TransitionTarget): The transition (essentially an agent) to hand off to.
        condition (Optional[ContextCondition]): Context-variable condition for the
            transition. When None, the condition always evaluates to True.
        available (AvailableCondition): Optional gate, based on context variables
            (e.g. StringAvailableCondition, ContextExpressionAvailableCondition),
            that controls whether this condition is considered.
    """

    target: TransitionTarget
    condition: Optional[ContextCondition] = None
    available: Optional[AvailableCondition] = None

    def has_target_type(self, target_type: type) -> bool:
        """Return True when this condition's target is an instance of ``target_type``.

        Args:
            target_type (type): A TransitionTarget subclass to test against.

        Returns:
            bool: True if the target matches, False otherwise.
        """
        matches = isinstance(self.target, target_type)
        return matches

    def target_requires_wrapping(self) -> bool:
        """Return True when the target must be wrapped in an agent.

        Returns:
            bool: True if wrapping is required, False otherwise.
        """
        needs_wrapper = self.target.needs_agent_wrapper()
        return needs_wrapper

View File

@@ -0,0 +1,18 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
from .auto import AutoPattern
from .manual import ManualPattern
from .pattern import DefaultPattern
from .random import RandomPattern
from .round_robin import RoundRobinPattern
# Public re-exports: the concrete orchestration patterns available from this package.
__all__ = [
    "AutoPattern",
    "DefaultPattern",
    "ManualPattern",
    "RandomPattern",
    "RoundRobinPattern",
]

View File

@@ -0,0 +1,159 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
from ..context_variables import ContextVariables
from ..targets.group_manager_target import GroupManagerSelectionMessage, GroupManagerTarget
from ..targets.transition_target import TransitionTarget
from .pattern import Pattern
if TYPE_CHECKING:
from ...conversable_agent import ConversableAgent
from ...groupchat import GroupChat, GroupChatManager
from ..group_tool_executor import GroupToolExecutor
class AutoPattern(Pattern):
    """AutoPattern implements a flexible pattern where agents are selected based on their expertise.

    In this pattern, a group manager automatically selects the next agent to speak based on the context
    of the conversation and agent descriptions. The after_work is always set to "group_manager" as
    this is the defining characteristic of this pattern.
    """

    def __init__(
        self,
        initial_agent: "ConversableAgent",
        agents: list["ConversableAgent"],
        user_agent: Optional["ConversableAgent"] = None,
        group_manager_args: Optional[dict[str, Any]] = None,
        context_variables: Optional[ContextVariables] = None,
        selection_message: Optional[GroupManagerSelectionMessage] = None,
        exclude_transit_message: bool = True,
        summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
    ):
        """Initialize the AutoPattern.

        The after_work is always set to group_manager selection, which is the defining
        characteristic of this pattern. You can customize the selection message used
        by the group manager when selecting the next agent.

        Args:
            initial_agent: The first agent to speak in the group chat.
            agents: List of all agents participating in the chat.
            user_agent: Optional user proxy agent.
            group_manager_args: Optional arguments for the GroupChatManager.
            context_variables: Initial context variables for the chat.
            selection_message: Custom message to use when the group manager is selecting agents.
            exclude_transit_message: Whether to exclude transit messages from the conversation.
            summary_method: Method for summarizing the conversation.
        """
        # Create the group_manager after_work with the provided selection message
        group_manager_after_work = GroupManagerTarget(selection_message=selection_message)

        super().__init__(
            initial_agent=initial_agent,
            agents=agents,
            user_agent=user_agent,
            group_manager_args=group_manager_args,
            context_variables=context_variables,
            group_after_work=group_manager_after_work,
            exclude_transit_message=exclude_transit_message,
            summary_method=summary_method,
        )

        # Store the selection message for potential use
        self.selection_message = selection_message

    def prepare_group_chat(
        self,
        max_rounds: int,
        messages: Union[list[dict[str, Any]], str],
    ) -> Tuple[
        list["ConversableAgent"],
        list["ConversableAgent"],
        Optional["ConversableAgent"],
        ContextVariables,
        "ConversableAgent",
        TransitionTarget,
        "GroupToolExecutor",
        "GroupChat",
        "GroupChatManager",
        list[dict[str, Any]],
        Any,
        list[str],
        list[Any],
    ]:
        """Prepare the group chat for group-manager-driven agent selection.

        Ensures that:
        1. The group manager has a valid LLM config (or at least one agent does)
        2. All agents have descriptions for the group manager to use when selecting

        Args:
            max_rounds: Maximum number of conversation rounds.
            messages: Initial message(s) to start the conversation.

        Returns:
            Tuple containing all necessary components for the group chat.
        """
        # Validate that group_manager_args has an LLM config which is required for this pattern
        if not self.group_manager_args.get("llm_config", False):
            # Check if any agent has an LLM config we can use
            has_llm_config = any(getattr(agent, "llm_config", False) for agent in self.agents)
            if not has_llm_config:
                raise ValueError(
                    "AutoPattern requires the group_manager_args to include an llm_config, "
                    "or at least one agent to have an llm_config"
                )

        # Check that all agents have descriptions for effective group manager selection
        for agent in self.agents:
            if not hasattr(agent, "description") or not agent.description:
                # Fall back to a generic description built from the agent's name
                agent.description = f"Agent {agent.name}"

        # Use the parent class's implementation to prepare the agents and group chat
        components = super().prepare_group_chat(
            max_rounds=max_rounds,
            messages=messages,
        )

        # Extract the group_after_work and the rest of the components
        (
            agents,
            wrapped_agents,
            user_agent,
            context_variables,
            initial_agent,
            _,
            tool_executor,
            groupchat,
            manager,
            processed_messages,
            last_agent,
            group_agent_names,
            temp_user_list,
        ) = components

        # Ensure we're using the group_manager after_work
        group_after_work = self.group_after_work

        # Return all components with our group_after_work
        return (
            agents,
            wrapped_agents,
            user_agent,
            context_variables,
            initial_agent,
            group_after_work,
            tool_executor,
            groupchat,
            manager,
            processed_messages,
            last_agent,
            group_agent_names,
            temp_user_list,
        )

View File

@@ -0,0 +1,176 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
from ..context_variables import ContextVariables
from ..group_tool_executor import GroupToolExecutor
from ..targets.transition_target import AskUserTarget, TransitionTarget
from .pattern import Pattern
if TYPE_CHECKING:
from ...conversable_agent import ConversableAgent
from ...groupchat import GroupChat, GroupChatManager
class ManualPattern(Pattern):
    """ManualPattern will ask the user to nominate the next agent to speak at each turn."""

    def __init__(
        self,
        initial_agent: "ConversableAgent",
        agents: list["ConversableAgent"],
        user_agent: Optional["ConversableAgent"] = None,
        group_manager_args: Optional[dict[str, Any]] = None,
        context_variables: Optional[ContextVariables] = None,
        exclude_transit_message: bool = True,
        summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
    ):
        """Initialize the ManualPattern.

        The after_work is always set to ask_user, which will prompt the user for the next agent

        Args:
            initial_agent: The first agent to speak in the group chat.
            agents: List of all agents participating in the chat.
            user_agent: Optional user proxy agent.
            group_manager_args: Optional arguments for the GroupChatManager.
            context_variables: Initial context variables for the chat.
            exclude_transit_message: Whether to exclude transit messages from the conversation.
            summary_method: Method for summarizing the conversation.
        """
        # The group after work will be to ask the user
        group_after_work = AskUserTarget()

        super().__init__(
            initial_agent=initial_agent,
            agents=agents,
            user_agent=user_agent,
            group_manager_args=group_manager_args,
            context_variables=context_variables,
            group_after_work=group_after_work,
            exclude_transit_message=exclude_transit_message,
            summary_method=summary_method,
        )

    def prepare_group_chat(
        self,
        max_rounds: int,
        messages: Union[list[dict[str, Any]], str],
    ) -> Tuple[
        list["ConversableAgent"],
        list["ConversableAgent"],
        Optional["ConversableAgent"],
        ContextVariables,
        "ConversableAgent",
        TransitionTarget,
        "GroupToolExecutor",
        "GroupChat",
        "GroupChatManager",
        list[dict[str, Any]],
        Any,
        list[str],
        list[Any],
    ]:
        """Prepare the group chat for manual (user-nominated) agent selection.

        Delegates agent/group-chat setup to the base Pattern, then attempts to
        configure the allowed speaker transitions so the user is not offered the
        user agent or the tool executor as choices.

        Args:
            max_rounds: Maximum number of conversation rounds.
            messages: Initial message(s) to start the conversation.

        Returns:
            Tuple containing all necessary components for the group chat.
        """
        # Use the parent class's implementation to prepare the agents and group chat
        components = super().prepare_group_chat(
            max_rounds=max_rounds,
            messages=messages,
        )

        # Extract the group_after_work and the rest of the components
        (
            agents,
            wrapped_agents,
            user_agent,
            context_variables,
            initial_agent,
            _,
            tool_executor,
            groupchat,
            manager,
            processed_messages,
            last_agent,
            group_agent_names,
            temp_user_list,
        ) = components

        # Ensure we're using the group_manager after_work
        group_after_work = self.group_after_work

        # Set up the allowed speaker transitions to exclude user_agent and GroupToolExecutor
        # NOTE(review): currently a no-op — see the disabled body of _setup_allowed_transitions.
        self._setup_allowed_transitions(groupchat, user_agent, tool_executor)

        # Return all components with our group_after_work
        return (
            agents,
            wrapped_agents,
            user_agent,
            context_variables,
            initial_agent,
            group_after_work,
            tool_executor,
            groupchat,
            manager,
            processed_messages,
            last_agent,
            group_agent_names,
            temp_user_list,
        )

    def _setup_allowed_transitions(
        self, groupchat: "GroupChat", user_agent: Optional["ConversableAgent"], tool_executor: "GroupToolExecutor"
    ) -> None:
        """Set up the allowed speaker transitions for the group chat so that when a user selects the next agent the tool executor and user agent don't appear as options.

        Creates transitions where:
        1. Any agent can speak after any other agent, including themselves
        2. The user_agent and GroupToolExecutor are excluded from transitions

        Args:
            groupchat: The GroupChat instance to configure
            user_agent: The user agent to exclude from transitions
            tool_executor: The GroupToolExecutor to exclude from transitions
        """
        # NOTE: THIS IS NOT WORKING - THE TRANSITIONS ARE NOT BEING KEPT?!
        # The implementation below is intentionally disabled (kept as a string literal).
        """
        # Get all agents in the group chat
        all_agents = groupchat.agents

        # Filter out user_agent and group tool executor
        eligible_agents = []
        for agent in all_agents:
            # Skip user_agent
            if agent == user_agent:
                continue

            # Skip GroupToolExecutor
            if isinstance(agent, GroupToolExecutor):
                continue

            eligible_agents.append(agent)

        # Create a fully connected graph among eligible agents
        # Each agent can be followed by any other eligible agent
        allowed_transitions = {}
        for agent in eligible_agents:
            # For each agent, every other eligible agent can follow
            allowed_transitions[agent] = eligible_agents

        # Set the transitions in the group chat
        groupchat.allowed_speaker_transitions_dict = allowed_transitions
        """

View File

@@ -0,0 +1,294 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Patterns of agent orchestrations
# Uses the group chat or the agents' handoffs to create a pattern
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
from ..context_variables import ContextVariables
from ..group_utils import (
create_group_manager,
create_group_transition,
link_agents_to_group_manager,
prepare_group_agents,
process_initial_messages,
setup_context_variables,
)
from ..targets.transition_target import TerminateTarget, TransitionTarget
if TYPE_CHECKING:
from ...agent import Agent
from ...conversable_agent import ConversableAgent
from ...groupchat import GroupChat, GroupChatManager
from ..group_tool_executor import GroupToolExecutor
class Pattern(ABC):
    """Base abstract class for all orchestration patterns.

    Patterns provide a reusable way to define how agents interact within a group chat.
    Each pattern encapsulates the logic for setting up agents, configuring handoffs,
    and determining the flow of conversation.

    This is an abstract base class and should not be instantiated directly.
    Use one of the concrete pattern implementations like AutoPattern,
    RoundRobinPattern, RandomPattern, or ManualPattern.
    """

    def __init__(
        self,
        initial_agent: "ConversableAgent",
        agents: list["ConversableAgent"],
        user_agent: Optional["ConversableAgent"] = None,
        group_manager_args: Optional[dict[str, Any]] = None,
        context_variables: Optional[ContextVariables] = None,
        group_after_work: Optional[TransitionTarget] = None,
        exclude_transit_message: bool = True,
        summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
    ):
        """Initialize the pattern with the required components.

        Args:
            initial_agent: The first agent to speak in the group chat.
            agents: List of all agents participating in the chat.
            user_agent: Optional user proxy agent.
            group_manager_args: Optional arguments for the GroupChatManager.
            context_variables: Initial context variables for the chat.
            group_after_work: Default after work transition behavior when no specific next agent is determined.
            exclude_transit_message: Whether to exclude transit messages from the conversation.
            summary_method: Method for summarizing the conversation.
        """
        self.initial_agent = initial_agent
        self.agents = agents
        self.user_agent = user_agent
        self.group_manager_args = group_manager_args or {}
        self.context_variables = context_variables or ContextVariables()
        # Default to terminating the chat when no after-work target is supplied
        self.group_after_work = group_after_work if group_after_work is not None else TerminateTarget()
        self.exclude_transit_message = exclude_transit_message
        self.summary_method = summary_method

    @abstractmethod
    def prepare_group_chat(
        self,
        max_rounds: int,
        messages: Union[list[dict[str, Any]], str],
    ) -> Tuple[
        list["ConversableAgent"],
        list["ConversableAgent"],
        Optional["ConversableAgent"],
        ContextVariables,
        "ConversableAgent",
        TransitionTarget,
        "GroupToolExecutor",
        "GroupChat",
        "GroupChatManager",
        list[dict[str, Any]],
        "ConversableAgent",
        list[str],
        list["Agent"],
    ]:
        """Prepare the group chat for orchestration.

        This is the main method called by initiate_group_chat to set up the pattern.
        Subclasses must implement or extend this method to define pattern-specific behavior.

        Args:
            max_rounds: Maximum number of conversation rounds.
            messages: Initial message(s) to start the conversation.

        Returns:
            Tuple containing:
                - List of agents involved in the group chat
                - List of wrapped agents
                - User agent, if applicable
                - Context variables for the group chat
                - Initial agent for the group chat
                - Group-level after work transition for the group chat
                - Tool executor for the group chat
                - GroupChat instance
                - GroupChatManager instance
                - Processed messages
                - Last agent to speak
                - List of group agent names
                - List of temporary user agents
        """
        from ...groupchat import GroupChat

        # Prepare the agents using the existing helper function
        tool_executor, wrapped_agents = prepare_group_agents(
            self.agents, self.context_variables, self.exclude_transit_message
        )

        # Process the initial messages BEFORE creating the GroupChat
        # This will create a temporary user agent if needed
        processed_messages, last_agent, group_agent_names, temp_user_list = process_initial_messages(
            messages, self.user_agent, self.agents, wrapped_agents
        )

        # Create transition function (has enclosed state for initial agent)
        group_transition = create_group_transition(
            initial_agent=self.initial_agent,
            tool_execution=tool_executor,
            group_agent_names=group_agent_names,
            user_agent=self.user_agent,
            group_after_work=self.group_after_work,
        )

        # Create the group chat - now we use temp_user_list if no user_agent
        groupchat = GroupChat(
            agents=[tool_executor]
            + self.agents
            + wrapped_agents
            + ([self.user_agent] if self.user_agent else temp_user_list),
            messages=[],
            max_round=max_rounds,
            speaker_selection_method=group_transition,
        )

        # Create the group manager
        manager = create_group_manager(groupchat, self.group_manager_args, self.agents, self.group_after_work)

        # Point all agent's context variables to this function's context_variables
        setup_context_variables(
            tool_execution=tool_executor,
            agents=self.agents,
            manager=manager,
            user_agent=self.user_agent,
            context_variables=self.context_variables,
        )

        # Link all agents with the GroupChatManager to allow access to the group chat
        link_agents_to_group_manager(groupchat.agents, manager)

        return (
            self.agents,
            wrapped_agents,
            self.user_agent,
            self.context_variables,
            self.initial_agent,
            self.group_after_work,
            tool_executor,
            groupchat,
            manager,
            processed_messages,
            last_agent,
            group_agent_names,
            temp_user_list,
        )  # type: ignore[return-value]

    @classmethod
    def create_default(
        cls,
        initial_agent: "ConversableAgent",
        agents: list["ConversableAgent"],
        user_agent: Optional["ConversableAgent"] = None,
        group_manager_args: Optional[dict[str, Any]] = None,
        context_variables: Optional[ContextVariables] = None,
        exclude_transit_message: bool = True,
        summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
    ) -> "DefaultPattern":
        """Create a default pattern with minimal configuration.

        This replaces the need for a separate BasePattern class by providing
        a factory method that creates a simple DefaultPattern instance.

        Args:
            initial_agent: The first agent to speak in the group chat.
            agents: List of all agents participating in the chat.
            user_agent: Optional user proxy agent.
            group_manager_args: Optional arguments for the GroupChatManager.
            context_variables: Initial context variables for the chat.
            exclude_transit_message: Whether to exclude transit messages from the conversation.
            summary_method: Method for summarizing the conversation.

        Returns:
            A DefaultPattern instance with basic configuration.
        """
        return DefaultPattern(
            initial_agent=initial_agent,
            agents=agents,
            user_agent=user_agent,
            group_manager_args=group_manager_args,
            context_variables=context_variables,
            exclude_transit_message=exclude_transit_message,
            summary_method=summary_method,
        )
class DefaultPattern(Pattern):
    """DefaultPattern implements a minimal pattern for simple agent interactions.

    This replaces the previous BasePattern and provides a concrete implementation
    of the Pattern abstract base class.
    """

    def prepare_group_chat(
        self,
        max_rounds: int,
        messages: Union[list[dict[str, Any]], str],
    ) -> Tuple[
        list["ConversableAgent"],
        list["ConversableAgent"],
        Optional["ConversableAgent"],
        ContextVariables,
        "ConversableAgent",
        TransitionTarget,
        "GroupToolExecutor",
        "GroupChat",
        "GroupChatManager",
        list[dict[str, Any]],
        Any,
        list[str],
        list[Any],
    ]:
        """Prepare the group chat with default configuration.

        This implementation calls the parent class method but ensures that
        the group_after_work in the returned tuple is the pattern's own.

        Args:
            max_rounds: Maximum number of conversation rounds.
            messages: Initial message(s) to start the conversation.

        Returns:
            Tuple containing all necessary components for the group chat.
        """
        # Use the parent class's implementation to prepare the agents and group chat
        (
            agents,
            wrapped_agents,
            user_agent,
            context_variables,
            initial_agent,
            _,  # Ignore the group_after_work from parent
            tool_executor,
            groupchat,
            manager,
            processed_messages,
            last_agent,
            group_agent_names,
            temp_user_list,
        ) = super().prepare_group_chat(
            max_rounds=max_rounds,
            messages=messages,
        )

        # Return all components with our group_after_work
        return (
            agents,
            wrapped_agents,
            user_agent,
            context_variables,
            initial_agent,
            self.group_after_work,  # Use our own group_after_work
            tool_executor,
            groupchat,
            manager,
            processed_messages,
            last_agent,
            group_agent_names,
            temp_user_list,
        )

View File

@@ -0,0 +1,106 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
from ..context_variables import ContextVariables
from ..targets.transition_target import RandomAgentTarget, TransitionTarget
from .pattern import Pattern
if TYPE_CHECKING:
from ...conversable_agent import ConversableAgent
from ...groupchat import GroupChat, GroupChatManager
from ..group_tool_executor import GroupToolExecutor
class RandomPattern(Pattern):
    """RandomPattern implements a random agent selection process."""

    def _generate_handoffs(
        self,
        initial_agent: "ConversableAgent",
        agents: list["ConversableAgent"],
        user_agent: Optional["ConversableAgent"],
    ) -> None:
        """Generate handoffs between agents in a random fashion.

        Each agent's after-work is set to a RandomAgentTarget over all other
        agents (including the user agent, when present).

        Note: initial_agent is currently unused here; it is kept for signature
        parity with other patterns' _generate_handoffs.
        """
        agent_list = agents + ([user_agent] if user_agent is not None else [])
        for agent in agent_list:
            # Get the list of agents except itself
            other_agents = [a for a in agent_list if a != agent]

            # Create a random after work
            agent.handoffs.set_after_work(target=RandomAgentTarget(agents=other_agents))

    def prepare_group_chat(
        self,
        max_rounds: int,
        messages: Union[list[dict[str, Any]], str],
    ) -> Tuple[
        list["ConversableAgent"],
        list["ConversableAgent"],
        Optional["ConversableAgent"],
        ContextVariables,
        "ConversableAgent",
        TransitionTarget,
        "GroupToolExecutor",
        "GroupChat",
        "GroupChatManager",
        list[dict[str, Any]],
        Any,
        list[str],
        list[Any],
    ]:
        """Prepare the group chat for random agent selection.

        Delegates setup to the base Pattern, then wires up random after-work
        handoffs between all participating agents.

        Args:
            max_rounds: Maximum number of conversation rounds.
            messages: Initial message(s) to start the conversation.

        Returns:
            Tuple containing all necessary components for the group chat.
        """
        # Use the parent class's implementation to prepare the agents and group chat
        (
            agents,
            wrapped_agents,
            user_agent,
            context_variables,
            initial_agent,
            group_after_work,
            tool_executor,
            groupchat,
            manager,
            processed_messages,
            last_agent,
            group_agent_names,
            temp_user_list,
        ) = super().prepare_group_chat(
            max_rounds=max_rounds,
            messages=messages,
        )

        # Create the random handoffs between agents
        self._generate_handoffs(initial_agent=initial_agent, agents=agents, user_agent=user_agent)

        # Return all components with our group_after_work
        return (
            agents,
            wrapped_agents,
            user_agent,
            context_variables,
            initial_agent,
            group_after_work,
            tool_executor,
            groupchat,
            manager,
            processed_messages,
            last_agent,
            group_agent_names,
            temp_user_list,
        )

View File

@@ -0,0 +1,117 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
from ..context_variables import ContextVariables
from ..targets.transition_target import AgentTarget, TransitionTarget
from .pattern import Pattern
if TYPE_CHECKING:
from ...conversable_agent import ConversableAgent
from ...groupchat import GroupChat, GroupChatManager
from ..group_tool_executor import GroupToolExecutor
class RoundRobinPattern(Pattern):
    """RoundRobinPattern implements a round robin with handoffs between agents."""

    def _generate_handoffs(
        self,
        initial_agent: "ConversableAgent",
        agents: list["ConversableAgent"],
        user_agent: Optional["ConversableAgent"],
    ) -> None:
        """Generate handoffs between agents in a round-robin fashion.

        The ordering is: initial_agent first, then the remaining agents in the
        given order, then the user agent (if any), wrapping back to the start.
        """
        # Create a list of the agents and the user_agent but put the initial_agent first
        agent_list = [initial_agent]

        # Add the rest of the agents, excluding the initial_agent and user_agent
        for agent in agents:
            if agent != initial_agent and (user_agent is None or agent != user_agent):
                agent_list.append(agent)

        # Add the user_agent last if it exists
        if user_agent is not None:
            agent_list.append(user_agent)

        # Create handoffs in a round-robin fashion
        for i, agent in enumerate(agent_list):
            # Last agent hands off to the first agent
            # Otherwise agent hands off to the next one
            handoff_target = agent_list[0] if i == len(agent_list) - 1 else agent_list[i + 1]
            agent.handoffs.set_after_work(target=AgentTarget(agent=handoff_target))

    def prepare_group_chat(
        self,
        max_rounds: int,
        messages: Union[list[dict[str, Any]], str],
    ) -> Tuple[
        list["ConversableAgent"],
        list["ConversableAgent"],
        Optional["ConversableAgent"],
        ContextVariables,
        "ConversableAgent",
        TransitionTarget,
        "GroupToolExecutor",
        "GroupChat",
        "GroupChatManager",
        list[dict[str, Any]],
        Any,
        list[str],
        list[Any],
    ]:
        """Prepare the group chat for round-robin agent selection.

        Delegates setup to the base Pattern, then wires up round-robin
        after-work handoffs between all participating agents.

        Args:
            max_rounds: Maximum number of conversation rounds.
            messages: Initial message(s) to start the conversation.

        Returns:
            Tuple containing all necessary components for the group chat.
        """
        # Use the parent class's implementation to prepare the agents and group chat
        (
            agents,
            wrapped_agents,
            user_agent,
            context_variables,
            initial_agent,
            group_after_work,
            tool_executor,
            groupchat,
            manager,
            processed_messages,
            last_agent,
            group_agent_names,
            temp_user_list,
        ) = super().prepare_group_chat(
            max_rounds=max_rounds,
            messages=messages,
        )

        # Create the handoffs between agents
        self._generate_handoffs(initial_agent=initial_agent, agents=agents, user_agent=user_agent)

        # Return all components with our group_after_work
        return (
            agents,
            wrapped_agents,
            user_agent,
            context_variables,
            initial_agent,
            group_after_work,
            tool_executor,
            groupchat,
            manager,
            processed_messages,
            last_agent,
            group_agent_names,
            temp_user_list,
        )

View File

@@ -0,0 +1,26 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
__all__ = ["ReplyResult"]
from typing import Optional
from pydantic import BaseModel
from .context_variables import ContextVariables
from .targets.transition_target import TransitionTarget
class ReplyResult(BaseModel):
    """Result of a tool call that is used to provide the return message and the target to transition to."""

    # Message returned as the tool's reply.
    message: str
    # Optional transition target to hand off to after the tool call; None means no transition.
    target: Optional[TransitionTarget] = None
    # Optional context variables associated with this reply.
    context_variables: Optional[ContextVariables] = None

    def __str__(self) -> str:
        """The string representation for ReplyResult will be just the message."""
        return self.message

View File

@@ -0,0 +1,41 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional, Union
from pydantic import BaseModel
from ..agent import Agent
if TYPE_CHECKING:
# Avoid circular import
from ..groupchat import GroupChat
class SpeakerSelectionResult(BaseModel):
    """Speaker selection outcome consumed by GroupChat._prepare_and_select_agents to pick the next speaker.

    Exactly one of the fields is expected to be set: an agent name, a speaker
    selection method, or a terminate flag (which ends the conversation).
    """

    terminate: Optional[bool] = None
    agent_name: Optional[str] = None
    speaker_selection_method: Optional[str] = None

    def get_speaker_selection_result(self, groupchat: "GroupChat") -> Optional[Union[Agent, str]]:
        """Get the speaker selection result. If None, the conversation will end."""
        if self.agent_name is not None:
            # Look the agent up by name in the groupchat
            selected = next((a for a in groupchat.agents if a.name == self.agent_name), None)
            if selected is None:
                raise ValueError(f"Agent '{self.agent_name}' not found in groupchat.")
            return selected
        if self.speaker_selection_method is not None:
            return self.speaker_selection_method
        if self.terminate:
            # A truthy terminate flag ends the conversation
            return None
        raise ValueError(
            "Unable to establish speaker selection result. No terminate, agent, or speaker selection method provided."
        )

View File

@@ -0,0 +1,4 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#

View File

@@ -0,0 +1,132 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any, Optional, Union
from pydantic import BaseModel
from ....doc_utils import export_module
from ...agent import Agent
from ..speaker_selection_result import SpeakerSelectionResult
from .transition_target import AgentTarget, TransitionTarget
from .transition_utils import __AGENT_WRAPPER_PREFIX__
if TYPE_CHECKING:
from ...conversable_agent import ConversableAgent
from ...groupchat import GroupChat
from ..patterns.pattern import Pattern
__all__ = ["GroupChatConfig", "GroupChatTarget"]
@export_module("autogen.agentchat.group")
class GroupChatConfig(BaseModel):
    """Configuration for a group chat transition target.

    Note: If context_variables are not passed in, the outer context variables will be passed in"""

    # Orchestration pattern describing the nested group chat's agents and flow.
    pattern: "Pattern"
    # Initial message(s) used to start the nested group chat.
    messages: Union[list[dict[str, Any]], str]
    # Maximum number of conversation rounds for the nested group chat.
    max_rounds: int = 20
@export_module("autogen.agentchat.group")
class GroupChatTarget(TransitionTarget):
    """Target that represents a group chat.

    A group chat cannot be resolved directly during speaker selection; it must first be
    encapsulated in a wrapper agent via `create_wrapper_agent`.
    """

    # Configuration (pattern, messages, max_rounds) for the inner group chat.
    group_chat_config: GroupChatConfig

    def can_resolve_for_speaker_selection(self) -> bool:
        """Check if the target can resolve for speaker selection. For GroupChatTarget the chat must be encapsulated into an agent."""
        return False

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Not supported: the group chat must be wrapped in an agent first.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "GroupChatTarget does not support the resolve method. An agent should be used to encapsulate this nested chat and then the target changed to an AgentTarget."
        )

    def display_name(self) -> str:
        """Get the display name for the target."""
        return "a group chat"

    def normalized_name(self) -> str:
        """Get a normalized name for the target that has no spaces, used for function calling."""
        return "group_chat"

    def __str__(self) -> str:
        """String representation for GroupChatTarget, can be shown as a function call message."""
        return "Transfer to group chat"

    def needs_agent_wrapper(self) -> bool:
        """Check if the target needs to be wrapped in an agent. GroupChatTarget must be wrapped in an agent."""
        return True

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Create a wrapper agent for the group chat.

        The wrapper runs the configured inner group chat as its only reply function and,
        when it finishes, hands control back to `parent_agent`.

        Args:
            parent_agent: The agent this target hangs off; control returns to it afterwards.
            index: Position of this target among the parent's handoffs, used to build a unique name.

        Returns:
            A ConversableAgent that encapsulates the group chat.
        """
        from autogen.agentchat import initiate_group_chat

        from ...conversable_agent import ConversableAgent  # to avoid circular import

        # Create the wrapper agent with a name that identifies it as a wrapped group chat
        group_chat_agent = ConversableAgent(
            name=f"{__AGENT_WRAPPER_PREFIX__}group_{parent_agent.name}_{index + 1}",
            # Copy LLM config from parent agent to ensure it can generate replies if needed
            llm_config=parent_agent.llm_config,
        )

        # Store the config directly on the agent
        group_chat_agent._group_chat_config = self.group_chat_config  # type: ignore[attr-defined]

        # Define the reply function that will run the group chat
        def group_chat_reply(
            agent: "ConversableAgent",
            messages: Optional[list[dict[str, Any]]] = None,
            sender: Optional["Agent"] = None,
            config: Optional[Any] = None,
        ) -> tuple[bool, Optional[dict[str, Any]]]:
            """Run the inner group chat and return its results as a reply."""
            # Get the configuration stored directly on the agent
            group_config = agent._group_chat_config  # type: ignore[attr-defined]

            # Pull through the second last message from the outer chat (the last message will be the handoff message)
            # This may need work to make sure we get the right message(s) from the outer chat
            message = (
                messages[-2]["content"]
                if messages and len(messages) >= 2 and "content" in messages[-2]
                else "No message to pass through."
            )

            try:
                # Run the group chat with direct agent references from the config
                result, _, _ = initiate_group_chat(
                    pattern=group_config.pattern,
                    messages=message,
                    max_rounds=group_config.max_rounds,
                )

                # Return the summary from the chat result summary
                return True, {"content": result.summary}
            except Exception as e:
                # Handle any errors during execution
                return True, {"content": f"Error running group chat: {str(e)}"}

        # Register the reply function with the wrapper agent
        group_chat_agent.register_reply(
            trigger=[ConversableAgent, None],
            reply_func=group_chat_reply,
            remove_other_reply_funcs=True,  # Use only this reply function
        )

        # After the group chat completes, transition back to the parent agent
        group_chat_agent.handoffs.set_after_work(AgentTarget(parent_agent))

        return group_chat_agent

View File

@@ -0,0 +1,151 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any, Optional, Type, Union
from pydantic import BaseModel, field_validator
from ....doc_utils import export_module
from ..context_str import ContextStr
from ..group_tool_executor import GroupToolExecutor
from ..speaker_selection_result import SpeakerSelectionResult
from .transition_target import TransitionTarget
from .transition_utils import __AGENT_WRAPPER_PREFIX__
if TYPE_CHECKING:
# Avoid circular import
from ...conversable_agent import ConversableAgent
from ...groupchat import GroupChat
__all__ = ["GroupManagerTarget"]
def prepare_groupchat_auto_speaker(
    groupchat: "GroupChat",
    last_group_agent: "ConversableAgent",
    group_chat_manager_selection_msg: Optional[Any],
) -> None:
    """Prepare the group chat for auto speaker selection by updating or restoring the speaker selection message.

    The tool executor and wrapped agents are excluded from the available agents list.

    Args:
        groupchat (GroupChat): GroupChat instance.
        last_group_agent ("ConversableAgent"): The last group agent for which the LLM config is used.
        group_chat_manager_selection_msg (GroupManagerSelectionMessage): Optional message to use for the agent selection (in internal group chat).
    """
    from ...groupchat import SELECT_SPEAKER_PROMPT_TEMPLATE

    # Use the default speaker selection prompt unless a custom message was supplied.
    if group_chat_manager_selection_msg is None:
        template = SELECT_SPEAKER_PROMPT_TEMPLATE
    else:
        template = group_chat_manager_selection_msg.get_message(last_group_agent)

    # Exclude the tool executor and wrapped agents from the selectable agents.
    selectable_agents = [
        candidate
        for candidate in groupchat.agents
        if not isinstance(candidate, GroupToolExecutor) and not candidate.name.startswith(__AGENT_WRAPPER_PREFIX__)
    ]

    # Substitute {agentlist} first so that a later substitution pass cannot fail on it,
    # then store the fully-rendered prompt back on the group chat.
    groupchat.select_speaker_prompt_template = template
    groupchat.select_speaker_prompt_template = groupchat.select_speaker_prompt(selectable_agents)
# GroupManagerSelectionMessage protocol and implementations
@export_module("autogen.agentchat.group")
class GroupManagerSelectionMessage(BaseModel):
    """Base class for all GroupManager selection message types."""

    def get_message(self, agent: "ConversableAgent") -> str:
        """Get the formatted message.

        Args:
            agent: The agent whose context may be used to format the message.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Requires subclasses to implement.")
@export_module("autogen.agentchat.group")
class GroupManagerSelectionMessageString(GroupManagerSelectionMessage):
    """Selection message that uses a plain string template."""

    # The literal message, returned as-is with no context substitution.
    message: str

    def get_message(self, agent: "ConversableAgent") -> str:
        """Get the message string (the agent argument is unused for plain strings)."""
        return self.message
@export_module("autogen.agentchat.group")
class GroupManagerSelectionMessageContextStr(GroupManagerSelectionMessage):
    """Selection message that uses a ContextStr template."""

    # The {agentlist} placeholder is swapped for a sentinel before validation so the
    # ContextStr substitution does not fail on it; get_message restores it afterwards
    # for the internal group chat's auto speaker selection.
    context_str_template: str

    @field_validator("context_str_template", mode="before")
    def _replace_agentlist_placeholder(cls: Type["GroupManagerSelectionMessageContextStr"], v: Any) -> Union[str, Any]:  # noqa: N805
        """Replace {agentlist} placeholder before validation/assignment."""
        if not isinstance(v, str):
            return ""
        # str.replace is a no-op when the placeholder is absent.
        return v.replace("{agentlist}", "<<agent_list>>")

    def get_message(self, agent: "ConversableAgent") -> str:
        """Get the formatted message with context variables substituted."""
        rendered = ContextStr(template=self.context_str_template).format(agent.context_variables)
        if rendered is None:
            return ""
        # Restore {agentlist} so the internal group chat can substitute it later.
        return rendered.replace("<<agent_list>>", "{agentlist}")
class GroupManagerTarget(TransitionTarget):
    """Target that hands speaker selection over to the group manager (auto selection)."""

    # Optional custom message used to build the speaker-selection prompt.
    selection_message: Optional[GroupManagerSelectionMessage] = None

    def can_resolve_for_speaker_selection(self) -> bool:
        """Check if the target can resolve for speaker selection."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Resolve to the speaker selection for the group."""
        if self.selection_message is not None:
            # Install the custom selection prompt before handing over to auto selection.
            prepare_groupchat_auto_speaker(groupchat, current_agent, self.selection_message)
        return SpeakerSelectionResult(speaker_selection_method="auto")

    def display_name(self) -> str:
        """Get the display name for the target."""
        return "the group manager"

    def normalized_name(self) -> str:
        """Get a normalized name for the target that has no spaces, used for function calling"""
        # Fix: previously returned display_name() ("the group manager"), which contains
        # spaces and is therefore invalid as a tool/function name component.
        return "group_manager"

    def __str__(self) -> str:
        """String representation for GroupManagerTarget, can be shown as a function call message."""
        return "Transfer to the group manager"

    def needs_agent_wrapper(self) -> bool:
        """Check if the target needs to be wrapped in an agent."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Create a wrapper agent for the target if needed."""
        raise NotImplementedError("GroupManagerTarget does not require wrapping in an agent.")

View File

@@ -0,0 +1,413 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import random
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel
from ..speaker_selection_result import SpeakerSelectionResult
from .transition_utils import __AGENT_WRAPPER_PREFIX__
if TYPE_CHECKING:
# Avoid circular import
from ...conversable_agent import ConversableAgent
from ...groupchat import GroupChat
__all__ = [
"AgentNameTarget",
"AgentTarget",
"AskUserTarget",
"NestedChatTarget",
"RandomAgentTarget",
"RevertToUserTarget",
"StayTarget",
"TerminateTarget",
"TransitionTarget",
]
# Common options for transitions
# terminate: Terminate the conversation
# revert_to_user: Revert to the user agent
# stay: Stay with the current agent
# group_manager: Use the group manager (auto speaker selection)
# ask_user: Use the user manager (ask the user, aka manual)
# TransitionOption = Literal["terminate", "revert_to_user", "stay", "group_manager", "ask_user"]
class TransitionTarget(BaseModel):
    """Base class for all transition targets across OnCondition, OnContextCondition, and after work.

    Subclasses either resolve directly during speaker selection or require wrapping
    in an agent (see `needs_agent_wrapper` / `create_wrapper_agent`).
    """

    def can_resolve_for_speaker_selection(self) -> bool:
        """Check if the target can resolve to an option for speaker selection (Agent, 'None' to end, Str for speaker selection method). In the case of a nested chat, this will return False as it should be encapsulated in an agent."""
        # Conservative default: subclasses that can resolve directly override this.
        return False

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Resolve to a speaker selection result (Agent, None for termination, or str for speaker selection method).

        Raises:
            NotImplementedError: Always on the base class.
        """
        raise NotImplementedError("Requires subclasses to implement.")

    def display_name(self) -> str:
        """Get the display name for the target."""
        raise NotImplementedError("Requires subclasses to implement.")

    def normalized_name(self) -> str:
        """Get a normalized name for the target that has no spaces, used for function calling"""
        raise NotImplementedError("Requires subclasses to implement.")

    def needs_agent_wrapper(self) -> bool:
        """Check if the target needs to be wrapped in an agent."""
        raise NotImplementedError("Requires subclasses to implement.")

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Create a wrapper agent for the target if needed."""
        raise NotImplementedError("Requires subclasses to implement.")
class AgentTarget(TransitionTarget):
    """Target that represents a direct agent reference."""

    # Only the agent's name is stored so the model stays serializable.
    agent_name: str

    def __init__(self, agent: "ConversableAgent", **data: Any) -> None:  # type: ignore[no-untyped-def]
        """Capture the agent's name from the given agent instance."""
        super().__init__(agent_name=agent.name, **data)

    def can_resolve_for_speaker_selection(self) -> bool:
        """A direct agent reference resolves immediately in speaker selection."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Produce a speaker selection result naming the referenced agent."""
        return SpeakerSelectionResult(agent_name=self.agent_name)

    def display_name(self) -> str:
        """Human-readable name: the agent's own name."""
        return self.agent_name

    def normalized_name(self) -> str:
        """Space-free name used when generating handoff function names."""
        return self.display_name()

    def __str__(self) -> str:
        """Representation suitable for a function-call message."""
        return "Transfer to " + self.agent_name

    def needs_agent_wrapper(self) -> bool:
        """Direct agent references never need a wrapper agent."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Not applicable; direct agent targets are never wrapped."""
        raise NotImplementedError("AgentTarget does not require wrapping in an agent.")
class AgentNameTarget(TransitionTarget):
    """Target that represents an agent by name."""

    agent_name: str

    def __init__(self, agent_name: str, **data: Any) -> None:
        """Allow the agent name to be passed positionally."""
        super().__init__(agent_name=agent_name, **data)

    def can_resolve_for_speaker_selection(self) -> bool:
        """A named agent resolves immediately in speaker selection."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Produce a speaker selection result naming the referenced agent."""
        return SpeakerSelectionResult(agent_name=self.agent_name)

    def display_name(self) -> str:
        """Human-readable name: the stored agent name."""
        return self.agent_name

    def normalized_name(self) -> str:
        """Space-free name used when generating handoff function names."""
        return self.display_name()

    def __str__(self) -> str:
        """Representation suitable for a function-call message."""
        return "Transfer to " + self.agent_name

    def needs_agent_wrapper(self) -> bool:
        """Name-based targets never need a wrapper agent."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Not applicable; name-based targets are never wrapped."""
        raise NotImplementedError("AgentNameTarget does not require wrapping in an agent.")
class NestedChatTarget(TransitionTarget):
    """Target that represents a nested chat configuration."""

    # Raw nested-chat configuration; must contain a "chat_queue" entry.
    nested_chat_config: dict[str, Any]

    def can_resolve_for_speaker_selection(self) -> bool:
        """Nested chats cannot resolve directly; they must be encapsulated in an agent."""
        return False

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Not supported; resolution happens through the wrapper agent instead."""
        raise NotImplementedError(
            "NestedChatTarget does not support the resolve method. An agent should be used to encapsulate this nested chat and then the target changed to an AgentTarget."
        )

    def display_name(self) -> str:
        """Human-readable name for this target."""
        return "a nested chat"

    def normalized_name(self) -> str:
        """Space-free name used when generating handoff function names."""
        return "nested_chat"

    def __str__(self) -> str:
        """Representation suitable for a function-call message."""
        return "Transfer to nested chat"

    def needs_agent_wrapper(self) -> bool:
        """Nested chats must always be encapsulated in a wrapper agent."""
        return True

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Build a wrapper agent that runs this nested chat, then hands control back."""
        from ...conversable_agent import ConversableAgent  # to avoid circular import - NEED SOLUTION

        wrapper = ConversableAgent(name=f"{__AGENT_WRAPPER_PREFIX__}nested_{parent_agent.name}_{index + 1}")

        reply_func = self.nested_chat_config.get("reply_func_from_nested_chats") or "summary_from_nested_chats"
        wrapper.register_nested_chats(
            self.nested_chat_config["chat_queue"],
            reply_func_from_nested_chats=reply_func,
            config=self.nested_chat_config.get("config"),
            trigger=lambda sender: True,
            position=0,
            use_async=self.nested_chat_config.get("use_async", False),
        )

        # Once the nested chat finishes, control returns to the parent agent.
        wrapper.handoffs.set_after_work(AgentTarget(parent_agent))
        return wrapper
class TerminateTarget(TransitionTarget):
    """Target that represents a termination of the conversation."""

    def can_resolve_for_speaker_selection(self) -> bool:
        """Termination resolves directly in speaker selection."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Signal that the conversation should end."""
        return SpeakerSelectionResult(terminate=True)

    def display_name(self) -> str:
        """Human-readable name for this target."""
        return "Terminate"

    def normalized_name(self) -> str:
        """Space-free name used when generating handoff function names."""
        return "terminate"

    def __str__(self) -> str:
        """Representation suitable for a function-call message."""
        return "Terminate"

    def needs_agent_wrapper(self) -> bool:
        """Termination targets resolve directly and need no wrapper."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Not applicable; termination targets are never wrapped."""
        raise NotImplementedError("TerminateTarget does not require wrapping in an agent.")
class StayTarget(TransitionTarget):
    """Target that represents staying with the current agent."""

    def can_resolve_for_speaker_selection(self) -> bool:
        """Staying put resolves directly in speaker selection."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Keep the floor with the agent that currently has it."""
        return SpeakerSelectionResult(agent_name=current_agent.name)

    def display_name(self) -> str:
        """Human-readable name for this target."""
        return "Stay"

    def normalized_name(self) -> str:
        """Space-free name used when generating handoff function names."""
        return "stay"

    def __str__(self) -> str:
        """Representation suitable for a function-call message."""
        return "Stay with agent"

    def needs_agent_wrapper(self) -> bool:
        """Stay targets resolve directly and need no wrapper."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Not applicable; stay targets are never wrapped."""
        raise NotImplementedError("StayTarget does not require wrapping in an agent.")
class RevertToUserTarget(TransitionTarget):
    """Target that represents reverting to the user agent."""

    def can_resolve_for_speaker_selection(self) -> bool:
        """Reverting to the user resolves directly in speaker selection."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Hand the floor to the user agent; a user agent is mandatory for this target."""
        if user_agent is None:
            raise ValueError("User agent must be provided to the chat for the revert_to_user option.")
        return SpeakerSelectionResult(agent_name=user_agent.name)

    def display_name(self) -> str:
        """Human-readable name for this target."""
        return "Revert to User"

    def normalized_name(self) -> str:
        """Space-free name used when generating handoff function names."""
        return "revert_to_user"

    def __str__(self) -> str:
        """Representation suitable for a function-call message."""
        return "Revert to User"

    def needs_agent_wrapper(self) -> bool:
        """Revert-to-user targets resolve directly and need no wrapper."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Not applicable; revert-to-user targets are never wrapped."""
        raise NotImplementedError("RevertToUserTarget does not require wrapping in an agent.")
class AskUserTarget(TransitionTarget):
    """Target that represents asking the user for input."""

    def can_resolve_for_speaker_selection(self) -> bool:
        """Asking the user resolves directly in speaker selection."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Switch speaker selection to manual so the user is asked."""
        return SpeakerSelectionResult(speaker_selection_method="manual")

    def display_name(self) -> str:
        """Human-readable name for this target."""
        return "Ask User"

    def normalized_name(self) -> str:
        """Space-free name used when generating handoff function names."""
        return "ask_user"

    def __str__(self) -> str:
        """Representation suitable for a function-call message."""
        return "Ask User"

    def needs_agent_wrapper(self) -> bool:
        """Ask-user targets resolve directly and need no wrapper."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Not applicable; ask-user targets are never wrapped."""
        raise NotImplementedError("AskUserTarget does not require wrapping in an agent.")
class RandomAgentTarget(TransitionTarget):
    """Target that represents a random selection from a list of agents."""

    agent_names: list[str]
    # Populated by resolve() with the name of the randomly chosen agent.
    nominated_name: str = "<Not Randomly Selected Yet>"

    def __init__(self, agents: list["ConversableAgent"], **data: Any) -> None:  # type: ignore[no-untyped-def]
        """Record the candidate agents' names for serialization."""
        super().__init__(agent_names=[agent.name for agent in agents], **data)

    def can_resolve_for_speaker_selection(self) -> bool:
        """Random selection resolves directly in speaker selection."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Pick a random candidate (never the current agent) and resolve to it."""
        candidates = [name for name in self.agent_names if name != current_agent.name]
        self.nominated_name = random.choice(candidates)
        return SpeakerSelectionResult(agent_name=self.nominated_name)

    def display_name(self) -> str:
        """The name of the most recently nominated agent."""
        return self.nominated_name

    def normalized_name(self) -> str:
        """Space-free name used when generating handoff function names."""
        return self.display_name()

    def __str__(self) -> str:
        """Representation suitable for a function-call message."""
        return f"Transfer to {self.nominated_name}"

    def needs_agent_wrapper(self) -> bool:
        """Random agent targets resolve directly and need no wrapper."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Not applicable; random agent targets are never wrapped."""
        raise NotImplementedError("RandomAgentTarget does not require wrapping in an agent.")
# TODO: Consider adding a SequentialChatTarget class

View File

@@ -0,0 +1,6 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Prefix for all wrapped agent names, used to mark (and later filter out)
# agents that merely encapsulate a nested chat or group chat.
__AGENT_WRAPPER_PREFIX__ = "wrapped_"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0

View File

@@ -0,0 +1,20 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .audio_adapters import TwilioAudioAdapter, WebSocketAudioAdapter
from .audio_observer import AudioObserver
from .function_observer import FunctionObserver
from .realtime_agent import RealtimeAgent
from .realtime_observer import RealtimeObserver
from .realtime_swarm import register_swarm
__all__ = [
"AudioObserver",
"FunctionObserver",
"RealtimeAgent",
"RealtimeObserver",
"TwilioAudioAdapter",
"WebSocketAudioAdapter",
"register_swarm",
]

View File

@@ -0,0 +1,8 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .twilio_audio_adapter import TwilioAudioAdapter
from .websocket_audio_adapter import WebSocketAudioAdapter
__all__ = ["TwilioAudioAdapter", "WebSocketAudioAdapter"]

View File

@@ -0,0 +1,148 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import base64
import json
from logging import Logger
from typing import TYPE_CHECKING, Optional
from .....doc_utils import export_module
from ..realtime_events import AudioDelta, RealtimeEvent, SpeechStarted
from ..realtime_observer import RealtimeObserver
if TYPE_CHECKING:
from ..websockets import WebSocketProtocol as WebSocket
# OpenAI Realtime API event types considered worth logging.
# NOTE(review): not referenced by the adapter class in this file — confirm it is used elsewhere before removing.
LOG_EVENT_TYPES = [
    "error",
    "response.content.done",
    "rate_limits.updated",
    "response.done",
    "input_audio_buffer.committed",
    "input_audio_buffer.speech_stopped",
    "input_audio_buffer.speech_started",
    "session.created",
]
# Enable to log the timestamp arithmetic behind truncation decisions.
SHOW_TIMING_MATH = False
@export_module("autogen.agentchat.realtime.experimental")
class TwilioAudioAdapter(RealtimeObserver):
    """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa."""

    def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None):
        """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa.

        Args:
            websocket: the websocket connection to the Twilio service
            logger: the logger to use for logging events
        """
        super().__init__(logger=logger)
        self.websocket = websocket

        # Connection specific state
        self.stream_sid: Optional[str] = None  # Twilio stream identifier, set when the "start" event arrives
        self.latest_media_timestamp = 0  # timestamp of the most recent inbound media frame (assumed ms — TODO confirm)
        self.last_assistant_item: Optional[str] = None  # id of the assistant item currently being played back
        self.mark_queue: list[str] = []  # playback marks sent to Twilio that have not yet been acknowledged
        self.response_start_timestamp_twilio: Optional[int] = None  # media timestamp at which the current response began

    async def on_event(self, event: RealtimeEvent) -> None:
        """Receive events from the OpenAI Realtime API, send audio back to Twilio."""
        logger = self.logger

        if isinstance(event, AudioDelta):
            # Round-trip through base64 decode/encode to normalize the payload before forwarding.
            audio_payload = base64.b64encode(base64.b64decode(event.delta)).decode("utf-8")
            audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}}
            await self.websocket.send_json(audio_delta)

            # First delta of a response: remember when playback of this response started.
            if self.response_start_timestamp_twilio is None:
                self.response_start_timestamp_twilio = self.latest_media_timestamp
                if SHOW_TIMING_MATH:
                    logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_twilio}ms")

            # Update last_assistant_item safely
            if event.item_id:
                self.last_assistant_item = event.item_id

            await self.send_mark()

        # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
        if isinstance(event, SpeechStarted):
            logger.info("Speech start detected.")
            if self.last_assistant_item:
                logger.info(f"Interrupting response with id: {self.last_assistant_item}")
                await self.handle_speech_started_event()

    async def handle_speech_started_event(self) -> None:
        """Handle interruption when the caller's speech starts."""
        logger = self.logger
        logger.info("Handling speech started event.")
        if self.mark_queue and self.response_start_timestamp_twilio is not None:
            # How much of the assistant's response has already been played back.
            elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_twilio
            if SHOW_TIMING_MATH:
                logger.info(
                    f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_twilio} = {elapsed_time}ms"
                )

            if self.last_assistant_item:
                if SHOW_TIMING_MATH:
                    logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")

                # Truncate the in-flight assistant audio at the point playback was interrupted.
                await self.realtime_client.truncate_audio(
                    audio_end_ms=elapsed_time,
                    content_index=0,
                    item_id=self.last_assistant_item,
                )

            # Ask Twilio to drop any audio it has buffered but not yet played.
            await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})

            # Reset per-response playback state.
            self.mark_queue.clear()
            self.last_assistant_item = None
            self.response_start_timestamp_twilio = None

    async def send_mark(self) -> None:
        """Send a mark of audio interruption to the Twilio websocket."""
        if self.stream_sid:
            mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}}
            await self.websocket.send_json(mark_event)
            self.mark_queue.append("responsePart")

    async def run_loop(self) -> None:
        """Run the adapter loop."""
        logger = self.logger
        async for message in self.websocket.iter_text():
            try:
                data = json.loads(message)
                if data["event"] == "media":
                    # Inbound caller audio: track its timestamp and forward it to the realtime client.
                    self.latest_media_timestamp = int(data["media"]["timestamp"])
                    await self.realtime_client.send_audio(audio=data["media"]["payload"])
                elif data["event"] == "start":
                    # New stream: remember its SID and reset per-stream playback state.
                    self.stream_sid = data["start"]["streamSid"]
                    logger.info(f"Incoming stream has started {self.stream_sid}")
                    self.response_start_timestamp_twilio = None
                    self.latest_media_timestamp = 0
                    self.last_assistant_item = None
                elif data["event"] == "mark":
                    # Twilio acknowledged a playback mark; retire the oldest outstanding one.
                    if self.mark_queue:
                        self.mark_queue.pop(0)
            except Exception as e:
                logger.warning(f"Error processing Twilio message: {e}", stack_info=True)

    async def initialize_session(self) -> None:
        """Control initial session with OpenAI."""
        # Twilio media streams carry G.711 mu-law audio, so configure both directions for it.
        session_update = {
            "input_audio_format": "g711_ulaw",
            "output_audio_format": "g711_ulaw",
        }
        await self.realtime_client.session_update(session_update)
if TYPE_CHECKING:
    # Static-only conformance check: verifies TwilioAudioAdapter satisfies the
    # RealtimeObserver interface; never executed at runtime.
    def twilio_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
        return TwilioAudioAdapter(websocket)

View File

@@ -0,0 +1,139 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import base64
import json
from logging import Logger
from typing import TYPE_CHECKING, Optional
from .....doc_utils import export_module
from ..realtime_events import AudioDelta, RealtimeEvent, SpeechStarted
from ..realtime_observer import RealtimeObserver
if TYPE_CHECKING:
from ..websockets import WebSocketProtocol as WebSocket
LOG_EVENT_TYPES = [
"error",
"response.content.done",
"rate_limits.updated",
"response.done",
"input_audio_buffer.committed",
"input_audio_buffer.speech_stopped",
"input_audio_buffer.speech_started",
"session.created",
]
SHOW_TIMING_MATH = False
@export_module("autogen.agentchat.realtime.experimental")
class WebSocketAudioAdapter(RealtimeObserver):
    def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None) -> None:
        """Adapter for streaming audio between a websocket client and the OpenAI Realtime API.

        Args:
            websocket (WebSocket): The websocket connection.
            logger (Logger): The logger for the observer.
        """
        super().__init__(logger=logger)
        self.websocket = websocket

        # Connection specific state
        self.stream_sid: Optional[str] = None  # stream identifier, set when the "start" event arrives
        self.latest_media_timestamp = 0  # timestamp of the most recent inbound media frame (assumed ms — TODO confirm)
        self.last_assistant_item: Optional[str] = None  # id of the assistant item currently being played back
        self.mark_queue: list[str] = []  # playback marks sent to the client that have not yet been acknowledged
        self.response_start_timestamp_socket: Optional[int] = None  # media timestamp at which the current response began
    async def on_event(self, event: RealtimeEvent) -> None:
        """Receive events from the OpenAI Realtime API, send audio back to websocket."""
        logger = self.logger

        if isinstance(event, AudioDelta):
            # Round-trip through base64 decode/encode to normalize the payload before forwarding.
            audio_payload = base64.b64encode(base64.b64decode(event.delta)).decode("utf-8")
            audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}}
            await self.websocket.send_json(audio_delta)

            # First delta of a response: remember when playback of this response started.
            if self.response_start_timestamp_socket is None:
                self.response_start_timestamp_socket = self.latest_media_timestamp
                if SHOW_TIMING_MATH:
                    logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_socket}ms")

            # Update last_assistant_item safely
            if event.item_id:
                self.last_assistant_item = event.item_id

            await self.send_mark()

        # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
        if isinstance(event, SpeechStarted):
            logger.info("Speech start detected.")
            if self.last_assistant_item:
                logger.info(f"Interrupting response with id: {self.last_assistant_item}")
                await self.handle_speech_started_event()
async def handle_speech_started_event(self) -> None:
"""Handle interruption when the caller's speech starts."""
logger = self.logger
logger.info("Handling speech started event.")
if self.mark_queue and self.response_start_timestamp_socket is not None:
elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_socket
if SHOW_TIMING_MATH:
logger.info(
f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_socket} = {elapsed_time}ms"
)
if self.last_assistant_item:
if SHOW_TIMING_MATH:
logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")
await self.realtime_client.truncate_audio(
audio_end_ms=elapsed_time,
content_index=0,
item_id=self.last_assistant_item,
)
await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})
self.mark_queue.clear()
self.last_assistant_item = None
self.response_start_timestamp_socket = None
async def send_mark(self) -> None:
if self.stream_sid:
mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}}
await self.websocket.send_json(mark_event)
self.mark_queue.append("responsePart")
async def initialize_session(self) -> None:
"""Control initial session with OpenAI."""
session_update = {"input_audio_format": "pcm16", "output_audio_format": "pcm16"}
await self.realtime_client.session_update(session_update)
async def run_loop(self) -> None:
"""Reads data from websocket and sends it to the RealtimeClient."""
logger = self.logger
async for message in self.websocket.iter_text():
try:
data = json.loads(message)
if data["event"] == "media":
self.latest_media_timestamp = int(data["media"]["timestamp"])
await self.realtime_client.send_audio(audio=data["media"]["payload"])
elif data["event"] == "start":
self.stream_sid = data["start"]["streamSid"]
logger.info(f"Incoming stream has started {self.stream_sid}")
self.response_start_timestamp_socket = None
self.latest_media_timestamp = 0
self.last_assistant_item = None
elif data["event"] == "mark":
if self.mark_queue:
self.mark_queue.pop(0)
except Exception as e:
logger.warning(f"Failed to process message: {e}", stack_info=True)
if TYPE_CHECKING:
def websocket_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
return WebSocketAudioAdapter(websocket)

View File

@@ -0,0 +1,42 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional
from ....doc_utils import export_module
from .realtime_events import InputAudioBufferDelta, RealtimeEvent
from .realtime_observer import RealtimeObserver
if TYPE_CHECKING:
from logging import Logger
@export_module("autogen.agentchat.realtime.experimental")
class AudioObserver(RealtimeObserver):
    """Observer that logs incoming user audio buffer deltas."""

    def __init__(self, *, logger: Optional["Logger"] = None) -> None:
        """Create the observer, delegating logger handling to the base class."""
        super().__init__(logger=logger)

    async def on_event(self, event: RealtimeEvent) -> None:
        """Observe voice input events from the Realtime.

        Args:
            event (dict[str, Any]): The event from the OpenAI Realtime API.
        """
        # Only audio buffer deltas are of interest; everything else is ignored.
        if not isinstance(event, InputAudioBufferDelta):
            return
        self.logger.info("Received audio buffer delta")

    async def initialize_session(self) -> None:
        """This observer requires no session initialization."""

    async def run_loop(self) -> None:
        """This observer has no background loop to run."""


if TYPE_CHECKING:
    function_observer: RealtimeObserver = AudioObserver()

View File

@@ -0,0 +1,15 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .gemini.client import GeminiRealtimeClient
from .oai.base_client import OpenAIRealtimeClient
from .realtime_client import RealtimeClientProtocol, Role, get_client
# Public API of the realtime clients subpackage.
__all__ = [
    "GeminiRealtimeClient",
    "OpenAIRealtimeClient",
    "RealtimeClientProtocol",
    "Role",
    "get_client",
]

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .client import GeminiRealtimeClient
__all__ = ["GeminiRealtimeClient"]

View File

@@ -0,0 +1,274 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from ......doc_utils import export_module
from ......import_utils import optional_import_block, require_optional_import
from ......llm_config import LLMConfig
from ...realtime_events import AudioDelta, FunctionCall, RealtimeEvent, SessionCreated
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
with optional_import_block():
from websockets.asyncio.client import connect
if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection
from ..realtime_client import RealtimeClientProtocol
__all__ = ["GeminiRealtimeClient"]

# Fallback logger used when no logger is supplied to the client.
global_logger = getLogger(__name__)

# Components of the Gemini Live (BidiGenerateContent) websocket endpoint URL.
HOST = "generativelanguage.googleapis.com"
API_VERSION = "v1alpha"
@register_realtime_client()
@require_optional_import("websockets", "gemini", except_for=["get_factory", "__init__"])
@export_module("autogen.agentchat.realtime.experimental.clients")
class GeminiRealtimeClient(RealtimeClientBase):
    """(Experimental) Client for Gemini Realtime API."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for Gemini Realtime API.

        Args:
            llm_config: The config for the client.
            logger: The logger for the client.
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        self._connection: Optional["ClientConnection"] = None
        config = llm_config["config_list"][0]
        self._model: str = config["model"]
        self._voice = config.get("voice", "charon")
        self._temperature: float = config.get("temperature", 0.8)  # type: ignore[union-attr]
        self._response_modality = "AUDIO"
        self._api_key = config.get("api_key", None)
        # todo: add test with base_url just to make sure it works
        # NOTE(review): the API key is embedded in the URL query string; avoid
        # logging self._base_url so the credential is not leaked.
        self._base_url: str = config.get(
            "base_url",
            f"wss://{HOST}/ws/google.ai.generativelanguage.{API_VERSION}.GenerativeService.BidiGenerateContent?key={self._api_key}",
        )
        self._final_config: dict[str, Any] = {}
        # Session options recorded via session_update() before the connection
        # exists; applied once in _initialize_session().
        self._pending_session_updates: dict[str, Any] = {}
        # True once read_events() is running; outbound sends are dropped before that.
        self._is_reading_events = False

    @property
    def logger(self) -> Logger:
        """Get the logger for the Gemini Realtime API."""
        return self._logger or global_logger

    @property
    def connection(self) -> "ClientConnection":
        """Get the Gemini WebSocket connection.

        Raises:
            RuntimeError: If called before connect() has established the websocket.
        """
        if self._connection is None:
            raise RuntimeError("Gemini WebSocket is not initialized")
        return self._connection

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the Gemini Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        msg = {
            "tool_response": {"function_responses": [{"id": call_id, "response": {"result": {"string_value": result}}}]}
        }
        # Silently dropped unless the event loop is active (session already set up).
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def send_text(self, *, role: Role, text: str, turn_complete: bool = True) -> None:
        """Send a text message to the Gemini Realtime API.

        Args:
            role: The role of the message.
            text: The text of the message.
            turn_complete: A flag indicating if the turn is complete.
        """
        msg = {
            "client_content": {
                "turn_complete": turn_complete,
                "turns": [{"role": role, "parts": [{"text": text}]}],
            }
        }
        # Silently dropped unless the event loop is active (session already set up).
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def send_audio(self, audio: str) -> None:
        """Send audio to the Gemini Realtime API.

        Args:
            audio (str): The audio to send.
        """
        msg = {
            "realtime_input": {
                "media_chunks": [
                    {
                        "data": audio,
                        "mime_type": "audio/pcm",
                    }
                ]
            }
        }
        # Queue the delta locally (for observers/logging) even when not yet connected.
        await self.queue_input_audio_buffer_delta(audio)
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """No-op: audio truncation is not natively supported by the Gemini Realtime API."""
        self.logger.info("This is not natively supported by Gemini Realtime API.")
        pass

    async def _initialize_session(self) -> None:
        """Initialize the session with the Gemini Realtime API."""
        # Fold the pending session updates (instructions, tools) into the one-shot
        # "setup" frame that Gemini expects as the first message on the websocket.
        session_config = {
            "setup": {
                "system_instruction": {
                    "role": "system",
                    "parts": [{"text": self._pending_session_updates.get("instructions", "")}],
                },
                "model": f"models/{self._model}",
                "tools": [
                    {
                        "function_declarations": [
                            {
                                "name": tool_schema["name"],
                                "description": tool_schema["description"],
                                "parameters": tool_schema["parameters"],
                            }
                            for tool_schema in self._pending_session_updates.get("tools", [])
                        ]
                    },
                ],
                "generation_config": {
                    "response_modalities": [self._response_modality],
                    "speech_config": {"voiceConfig": {"prebuiltVoiceConfig": {"voiceName": self._voice}}},
                    "temperature": self._temperature,
                },
            }
        }
        self.logger.info(f"Sending session update: {session_config}")
        await self.connection.send(json.dumps(session_config))

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Record session updates to be applied when the connection is established.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        # Session configuration is only sent in the initial setup frame, so
        # updates arriving after read_events() started cannot be applied.
        if self._is_reading_events:
            self.logger.warning("Is reading events. Session update will be ignored.")
        else:
            self._pending_session_updates.update(session_options)

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the Gemini Realtime API."""
        try:
            async with connect(
                self._base_url, additional_headers={"Content-Type": "application/json"}
            ) as self._connection:
                yield
        finally:
            # Drop the reference so the `connection` property fails fast after exit.
            self._connection = None

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read Events from the Gemini Realtime Client"""
        if self._connection is None:
            raise RuntimeError("Client is not connected, call connect() first.")
        # The setup frame must be the first thing sent on the websocket.
        await self._initialize_session()
        self._is_reading_events = True
        async for event in self._read_events():
            yield event

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the Gemini Realtime connection."""
        async for raw_message in self.connection:
            # Frames may arrive as bytes; normalize to text before JSON parsing.
            message = raw_message.decode("ascii") if isinstance(raw_message, bytes) else raw_message
            events = self._parse_message(json.loads(message))
            for event in events:
                yield event

    def _parse_message(self, response: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the Gemini Realtime API.

        Args:
            response (dict[str, Any]): The response to parse.

        Returns:
            list[RealtimeEvent]: The parsed events.
        """
        if "serverContent" in response and "modelTurn" in response["serverContent"]:
            try:
                # Pop the (potentially large) base64 audio out of the raw message so
                # it is not duplicated inside AudioDelta.raw_message.
                b64data = response["serverContent"]["modelTurn"]["parts"][0]["inlineData"].pop("data")
                return [
                    AudioDelta(
                        delta=b64data,
                        item_id=None,
                        raw_message=response,
                    )
                ]
            except KeyError:
                # Model turn without inline audio data: nothing to surface.
                return []
        elif "toolCall" in response:
            return [
                FunctionCall(
                    raw_message=response,
                    call_id=call["id"],
                    name=call["name"],
                    arguments=call["args"],
                )
                for call in response["toolCall"]["functionCalls"]
            ]
        elif "setupComplete" in response:
            return [
                SessionCreated(raw_message=response),
            ]
        else:
            # Unrecognized payloads are passed through as generic events.
            return [RealtimeEvent(raw_message=response)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The LLM config for the client.
            logger: The logger for the client.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Only claim configs explicitly targeting Google and carrying no extra kwargs.
        if llm_config["config_list"][0].get("api_type") == "google" and list(kwargs.keys()) == []:
            return lambda: GeminiRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
        return None


# needed for mypy to check if GeminiRealtimeClient implements RealtimeClientProtocol
if TYPE_CHECKING:
    _client: RealtimeClientProtocol = GeminiRealtimeClient(llm_config={})

View File

@@ -0,0 +1,8 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .base_client import OpenAIRealtimeClient
from .rtc_client import OpenAIRealtimeWebRTCClient
__all__ = ["OpenAIRealtimeClient", "OpenAIRealtimeWebRTCClient"]

View File

@@ -0,0 +1,220 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from ......doc_utils import export_module
from ......import_utils import optional_import_block, require_optional_import
from ......llm_config import LLMConfig
from ...realtime_events import RealtimeEvent
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
from .utils import parse_oai_message
with optional_import_block():
from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI
from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection
if TYPE_CHECKING:
from ..realtime_client import RealtimeClientProtocol
__all__ = ["OpenAIRealtimeClient"]
global_logger = getLogger(__name__)
@register_realtime_client()
@require_optional_import("openai>=1.66.2", "openai-realtime", except_for=["get_factory", "__init__"])
@export_module("autogen.agentchat.realtime.experimental.clients")
class OpenAIRealtimeClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client.
            logger: the logger to use for logging events
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        self._connection: Optional["AsyncRealtimeConnection"] = None
        self.config = llm_config["config_list"][0]
        # model is passed to self._client.beta.realtime.connect function later
        self._model: str = self.config["model"]
        self._voice: str = self.config.get("voice", "alloy")
        # NOTE(review): temperature is read from llm_config here, while the Gemini
        # client reads it from config_list[0] — confirm which source is intended.
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]
        # SDK client is created lazily on first connect() and reused thereafter.
        self._client: Optional["AsyncOpenAI"] = None

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API."""
        return self._logger or global_logger

    @property
    def connection(self) -> "AsyncRealtimeConnection":
        """Get the OpenAI WebSocket connection.

        Raises:
            RuntimeError: If called before connect() has established the connection.
        """
        if self._connection is None:
            raise RuntimeError("OpenAI WebSocket is not initialized")
        return self._connection

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self.connection.conversation.item.create(
            item={
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        )
        # Ask the model to continue now that the tool output is available.
        await self.connection.response.create()

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        # Cancel any in-flight response before injecting the new message.
        await self.connection.response.cancel()
        await self.connection.conversation.item.create(
            item={"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]}
        )
        await self.connection.response.create()

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        Args:
            audio (str): The audio to send.
        """
        # Queue the delta locally (for observers/logging), then forward upstream.
        await self.queue_input_audio_buffer_delta(audio)
        await self.connection.input_audio_buffer.append(audio=audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self.connection.conversation.item.truncate(
            audio_end_ms=audio_end_ms, content_index=content_index, item_id=item_id
        )

    async def _initialize_session(self) -> None:
        """Control initial session with OpenAI."""
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        await self.session_update(session_options=session_update)

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        await self.connection.session.update(session=session_options)  # type: ignore[arg-type]
        logger.info("Sending session update finished")

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API."""
        try:
            # Create the SDK client once and reuse it across reconnects.
            if not self._client:
                self._client = AsyncOpenAI(
                    api_key=self.config.get("api_key", None),
                    organization=self.config.get("organization", None),
                    project=self.config.get("project", None),
                    base_url=self.config.get("base_url", None),
                    websocket_base_url=self.config.get("websocket_base_url", None),
                    timeout=self.config.get("timeout", NOT_GIVEN),
                    max_retries=self.config.get("max_retries", DEFAULT_MAX_RETRIES),
                    default_headers=self.config.get("default_headers", None),
                    default_query=self.config.get("default_query", None),
                )
            async with self._client.beta.realtime.connect(
                model=self._model,
            ) as self._connection:
                # Apply default session options as soon as the connection opens.
                await self._initialize_session()
                yield
        finally:
            # Drop the reference so the `connection` property fails fast after exit.
            self._connection = None

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API."""
        if self._connection is None:
            raise RuntimeError("Client is not connected, call connect() first.")
        try:
            async for event in self._read_events():
                yield event
        finally:
            self._connection = None

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API."""
        async for message in self._connection:
            # SDK events are pydantic models; convert to plain dicts for parsing.
            for event in self._parse_message(message.model_dump()):
                yield event

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            RealtimeEvent: The parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Unspecified api_type defaults to "openai"; no extra kwargs allowed here
        # (the WebRTC variant claims configs that carry a "websocket" kwarg).
        if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == []:
            return lambda: OpenAIRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
        return None


# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
if TYPE_CHECKING:
    _client: RealtimeClientProtocol = OpenAIRealtimeClient(llm_config={})

View File

@@ -0,0 +1,243 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from autogen.import_utils import optional_import_block, require_optional_import
from ......doc_utils import export_module
from ......llm_config import LLMConfig
from ...realtime_events import RealtimeEvent
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
from .utils import parse_oai_message
if TYPE_CHECKING:
from ...websockets import WebSocketProtocol as WebSocket
from ..realtime_client import RealtimeClientProtocol
with optional_import_block():
import httpx
__all__ = ["OpenAIRealtimeWebRTCClient"]
global_logger = getLogger(__name__)
@register_realtime_client()
@require_optional_import("httpx", "openai-realtime", except_for="get_factory")
@export_module("autogen.agentchat.realtime.experimental.clients.oai")
class OpenAIRealtimeWebRTCClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        websocket: "WebSocket",
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client.
            websocket: the websocket to use for the connection
            logger: the logger to use for logging events
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        # Control channel to the browser-side JS client, which owns the actual
        # WebRTC connection to OpenAI; all traffic is relayed through it.
        self._websocket = websocket
        config = llm_config["config_list"][0]
        self._model: str = config["model"]
        self._voice: str = config.get("voice", "alloy")
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]
        self._config = config
        # REST endpoint used to mint an ephemeral realtime session for the browser.
        self._base_url = config.get("base_url", "https://api.openai.com/v1/realtime/sessions")

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API."""
        return self._logger or global_logger

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        })
        # Ask the model to continue now that the tool output is available.
        await self._websocket.send_json({"type": "response.create"})

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        # Cancel any in-flight response before injecting the new message
        # (mirrors OpenAIRealtimeClient.send_text over the relay channel).
        await self._websocket.send_json({
            "type": "response.cancel",
        })
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
        })
        await self._websocket.send_json({"type": "response.create"})

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        in case of WebRTC, audio is already sent by js client, so we just queue it in order to be logged.

        Args:
            audio (str): The audio to send.
        """
        await self.queue_input_audio_buffer_delta(audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self._websocket.send_json({
            "type": "conversation.item.truncate",
            "content_index": content_index,
            "item_id": item_id,
            "audio_end_ms": audio_end_ms,
        })

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        In the case of WebRTC we can not send it directly, but we can send it
        to the javascript over the websocket, and rely on it to send session
        update to OpenAI

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        await self._websocket.send_json({"type": "session.update", "session": session_options})
        logger.info("Sending session update finished")

    def session_init_data(self) -> list[dict[str, Any]]:
        """Control initial session with OpenAI."""
        # Default session options the JS client should apply right after connecting.
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        return [{"type": "session.update", "session": session_update}]

    async def _initialize_session(self) -> None: ...

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API.

        In the case of WebRTC, we pass connection information over the
        websocket, so that javascript on the other end of websocket open
        actual connection to OpenAI
        """
        try:
            base_url = self._base_url
            api_key = self._config.get("api_key", None)
            headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
            data = {
                "model": self._model,
                "voice": self._voice,
            }
            # Mint an ephemeral realtime session via REST; the browser uses it
            # to establish the WebRTC peer connection itself.
            async with httpx.AsyncClient() as client:
                response = await client.post(base_url, headers=headers, json=data)
                response.raise_for_status()
                json_data = response.json()
                json_data["model"] = self._model
            if self._websocket is not None:
                session_init = self.session_init_data()
                # Hand the session credentials and initial options to the JS client.
                await self._websocket.send_json({"type": "ag2.init", "config": json_data, "init": session_init})
            yield
        finally:
            # Nothing to tear down server-side; the browser owns the connection.
            pass

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from the OpenAI Realtime API."""
        async for event in self._read_events():
            yield event

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API connection.

        Again, in case of WebRTC, we do not read OpenAI messages directly since we
        do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript
        client on the other side of the websocket that is connected to OpenAI is relaying events to us.
        """
        while True:
            try:
                message_json = await self._websocket.receive_text()
                message = json.loads(message_json)
                for event in self._parse_message(message):
                    yield event
            except Exception as e:
                # Any receive/parse failure ends the relay loop.
                self.logger.exception(f"Error reading from connection {e}")
                break

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            RealtimeEvent: The parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Claims openai configs only when exactly a "websocket" kwarg is supplied
        # (otherwise the plain websocket client handles the config).
        if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == ["websocket"]:
            return lambda: OpenAIRealtimeWebRTCClient(llm_config=llm_config, logger=logger, **kwargs)
        return None


# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
if TYPE_CHECKING:

    def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol:
        return OpenAIRealtimeWebRTCClient(llm_config={}, websocket=websocket)

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Any
from ...realtime_events import (
AudioDelta,
FunctionCall,
InputAudioBufferDelta,
RealtimeEvent,
SessionCreated,
SessionUpdated,
SpeechStarted,
)
__all__ = ["parse_oai_message"]
def parse_oai_message(message: dict[str, Any]) -> RealtimeEvent:
    """Translate a raw OpenAI Realtime API message into a typed RealtimeEvent.

    Args:
        message (dict[str, Any]): The raw message received from the API.

    Returns:
        RealtimeEvent: The typed event for the message's "type" field; messages
        with an unrecognized type are wrapped in a generic RealtimeEvent.
    """
    msg_type = message.get("type")
    if msg_type == "session.created":
        return SessionCreated(raw_message=message)
    if msg_type == "session.updated":
        return SessionUpdated(raw_message=message)
    if msg_type == "response.audio.delta":
        return AudioDelta(raw_message=message, delta=message["delta"], item_id=message["item_id"])
    if msg_type == "input_audio_buffer.speech_started":
        return SpeechStarted(raw_message=message)
    if msg_type == "input_audio_buffer.delta":
        return InputAudioBufferDelta(delta=message["delta"], item_id=None, raw_message=message)
    if msg_type == "response.function_call_arguments.done":
        # Function-call arguments arrive JSON-encoded; decode them for the caller.
        return FunctionCall(
            raw_message=message,
            call_id=message["call_id"],
            name=message["name"],
            arguments=json.loads(message["arguments"]),
        )
    return RealtimeEvent(raw_message=message)

View File

@@ -0,0 +1,190 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
from collections.abc import AsyncGenerator
from logging import Logger
from typing import Any, AsyncContextManager, Callable, Literal, Optional, Protocol, TypeVar, Union, runtime_checkable
from asyncer import create_task_group
from .....doc_utils import export_module
from .....llm_config import LLMConfig
from ..realtime_events import InputAudioBufferDelta, RealtimeEvent
# Public API of this module.
__all__ = ["RealtimeClientProtocol", "Role", "get_client", "register_realtime_client"]

# Literal alias for the conversation roles accepted by realtime clients.
Role = Literal["user", "assistant", "system"]
@runtime_checkable
@export_module("autogen.agentchat.realtime.experimental.clients")
class RealtimeClientProtocol(Protocol):
    """Structural interface that every Realtime API client implementation must satisfy."""

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to a Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        ...

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to a Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        ...

    async def send_audio(self, audio: str) -> None:
        """Send audio to a Realtime API.

        Args:
            audio (str): The audio to send.
        """
        ...

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in a Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        ...

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to a Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        ...

    def connect(self) -> AsyncContextManager[None]:
        """Return an async context manager that opens and closes the connection."""
        ...

    def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime Client."""
        ...

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime connection."""
        ...

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from a Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            list[RealtimeEvent]: The parsed events.
        """
        ...

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        ...
class RealtimeClientBase:
    """Shared plumbing for realtime clients: an internal asyncio queue that
    decouples the connection-reader task from event consumers."""

    def __init__(self):
        # Holds events produced by the connection reader; a None entry is the
        # end-of-stream sentinel consumed by _read_events().
        self._eventQueue = asyncio.Queue()

    async def add_event(self, event: Optional[RealtimeEvent]):
        """Enqueue an event (or the None end-of-stream sentinel)."""
        await self._eventQueue.put(event)

    async def get_event(self) -> Optional[RealtimeEvent]:
        """Dequeue the next event, waiting until one is available."""
        return await self._eventQueue.get()

    async def _read_from_connection_task(self):
        """Drain the provider connection into the queue, then signal end-of-stream."""
        async for incoming in self._read_from_connection():
            await self.add_event(incoming)
        await self.add_event(None)

    async def _read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Yield queued events until the None sentinel (or an error) is seen."""
        async with create_task_group() as tg:
            # Reader runs concurrently and feeds the queue we consume below.
            tg.start_soon(self._read_from_connection_task)
            while True:
                try:
                    if (event := await self._eventQueue.get()) is None:
                        break
                    yield event
                except Exception:
                    break

    async def queue_input_audio_buffer_delta(self, audio: str) -> None:
        """queue InputAudioBufferDelta.

        Args:
            audio (str): The audio.
        """
        await self.add_event(InputAudioBufferDelta(delta=audio, item_id=None, raw_message={}))
# Registry of Realtime client classes keyed by fully qualified class name;
# populated by the register_realtime_client() decorator below.
_realtime_client_classes: dict[str, type[RealtimeClientProtocol]] = {}

# Type variable constrained to classes implementing RealtimeClientProtocol.
T = TypeVar("T", bound=RealtimeClientProtocol)
def register_realtime_client() -> Callable[[type[T]], type[T]]:
    """Create a class decorator that registers a Realtime API client.

    Returns:
        Callable[[type[T]], type[T]]: The decorator to register the Realtime API client
    """

    def _register(client_cls: type[T]) -> type[T]:
        """Record *client_cls* in the module registry under its fully qualified name.

        Args:
            client_cls: The client to register.
        """
        registry_key = ".".join((client_cls.__module__, client_cls.__name__))
        _realtime_client_classes[registry_key] = client_cls
        return client_cls

    return _register
@export_module("autogen.agentchat.realtime.experimental.clients")
def get_client(llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any) -> "RealtimeClientProtocol":
    """Get a registered Realtime API client matching the given config.

    Args:
        llm_config: The config for the client.
        logger: The logger to use for logging events.
        **kwargs: Additional arguments.

    Returns:
        RealtimeClientProtocol: The first registered client whose factory accepts the config.

    Raises:
        ValueError: If no registered client matches the config.
    """
    # Iterate over the registered classes directly; the registry keys (fully
    # qualified names) are not needed here, so use .values() instead of
    # discarding them from .items(). The redundant `global` declaration was
    # dropped -- the registry is only read here, never rebound.
    for client_cls in _realtime_client_classes.values():
        factory = client_cls.get_factory(llm_config=llm_config, logger=logger, **kwargs)
        if factory:
            return factory()
    raise ValueError("Realtime API client not found.")

View File

@@ -0,0 +1,85 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
from typing import TYPE_CHECKING, Any, Optional
from asyncer import asyncify
from pydantic import BaseModel
from ....doc_utils import export_module
from .realtime_events import FunctionCall, RealtimeEvent
from .realtime_observer import RealtimeObserver
if TYPE_CHECKING:
from logging import Logger
@export_module("autogen.agentchat.realtime.experimental")
class FunctionObserver(RealtimeObserver):
    """Observer for handling function calls from the OpenAI Realtime API."""

    def __init__(self, *, logger: Optional["Logger"] = None) -> None:
        """Observer for handling function calls from the OpenAI Realtime API.

        Args:
            logger (Optional[Logger]): The logger for the observer.
        """
        super().__init__(logger=logger)

    async def on_event(self, event: RealtimeEvent) -> None:
        """Handle function call events from the OpenAI Realtime API.

        Non-FunctionCall events are ignored by this observer.

        Args:
            event (RealtimeEvent): The event from the OpenAI Realtime API.
        """
        if isinstance(event, FunctionCall):
            self.logger.info("Received function call event")
            await self.call_function(
                call_id=event.call_id,
                name=event.name,
                kwargs=event.arguments,
            )

    async def call_function(self, call_id: str, name: str, kwargs: dict[str, Any]) -> None:
        """Call a function registered with the agent and send the result back.

        Args:
            call_id (str): The ID of the function call.
            name (str): The name of the function to call.
            kwargs (dict[str, Any]): The arguments to pass to the function.
        """
        if name in self.agent.registered_realtime_tools:
            func = self.agent.registered_realtime_tools[name].func
            # Sync functions are wrapped so they can be awaited uniformly.
            func = func if asyncio.iscoroutinefunction(func) else asyncify(func)
            try:
                result = await func(**kwargs)
            except Exception:
                # Best-effort: report failure as the function result instead of raising.
                result = "Function call failed"
                self.logger.info(f"Function call failed: {name=}, {kwargs=}", stack_info=True)
            # Serialize the result to a string for the realtime API.
            if isinstance(result, BaseModel):
                result = result.model_dump_json()
            elif not isinstance(result, str):
                try:
                    result = json.dumps(result)
                except Exception:
                    # Fall back to repr-style stringification for non-JSON-serializable results.
                    result = str(result)
            await self.realtime_client.send_function_result(call_id, result)
        else:
            self.logger.warning(f"Function {name} called, but is not registered with the realtime agent.")

    async def initialize_session(self) -> None:
        """Add registered tools to OpenAI with a session update."""
        session_update = {
            "tools": [tool.realtime_tool_schema for tool in self.agent.registered_realtime_tools.values()],
            "tool_choice": "auto",
        }
        await self.realtime_client.session_update(session_update)

    async def run_loop(self) -> None:
        """Run the observer loop (no extra work needed; events are handled in on_event)."""
        pass
if TYPE_CHECKING:
    # Static-only conformance check: ensures FunctionObserver is assignable to
    # RealtimeObserver for the type checker; never executed at runtime.
    function_observer: RealtimeObserver = FunctionObserver()

View File

@@ -0,0 +1,158 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from logging import Logger, getLogger
from typing import Any, Callable, Optional, TypeVar, Union
from anyio import lowlevel
from asyncer import create_task_group
from ....doc_utils import export_module
from ....llm_config import LLMConfig
from ....tools import Tool
from .clients.realtime_client import RealtimeClientProtocol, get_client
from .function_observer import FunctionObserver
from .realtime_observer import RealtimeObserver
F = TypeVar("F", bound=Callable[..., Any])
global_logger = getLogger(__name__)
@dataclass
class RealtimeAgentCallbacks:
    """Callbacks for the Realtime Agent."""

    # async empty placeholder function
    # Default is an awaitable no-op (anyio checkpoint) so callers can always
    # safely `await callbacks.on_observers_ready()`.
    on_observers_ready: Callable[[], Any] = lambda: lowlevel.checkpoint()
@export_module("autogen.agentchat.realtime.experimental")
class RealtimeAgent:
    """(Experimental) Agent that connects a Realtime API client with a set of observers."""

    def __init__(
        self,
        *,
        name: str,
        audio_adapter: Optional[RealtimeObserver] = None,
        system_message: str = "You are a helpful AI Assistant.",
        llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None,
        logger: Optional[Logger] = None,
        observers: Optional[list[RealtimeObserver]] = None,
        **client_kwargs: Any,
    ):
        """(Experimental) Agent for interacting with the Realtime Clients.

        Args:
            name (str): The name of the agent.
            audio_adapter (Optional[RealtimeObserver] = None): The audio adapter for the agent.
            system_message (str): The system message for the agent.
            llm_config (LLMConfig, dict[str, Any], bool): The config for the agent.
            logger (Optional[Logger]): The logger for the agent.
            observers (Optional[list[RealtimeObserver]]): The additional observers for the agent.
            **client_kwargs (Any): The keyword arguments for the client.
        """
        self._logger = logger
        self._name = name
        self._system_message = system_message
        # Resolve the effective LLM config (falls back to the ambient one when None).
        llm_config = LLMConfig.get_current_llm_config(llm_config)
        self._realtime_client: RealtimeClientProtocol = get_client(
            llm_config=llm_config, logger=self.logger, **client_kwargs
        )
        self._registered_realtime_tools: dict[str, Tool] = {}
        # A FunctionObserver is always attached so registered tools can be invoked.
        self._observers: list[RealtimeObserver] = observers if observers else []
        self._observers.append(FunctionObserver(logger=logger))
        if audio_adapter:
            self._observers.append(audio_adapter)
        self.callbacks = RealtimeAgentCallbacks()

    @property
    def system_message(self) -> str:
        """Get the system message for the agent."""
        return self._system_message

    @property
    def logger(self) -> Logger:
        """Get the logger for the agent."""
        return self._logger or global_logger

    @property
    def realtime_client(self) -> RealtimeClientProtocol:
        """Get the OpenAI Realtime Client."""
        return self._realtime_client

    @property
    def registered_realtime_tools(self) -> dict[str, Tool]:
        """Get the registered realtime tools."""
        return self._registered_realtime_tools

    def register_observer(self, observer: RealtimeObserver) -> None:
        """Register an observer with the Realtime Agent.

        Args:
            observer (RealtimeObserver): The observer to register.
        """
        self._observers.append(observer)

    async def start_observers(self) -> None:
        """Start all observers in the agent's task group and wait until each is ready."""
        # NOTE: uses self._tg, which is created in run(); call only from within run().
        for observer in self._observers:
            self._tg.soonify(observer.run)(self)
        # wait for the observers to be ready
        for observer in self._observers:
            await observer.wait_for_ready()
        await self.callbacks.on_observers_ready()

    async def run(self) -> None:
        """Run the agent."""
        # everything is run in the same task group to enable easy cancellation using self._tg.cancel_scope.cancel()
        async with create_task_group() as self._tg:  # noqa: SIM117
            # connect with the client first (establishes a connection and initializes a session)
            async with self._realtime_client.connect():
                # start the observers and wait for them to be ready
                await self.realtime_client.session_update(session_options={"instructions": self.system_message})
                await self.start_observers()
                # iterate over the events, fanning each one out to every observer
                async for event in self.realtime_client.read_events():
                    for observer in self._observers:
                        await observer.on_event(event)

    def register_realtime_function(
        self,
        *,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> Callable[[Union[F, Tool]], Tool]:
        """Decorator for registering a function to be used by an agent.

        Args:
            name (str): The name of the function.
            description (str): The description of the function.

        Returns:
            Callable[[Union[F, Tool]], Tool]: The decorator for registering a function.
        """

        def _decorator(func_or_tool: Union[F, Tool]) -> Tool:
            """Decorator for registering a function to be used by an agent.

            Args:
                func_or_tool (Union[F, Tool]): The function or tool to register.

            Returns:
                Tool: The registered tool.
            """
            tool = Tool(func_or_tool=func_or_tool, name=name, description=description)
            # Registered under the (possibly derived) tool name, not the given name.
            self._registered_realtime_tools[tool.name] = tool
            return tool

        return _decorator

View File

@@ -0,0 +1,42 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Literal
from pydantic import BaseModel
class RealtimeEvent(BaseModel):
    """Base class for all realtime events; keeps the raw provider message it was parsed from."""

    # The unparsed message as received from the Realtime API.
    raw_message: dict[str, Any]
class SessionCreated(RealtimeEvent):
    """Emitted when a realtime session has been created."""

    type: Literal["session.created"] = "session.created"
class SessionUpdated(RealtimeEvent):
    """Emitted when a realtime session's options have been updated."""

    type: Literal["session.updated"] = "session.updated"
class AudioDelta(RealtimeEvent):
    """A chunk of response audio from the model."""

    type: Literal["response.audio.delta"] = "response.audio.delta"
    # Audio payload chunk (string-encoded).
    delta: str
    # ID of the conversation item the audio belongs to; shape/type depends on the provider.
    item_id: Any
class InputAudioBufferDelta(RealtimeEvent):
    """A chunk of input audio appended to the input buffer."""

    type: Literal["input_audio_buffer.delta"] = "input_audio_buffer.delta"
    # Audio payload chunk (string-encoded).
    delta: str
    # ID of the related item; may be None (see RealtimeClientBase.queue_input_audio_buffer_delta).
    item_id: Any
class SpeechStarted(RealtimeEvent):
    """Emitted when the API detects that the user started speaking."""

    type: Literal["input_audio_buffer.speech_started"] = "input_audio_buffer.speech_started"
class FunctionCall(RealtimeEvent):
    """Emitted when the model has finished producing the arguments of a function call."""

    type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
    # Name of the function/tool being called.
    name: str
    # Parsed call arguments (keyword arguments for the tool).
    arguments: dict[str, Any]
    # Call ID used to correlate the result sent back to the API.
    call_id: str

View File

@@ -0,0 +1,100 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Optional
from anyio import Event
from ....doc_utils import export_module
from .clients.realtime_client import RealtimeClientProtocol
from .realtime_events import RealtimeEvent
if TYPE_CHECKING:
from .realtime_agent import RealtimeAgent
__all__ = ["RealtimeObserver"]
global_logger = getLogger(__name__)
@export_module("autogen.agentchat.realtime.experimental")
class RealtimeObserver(ABC):
    """Observer for the OpenAI Realtime API."""

    def __init__(self, *, logger: Optional[Logger] = None) -> None:
        """Observer for the OpenAI Realtime API.

        Args:
            logger (Logger): The logger for the observer.
        """
        # Set once initialize_session() completes; run() callers wait on it.
        self._ready_event = Event()
        # Attached lazily in run(); accessing .agent before that raises.
        self._agent: Optional[RealtimeAgent] = None
        self._logger = logger

    @property
    def logger(self) -> Logger:
        """The observer's logger, falling back to the module-level logger."""
        return self._logger or global_logger

    @property
    def agent(self) -> "RealtimeAgent":
        """The realtime agent this observer is attached to.

        Raises:
            RuntimeError: If the observer has not been attached via run().
        """
        if self._agent is None:
            raise RuntimeError("Agent has not been set.")
        return self._agent

    @property
    def realtime_client(self) -> RealtimeClientProtocol:
        """The realtime client of the attached agent.

        Raises:
            RuntimeError: If the agent or its client has not been set.
        """
        if self._agent is None:
            raise RuntimeError("Agent has not been set.")
        if self._agent.realtime_client is None:
            raise RuntimeError("Realtime client has not been set.")
        return self._agent.realtime_client

    async def run(self, agent: "RealtimeAgent") -> None:
        """Run the observer with the agent.

        When implementing, be sure to call `self._ready_event.set()` when the observer is ready to process events.

        Args:
            agent (RealtimeAgent): The realtime agent attached to the observer.
        """
        self._agent = agent
        await self.initialize_session()
        # Mark ready only after the session is initialized, then enter the loop.
        self._ready_event.set()
        await self.run_loop()

    @abstractmethod
    async def run_loop(self) -> None:
        """Run the loop if needed.

        This method is called after the observer is ready to process events.
        Events will be processed by the on_event method, this is just a hook for additional processing.
        Use initialize_session to set up the session.
        """
        ...

    @abstractmethod
    async def initialize_session(self) -> None:
        """Initialize the session for the observer."""
        ...

    async def wait_for_ready(self) -> None:
        """Wait until the observer has signalled that it is ready to process events."""
        await self._ready_event.wait()

    @abstractmethod
    async def on_event(self, event: RealtimeEvent) -> None:
        """Handle an event from the OpenAI Realtime API.

        Args:
            event (RealtimeEvent): The event from the OpenAI Realtime API.
        """
        ...

    async def on_close(self) -> None:
        """Handle close of RealtimeClient."""
        ...

View File

@@ -0,0 +1,483 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import logging
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
import anyio
from asyncer import asyncify, create_task_group, syncify
from ....agentchat.contrib.swarm_agent import AfterWorkOption, initiate_swarm_chat
from ....cache import AbstractCache
from ....code_utils import content_str
from ....doc_utils import export_module
from ... import Agent, ChatResult, ConversableAgent, LLMAgent
from ...utils import consolidate_chat_info, gather_usage_summary
if TYPE_CHECKING:
from .clients import Role
from .realtime_agent import RealtimeAgent
__all__ = ["register_swarm"]

# Default system message installed on the realtime agent when register_swarm()
# is called without an explicit system_message.
SWARM_SYSTEM_MESSAGE = (
    "You are a helpful voice assistant. Your task is to listen to user and to coordinate the tasks based on his/her inputs."
    "You can and will communicate using audio output only."
)

# Role used when relaying a swarm question to the realtime user.
QUESTION_ROLE: "Role" = "user"
# Template for relaying a swarm question; formatted with the question text in ask_question().
QUESTION_MESSAGE = (
    "I have a question/information for myself. DO NOT ANSWER YOURSELF, GET THE ANSWER FROM ME. "
    "repeat the question to me **WITH AUDIO OUTPUT** and AFTER YOU GET THE ANSWER FROM ME call 'answer_task_question' with the answer in first person\n\n"
    "IMPORTANT: repeat just the question, without any additional information or context\n\n"
    "The question is: '{}'\n\n"
)
# Seconds to wait for an answer before the question is re-sent.
QUESTION_TIMEOUT_SECONDS = 20

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])
def message_to_dict(message: Union[dict[str, Any], str]) -> dict[str, Any]:
    """Coerce a message into dict form.

    A plain string becomes ``{"content": <string>}``; a dict is returned as-is;
    anything else is converted via ``dict()``.
    """
    if isinstance(message, dict):
        return message
    if isinstance(message, str):
        return {"content": message}
    return dict(message)
def parse_oai_message(message: Union[dict[str, Any], str], role: str, adressee: Agent) -> dict[str, Any]:
    """
    Parse a message into an OpenAI-compatible message format.

    Args:
        message: The message to parse.
        role: The role associated with the message.
        adressee: The agent that will receive the message. (Note: parameter name
            is a misspelling of "addressee", kept for interface compatibility.)

    Returns:
        The parsed message in OpenAI-compatible format.

    Raises:
        ValueError: If the message lacks required fields like 'content', 'function_call', or 'tool_calls'.
    """
    message = message_to_dict(message)
    # Extract relevant fields while ensuring none are None
    oai_message = {
        key: message[key]
        for key in ("content", "function_call", "tool_calls", "tool_responses", "tool_call_id", "name", "context")
        if key in message and message[key] is not None
    }
    # Validate or set the content field
    if "content" not in oai_message:
        if "function_call" in oai_message or "tool_calls" in oai_message:
            # Tool/function calls are allowed to have no textual content.
            oai_message["content"] = None
        else:
            raise ValueError("Message must have either 'content', 'function_call', or 'tool_calls' field.")
    # Determine and assign the role
    if message.get("role") in ["function", "tool"]:
        oai_message["role"] = message["role"]
        # Ensure all tool responses have string content
        for tool_response in oai_message.get("tool_responses", []):
            tool_response["content"] = str(tool_response["content"])
    elif "override_role" in message:
        # Explicit override wins over the caller-provided default role.
        oai_message["role"] = message["override_role"]
    else:
        oai_message["role"] = role
    # Enforce specific role requirements for assistant messages
    if oai_message.get("function_call") or oai_message.get("tool_calls"):
        oai_message["role"] = "assistant"
    # Add a name field if missing, defaulting to the recipient's name.
    if "name" not in oai_message:
        oai_message["name"] = adressee.name
    return oai_message
class SwarmableAgent(Agent):
    """A class for an agent that can participate in a swarm chat."""

    def __init__(
        self,
        name: str,
        system_message: str = "You are a helpful AI Assistant.",
        is_termination_msg: Optional[Callable[..., bool]] = None,
        description: Optional[str] = None,
        silent: Optional[bool] = None,
    ):
        """Initialize the swarmable agent.

        Args:
            name (str): Name of the agent.
            system_message (str): System message for the agent.
            is_termination_msg (Optional[Callable[..., bool]]): Predicate deciding whether
                a received message terminates the chat; defaults to checking for a
                literal "TERMINATE" content.
            description (Optional[str]): Description of the agent; falls back to the
                system message when not provided.
            silent (Optional[bool]): Whether to suppress output.
        """
        # Per-peer message history, keyed by the other agent.
        self._oai_messages: dict[Agent, Any] = defaultdict(list)
        self._system_message = system_message
        self._description = description if description is not None else system_message
        self._is_termination_msg = (
            is_termination_msg
            if is_termination_msg is not None
            else (lambda x: content_str(x.get("content")) == "TERMINATE")
        )
        self.silent = silent
        self._name = name
        # Initialize standalone client cache object.
        self.client_cache = None
        self.previous_cache = None
        # Whether to auto-reply when receiving a message from a given sender.
        self.reply_at_receive: dict[Agent, bool] = defaultdict(bool)

    @property
    def system_message(self) -> str:
        """The system message of the agent."""
        return self._system_message

    def update_system_message(self, system_message: str) -> None:
        """Update this agent's system message.

        Args:
            system_message (str): system message for inference.
        """
        self._system_message = system_message

    @property
    def name(self) -> str:
        """The name of the agent."""
        return self._name

    @property
    def description(self) -> str:
        """The description of the agent."""
        return self._description

    def send(
        self,
        message: Union[dict[str, Any], str],
        recipient: Agent,
        request_reply: Optional[bool] = None,
        silent: Optional[bool] = False,
    ) -> None:
        """Record the outgoing message and deliver it to the recipient."""
        self._oai_messages[recipient].append(parse_oai_message(message, "assistant", recipient))
        recipient.receive(message, self, request_reply)

    def receive(
        self,
        message: Union[dict[str, Any], str],
        sender: Agent,
        request_reply: Optional[bool] = None,
        silent: Optional[bool] = False,
    ) -> None:
        """Record the incoming message and, if requested, generate and send a reply."""
        self._oai_messages[sender].append(parse_oai_message(message, "user", self))
        # Skip replying when explicitly declined, or when not configured for this sender.
        if request_reply is False or (request_reply is None and self.reply_at_receive[sender] is False):
            return
        reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender)
        if reply is not None:
            self.send(reply, sender, silent=silent)

    def generate_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional["Agent"] = None,
        **kwargs: Any,
    ) -> Union[str, dict[str, Any], None]:
        """Generate a reply from the conversation history with *sender*.

        Raises:
            ValueError: If neither messages nor sender is provided.
        """
        if messages is None:
            if sender is None:
                raise ValueError("Either messages or sender must be provided.")
            messages = self._oai_messages[sender]
        _, reply = self.check_termination_and_human_reply(messages=messages, sender=sender, config=None)
        return reply

    def check_termination_and_human_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Union[str, None]]:
        """Hook for subclasses; must decide on termination and produce a human reply."""
        raise NotImplementedError

    def initiate_chat(
        self,
        recipient: ConversableAgent,
        message: Union[dict[str, Any], str],
        clear_history: bool = True,
        silent: Optional[bool] = False,
        cache: Optional[AbstractCache] = None,
        summary_args: Optional[dict[str, Any]] = {},  # noqa: B006 -- read-only default, never mutated
        **kwargs: dict[str, Any],
    ) -> ChatResult:
        """Initiate a chat with *recipient*, returning the resulting ChatResult."""
        # locals() snapshot feeds consolidate_chat_info with all call parameters.
        _chat_info = locals().copy()
        _chat_info["sender"] = self
        consolidate_chat_info(_chat_info, uniform_sender=self)
        recipient._raise_exception_on_async_reply_functions()
        # Temporarily install the given cache on the recipient for this chat.
        recipient.previous_cache = recipient.client_cache  # type: ignore[attr-defined]
        recipient.client_cache = cache  # type: ignore[attr-defined, assignment]
        self._prepare_chat(recipient, clear_history)
        self.send(message, recipient, silent=silent)
        summary = self._last_msg_as_summary(self, recipient, summary_args)
        # Restore the recipient's original cache.
        recipient.client_cache = recipient.previous_cache  # type: ignore[attr-defined]
        recipient.previous_cache = None  # type: ignore[attr-defined]
        chat_result = ChatResult(
            chat_history=self.chat_messages[recipient],
            summary=summary,
            cost=gather_usage_summary([self, recipient]),  # type: ignore[arg-type]
            human_input=[],
        )
        return chat_result

    async def a_generate_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional["Agent"] = None,
        **kwargs: Any,
    ) -> Union[str, dict[str, Any], None]:
        """Async wrapper delegating to the synchronous generate_reply."""
        return self.generate_reply(messages=messages, sender=sender, **kwargs)

    async def a_receive(
        self,
        message: Union[dict[str, Any], str],
        sender: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """Async wrapper delegating to the synchronous receive."""
        self.receive(message, sender, request_reply)

    async def a_send(
        self,
        message: Union[dict[str, Any], str],
        recipient: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """Async wrapper delegating to the synchronous send."""
        self.send(message, recipient, request_reply)

    @property
    def chat_messages(self) -> dict[Agent, list[dict[str, Any]]]:
        """A dictionary of conversations from agent to list of messages."""
        return self._oai_messages

    def last_message(self, agent: Optional[Agent] = None) -> Optional[dict[str, Any]]:
        """Return the last message exchanged with *agent* (or the only peer, if unambiguous).

        Raises:
            ValueError: If agent is None and more than one conversation exists.
            KeyError: If the given agent has no conversation history.
        """
        if agent is None:
            n_conversations = len(self._oai_messages)
            if n_conversations == 0:
                return None
            if n_conversations == 1:
                for conversation in self._oai_messages.values():
                    return conversation[-1]  # type: ignore[no-any-return]
            raise ValueError("More than one conversation is found. Please specify the sender to get the last message.")
        # BUG FIX: was `self._oai_messages()` -- calling the dict raised
        # TypeError instead of performing the intended membership test.
        if agent not in self._oai_messages:
            raise KeyError(
                f"The agent '{agent.name}' is not present in any conversation. No history available for this agent."
            )
        return self._oai_messages[agent][-1]  # type: ignore[no-any-return]

    def _prepare_chat(
        self,
        recipient: ConversableAgent,
        clear_history: bool,
        prepare_recipient: bool = True,
        reply_at_receive: bool = True,
    ) -> None:
        """Set up this side of the chat (history, reply behavior) and optionally the recipient's."""
        self.reply_at_receive[recipient] = reply_at_receive
        if clear_history:
            self._oai_messages[recipient].clear()
        if prepare_recipient:
            recipient._prepare_chat(self, clear_history, False, reply_at_receive)  # type: ignore[arg-type]

    def _raise_exception_on_async_reply_functions(self) -> None:
        """No async reply functions are registered on this agent; nothing to check."""
        pass

    def set_ui_tools(self, tools: Optional[list] = None) -> None:
        """Set UI tools for the agent."""
        pass

    def unset_ui_tools(self) -> None:
        """Unset UI tools for the agent."""
        pass

    @staticmethod
    def _last_msg_as_summary(sender: Agent, recipient: Agent, summary_args: Optional[dict[str, Any]]) -> str:
        """Get a chat summary from the last message of the recipient."""
        summary = ""
        try:
            content = recipient.last_message(sender)["content"]  # type: ignore[attr-defined]
            if isinstance(content, str):
                summary = content.replace("TERMINATE", "")
            elif isinstance(content, list):
                # Multi-part content: join the text parts, stripping the termination marker.
                summary = "\n".join(
                    x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x
                )
        except (IndexError, AttributeError) as e:
            warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning)
        return summary
# check that the SwarmableAgent class is implementing LLMAgent protocol
if TYPE_CHECKING:
    # Static-only conformance check: assigning a SwarmableAgent to an LLMAgent
    # annotation makes the type checker verify the protocol; never executed.
    def _create_swarmable_agent(
        name: str,
        system_message: str,
        is_termination_msg: Optional[Callable[..., bool]],
        description: Optional[str],
        silent: Optional[bool],
    ) -> LLMAgent:
        return SwarmableAgent(
            name=name,
            system_message=system_message,
            is_termination_msg=is_termination_msg,
            description=description,
            silent=silent,
        )
class SwarmableRealtimeAgent(SwarmableAgent):
    """Bridges a RealtimeAgent into a swarm chat as the human-proxy participant."""

    def __init__(
        self,
        realtime_agent: "RealtimeAgent",
        initial_agent: ConversableAgent,
        agents: list[ConversableAgent],
        question_message: Optional[str] = None,
    ) -> None:
        """Wrap *realtime_agent* so it can answer swarm questions via voice.

        Args:
            realtime_agent (RealtimeAgent): The realtime agent to bridge.
            initial_agent (ConversableAgent): The agent that starts the swarm chat.
            agents (list[ConversableAgent]): The swarm participants.
            question_message (Optional[str]): Template for relaying questions;
                defaults to QUESTION_MESSAGE.
        """
        self._initial_agent = initial_agent
        self._agents = agents
        self._realtime_agent = realtime_agent
        # Event set when the user answers the current question (see set_answer).
        self._answer_event: anyio.Event = anyio.Event()
        self._answer: str = ""
        self.question_message = question_message or QUESTION_MESSAGE
        super().__init__(
            name=realtime_agent._name,
            is_termination_msg=None,
            description=None,
            silent=None,
        )

    def reset_answer(self) -> None:
        """Reset the answer event."""
        # anyio.Event cannot be cleared; a fresh instance is required.
        self._answer_event = anyio.Event()

    def set_answer(self, answer: str) -> str:
        """Set the answer to the question."""
        self._answer = answer
        self._answer_event.set()
        return "Answer set successfully."

    async def get_answer(self) -> str:
        """Get the answer to the question."""
        await self._answer_event.wait()
        return self._answer

    async def ask_question(self, question: str, question_timeout: int) -> None:
        """Send a question for the user to the agent and wait for the answer.

        If the answer is not received within the timeout, the question is repeated.

        Args:
            question: The question to ask the user.
            question_timeout: The time in seconds to wait for the answer.
        """
        self.reset_answer()
        realtime_client = self._realtime_agent._realtime_client
        await realtime_client.send_text(role=QUESTION_ROLE, text=question)

        async def _check_event_set(timeout: int = question_timeout) -> bool:
            # Poll once per second so the loop stays responsive to cancellation.
            for _ in range(timeout):
                if self._answer_event.is_set():
                    return True
                await anyio.sleep(1)
            return False

        # Keep re-sending the question until an answer arrives.
        while not await _check_event_set():
            await realtime_client.send_text(role=QUESTION_ROLE, text=question)

    def check_termination_and_human_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Optional[str]]:
        """Check if the conversation should be terminated and if the agent should reply.

        Called when its agents turn in the chat conversation. Blocks (via syncify)
        until the user answers through the realtime channel.

        Args:
            messages (list[dict[str, Any]]): The messages in the conversation.
            sender (Agent): The agent that sent the message.
            config (Optional[Any]): The configuration for the agent.
        """
        if not messages:
            return False, None

        async def get_input() -> None:
            async with create_task_group() as tg:
                tg.soonify(self.ask_question)(
                    self.question_message.format(messages[-1]["content"]),
                    question_timeout=QUESTION_TIMEOUT_SECONDS,
                )

        # Bridge from this synchronous callback into the async realtime world.
        syncify(get_input)()
        return True, {"role": "user", "content": self._answer}  # type: ignore[return-value]

    def start_chat(self) -> None:
        """Not supported on this agent; the swarm is started via on_observers_ready."""
        raise NotImplementedError

    def configure_realtime_agent(self, system_message: Optional[str]) -> None:
        """Install the swarm system message, answer tool, and startup hook on the realtime agent."""
        realtime_agent = self._realtime_agent
        logger = realtime_agent.logger
        if not system_message:
            # Warn only if the user customized the message in __init__, since it will be replaced.
            if realtime_agent.system_message != "You are a helpful AI Assistant.":
                logger.warning(
                    "Overriding system message set up in `__init__`, please use `system_message` parameter of the `register_swarm` function instead."
                )
            system_message = SWARM_SYSTEM_MESSAGE
        realtime_agent._system_message = system_message
        # The voice user answers swarm questions through this registered tool.
        realtime_agent.register_realtime_function(
            name="answer_task_question", description="Answer question from the task"
        )(self.set_answer)

        async def on_observers_ready() -> None:
            # Launch the (synchronous) swarm chat in the agent's task group.
            self._realtime_agent._tg.soonify(asyncify(initiate_swarm_chat))(
                initial_agent=self._initial_agent,
                agents=self._agents,
                user_agent=self,  # type: ignore[arg-type]
                messages="Find out what the user wants.",
                after_work=AfterWorkOption.REVERT_TO_USER,
            )

        self._realtime_agent.callbacks.on_observers_ready = on_observers_ready
@export_module("autogen.agentchat.realtime.experimental")
def register_swarm(
    *,
    realtime_agent: "RealtimeAgent",
    initial_agent: ConversableAgent,
    agents: list[ConversableAgent],
    system_message: Optional[str] = None,
    question_message: Optional[str] = None,
) -> None:
    """Create a SwarmableRealtimeAgent and wire it into the realtime agent.

    The realtime agent is configured in place (system message, answer tool,
    startup callback); nothing is returned.

    Args:
        realtime_agent (RealtimeAgent): The RealtimeAgent to create the SwarmableRealtimeAgent from.
        initial_agent (ConversableAgent): The initial agent.
        agents (list[ConversableAgent]): The agents in the swarm.
        system_message (Optional[str]): The system message to set for the agent. If None, the default system message is used.
        question_message (Optional[str]): The question message to set for the agent. If None, the default QUESTION_MESSAGE is used.
    """
    swarmable_agent = SwarmableRealtimeAgent(
        realtime_agent=realtime_agent, initial_agent=initial_agent, agents=agents, question_message=question_message
    )
    swarmable_agent.configure_realtime_agent(system_message=system_message)

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from collections.abc import AsyncIterator
from typing import Any, Protocol, runtime_checkable
__all__ = ["WebSocketProtocol"]
@runtime_checkable
class WebSocketProtocol(Protocol):
    """WebSocket protocol for sending and receiving JSON data modelled after FastAPI's WebSocket."""

    # Send *data* as JSON; mode selects the frame type (presumably "text" or "binary" -- confirm per implementation).
    async def send_json(self, data: Any, mode: str = "text") -> None: ...

    # Receive and decode one JSON message.
    async def receive_json(self, mode: str = "text") -> Any: ...

    # Receive one raw text message.
    async def receive_text(self) -> str: ...

    # Iterate over incoming text messages.
    def iter_text(self) -> AsyncIterator[str]: ...

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from ..realtime.experimental import (
FunctionObserver,
RealtimeAgent,
RealtimeObserver,
TwilioAudioAdapter,
WebSocketAudioAdapter,
register_swarm,
)
# Public re-exports of the experimental realtime agent API.
__all__ = [
    "FunctionObserver",
    "RealtimeAgent",
    "RealtimeObserver",
    "TwilioAudioAdapter",
    "WebSocketAudioAdapter",
    "register_swarm",
]

View File

@@ -0,0 +1,111 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import Any, Callable, Literal, Optional, Union
from ..doc_utils import export_module
from ..llm_config import LLMConfig
from ..runtime_logging import log_new_agent, logging_enabled
from .conversable_agent import ConversableAgent
@export_module("autogen")
class UserProxyAgent(ConversableAgent):
    """(In preview) A proxy agent for the user, that can execute code and provide feedback to the other agents.

    UserProxyAgent is a subclass of ConversableAgent configured with `human_input_mode` to ALWAYS
    and `llm_config` to False. By default, the agent will prompt for human input every time a message is received.
    Code execution is enabled by default. LLM-based auto reply is disabled by default.
    To modify auto reply, register a method with [`register_reply`](../ConversableAgent#register-reply).
    To modify the way to get human input, override `get_human_input` method.
    To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`,
    `run_code`, and `execute_function` methods respectively.
    """

    # Default UserProxyAgent.description values, based on human_input_mode
    DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS = {
        "ALWAYS": "An attentive HUMAN user who can answer questions about the task, and can perform tasks such as running Python code or inputting command line commands at a Linux terminal and reporting back the execution results.",
        "TERMINATE": "A user that can run Python code or input command line commands at a Linux terminal and report back the execution results.",
        "NEVER": "A computer terminal that performs no other action than running Python scripts (provided to it quoted in ```python code blocks), or sh shell scripts (provided to it quoted in ```sh code blocks).",
    }

    def __init__(
        self,
        name: str,
        is_termination_msg: Optional[Callable[[dict[str, Any]], bool]] = None,
        max_consecutive_auto_reply: Optional[int] = None,
        human_input_mode: Literal["ALWAYS", "TERMINATE", "NEVER"] = "ALWAYS",
        function_map: Optional[dict[str, Callable[..., Any]]] = None,
        code_execution_config: Optional[Union[dict[str, Any], Literal[False]]] = None,
        default_auto_reply: Optional[Union[str, dict[str, Any]]] = "",
        llm_config: Optional[Union[LLMConfig, dict[str, Any], Literal[False]]] = False,
        system_message: Optional[Union[str, list[str]]] = "",
        description: Optional[str] = None,
        **kwargs: Any,
    ):
        """Args:
        name (str): name of the agent.
        is_termination_msg (function): a function that takes a message in the form of a dictionary
            and returns a boolean value indicating if this received message is a termination message.
            The dict can contain the following keys: "content", "role", "name", "function_call".
        max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
            default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
            The limit only plays a role when human_input_mode is not "ALWAYS".
        human_input_mode (str): whether to ask for human inputs every time a message is received.
            Possible values are "ALWAYS", "TERMINATE", "NEVER".
            (1) When "ALWAYS", the agent prompts for human input every time a message is received.
                Under this mode, the conversation stops when the human input is "exit",
                or when is_termination_msg is True and there is no human input.
            (2) When "TERMINATE", the agent only prompts for human input only when a termination message is received or
                the number of auto reply reaches the max_consecutive_auto_reply.
            (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
                when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
        function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions.
        code_execution_config (dict or False or None): config for the code execution.
            To disable code execution, set to False. Setting it to None (the default) enables
            code execution with default settings (equivalent to passing an empty dict).
            Otherwise, set to a dictionary with the following keys:
            - work_dir (Optional, str): The working directory for the code execution.
                If None, a default working directory will be used.
                The default working directory is the "extensions" directory under
                "path_to_autogen".
            - use_docker (Optional, list, str or bool): The docker image to use for code execution.
                Default is True, which means the code will be executed in a docker container. A default list of images will be used.
                If a list or a str of image name(s) is provided, the code will be executed in a docker container
                with the first image successfully pulled.
                If False, the code will be executed in the current environment.
                We strongly recommend using docker for code execution.
            - timeout (Optional, int): The maximum execution time in seconds.
            - last_n_messages (Experimental, Optional, int): The number of messages to look back for code execution. Default to 1.
        default_auto_reply (str or dict or None): the default auto reply message when no code execution or llm based reply is generated.
        llm_config (LLMConfig or dict or False or None): llm inference configuration.
            Please refer to [OpenAIWrapper.create](https://docs.ag2.ai/latest/docs/api-reference/autogen/OpenAIWrapper/#autogen.OpenAIWrapper.create)
            for available options.
            Default to False, which disables llm-based auto reply.
            When set to None, will use self.DEFAULT_CONFIG, which defaults to False.
        system_message (str or List): system message for ChatCompletion inference.
            Only used when llm_config is not False. Use it to reprogram the agent.
        description (str): a short description of the agent. This description is used by other agents
            (e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message)
        **kwargs (dict): Please refer to other kwargs in
            [ConversableAgent](https://docs.ag2.ai/latest/docs/api-reference/autogen/ConversableAgent).
        """
        # BUG FIX (mutable default argument): the previous default of `{}` was a single dict
        # object shared by every UserProxyAgent instance, so any in-place mutation of the
        # config leaked across instances. `None` now means "enable code execution with
        # default settings", producing a fresh dict per instance — same observable behavior.
        if code_execution_config is None:
            code_execution_config = {}
        super().__init__(
            name=name,
            system_message=system_message,
            is_termination_msg=is_termination_msg,
            max_consecutive_auto_reply=max_consecutive_auto_reply,
            human_input_mode=human_input_mode,
            function_map=function_map,
            code_execution_config=code_execution_config,
            llm_config=llm_config,
            default_auto_reply=default_auto_reply,
            description=(
                description if description is not None else self.DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS[human_input_mode]
            ),
            **kwargs,
        )
        # Record agent creation when runtime logging is active.
        if logging_enabled():
            log_new_agent(self, locals())

View File

@@ -0,0 +1,206 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import re
from typing import Any, Optional, Union
from ..doc_utils import export_module
from .agent import Agent
def consolidate_chat_info(
    chat_info: Union[dict[str, Any], list[dict[str, Any]]], uniform_sender: Optional[Agent] = None
) -> None:
    """Validate one chat spec (or a list of them) in place.

    Each spec must name a recipient, a sender (unless `uniform_sender` supplies
    one for all specs), and an acceptable `summary_method`. When the summary
    method is "reflection_with_llm", at least one side must have an llm client.
    Raises AssertionError on any violation; returns nothing.
    """
    chats = [chat_info] if isinstance(chat_info, dict) else chat_info
    for chat in chats:
        sender = uniform_sender
        if sender is None:
            assert "sender" in chat, "sender must be provided."
            sender = chat["sender"]
        assert "recipient" in chat, "recipient must be provided."
        summary_method = chat.get("summary_method")
        assert (
            summary_method is None or callable(summary_method) or summary_method in ("last_msg", "reflection_with_llm")
        ), "summary_method must be a string chosen from 'reflection_with_llm' or 'last_msg' or a callable, or None."
        if summary_method == "reflection_with_llm":
            assert sender.client is not None or chat["recipient"].client is not None, (
                "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm."
            )
@export_module("autogen")
def gather_usage_summary(agents: list[Agent]) -> dict[str, dict[str, Any]]:
    r"""Gather usage summary from all agents.

    Args:
        agents: (list): List of agents.

    Returns:
        dictionary: A dictionary containing two keys:
            - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
            - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".

    Note:
        If none of the agents incurred any cost (not having a client), then both
        returned summaries will be `{'total_cost': 0}`.
    """

    def _merge(target: dict[str, Any], source: dict[str, Any]) -> None:
        # An agent whose client reports no usage contributes nothing.
        if source is None:
            return
        target["total_cost"] += source.get("total_cost", 0)
        for model, stats in source.items():
            if model == "total_cost":
                continue
            if model not in target:
                target[model] = stats.copy()
                continue
            entry = target[model]
            for key in ("cost", "prompt_tokens", "completion_tokens", "total_tokens"):
                entry[key] += stats.get(key, 0)

    including: dict[str, Any] = {"total_cost": 0}
    excluding: dict[str, Any] = {"total_cost": 0}
    for agent in agents:
        client = getattr(agent, "client", None)
        if client:
            _merge(including, client.total_usage_summary)  # type: ignore[attr-defined]
            _merge(excluding, client.actual_usage_summary)  # type: ignore[attr-defined]

    return {
        "usage_including_cached_inference": including,
        "usage_excluding_cached_inference": excluding,
    }
def parse_tags_from_content(tag: str, content: Union[str, list[dict[str, Any]]]) -> list[dict[str, Any]]:
    """Parse HTML-style tags from message contents.

    Looks for `<tag ...>` occurrences in the text and extracts their attributes.
    The attribute section may be a bare value (collected under "src") or a set
    of key="value" pairs.

    Examples:
        `<img http://example.com/image.png> -> [{"tag": "img", "attr": {"src": "http://example.com/image.png"}, "match": re.Match}]`

    Args:
        tag (str): The HTML style tag to be parsed.
        content (Union[str, list[dict[str, Any]]]): The message content to parse; either a plain
            string or a multimodal list of content items.

    Returns:
        list[dict[str, Any]]: One dict per parsed tag with keys 'tag', 'attr' (parsed
        attributes), and 'match' (the regex match object).

    Raises:
        ValueError: If the content is not a string or a list.
    """
    if isinstance(content, str):
        return _parse_tags_from_text(tag, content)
    if isinstance(content, list):
        # Multimodal message: only the text items can carry tags.
        found: list[dict[str, Any]] = []
        for entry in content:
            if entry.get("type") == "text":
                found.extend(_parse_tags_from_text(tag, entry["text"]))
        return found
    raise ValueError(f"content must be str or list, but got {type(content)}")
def _parse_tags_from_text(tag: str, text: str) -> list[dict[str, Any]]:
pattern = re.compile(f"<{tag} (.*?)>")
results = []
for match in re.finditer(pattern, text):
tag_attr = match.group(1).strip()
attr = _parse_attributes_from_tags(tag_attr)
results.append({"tag": tag, "attr": attr, "match": match})
return results
def _parse_attributes_from_tags(tag_content: str) -> dict[str, str]:
    """Turn the inside of a tag into an attribute dict.

    Quoted key="value" pairs become entries; any bare token is accumulated into
    the implicit "src" attribute, space-separated.
    """
    raw_tokens = re.findall(r"([^ ]+)", tag_content)
    merged_tokens = _reconstruct_attributes(raw_tokens)

    attributes: dict[str, str] = {}

    def _add_to_src(value: Any) -> None:
        # Bare values accumulate into "src", separated by spaces.
        if "src" in attributes:
            attributes["src"] += f" {value}"
        else:
            attributes["src"] = value

    for token in merged_tokens:
        if "=" not in token:
            _add_to_src(token)
            continue
        key, value = token.split("=", 1)
        if value.startswith("'") or value.startswith('"'):
            attributes[key] = value[1:-1]  # remove quotes
        else:
            # Unquoted "key=value" tokens are treated as part of src, whole token kept.
            _add_to_src(token)
    return attributes
def _reconstruct_attributes(attrs: list[str]) -> list[str]:
"""Reconstructs attributes from a list of strings where some attributes may be split across multiple elements."""
def is_attr(attr: str) -> bool:
if "=" in attr:
_, value = attr.split("=", 1)
if value.startswith("'") or value.startswith('"'):
return True
return False
reconstructed = []
found_attr = False
for attr in attrs:
if is_attr(attr):
reconstructed.append(attr)
found_attr = True
else:
if found_attr:
reconstructed[-1] += f" {attr}"
found_attr = True
elif reconstructed:
reconstructed[-1] += f" {attr}"
else:
reconstructed.append(attr)
return reconstructed

View File

@@ -0,0 +1,596 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import logging
import os
import pathlib
import re
import string
import subprocess
import sys
import time
import venv
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from hashlib import md5
from types import SimpleNamespace
from typing import Callable, Optional, Union
import docker
from .types import UserMessageImageContentPart, UserMessageTextContentPart
# Sentinel distinguishing "argument not supplied" from an explicit None/False
# (used by execute_code's use_docker parameter).
SENTINEL = object()
DEFAULT_MODEL = "gpt-4"
FAST_MODEL = "gpt-3.5-turbo"
# Regular expression for finding a code block
# ```[ \t]*(\w+)?[ \t]*\r?\n(.*?)[ \t]*\r?\n``` Matches multi-line code blocks.
# The [ \t]* matches the potential spaces before language name.
# The (\w+)? matches the language, where the ? indicates it is optional.
# The [ \t]* matches the potential spaces (not newlines) after language name.
# The \r?\n makes sure there is a linebreak after ```.
# The (.*?) matches the code itself (non-greedy).
# The \r?\n makes sure there is a linebreak before ```.
# The [ \t]* matches the potential spaces before closing ``` (the spec allows indentation).
CODE_BLOCK_PATTERN = r"```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*```"
# Default directory where generated code files are written before execution.
WORKING_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extensions")
# Language label used when a code block's language cannot be determined.
UNKNOWN = "unknown"
# Message returned to the caller when execution exceeds its time budget.
TIMEOUT_MSG = "Timeout"
DEFAULT_TIMEOUT = 600
WIN32 = sys.platform == "win32"
PATH_SEPARATOR = (WIN32 and "\\") or "/"
# Language tags that all map to the "python" interpreter in _cmd().
PYTHON_VARIANTS = ["python", "Python", "py"]
logger = logging.getLogger(__name__)
def content_str(content: Union[str, list[Union[UserMessageTextContentPart, UserMessageImageContentPart]], None]) -> str:
    """Convert the `content` field of an OpenAI message into a string.

    Text parts are concatenated verbatim; image parts ("image_url" /
    "input_image") are replaced by the literal placeholder "<image>". A `None`
    content yields the empty string.

    Args:
        content: The content to be processed. Can be a string, a list of dictionaries
            representing text and image URLs, or None.

    Returns:
        str: A string representation of the input content.

    Raises:
        TypeError: If `content` is neither None, str, nor list, or a list element is not a dict.
        ValueError: If a content item's "type" is unrecognized.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        raise TypeError(f"content must be None, str, or list, but got {type(content)}")

    pieces: list[str] = []
    for part in content:
        if not isinstance(part, dict):
            raise TypeError("Wrong content format: every element should be dict if the content is a list.")
        assert "type" in part, "Wrong content format. Missing 'type' key in content's dict."
        part_type = part["type"]
        if part_type == "text":
            pieces.append(part["text"])
        elif part_type in ("input_image", "image_url"):
            pieces.append("<image>")
        else:
            raise ValueError(f"Wrong content format: unknown type {part_type} within the content")
    return "".join(pieces)
def infer_lang(code: str) -> str:
    """Best-effort guess whether `code` is a shell command or python source.

    TODO: make it robust.
    """
    # Lines invoking an interpreter or pip are shell commands, not python source.
    if code.startswith(("python ", "pip", "python3 ")):
        return "sh"
    # Anything that compiles as python is treated as python.
    try:
        compile(code, "test", "exec")
    except SyntaxError:
        # not a valid python code
        return UNKNOWN
    return "python"
# TODO: In the future move, to better support https://spec.commonmark.org/0.30/#fenced-code-blocks
# perhaps by using a full Markdown parser.
def extract_code(
    text: Union[str, list], pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False
) -> list[tuple[str, str]]:
    """Extract code from a text.

    Args:
        text (str or List): The content to extract code from. The content can be
            a string or a list, as returned by standard GPT or multimodal GPT.
        pattern (str, optional): The regular expression pattern for finding the
            code block. Defaults to CODE_BLOCK_PATTERN.
        detect_single_line_code (bool, optional): Enable the new feature for
            extracting single line code. Defaults to False.

    Returns:
        list: A list of tuples, each containing the language and the code.
            If there is no code block in the input text, the language would be "unknown".
            If there is code block but the language is not specified, the language would be "".
    """
    flattened = content_str(text)

    if not detect_single_line_code:
        matches = re.findall(pattern, flattened, flags=re.DOTALL)
        return matches if matches else [(UNKNOWN, flattened)]

    # Match multi-line fenced blocks OR inline `code` spans, separated by the | operator.
    # `([^`]+)`: Matches inline code.
    combined = re.compile(CODE_BLOCK_PATTERN + r"|`([^`]+)`")
    extracted: list[tuple[str, str]] = []
    for lang, fenced, inline in combined.findall(flattened):
        if fenced:
            extracted.append((lang.strip(), fenced.strip()))
        elif inline:
            # Inline spans carry no language tag.
            extracted.append(("", inline.strip()))
    return extracted
def timeout_handler(signum, frame):
    """Signal handler that converts an alarm signal into a TimeoutError (for use with `signal.signal`)."""
    raise TimeoutError("Timed out!")
def get_powershell_command():
    """Return the name of the available PowerShell executable.

    Tries Windows PowerShell ("powershell") first, then PowerShell Core ("pwsh").

    Returns:
        str: "powershell" or "pwsh" when one responds successfully; implicitly
        None when a probe runs but exits non-zero.

    Raises:
        FileNotFoundError: neither executable is installed.
        NotADirectoryError: a broken PATH entry prevents locating pwsh.
        PermissionError: no permission to run powershell.
    """
    try:
        result = subprocess.run(["powershell", "$PSVersionTable.PSVersion.Major"], capture_output=True, text=True)
        if result.returncode == 0:
            return "powershell"
    except (FileNotFoundError, NotADirectoryError):
        # 'powershell' command is not found, so now we try looking for 'pwsh'.
        try:
            result = subprocess.run(
                ["pwsh", "-Command", "$PSVersionTable.PSVersion.Major"], capture_output=True, text=True
            )
            if result.returncode == 0:
                return "pwsh"
        # BUG FIX: a missing 'pwsh' raises FileNotFoundError, which the previous
        # `except FileExistsError` clause could never catch — the helpful
        # "install PowerShell" message below was therefore unreachable.
        except FileNotFoundError as e:
            raise FileNotFoundError(
                "Neither powershell.exe nor pwsh.exe is present in the system. "
                "Please install PowerShell and try again. "
            ) from e
        except NotADirectoryError as e:
            raise NotADirectoryError(
                "PowerShell is either not installed or its path is not given "
                "properly in the environment variable PATH. Please check the "
                "path and try again. "
            ) from e
    except PermissionError as e:
        raise PermissionError("No permission to run powershell.") from e
def _cmd(lang: str) -> str:
    """Map a code-block language tag to the interpreter command used to run it.

    Raises NotImplementedError for unrecognized languages.
    """
    if lang in PYTHON_VARIANTS:
        return "python"
    if lang.startswith("python") or lang in ("bash", "sh"):
        return lang
    if lang == "shell":
        return "sh"
    if lang == "javascript":
        return "node"
    if lang in ("ps1", "pwsh", "powershell"):
        # Resolve whichever PowerShell flavour is actually installed.
        return get_powershell_command()
    raise NotImplementedError(f"{lang} not recognized in code execution")
def is_docker_running() -> bool:
    """Check if docker is running.

    Returns:
        bool: True if docker is running; False otherwise.
    """
    try:
        docker.from_env().ping()
    except docker.errors.DockerException:
        return False
    return True
def in_docker_container() -> bool:
    """Check if the current process is running inside a docker container.

    Detection relies on the `/.dockerenv` marker file docker creates at the
    container filesystem root.

    Returns:
        bool: True if the code is running in a docker container; False otherwise.
    """
    return pathlib.Path("/.dockerenv").exists()
def decide_use_docker(use_docker: Optional[bool]) -> Optional[bool]:
    """Resolve the effective use_docker setting.

    An explicit True/False is returned unchanged. When `use_docker` is None, the
    AUTOGEN_USE_DOCKER environment variable decides (defaulting to "True"); the
    string "none" (any case) maps to None, and anything unrecognized raises.

    Raises:
        ValueError: AUTOGEN_USE_DOCKER holds an unrecognized value.
    """
    if use_docker is not None:
        return use_docker

    raw = os.environ.get("AUTOGEN_USE_DOCKER", "True")
    # Case-insensitive comparison against the accepted spellings.
    flag = raw.lower()
    if flag in {"1", "true", "yes", "t"}:
        return True
    if flag in {"0", "false", "no", "f"}:
        return False
    if flag == "none":  # Special case for 'None' as a string
        return None
    raise ValueError(
        f'Invalid value for AUTOGEN_USE_DOCKER: {raw}. Please set AUTOGEN_USE_DOCKER to "1/True/yes", "0/False/no", or "None".'
    )
def check_can_use_docker_or_throw(use_docker) -> None:
    """Raise when docker execution is requested but no docker is available.

    A problem exists only when `use_docker` is truthy, the process is not
    already inside a container, and no docker daemon is reachable.
    `use_docker=None` is a no-op.

    Raises:
        RuntimeError: docker execution requested but docker is unavailable.
    """
    if use_docker is None:
        return
    if use_docker and not in_docker_container() and not is_docker_running():
        raise RuntimeError(
            "Code execution is set to be run in docker (default behaviour) but docker is not running.\n"
            "The options available are:\n"
            "- Make sure docker is running (advised approach for code execution)\n"
            '- Set "use_docker": False in code_execution_config\n'
            '- Set AUTOGEN_USE_DOCKER to "0/False/no" in your environment variables'
        )
def _sanitize_filename_for_docker_tag(filename: str) -> str:
"""Convert a filename to a valid docker tag.
See https://docs.docker.com/engine/reference/commandline/tag/ for valid tag
format.
Args:
filename (str): The filename to be converted.
Returns:
str: The sanitized Docker tag.
"""
# Replace any character not allowed with an underscore
allowed_chars = set(string.ascii_letters + string.digits + "_.-")
sanitized = "".join(char if char in allowed_chars else "_" for char in filename)
# Ensure it does not start with a period or a dash
if sanitized.startswith(".") or sanitized.startswith("-"):
sanitized = "_" + sanitized[1:]
# Truncate if longer than 128 characters
return sanitized[:128]
def execute_code(
    code: Optional[str] = None,
    timeout: Optional[int] = None,
    filename: Optional[str] = None,
    work_dir: Optional[str] = None,
    use_docker: Union[list[str], str, bool] = SENTINEL,
    lang: Optional[str] = "python",
) -> tuple[int, str, Optional[str]]:
    """Execute code in a docker container.

    This function is not tested on MacOS.

    Args:
        code (Optional, str): The code to execute.
            If None, the code from the file specified by filename will be executed.
            Either code or filename must be provided.
        timeout (Optional, int): The maximum execution time in seconds.
            If None, a default timeout will be used. The default timeout is 600 seconds. On Windows, the timeout is not enforced when use_docker=False.
        filename (Optional, str): The file name to save the code or where the code is stored when `code` is None.
            If None, a file with a randomly generated name will be created.
            The randomly generated file will be deleted after execution.
            The file name must be a relative path. Relative paths are relative to the working directory.
        work_dir (Optional, str): The working directory for the code execution.
            If None, a default working directory will be used.
            The default working directory is the "extensions" directory under
            "path_to_autogen".
        use_docker (list, str or bool): The docker image to use for code execution.
            Default is True, which means the code will be executed in a docker container. A default list of images will be used.
            If a list or a str of image name(s) is provided, the code will be executed in a docker container
            with the first image successfully pulled.
            If False, the code will be executed in the current environment.
            If the code is executed in the current environment, the code must be trusted.
        lang (Optional, str): The language of the code. Default is "python".

    Returns:
        int: 0 if the code executes successfully.
        str: The error message if the code fails to execute; the stdout otherwise.
        image: The docker image name after container run when docker is used.
    """
    if all((code is None, filename is None)):
        error_msg = f"Either {code=} or {filename=} must be provided."
        logger.error(error_msg)
        raise AssertionError(error_msg)

    running_inside_docker = in_docker_container()
    docker_running = is_docker_running()

    # SENTINEL is used to indicate that the user did not explicitly set the argument
    if use_docker is SENTINEL:
        use_docker = decide_use_docker(use_docker=None)

    check_can_use_docker_or_throw(use_docker)

    timeout = timeout or DEFAULT_TIMEOUT
    original_filename = filename
    # Native Windows execution has no sh; route shell code through PowerShell.
    if WIN32 and lang in ["sh", "shell"] and (not use_docker):
        lang = "ps1"
    if filename is None:
        # Create a file with an automatically generated, content-derived name.
        code_hash = md5(code.encode()).hexdigest()
        filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"
    if work_dir is None:
        work_dir = WORKING_DIR

    filepath = os.path.join(work_dir, filename)
    file_dir = os.path.dirname(filepath)
    os.makedirs(file_dir, exist_ok=True)

    if code is not None:
        with open(filepath, "w", encoding="utf-8") as fout:
            fout.write(code)

    if not use_docker or running_inside_docker:
        # Docker disabled or already running in a docker container: execute natively.
        # BUG FIX: the script path previously interpolated a literal "(unknown)"
        # placeholder instead of the generated file name, so the command could
        # never run the file that was just written.
        cmd = [
            sys.executable if lang.startswith("python") else _cmd(lang),
            f".\\{filename}" if WIN32 else filename,
        ]
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(
                subprocess.run,
                cmd,
                cwd=work_dir,
                capture_output=True,
                text=True,
            )
            try:
                result = future.result(timeout=timeout)
            except TimeoutError:
                if original_filename is None:
                    os.remove(filepath)
                return 1, TIMEOUT_MSG, None
        if original_filename is None:
            os.remove(filepath)
        if result.returncode:
            logs = result.stderr
            # Strip local paths/filenames from the error output so callers don't
            # see the temporary execution location.
            if original_filename is None:
                abs_path = str(pathlib.Path(filepath).absolute())
                logs = logs.replace(str(abs_path), "").replace(filename, "")
            else:
                abs_path = str(pathlib.Path(work_dir).absolute()) + PATH_SEPARATOR
                logs = logs.replace(str(abs_path), "")
        else:
            logs = result.stdout
        return result.returncode, logs, None

    # create a docker client
    if use_docker and not docker_running:
        raise RuntimeError(
            "Docker package is missing or docker is not running. Please make sure docker is running or set use_docker=False."
        )

    client = docker.from_env()
    image_list = (
        ["python:3-slim", "python:3", "python:3-windowsservercore"]
        if use_docker is True
        else [use_docker]
        if isinstance(use_docker, str)
        else use_docker
    )
    for image in image_list:
        # Use the first image that is already present or can be pulled.
        try:
            client.images.get(image)
            break
        except docker.errors.ImageNotFound:
            # pull the image
            print("Pulling image", image)
            try:
                client.images.pull(image)
                break
            except docker.errors.DockerException:
                print("Failed to pull image", image)
    # Get a randomized str based on current time to wrap the exit code, so the
    # marker cannot collide with program output.
    exit_code_str = f"exitcode{time.time()}"
    abs_path = pathlib.Path(work_dir).absolute()
    # BUG FIX: same "(unknown)" placeholder issue as above — the in-container
    # command must reference the generated script file by name.
    cmd = [
        "sh",
        "-c",
        f'{_cmd(lang)} "{filename}"; exit_code=$?; echo -n {exit_code_str}; echo -n $exit_code; echo {exit_code_str}',
    ]
    # create a docker container
    container = client.containers.run(
        image,
        command=cmd,
        working_dir="/workspace",
        detach=True,
        # Mount the working directory into the container so the script is visible.
        volumes={abs_path: {"bind": "/workspace", "mode": "rw"}},
    )
    start_time = time.time()
    while container.status != "exited" and time.time() - start_time < timeout:
        # Reload the container object to refresh its status.
        container.reload()
    if container.status != "exited":
        container.stop()
        container.remove()
        if original_filename is None:
            os.remove(filepath)
        return 1, TIMEOUT_MSG, image
    # get the container logs
    logs = container.logs().decode("utf-8").rstrip()
    # Commit the container state as an image so subsequent runs reuse it.
    tag = _sanitize_filename_for_docker_tag(filename)
    container.commit(repository="python", tag=tag)
    # remove the container
    container.remove()
    # check if the code executed successfully
    exit_code = container.attrs["State"]["ExitCode"]
    if exit_code == 0:
        # Extract the real script exit code from the wrapped marker in the logs.
        pattern = re.compile(f"{exit_code_str}(\\d+){exit_code_str}")
        match = pattern.search(logs)
        exit_code = 1 if match is None else int(match.group(1))
        # remove the exit code from the logs
        logs = logs if match is None else pattern.sub("", logs)

    if original_filename is None:
        os.remove(filepath)
    if exit_code:
        logs = logs.replace(f"/workspace/{filename if original_filename is None else ''}", "")
    # return the exit code, logs and image
    return exit_code, logs, f"python:{tag}"
# Legacy (openai<1) completion config used by eval_function_completions to
# auto-generate docstring-derived assertions for a function definition.
_GENERATE_ASSERTIONS_CONFIG = {
    "prompt": """Given the signature and docstring, write the exactly same number of assertion(s) for the provided example(s) in the docstring, without assertion messages.
func signature:
{definition}
assertions:""",
    "model": FAST_MODEL,
    "max_tokens": 256,
    "stop": "\n\n",
}
def _remove_check(response):
"""Remove the check function from the response."""
# find the position of the check function
pos = response.find("def check(")
if pos == -1:
return response
return response[:pos]
def eval_function_completions(
    responses: list[str],
    definition: str,
    test: Optional[str] = None,
    entry_point: Optional[str] = None,
    assertions: Optional[Union[str, Callable[[str], tuple[str, float]]]] = None,
    timeout: Optional[float] = 3,
    use_docker: Optional[bool] = True,
) -> dict:
    """`(openai<1)` Select a response from a list of responses for the function completion task (using generated assertions), and/or evaluate if the task is successful using a gold test.

    Note: when the assertion-filter path runs, `responses` is mutated in place
    (each entry is stripped of any trailing `def check(` harness).

    Args:
        responses: The list of responses.
        definition: The input definition.
        test: The test code.
        entry_point: The name of the function.
        assertions: The assertion code which serves as a filter of the responses, or an assertion generator.
            When provided, only the responses that pass the assertions will be considered for the actual test (if provided).
        timeout: The timeout for executing the code.
        use_docker: Whether to use docker for code execution.

    Returns:
        dict: The success metrics.
    """
    n = len(responses)
    if assertions is None:
        # No assertion filter: run the gold test against every response and
        # report aggregate success statistics.
        success_list = []
        for i in range(n):
            response = _remove_check(responses[i])
            # Prepend the definition unless the response already contains it.
            code = (
                f"{response}\n{test}\ncheck({entry_point})"
                if response.startswith("def")
                else f"{definition}{response}\n{test}\ncheck({entry_point})"
            )
            success = execute_code(code, timeout=timeout, use_docker=use_docker)[0] == 0
            success_list.append(success)
        return {
            # Probability that at least one of n samples succeeds, given the observed rate.
            "expected_success": 1 - pow(1 - sum(success_list) / n, n),
            "success": any(s for s in success_list),
        }
    if callable(assertions) and n > 1:
        # assertion generator: produce assertions (and their generation cost) on demand
        assertions, gen_cost = assertions(definition)
    else:
        # NOTE(review): a string `assertions` value is discarded here (set to None)
        # whenever n <= 1 or it is not callable — confirm this is intended.
        assertions, gen_cost = None, 0
    if n > 1 or test is None:
        # Select the first response that passes the assertions; `i`/`response`
        # intentionally leak from the loop as the selected index/candidate.
        for i in range(n):
            response = responses[i] = _remove_check(responses[i])
            code = (
                f"{response}\n{assertions}" if response.startswith("def") else f"{definition}{response}\n{assertions}"
            )
            succeed_assertions = execute_code(code, timeout=timeout, use_docker=use_docker)[0] == 0
            if succeed_assertions:
                break
    else:
        # just test, no need to check assertions
        succeed_assertions = False
        i, response = 0, responses[0]
    if test is None:
        # No gold test code: report only the assertion-filter outcome.
        return {
            "index_selected": i,
            "succeed_assertions": succeed_assertions,
            "gen_cost": gen_cost,
            "assertions": assertions,
        }
    code_test = (
        f"{response}\n{test}\ncheck({entry_point})"
        if response.startswith("def")
        else f"{definition}{response}\n{test}\ncheck({entry_point})"
    )
    success = execute_code(code_test, timeout=timeout, use_docker=use_docker)[0] == 0
    return {
        "index_selected": i,
        "succeed_assertions": succeed_assertions,
        "success": success,
        "gen_cost": gen_cost,
        "assertions": assertions,
    }
# Legacy (openai<1) prompt template and stop sequences for function-completion.
_FUNC_COMPLETION_PROMPT = "# Python 3{definition}"
_FUNC_COMPLETION_STOP = ["\nclass", "\ndef", "\nif", "\nprint"]
# Escalating configs tried in order: cheap model first, then stronger models
# with progressively fewer samples.
_IMPLEMENT_CONFIGS = [
    {"model": FAST_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "temperature": 0, "cache_seed": 0},
    {"model": FAST_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 7, "cache_seed": 0},
    {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "temperature": 0, "cache_seed": 1},
    {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 2, "cache_seed": 2},
    {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 1, "cache_seed": 2},
]
def create_virtual_env(dir_path: str, **env_args) -> SimpleNamespace:
    """Creates a python virtual environment and returns the context.

    Args:
        dir_path (str): Directory path where the env will be created.
        **env_args: Any extra args to pass to the `EnvBuilder`

    Returns:
        SimpleNamespace: the virtual env context object.
    """
    # When the caller supplies no builder options, default to installing pip.
    builder = venv.EnvBuilder(**(env_args or {"with_pip": True}))
    builder.create(dir_path)
    return builder.ensure_directories(dir_path)

View File

@@ -0,0 +1,22 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Original portions of this file are derived from https://github.com/microsoft/autogen under the MIT License.
# SPDX-License-Identifier: MIT
from .base import CodeBlock, CodeExecutor, CodeExtractor, CodeResult
from .docker_commandline_code_executor import DockerCommandLineCodeExecutor
from .factory import CodeExecutorFactory
from .local_commandline_code_executor import LocalCommandLineCodeExecutor
from .markdown_code_extractor import MarkdownCodeExtractor
# Public API of the autogen.coding package.
__all__ = (
    "CodeBlock",
    "CodeExecutor",
    "CodeExecutorFactory",
    "CodeExtractor",
    "CodeResult",
    "DockerCommandLineCodeExecutor",
    "LocalCommandLineCodeExecutor",
    "MarkdownCodeExtractor",
)

View File

@@ -0,0 +1,119 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Literal, Optional, Protocol, TypedDict, Union, runtime_checkable
from pydantic import BaseModel, Field
from ..doc_utils import export_module
from ..types import UserMessageImageContentPart, UserMessageTextContentPart
# Names re-exported as the public surface of this module.
__all__ = ("CodeBlock", "CodeExecutionConfig", "CodeExecutor", "CodeExtractor", "CodeResult")
@export_module("autogen.coding")
class CodeBlock(BaseModel):
    """(Experimental) A class that represents a code block."""

    # Mirrors a fenced markdown block: the body and its info-string language tag.
    code: str = Field(description="The code to execute.")
    language: str = Field(description="The language of the code.")
@export_module("autogen.coding")
class CodeResult(BaseModel):
    """(Experimental) A class that represents the result of a code execution."""

    # Process-style exit code: 0 means success, non-zero means failure.
    exit_code: int = Field(description="The exit code of the code execution.")
    output: str = Field(description="The output of the code execution.")
@export_module("autogen.coding")
class CodeExtractor(Protocol):
    """(Experimental) A code extractor class that extracts code blocks from a message."""

    def extract_code_blocks(
        self, message: Optional[Union[str, list[Union[UserMessageTextContentPart, UserMessageImageContentPart]]]]
    ) -> list[CodeBlock]:
        """(Experimental) Extract code blocks from a message.

        Args:
            message (str): The message to extract code blocks from.

        Returns:
            List[CodeBlock]: The extracted code blocks.
        """
        ...  # pragma: no cover
@runtime_checkable
@export_module("autogen.coding")
class CodeExecutor(Protocol):
    """(Experimental) A code executor class that executes code blocks and returns the result."""

    @property
    def code_extractor(self) -> CodeExtractor:
        """(Experimental) The code extractor used by this code executor."""
        ...  # pragma: no cover

    def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CodeResult:
        """(Experimental) Execute code blocks and return the result.

        This method should be implemented by the code executor.

        Args:
            code_blocks (List[CodeBlock]): The code blocks to execute.

        Returns:
            CodeResult: The result of the code execution.
        """
        ...  # pragma: no cover

    def restart(self) -> None:
        """(Experimental) Restart the code executor.

        This method should be implemented by the code executor.
        This method is called when the agent is reset.
        """
        ...  # pragma: no cover
class IPythonCodeResult(CodeResult):
    """(Experimental) A code result class for IPython code executor."""

    # Paths of artifacts (e.g. images, HTML) written while executing the cells.
    output_files: list[str] = Field(
        default_factory=list,
        description="The list of files that the executed code blocks generated.",
    )
# Configuration mapping used by agents to select and parametrize a code
# executor. Every key is optional (total=False); the per-executor keys
# ("ipython-embedded", "commandline-local") hold kwargs for that executor.
CodeExecutionConfig = TypedDict(
    "CodeExecutionConfig",
    {
        "executor": Union[Literal["ipython-embedded", "commandline-local"], CodeExecutor],
        "last_n_messages": Union[int, Literal["auto"]],
        "timeout": int,
        "use_docker": Union[bool, str, list[str]],
        "work_dir": str,
        "ipython-embedded": Mapping[str, Any],
        "commandline-local": Mapping[str, Any],
    },
    total=False,
)
class CommandLineCodeResult(CodeResult):
    """(Experimental) A code result class for command line code executor."""

    # Path of the file the code block was saved to before execution, if any.
    code_file: Optional[str] = Field(
        default=None,
        description="The file that the executed code block was saved to.",
    )

View File

@@ -0,0 +1,268 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from __future__ import annotations
import atexit
import logging
import sys
import uuid
from hashlib import md5
from pathlib import Path
from time import sleep
from types import TracebackType
from typing import Any, ClassVar, Optional, Union
import docker
from docker.errors import ImageNotFound
from ..code_utils import TIMEOUT_MSG, _cmd
from ..doc_utils import export_module
from .base import CodeBlock, CodeExecutor, CodeExtractor, CommandLineCodeResult
from .markdown_code_extractor import MarkdownCodeExtractor
from .utils import _get_file_name_from_content, silence_pip
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
def _wait_for_ready(container: Any, timeout: int = 60, stop_time: float = 0.1) -> None:
elapsed_time = 0.0
while container.status != "running" and elapsed_time < timeout:
sleep(stop_time)
elapsed_time += stop_time
container.reload()
continue
if container.status != "running":
raise ValueError("Container failed to start")
# Only the executor class is public; _wait_for_ready is an internal helper.
__all__ = ("DockerCommandLineCodeExecutor",)
@export_module("autogen.coding")
class DockerCommandLineCodeExecutor(CodeExecutor):
    """(Experimental) Executes code blocks on the command line inside a Docker container.

    See ``__init__`` for the full behavioral contract and configuration options.
    """

    # Languages executed by default; False means the block is only saved to disk.
    DEFAULT_EXECUTION_POLICY: ClassVar[dict[str, bool]] = {
        "bash": True,
        "shell": True,
        "sh": True,
        "pwsh": True,
        "powershell": True,
        "ps1": True,
        "python": True,
        "javascript": False,
        "html": False,
        "css": False,
    }
    # Short language tags accepted in fenced blocks, normalized before policy lookup.
    LANGUAGE_ALIASES: ClassVar[dict[str, str]] = {"py": "python", "js": "javascript"}
    def __init__(
        self,
        image: str = "python:3-slim",
        container_name: Optional[str] = None,
        timeout: int = 60,
        work_dir: Optional[Union[Path, str]] = None,
        bind_dir: Optional[Union[Path, str]] = None,
        auto_remove: bool = True,
        stop_container: bool = True,
        execution_policies: Optional[dict[str, bool]] = None,
    ):
        """(Experimental) A code executor class that executes code through
        a command line environment in a Docker container.
        The executor first saves each code block in a file in the working
        directory, and then executes the code file in the container.
        The executor executes the code blocks in the order they are received.
        Currently, the executor only supports Python and shell scripts.
        For Python code, use the language "python" for the code block.
        For shell scripts, use the language "bash", "shell", or "sh" for the code
        block.
        Args:
            image: Docker image to use for code execution. Defaults to "python:3-slim".
            container_name: Name of the Docker container which is created. If None, will autogenerate a name. Defaults to None.
            timeout: The timeout for code execution. Defaults to 60.
            work_dir: The working directory for the code execution. Defaults to Path(".").
            bind_dir: The directory that will be bound to the code executor container. Useful for cases where you want to spawn
                the container from within a container. Defaults to work_dir.
            auto_remove: If true, will automatically remove the Docker container when it is stopped. Defaults to True.
            stop_container: If true, will automatically stop the
                container when stop is called, when the context manager exits or when
                the Python process exits with atext. Defaults to True.
            execution_policies: A dictionary mapping language names to boolean values that determine
                whether code in that language should be executed. True means code in that language
                will be executed, False means it will only be saved to a file. This overrides the
                default execution policies. Defaults to None.
        Raises:
            ValueError: On argument error, or if the container fails to start.
        """
        work_dir = work_dir if work_dir is not None else Path()
        if timeout < 1:
            raise ValueError("Timeout must be greater than or equal to 1.")
        if isinstance(work_dir, str):
            work_dir = Path(work_dir)
        # NOTE(review): mkdir without parents=True — a nested, non-existent
        # work_dir raises FileNotFoundError; confirm this is intended.
        work_dir.mkdir(exist_ok=True)
        if bind_dir is None:
            bind_dir = work_dir
        elif isinstance(bind_dir, str):
            bind_dir = Path(bind_dir)
        client = docker.from_env()
        # Check if the image exists
        try:
            client.images.get(image)
        except ImageNotFound:
            logging.info(f"Pulling image {image}...")
            # Let the docker exception escape if this fails.
            client.images.pull(image)
        if container_name is None:
            container_name = f"autogen-code-exec-{uuid.uuid4()}"
        # Start a container from the image, read to exec commands later.
        # bind_dir is mounted at /workspace, where code files will be written.
        self._container = client.containers.create(
            image,
            name=container_name,
            entrypoint="/bin/sh",
            tty=True,
            auto_remove=auto_remove,
            volumes={str(bind_dir.resolve()): {"bind": "/workspace", "mode": "rw"}},
            working_dir="/workspace",
        )
        self._container.start()
        _wait_for_ready(self._container)
        # Stop the container at interpreter exit so it is not leaked.
        def cleanup() -> None:
            try:
                container = client.containers.get(container_name)
                container.stop()
            except docker.errors.NotFound:
                pass
        atexit.unregister(cleanup)
        if stop_container:
            atexit.register(cleanup)
        self._cleanup = cleanup
        # Check if the container is running
        if self._container.status != "running":
            raise ValueError(f"Failed to start container from image {image}. Logs: {self._container.logs()}")
        self._timeout = timeout
        self._work_dir: Path = work_dir
        self._bind_dir: Path = bind_dir
        # Per-instance copy so user overrides never mutate the class default.
        self.execution_policies = self.DEFAULT_EXECUTION_POLICY.copy()
        if execution_policies is not None:
            self.execution_policies.update(execution_policies)
    @property
    def timeout(self) -> int:
        """(Experimental) The timeout for code execution."""
        return self._timeout
    @property
    def work_dir(self) -> Path:
        """(Experimental) The working directory for the code execution."""
        return self._work_dir
    @property
    def bind_dir(self) -> Path:
        """(Experimental) The binding directory for the code execution container."""
        return self._bind_dir
    @property
    def code_extractor(self) -> CodeExtractor:
        """(Experimental) Export a code extractor that can be used by an agent."""
        return MarkdownCodeExtractor()
    def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult:
        """(Experimental) Execute the code blocks and return the result.

        Blocks run sequentially; execution stops at the first failing block.

        Args:
            code_blocks (List[CodeBlock]): The code blocks to execute.

        Returns:
            CommandlineCodeResult: The result of the code execution.
        """
        if len(code_blocks) == 0:
            raise ValueError("No code blocks to execute.")
        outputs = []
        files = []
        last_exit_code = 0
        for code_block in code_blocks:
            lang = self.LANGUAGE_ALIASES.get(code_block.language.lower(), code_block.language.lower())
            if lang not in self.DEFAULT_EXECUTION_POLICY:
                outputs.append(f"Unsupported language {lang}\n")
                last_exit_code = 1
                break
            execute_code = self.execution_policies.get(lang, False)
            code = silence_pip(code_block.code, lang)
            # Check if there is a filename comment
            try:
                filename = _get_file_name_from_content(code, self._work_dir)
            except ValueError:
                outputs.append("Filename is not in the workspace")
                last_exit_code = 1
                break
            if not filename:
                # Content-addressed name so identical blocks reuse the same file.
                filename = f"tmp_code_{md5(code.encode()).hexdigest()}.{lang}"
            code_path = self._work_dir / filename
            with code_path.open("w", encoding="utf-8") as fout:
                fout.write(code)
            files.append(code_path)
            if not execute_code:
                outputs.append(f"Code saved to {code_path!s}\n")
                continue
            # GNU `timeout` exits with status 124 when the command exceeds the limit.
            command = ["timeout", str(self._timeout), _cmd(lang), filename]
            result = self._container.exec_run(command)
            exit_code = result.exit_code
            output = result.output.decode("utf-8")
            if exit_code == 124:
                output += "\n" + TIMEOUT_MSG
            outputs.append(output)
            last_exit_code = exit_code
            if exit_code != 0:
                break
        code_file = str(files[0]) if files else None
        return CommandLineCodeResult(exit_code=last_exit_code, output="".join(outputs), code_file=code_file)
    def restart(self) -> None:
        """(Experimental) Restart the code executor."""
        self._container.restart()
        if self._container.status != "running":
            raise ValueError(f"Failed to restart container. Logs: {self._container.logs()}")
    def stop(self) -> None:
        """(Experimental) Stop the code executor."""
        self._cleanup()
    def __enter__(self) -> Self:
        return self
    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self.stop()

View File

@@ -0,0 +1,47 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from ..doc_utils import export_module
from .base import CodeExecutionConfig, CodeExecutor
# The factory is the only public name in this module.
__all__ = ("CodeExecutorFactory",)
@export_module("autogen.coding")
class CodeExecutorFactory:
    """(Experimental) A factory class for creating code executors."""

    @staticmethod
    def create(code_execution_config: CodeExecutionConfig) -> CodeExecutor:
        """(Experimental) Build a code executor from a code execution config.

        Args:
            code_execution_config (Dict): The code execution config dictionary;
                it must contain the key "executor", whose value is either a
                string naming a built-in executor or an already-constructed
                CodeExecutor instance (returned unchanged).

        Returns:
            CodeExecutor: The code executor.

        Raises:
            ValueError: If the code executor is unknown or not specified.
        """
        chosen = code_execution_config.get("executor")
        # Pre-built executor instances pass straight through.
        if isinstance(chosen, CodeExecutor):
            return chosen
        if chosen == "ipython-embedded":
            from .jupyter.embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor

            return EmbeddedIPythonCodeExecutor(**code_execution_config.get("ipython-embedded", {}))
        if chosen == "commandline-local":
            from .local_commandline_code_executor import LocalCommandLineCodeExecutor

            return LocalCommandLineCodeExecutor(**code_execution_config.get("commandline-local", {}))
        raise ValueError(f"Unknown code executor {chosen}")

View File

@@ -0,0 +1,202 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import functools
import importlib
import inspect
from dataclasses import dataclass, field
from importlib.abc import SourceLoader
from textwrap import dedent, indent
from typing import Any, Callable, Generic, TypeVar, Union
from typing_extensions import ParamSpec
# Generic return type and parameter-spec shared by the requirement wrappers below.
T = TypeVar("T")
P = ParamSpec("P")
def _to_code(func: Union["FunctionWithRequirements[T, P]", Callable[P, T], "FunctionWithRequirementsStr"]) -> str:
    """Return the source text of *func*, with a single leading decorator removed."""
    if isinstance(func, FunctionWithRequirementsStr):
        # String-backed functions already carry their source verbatim.
        return func.func

    source = inspect.getsource(func)
    if source.startswith("@"):
        # Drop the decorator line so only the bare definition remains.
        source = source[source.index("\n") + 1 :]
    return source
@dataclass
class Alias:
    """Represents an `import <name> as <alias>` statement."""

    name: str
    alias: str
@dataclass
class ImportFromModule:
    """Represents a `from <module> import <imports>` statement."""

    module: str
    imports: list[Union[str, Alias]]
# Any import form accepted by the function-file builders below.
Import = Union[str, ImportFromModule, Alias]
def _import_to_str(im: Import) -> str:
    """Render an Import specification as a literal Python import statement."""
    if isinstance(im, str):
        return f"import {im}"
    if isinstance(im, Alias):
        return f"import {im.name} as {im.alias}"

    # ImportFromModule: render each imported name, aliased where requested.
    def render(item: Union[str, Alias]) -> str:
        return item if isinstance(item, str) else f"{item.name} as {item.alias}"

    joined = ", ".join(map(render, im.imports))
    return f"from {im.module} import {joined}"
class _StringLoader(SourceLoader):
def __init__(self, data: str):
self.data = data
def get_source(self, fullname: str) -> str:
return self.data
def get_data(self, path: str) -> bytes:
return self.data.encode("utf-8")
def get_filename(self, fullname: str) -> str:
return "<not a real path>/" + fullname + ".py"
@dataclass
class FunctionWithRequirementsStr:
    """A function given as source text, plus its package/import requirements.

    The source is compiled at construction time (via an in-memory import) to
    validate it and to recover the function's name and callable.
    """

    func: str
    _compiled_func: Callable[..., Any]
    _func_name: str
    python_packages: list[str] = field(default_factory=list)
    global_imports: list[Import] = field(default_factory=list)
    # NOTE(review): the mutable [] defaults below are shared across calls; they
    # are only read here, but a None-sentinel would be the safer idiom.
    def __init__(self, func: str, python_packages: list[str] = [], global_imports: list[Import] = []):
        self.func = func
        self.python_packages = python_packages
        self.global_imports = global_imports
        # Compile the source string by importing it as a throwaway module.
        # NOTE(review): relies on `importlib.util` being reachable through the
        # bare `import importlib` at the top of the file — confirm.
        module_name = "func_module"
        loader = _StringLoader(func)
        spec = importlib.util.spec_from_loader(module_name, loader)
        if spec is None:
            raise ValueError("Could not create spec")
        module = importlib.util.module_from_spec(spec)
        if spec.loader is None:
            raise ValueError("Could not create loader")
        try:
            spec.loader.exec_module(module)
        except Exception as e:
            raise ValueError(f"Could not compile function: {e}") from e
        # Exactly one function must be defined so its identity is unambiguous.
        functions = inspect.getmembers(module, inspect.isfunction)
        if len(functions) != 1:
            raise ValueError("The string must contain exactly one function")
        self._func_name, self._compiled_func = functions[0]
    def __call__(self, *args: Any, **kwargs: Any) -> None:
        # Calling the wrapper directly is deliberately unsupported; the source
        # is meant to be written into a functions file and executed there.
        raise NotImplementedError("String based function with requirement objects are not directly callable")
@dataclass
class FunctionWithRequirements(Generic[T, P]):
    """A callable bundled with the packages and imports it needs to run."""

    func: Callable[P, T]
    python_packages: list[str] = field(default_factory=list)
    global_imports: list[Import] = field(default_factory=list)
    # NOTE(review): the mutable [] defaults in the constructors below are
    # shared across calls; they are only read, but None-sentinels are safer.
    @classmethod
    def from_callable(
        cls, func: Callable[P, T], python_packages: list[str] = [], global_imports: list[Import] = []
    ) -> "FunctionWithRequirements[T, P]":
        """Wrap an existing callable with its requirements."""
        return cls(python_packages=python_packages, global_imports=global_imports, func=func)
    @staticmethod
    def from_str(
        func: str, python_packages: list[str] = [], global_imports: list[Import] = []
    ) -> FunctionWithRequirementsStr:
        """Wrap function source text with its requirements."""
        return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports)
    # Type this based on F
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # Delegates straight to the wrapped callable.
        return self.func(*args, **kwargs)
def with_requirements(
    python_packages: list[str] = [], global_imports: list[Import] = []
) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]:
    """Decorate a function with package and import requirements

    Args:
        python_packages (List[str], optional): Packages required to function. Can include version info.. Defaults to [].
        global_imports (List[Import], optional): Required imports. Defaults to [].

    Returns:
        Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: The decorated function
    """

    def decorator(func: Callable[P, T]) -> FunctionWithRequirements[T, P]:
        # Bundle the callable with its requirements, preserving __name__ /
        # __doc__ metadata on the wrapper object.
        wrapped = FunctionWithRequirements(
            python_packages=python_packages, global_imports=global_imports, func=func
        )
        functools.update_wrapper(wrapped, func)
        return wrapped

    return decorator
def _build_python_functions_file(
    funcs: list[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]],
) -> str:
    """Assemble a single Python source file containing *funcs* and their imports."""
    # Deduplicate the global imports across every requirement-carrying function.
    import_lines: set[str] = set()
    for fn in funcs:
        if isinstance(fn, (FunctionWithRequirements, FunctionWithRequirementsStr)):
            import_lines.update(_import_to_str(imp) for imp in fn.global_imports)

    sections = ["\n".join(import_lines) + "\n\n"]
    sections.extend(_to_code(fn) + "\n\n" for fn in funcs)
    return "".join(sections)
def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
    """Generate a stub for a function as a string

    Args:
        func (Callable[..., Any]): The function to generate a stub for

    Returns:
        str: The stub for the function
    """
    if isinstance(func, FunctionWithRequirementsStr):
        # Stub the compiled function behind a string-based definition.
        return to_stub(func._compiled_func)

    lines = [f"def {func.__name__}{inspect.signature(func)}:"]
    doc = func.__doc__
    if doc:
        # Re-indent the docstring one level under the def.
        quoted = '"""' + dedent(doc) + '"""'
        lines.append(indent(quoted, "    "))
    lines.append("    ...")
    return "\n".join(lines)

View File

@@ -0,0 +1,23 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Original portions of this file are derived from https://github.com/microsoft/autogen under the MIT License.
# SPDX-License-Identifier: MIT
from .base import JupyterConnectable, JupyterConnectionInfo
from .docker_jupyter_server import DockerJupyterServer
from .embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor
from .jupyter_client import JupyterClient
from .jupyter_code_executor import JupyterCodeExecutor
from .local_jupyter_server import LocalJupyterServer
# Public API of the autogen.coding.jupyter subpackage.
__all__ = [
    "DockerJupyterServer",
    "EmbeddedIPythonCodeExecutor",
    "JupyterClient",
    "JupyterCodeExecutor",
    "JupyterConnectable",
    "JupyterConnectionInfo",
    "LocalJupyterServer",
]

View File

@@ -0,0 +1,36 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from dataclasses import dataclass
from typing import Optional, Protocol, runtime_checkable
from ...doc_utils import export_module
@dataclass
@export_module("autogen.coding.jupyter")
class JupyterConnectionInfo:
    """(Experimental) Connection details for a Jupyter (kernel gateway) server."""

    host: str
    """`str` - Host of the Jupyter gateway server"""
    use_https: bool
    """`bool` - Whether to use HTTPS"""
    port: Optional[int] = None
    """`Optional[int]` - Port of the Jupyter gateway server. If None, the default port is used"""
    token: Optional[str] = None
    """`Optional[str]` - Token for authentication. If None, no token is used"""
@runtime_checkable
@export_module("autogen.coding.jupyter")
class JupyterConnectable(Protocol):
    """(Experimental) Anything that can expose a Jupyter server connection."""

    @property
    def connection_info(self) -> JupyterConnectionInfo:
        """Return the connection information for this connectable."""
        pass

View File

@@ -0,0 +1,167 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from __future__ import annotations
import atexit
import io
import logging
import secrets
import sys
import uuid
from pathlib import Path
from types import TracebackType
from typing import Optional
import docker
from ...doc_utils import export_module
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
from ..docker_commandline_code_executor import _wait_for_ready
from .base import JupyterConnectable, JupyterConnectionInfo
from .import_utils import require_jupyter_kernel_gateway_installed
from .jupyter_client import JupyterClient
@require_jupyter_kernel_gateway_installed()
@export_module("autogen.coding.jupyter")
class DockerJupyterServer(JupyterConnectable):
    """(Experimental) Runs a Jupyter kernel gateway server inside a Docker container.

    See ``__init__`` for configuration details.
    """

    # Dockerfile used to build the bundled gateway image; rendered verbatim.
    DEFAULT_DOCKERFILE = """FROM quay.io/jupyter/docker-stacks-foundation
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
USER ${NB_UID}
RUN mamba install --yes jupyter_kernel_gateway ipykernel && \
    mamba clean --all -f -y && \
    fix-permissions "${CONDA_DIR}" && \
    fix-permissions "/home/${NB_USER}"
ENV TOKEN="UNSET"
CMD python -m jupyter kernelgateway --KernelGatewayApp.ip=0.0.0.0 \
    --KernelGatewayApp.port=8888 \
    --KernelGatewayApp.auth_token="${TOKEN}" \
    --JupyterApp.answer_yes=true \
    --JupyterWebsocketPersonality.list_kernels=true
EXPOSE 8888
WORKDIR "${HOME}"
"""
    # Sentinel type: passing an instance requests a random token.
    class GenerateToken:
        pass
    def __init__(
        self,
        *,
        custom_image_name: Optional[str] = None,
        container_name: Optional[str] = None,
        auto_remove: bool = True,
        stop_container: bool = True,
        docker_env: dict[str, str] = {},  # NOTE(review): mutable default; only read below, but a None-sentinel would be safer
        token: str | GenerateToken = GenerateToken(),
    ):
        """Start a Jupyter kernel gateway server in a Docker container.

        Args:
            custom_image_name (Optional[str], optional): Custom image to use. If this is None,
                then the bundled image will be built and used. The default image is based on
                quay.io/jupyter/docker-stacks-foundation and extended to include jupyter_kernel_gateway
            container_name (Optional[str], optional): Name of the container to start.
                A name will be generated if None.
            auto_remove (bool, optional): If true the Docker container will be deleted
                when it is stopped.
            stop_container (bool, optional): If true the container will be stopped,
                either by program exit or using the context manager
            docker_env (Dict[str, str], optional): Extra environment variables to pass
                to the running Docker container.
            token (Union[str, GenerateToken], optional): Token to use for authentication.
                If GenerateToken is used, a random token will be generated. Empty string
                will be unauthenticated.
        """
        if container_name is None:
            container_name = f"autogen-jupyterkernelgateway-{uuid.uuid4()}"
        client = docker.from_env()
        if custom_image_name is None:
            image_name = "autogen-jupyterkernelgateway"
            # Make sure the image exists
            try:
                client.images.get(image_name)
            except docker.errors.ImageNotFound:
                # Build the image
                # Get this script directory
                here = Path(__file__).parent
                dockerfile = io.BytesIO(self.DEFAULT_DOCKERFILE.encode("utf-8"))
                logging.info(f"Image {image_name} not found. Building it now.")
                # NOTE(review): `path` is a Path object and `fileobj` is also
                # given — docker SDK documents these as alternatives; confirm
                # the SDK accepts this combination.
                client.images.build(path=here, fileobj=dockerfile, tag=image_name)
                logging.info(f"Image {image_name} built successfully.")
        else:
            image_name = custom_image_name
            # Check if the image exists
            try:
                client.images.get(image_name)
            except docker.errors.ImageNotFound:
                raise ValueError(f"Custom image {image_name} does not exist")
        if isinstance(token, DockerJupyterServer.GenerateToken):
            self._token = secrets.token_hex(32)
        else:
            self._token = token
        # Run the container
        env = {"TOKEN": self._token}
        env.update(docker_env)
        container = client.containers.run(
            image_name,
            detach=True,
            auto_remove=auto_remove,
            environment=env,
            publish_all_ports=True,
            name=container_name,
        )
        _wait_for_ready(container)
        # Resolve the host port Docker mapped to the gateway's 8888.
        container_ports = container.ports
        self._port = int(container_ports["8888/tcp"][0]["HostPort"])
        self._container_id = container.id
        # Stop the container at interpreter exit so it is not leaked.
        def cleanup() -> None:
            try:
                inner_container = client.containers.get(container.id)
                inner_container.stop()
            except docker.errors.NotFound:
                pass
        atexit.unregister(cleanup)
        if stop_container:
            atexit.register(cleanup)
        self._cleanup_func = cleanup
        self._stop_container = stop_container
    @property
    def connection_info(self) -> JupyterConnectionInfo:
        """Connection details for the gateway running in the container."""
        return JupyterConnectionInfo(host="127.0.0.1", use_https=False, port=self._port, token=self._token)
    def stop(self) -> None:
        """Stop the backing container."""
        self._cleanup_func()
    def get_client(self) -> JupyterClient:
        """Return a client bound to this server's connection info."""
        return JupyterClient(self.connection_info)
    def __enter__(self) -> Self:
        return self
    def __exit__(
        self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> None:
        self.stop()

View File

@@ -0,0 +1,182 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import base64
import json
import os
import re
import uuid
from pathlib import Path
from queue import Empty
from typing import Any
from pydantic import BaseModel, Field, field_validator
from ...doc_utils import export_module
from ...import_utils import optional_import_block, require_optional_import
from ..base import CodeBlock, CodeExtractor, IPythonCodeResult
from ..markdown_code_extractor import MarkdownCodeExtractor
from .import_utils import require_jupyter_kernel_gateway_installed
with optional_import_block():
from jupyter_client import KernelManager # type: ignore[attr-defined]
from jupyter_client.kernelspec import KernelSpecManager
# The executor class is the only public name in this module.
__all__ = ["EmbeddedIPythonCodeExecutor"]
@require_optional_import("jupyter_client", "jupyter-executor")
@require_jupyter_kernel_gateway_installed()
@export_module("autogen.coding.jupyter")
class EmbeddedIPythonCodeExecutor(BaseModel):
    """(Experimental) A code executor class that executes code statefully using an embedded
    IPython kernel managed by this class.

    **This will execute LLM generated code on the local machine.**

    Each execution is stateful and can access variables created from previous
    executions in the same session. The kernel must be installed before using
    this class. The kernel can be installed using the following command:
    `python -m ipykernel install --user --name {kernel_name}`
    where `kernel_name` is the name of the kernel to install.
    """

    timeout: int = Field(default=60, ge=1, description="The timeout for code execution.")
    kernel_name: str = Field(default="python3", description="The kernel name to use. Make sure it is installed.")
    output_dir: str = Field(default=".", description="The directory to save output files.")
    @field_validator("output_dir")
    @classmethod
    def _output_dir_must_exist(cls, value: str) -> str:
        # Fail fast: artifacts are written here during execution.
        if not os.path.exists(value):
            raise ValueError(f"Output directory {value} does not exist.")
        return value
    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        # Check if the kernel is installed.
        if self.kernel_name not in KernelSpecManager().find_kernel_specs():
            raise ValueError(
                f"Kernel {self.kernel_name} is not installed. "
                "Please first install it with "
                f"`python -m ipykernel install --user --name {self.kernel_name}`."
            )
        # Start the kernel and open its channels immediately; the executor is
        # stateful across execute_code_blocks calls until restart() is called.
        self._kernel_manager = KernelManager(kernel_name=self.kernel_name)
        self._kernel_manager.start_kernel()
        self._kernel_client = self._kernel_manager.client()
        self._kernel_client.start_channels()
        # Mirror the validated pydantic fields into private attributes.
        self._timeout = self.timeout
        self._kernel_name = self.kernel_name
        self._output_dir = Path(self.output_dir)
    @property
    def code_extractor(self) -> CodeExtractor:
        """(Experimental) Export a code extractor that can be used by an agent."""
        return MarkdownCodeExtractor()
    def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> IPythonCodeResult:
        """(Experimental) Execute a list of code blocks and return the result.

        This method executes a list of code blocks as cells in an IPython kernel
        managed by this class.
        See: https://jupyter-client.readthedocs.io/en/stable/messaging.html
        for the message protocol.

        Args:
            code_blocks (List[CodeBlock]): A list of code blocks to execute.

        Returns:
            IPythonCodeResult: The result of the code execution.
        """
        self._kernel_client.wait_for_ready()
        outputs = []
        output_files = []
        for code_block in code_blocks:
            code = self._process_code(code_block.code)
            self._kernel_client.execute(code, store_history=True)
            # Drain IOPub messages for this cell until the kernel goes idle.
            while True:
                try:
                    msg = self._kernel_client.get_iopub_msg(timeout=self._timeout)
                    msg_type = msg["msg_type"]
                    content = msg["content"]
                    if msg_type in ["execute_result", "display_data"]:
                        for data_type, data in content["data"].items():
                            if data_type == "text/plain":
                                # Output is a text.
                                outputs.append(data)
                            elif data_type.startswith("image/"):
                                # Output is an image.
                                path = self._save_image(data)
                                outputs.append(f"Image data saved to {path}")
                                output_files.append(path)
                            elif data_type == "text/html":
                                # Output is an html.
                                path = self._save_html(data)
                                outputs.append(f"HTML data saved to {path}")
                                output_files.append(path)
                            else:
                                # Output raw data.
                                outputs.append(json.dumps(data))
                    elif msg_type == "stream":
                        # Output is a text.
                        outputs.append(content["text"])
                    elif msg_type == "error":
                        # Output is an error.
                        # NOTE(review): content["traceback"] is a list of strings;
                        # it is rendered with its list repr here — confirm intended.
                        return IPythonCodeResult(
                            exit_code=1,
                            output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}",
                        )
                    if msg_type == "status" and content["execution_state"] == "idle":
                        break
                # handle time outs.
                except Empty:
                    return IPythonCodeResult(
                        exit_code=1,
                        output=f"ERROR: Timeout waiting for output from code block: {code_block.code}",
                    )
        # We return the full output.
        return IPythonCodeResult(
            exit_code=0, output="\n".join([str(output) for output in outputs]), output_files=output_files
        )
    def restart(self) -> None:
        """(Experimental) Restart a new session."""
        # Tear down the old kernel entirely before starting a fresh one.
        self._kernel_client.stop_channels()
        self._kernel_manager.shutdown_kernel()
        self._kernel_manager = KernelManager(kernel_name=self.kernel_name)
        self._kernel_manager.start_kernel()
        self._kernel_client = self._kernel_manager.client()
        self._kernel_client.start_channels()
    def _save_image(self, image_data_base64: str) -> str:
        """Save image data to a file."""
        image_data = base64.b64decode(image_data_base64)
        # Randomly generate a filename.
        filename = f"{uuid.uuid4().hex}.png"
        path = os.path.join(self.output_dir, filename)
        with open(path, "wb") as f:
            f.write(image_data)
        return os.path.abspath(path)
    def _save_html(self, html_data: str) -> str:
        """Save html data to a file."""
        # Randomly generate a filename.
        filename = f"{uuid.uuid4().hex}.html"
        path = os.path.join(self.output_dir, filename)
        with open(path, "w") as f:
            f.write(html_data)
        return os.path.abspath(path)
    def _process_code(self, code: str) -> str:
        """Process code before execution."""
        # Find lines that start with `! pip install` and make sure "-qqq" flag is added.
        lines = code.split("\n")
        for i, line in enumerate(lines):
            # use regex to find lines that start with `! pip install` or `!pip install`.
            match = re.search(r"^! ?pip install", line)
            if match is not None and "-qqq" not in line:
                lines[i] = line.replace(match.group(0), match.group(0) + " -qqq")
        return "\n".join(lines)

View File

@@ -0,0 +1,82 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import subprocess
from functools import lru_cache
from logging import getLogger
from typing import Callable, TypeVar
from ...import_utils import patch_object
logger = getLogger(__name__)
__all__ = ["require_jupyter_kernel_gateway_installed", "skip_on_missing_jupyter_kernel_gateway"]
@lru_cache
def is_jupyter_kernel_gateway_installed() -> bool:
    """Check if jupyter-kernel-gateway is installed."""
    # Probe the CLI: a missing `jupyter` executable or a non-zero exit both
    # mean the gateway is unavailable. The result is cached per process.
    try:
        subprocess.run(
            ["jupyter", "kernelgateway", "--version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        getLogger(__name__).warning(
            "jupyter-kernel-gateway is required for JupyterCodeExecutor, please install it with `pip install ag2[jupyter-executor]`"
        )
        return False
    return True
T = TypeVar("T")
def require_jupyter_kernel_gateway_installed() -> Callable[[T], T]:
    """Decorator that checks if jupyter-kernel-gateway is installed before function execution.

    Returns:
        Callable[[T], T]: A decorator function that either:
        - Returns the original function unchanged if jupyter-kernel-gateway is installed
        - Returns a patched version of the function that will raise a helpful error indicating the missing dependency when called
    """
    # Probe once at decoration-factory time; the closure captures the result.
    gateway_available = is_jupyter_kernel_gateway_installed()

    def decorator(o: T) -> T:
        if gateway_available:
            return o
        # Replace the object with a stub that raises a helpful install hint.
        return patch_object(o, missing_modules={}, dep_target="jupyter-executor")

    return decorator
def skip_on_missing_jupyter_kernel_gateway() -> Callable[[T], T]:
    """Decorator to skip a test if an optional module is missing"""
    # Add pytest.mark.jupyter_executor decorator
    mark_name = "jupyter_executor"
    gateway_available = is_jupyter_kernel_gateway_installed()

    def decorator(o: T) -> T:
        import pytest

        marked = getattr(pytest.mark, mark_name)(o)
        if gateway_available:
            return marked  # type: ignore[no-any-return]
        # Wrap the marked object in a skip when the gateway is missing.
        return pytest.mark.skip(  # type: ignore[return-value,no-any-return]
            reason="jupyter-kernel-gateway is required for JupyterCodeExecutor, please install it with `pip install ag2[jupyter-executor]`"
        )(marked)

    return decorator

View File

@@ -0,0 +1,231 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from __future__ import annotations
import datetime
import json
import sys
import uuid
from dataclasses import dataclass
from types import TracebackType
from typing import Any, Optional, cast
from ...doc_utils import export_module
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
import requests
from requests.adapters import HTTPAdapter, Retry
from ...import_utils import optional_import_block, require_optional_import
from .base import JupyterConnectionInfo
with optional_import_block():
import websocket
from websocket import WebSocket
@export_module("autogen.coding.jupyter")
class JupyterClient:
    def __init__(self, connection_info: JupyterConnectionInfo):
        """(Experimental) A client for communicating with a Jupyter gateway server.

        Args:
            connection_info (JupyterConnectionInfo): Connection information
        """
        self._connection_info = connection_info
        self._session = requests.Session()
        # Retry transient failures with a short exponential backoff.
        retries = Retry(total=5, backoff_factor=0.1)
        self._session.mount("http://", HTTPAdapter(max_retries=retries))
        # Fix: also mount the adapter for https, otherwise connections with
        # use_https=True silently get no retry behavior.
        self._session.mount("https://", HTTPAdapter(max_retries=retries))

    def _get_headers(self) -> dict[str, str]:
        """Build auth headers; empty when no token is configured."""
        if self._connection_info.token is None:
            return {}
        return {"Authorization": f"token {self._connection_info.token}"}

    def _get_api_base_url(self) -> str:
        """Return the http(s) base URL of the gateway's REST API."""
        protocol = "https" if self._connection_info.use_https else "http"
        port = f":{self._connection_info.port}" if self._connection_info.port else ""
        return f"{protocol}://{self._connection_info.host}{port}"

    def _get_ws_base_url(self) -> str:
        """Return the websocket base URL, matching the API's TLS setting."""
        # Fix: use wss when the HTTP API is served over https; the original
        # always returned ws:// which breaks TLS-only gateways.
        protocol = "wss" if self._connection_info.use_https else "ws"
        port = f":{self._connection_info.port}" if self._connection_info.port else ""
        return f"{protocol}://{self._connection_info.host}{port}"

    def list_kernel_specs(self) -> dict[str, dict[str, str]]:
        """Return the kernel specs installed on the gateway."""
        response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers())
        return cast(dict[str, dict[str, str]], response.json())

    def list_kernels(self) -> list[dict[str, str]]:
        """Return the kernels currently running on the gateway."""
        response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers())
        return cast(list[dict[str, str]], response.json())

    def start_kernel(self, kernel_spec_name: str) -> str:
        """Start a new kernel.

        Args:
            kernel_spec_name (str): Name of the kernel spec to start

        Returns:
            str: ID of the started kernel
        """
        response = self._session.post(
            f"{self._get_api_base_url()}/api/kernels",
            headers=self._get_headers(),
            json={"name": kernel_spec_name},
        )
        return cast(str, response.json()["id"])

    def delete_kernel(self, kernel_id: str) -> None:
        """Shut down and remove the kernel with the given ID."""
        response = self._session.delete(
            f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers()
        )
        response.raise_for_status()

    def restart_kernel(self, kernel_id: str) -> None:
        """Restart the kernel with the given ID, preserving its kernel spec."""
        response = self._session.post(
            f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers()
        )
        response.raise_for_status()

    @require_optional_import("websocket", "jupyter-executor")
    def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient:
        """Open a websocket to the kernel's channels endpoint and wrap it in a client."""
        ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels"
        ws = websocket.create_connection(ws_url, header=self._get_headers())
        return JupyterKernelClient(ws)
@require_optional_import("websocket", "jupyter-executor")
class JupyterKernelClient:
    """(Experimental) A client for communicating with a Jupyter kernel."""

    @dataclass
    class ExecutionResult:
        """Outcome of executing one code cell on the kernel."""

        @dataclass
        class DataItem:
            """A single rich-display payload, e.g. an image or HTML fragment."""

            mime_type: str
            data: str

        is_ok: bool  # True when the cell ran without error and without timing out.
        output: str  # Concatenated text output, or an error/timeout description.
        data_items: list[DataItem]  # Non-text outputs (image/*, text/html).

    def __init__(self, websocket: WebSocket):  # type: ignore[no-any-unimported]
        """Wrap an open websocket connected to a kernel's channels endpoint."""
        self._session_id: str = uuid.uuid4().hex
        self._websocket: WebSocket = websocket  # type: ignore[no-any-unimported]

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> None:
        self.stop()

    def stop(self) -> None:
        """Close the underlying websocket connection."""
        self._websocket.close()

    def _send_message(self, *, content: dict[str, Any], channel: str, message_type: str) -> str:
        """Send one Jupyter-protocol message and return its generated msg_id."""
        timestamp = datetime.datetime.now().isoformat()
        message_id = uuid.uuid4().hex
        # Message envelope following the Jupyter messaging specification.
        # parent_header is empty because this client only sends top-level requests.
        message = {
            "header": {
                "username": "autogen",
                "version": "5.0",
                "session": self._session_id,
                "msg_id": message_id,
                "msg_type": message_type,
                "date": timestamp,
            },
            "parent_header": {},
            "channel": channel,
            "content": content,
            "metadata": {},
            "buffers": {},
        }
        self._websocket.send_text(json.dumps(message))
        return message_id

    def _receive_message(self, timeout_seconds: Optional[float]) -> Optional[dict[str, Any]]:
        """Receive and decode one message; return None when the socket times out."""
        self._websocket.settimeout(timeout_seconds)
        try:
            data = self._websocket.recv()
            if isinstance(data, bytes):
                data = data.decode("utf-8")
            return cast(dict[str, Any], json.loads(data))
        except websocket.WebSocketTimeoutException:
            return None

    def wait_for_ready(self, timeout_seconds: Optional[float] = None) -> bool:
        """Return True once the kernel answers a kernel_info_request, False on timeout."""
        message_id = self._send_message(content={}, channel="shell", message_type="kernel_info_request")
        while True:
            message = self._receive_message(timeout_seconds)
            # This means we timed out with no new messages.
            if message is None:
                return False
            if (
                message.get("parent_header", {}).get("msg_id") == message_id
                and message["msg_type"] == "kernel_info_reply"
            ):
                return True

    def execute(self, code: str, timeout_seconds: Optional[float] = None) -> ExecutionResult:
        """Run `code` on the kernel and collect its outputs.

        Loops over incoming messages until the kernel reports the idle state,
        an error message arrives, or no message is received within
        `timeout_seconds` (None means wait indefinitely).
        """
        message_id = self._send_message(
            content={
                "code": code,
                "silent": False,
                "store_history": True,
                "user_expressions": {},
                "allow_stdin": False,
                "stop_on_error": True,
            },
            channel="shell",
            message_type="execute_request",
        )

        text_output = []
        data_output = []
        while True:
            message = self._receive_message(timeout_seconds)
            if message is None:
                return JupyterKernelClient.ExecutionResult(
                    is_ok=False, output="ERROR: Timeout waiting for output from code block.", data_items=[]
                )

            # Ignore messages that are not for this execution.
            if message.get("parent_header", {}).get("msg_id") != message_id:
                continue

            msg_type = message["msg_type"]
            content = message["content"]
            if msg_type in ["execute_result", "display_data"]:
                for data_type, data in content["data"].items():
                    if data_type == "text/plain":
                        text_output.append(data)
                    elif data_type.startswith("image/") or data_type == "text/html":
                        # Rich payloads are returned to the caller to persist.
                        data_output.append(self.ExecutionResult.DataItem(mime_type=data_type, data=data))
                    else:
                        text_output.append(json.dumps(data))
            elif msg_type == "stream":
                text_output.append(content["text"])
            elif msg_type == "error":
                # Output is an error.
                return JupyterKernelClient.ExecutionResult(
                    is_ok=False,
                    output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}",
                    data_items=[],
                )
            if msg_type == "status" and content["execution_state"] == "idle":
                break

        return JupyterKernelClient.ExecutionResult(
            is_ok=True, output="\n".join([str(output) for output in text_output]), data_items=data_output
        )

View File

@@ -0,0 +1,160 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import base64
import json
import os
import sys
import uuid
from pathlib import Path
from types import TracebackType
from typing import Optional, Union
from ...doc_utils import export_module
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
from ..base import CodeBlock, CodeExecutor, CodeExtractor, IPythonCodeResult
from ..markdown_code_extractor import MarkdownCodeExtractor
from ..utils import silence_pip
from .base import JupyterConnectable, JupyterConnectionInfo
from .jupyter_client import JupyterClient
@export_module("autogen.coding.jupyter")
class JupyterCodeExecutor(CodeExecutor):
    def __init__(
        self,
        jupyter_server: Union[JupyterConnectable, JupyterConnectionInfo],
        kernel_name: str = "python3",
        timeout: int = 60,
        output_dir: Union[Path, str] = Path(),
    ):
        """(Experimental) A code executor class that executes code statefully using
        a Jupyter server supplied to this class.

        Each execution is stateful and can access variables created from previous
        executions in the same session.

        Args:
            jupyter_server (Union[JupyterConnectable, JupyterConnectionInfo]): The Jupyter server to use.
            kernel_name (str): The kernel name to use. Make sure it is installed.
                By default, it is "python3".
            timeout (int): The timeout for code execution, by default 60.
            output_dir (str): The directory to save output files, by default ".".

        Raises:
            ValueError: If the timeout is below 1, the output directory does not
                exist, `jupyter_server` has the wrong type, or the requested
                kernel is not installed on the server.
        """
        if timeout < 1:
            raise ValueError("Timeout must be greater than or equal to 1.")

        if isinstance(output_dir, str):
            output_dir = Path(output_dir)

        if not output_dir.exists():
            raise ValueError(f"Output directory {output_dir} does not exist.")

        if isinstance(jupyter_server, JupyterConnectable):
            self._connection_info = jupyter_server.connection_info
        elif isinstance(jupyter_server, JupyterConnectionInfo):
            self._connection_info = jupyter_server
        else:
            raise ValueError("jupyter_server must be a JupyterConnectable or JupyterConnectionInfo.")

        self._jupyter_client = JupyterClient(self._connection_info)
        available_kernels = self._jupyter_client.list_kernel_specs()
        if kernel_name not in available_kernels["kernelspecs"]:
            raise ValueError(f"Kernel {kernel_name} is not installed.")

        self._kernel_id = self._jupyter_client.start_kernel(kernel_name)
        self._kernel_name = kernel_name
        self._jupyter_kernel_client = self._jupyter_client.get_kernel_client(self._kernel_id)
        self._timeout = timeout
        self._output_dir = output_dir

    @property
    def code_extractor(self) -> CodeExtractor:
        """(Experimental) Export a code extractor that can be used by an agent."""
        return MarkdownCodeExtractor()

    def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> IPythonCodeResult:
        """(Experimental) Execute a list of code blocks and return the result.

        This method executes a list of code blocks as cells in the Jupyter kernel.
        See: https://jupyter-client.readthedocs.io/en/stable/messaging.html
        for the message protocol.

        Args:
            code_blocks (List[CodeBlock]): A list of code blocks to execute.

        Returns:
            IPythonCodeResult: The result of the code execution.
        """
        # Fix: bound the readiness wait by the configured timeout instead of
        # blocking indefinitely when the kernel never becomes ready.
        if not self._jupyter_kernel_client.wait_for_ready(timeout_seconds=self._timeout):
            return IPythonCodeResult(
                exit_code=1,
                output="ERROR: Timeout waiting for Jupyter kernel to become ready.",
            )
        outputs = []
        output_files = []
        for code_block in code_blocks:
            code = silence_pip(code_block.code, code_block.language)
            result = self._jupyter_kernel_client.execute(code, timeout_seconds=self._timeout)
            if result.is_ok:
                outputs.append(result.output)
                for data in result.data_items:
                    if data.mime_type == "image/png":
                        path = self._save_image(data.data)
                        outputs.append(f"Image data saved to {path}")
                        output_files.append(path)
                    elif data.mime_type == "text/html":
                        path = self._save_html(data.data)
                        outputs.append(f"HTML data saved to {path}")
                        output_files.append(path)
                    else:
                        outputs.append(json.dumps(data.data))
            else:
                # Stop at the first failing block and surface its error.
                return IPythonCodeResult(
                    exit_code=1,
                    output=f"ERROR: {result.output}",
                )

        return IPythonCodeResult(
            exit_code=0, output="\n".join([str(output) for output in outputs]), output_files=output_files
        )

    def restart(self) -> None:
        """(Experimental) Restart a new session."""
        self._jupyter_client.restart_kernel(self._kernel_id)
        self._jupyter_kernel_client = self._jupyter_client.get_kernel_client(self._kernel_id)

    def _save_image(self, image_data_base64: str) -> str:
        """Save base64-encoded PNG data to a file and return its absolute path."""
        image_data = base64.b64decode(image_data_base64)
        # Randomly generate a filename to avoid collisions between outputs.
        filename = f"{uuid.uuid4().hex}.png"
        path = os.path.join(self._output_dir, filename)
        with open(path, "wb") as f:
            f.write(image_data)
        return os.path.abspath(path)

    def _save_html(self, html_data: str) -> str:
        """Save HTML markup to a file and return its absolute path."""
        # Randomly generate a filename to avoid collisions between outputs.
        filename = f"{uuid.uuid4().hex}.html"
        path = os.path.join(self._output_dir, filename)
        with open(path, "w") as f:
            f.write(html_data)
        return os.path.abspath(path)

    def stop(self) -> None:
        """Stop the kernel."""
        self._jupyter_client.delete_kernel(self._kernel_id)

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> None:
        self.stop()

View File

@@ -0,0 +1,172 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from __future__ import annotations
import atexit
import json
import secrets
import signal
import subprocess
import sys
from types import TracebackType
from typing import Optional
from ...doc_utils import export_module
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
from .base import JupyterConnectable, JupyterConnectionInfo
from .import_utils import require_jupyter_kernel_gateway_installed
from .jupyter_client import JupyterClient
@require_jupyter_kernel_gateway_installed()
@export_module("autogen.coding.jupyter")
class LocalJupyterServer(JupyterConnectable):
    class GenerateToken:
        """Sentinel type: an instance requests automatic token generation."""

        pass

    def __init__(
        self,
        ip: str = "127.0.0.1",
        port: Optional[int] = None,
        token: str | GenerateToken = GenerateToken(),
        log_file: str = "jupyter_gateway.log",
        log_level: str = "INFO",
        log_max_bytes: int = 1048576,
        log_backup_count: int = 3,
    ):
        """Runs a Jupyter Kernel Gateway server locally.

        Args:
            ip (str, optional): IP address to bind to. Defaults to "127.0.0.1".
            port (Optional[int], optional): Port to use, if None it automatically selects a port. Defaults to None.
            token (Union[str, GenerateToken], optional): Token to use for Jupyter server. By default will generate a token. Using None will use no token for authentication. Defaults to GenerateToken().
            log_file (str, optional): File for Jupyter Kernel Gateway logs. Defaults to "jupyter_gateway.log".
            log_level (str, optional): Level for Jupyter Kernel Gateway logs. Defaults to "INFO".
            log_max_bytes (int, optional): Max logfile size. Defaults to 1048576.
            log_backup_count (int, optional): Number of backups for rotating log. Defaults to 3.
        """
        # Remove as soon as https://github.com/jupyter-server/kernel_gateway/issues/398 is fixed
        if sys.platform == "win32":
            raise ValueError("LocalJupyterServer is not supported on Windows due to kernelgateway bug.")

        # Check Jupyter gateway server is installed
        try:
            subprocess.run(
                [sys.executable, "-m", "jupyter", "kernelgateway", "--version"],
                check=True,
                capture_output=True,
                text=True,
            )
        except subprocess.CalledProcessError:
            raise ValueError(
                "Jupyter gateway server is not installed. Please install it with `pip install jupyter_kernel_gateway`."
            )

        self.ip: str = ip
        if isinstance(token, LocalJupyterServer.GenerateToken):
            token = secrets.token_hex(32)

        self.token: str = token
        # Declared for mypy; assigned below once the process is spawned.
        self._subprocess: subprocess.Popen[str]
        # Rotating file logging configuration handed to the gateway process.
        logging_config = {
            "handlers": {
                "file": {
                    "class": "logging.handlers.RotatingFileHandler",
                    "level": log_level,
                    "maxBytes": log_max_bytes,
                    "backupCount": log_backup_count,
                    "filename": log_file,
                }
            },
            "loggers": {"KernelGatewayApp": {"level": log_level, "handlers": ["file", "console"]}},
        }

        # Run Jupyter gateway server with detached subprocess
        args = [
            sys.executable,
            "-m",
            "jupyter",
            "kernelgateway",
            "--KernelGatewayApp.ip",
            ip,
            "--KernelGatewayApp.auth_token",
            token,
            "--JupyterApp.answer_yes",
            "true",
            "--JupyterApp.logging_config",
            json.dumps(logging_config),
            "--JupyterWebsocketPersonality.list_kernels",
            "true",
        ]
        if port is not None:
            args.extend(["--KernelGatewayApp.port", str(port)])
            args.extend(["--KernelGatewayApp.port_retries", "0"])
        self._subprocess = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # Satisfy mypy, we know this is not None because we passed PIPE
        assert self._subprocess.stderr is not None
        # Read stderr until we see "is available at" or the process has exited with an error
        stderr = ""
        while True:
            result = self._subprocess.poll()
            if result is not None:
                # Process died before reporting readiness: surface its stderr.
                stderr += self._subprocess.stderr.read()
                raise ValueError(f"Jupyter gateway server failed to start with exit code: {result}. stderr:\n{stderr}")
            line = self._subprocess.stderr.readline()
            stderr += line

            if "ERROR:" in line:
                error_info = line.split("ERROR:")[1]
                raise ValueError(f"Jupyter gateway server failed to start. {error_info}")

            if "is available at" in line:
                # We need to extract what port it settled on
                # Example output:
                #     Jupyter Kernel Gateway 3.0.0 is available at http://127.0.0.1:8890
                if port is None:
                    port = int(line.split(":")[-1])
                self.port: int = port

                break

        # Poll the subprocess to check if it is still running
        result = self._subprocess.poll()
        if result is not None:
            raise ValueError(
                f"Jupyter gateway server failed to start. Please check the logs ({log_file}) for more information."
            )

        # Ensure the gateway is shut down when the interpreter exits.
        atexit.register(self.stop)

    def stop(self) -> None:
        """Interrupt the gateway subprocess (SIGINT / CTRL_C_EVENT) and wait for it to exit."""
        if self._subprocess.poll() is None:
            if sys.platform == "win32":
                self._subprocess.send_signal(signal.CTRL_C_EVENT)
            else:
                self._subprocess.send_signal(signal.SIGINT)
            self._subprocess.wait()

    @property
    def connection_info(self) -> JupyterConnectionInfo:
        """Connection details (host, port, token) for clients of this server."""
        return JupyterConnectionInfo(host=self.ip, use_https=False, port=self.port, token=self.token)

    def get_client(self) -> JupyterClient:
        """Create a JupyterClient bound to this server's connection info."""
        return JupyterClient(self.connection_info)

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> None:
        self.stop()

View File

@@ -0,0 +1,405 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import logging
import os
import re
import subprocess
import sys
import warnings
from hashlib import md5
from pathlib import Path
from string import Template
from types import SimpleNamespace
from typing import Any, Callable, ClassVar, Optional, Union
from typing_extensions import ParamSpec
from ..code_utils import PYTHON_VARIANTS, TIMEOUT_MSG, WIN32, _cmd
from ..doc_utils import export_module
from .base import CodeBlock, CodeExecutor, CodeExtractor, CommandLineCodeResult
from .func_with_reqs import (
FunctionWithRequirements,
FunctionWithRequirementsStr,
_build_python_functions_file,
to_stub,
)
from .markdown_code_extractor import MarkdownCodeExtractor
from .utils import _get_file_name_from_content, silence_pip
__all__ = ("LocalCommandLineCodeExecutor",)
A = ParamSpec("A")
@export_module("autogen.coding")
class LocalCommandLineCodeExecutor(CodeExecutor):
    SUPPORTED_LANGUAGES: ClassVar[list[str]] = [
        "bash",
        "shell",
        "sh",
        "pwsh",
        "powershell",
        "ps1",
        "python",
        "javascript",
        "html",
        "css",
    ]
    # Languages mapped to False are saved to disk but never executed.
    DEFAULT_EXECUTION_POLICY: ClassVar[dict[str, bool]] = {
        "bash": True,
        "shell": True,
        "sh": True,
        "pwsh": True,
        "powershell": True,
        "ps1": True,
        "python": True,
        "javascript": False,
        "html": False,
        "css": False,
    }
    FUNCTION_PROMPT_TEMPLATE: ClassVar[
        str
    ] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
For example, if there was a function called `foo` you could import it by writing `from $module_name import foo`
$functions"""

    def __init__(
        self,
        timeout: int = 60,
        virtual_env_context: Optional[SimpleNamespace] = None,
        work_dir: Union[Path, str] = Path(),
        functions: Optional[
            list[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]
        ] = None,
        functions_module: str = "functions",
        execution_policies: Optional[dict[str, bool]] = None,
    ):
        """(Experimental) A code executor class that executes or saves LLM generated code a local command line
        environment.

        **This will execute or save LLM generated code on the local machine.**

        Each code block is saved as a file in the working directory. Depending on the execution policy,
        the code may be executed in a separate process.
        The code blocks are executed or save in the order they are received.
        Command line code is sanitized against a list of dangerous commands to prevent self-destructive commands from being executed,
        which could potentially affect the user's environment. Supported languages include Python, shell scripts (bash, shell, sh),
        PowerShell (pwsh, powershell, ps1), HTML, CSS, and JavaScript.
        Execution policies determine whether each language's code blocks are executed or saved only.

        ## Execution with a Python virtual environment
        A python virtual env can be used to execute code and install dependencies. This has the added benefit of not polluting the
        base environment with unwanted modules.
        ```python
        from autogen.code_utils import create_virtual_env
        from autogen.coding import LocalCommandLineCodeExecutor

        venv_dir = ".venv"
        venv_context = create_virtual_env(venv_dir)

        executor = LocalCommandLineCodeExecutor(virtual_env_context=venv_context)
        ```

        Args:
            timeout (int): The timeout for code execution, default is 60 seconds.
            virtual_env_context (Optional[SimpleNamespace]): The virtual environment context to use.
            work_dir (Union[Path, str]): The working directory for code execution, defaults to the current directory.
            functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]): A list of callable functions available to the executor. Defaults to an empty list.
            functions_module (str): The module name under which functions are accessible.
            execution_policies (Optional[Dict[str, bool]]): A dictionary mapping languages to execution policies (True for execution, False for saving only). Defaults to class-wide DEFAULT_EXECUTION_POLICY.
        """
        if timeout < 1:
            raise ValueError("Timeout must be greater than or equal to 1.")

        if isinstance(work_dir, str):
            work_dir = Path(work_dir)

        if not functions_module.isidentifier():
            raise ValueError("Module name must be a valid Python identifier")

        self._functions_module = functions_module

        work_dir.mkdir(exist_ok=True)

        self._timeout = timeout
        self._work_dir: Path = work_dir
        self._virtual_env_context: Optional[SimpleNamespace] = virtual_env_context

        # Fix: avoid a mutable default argument; None stands in for "no functions".
        self._functions = functions if functions is not None else []
        # Setup could take some time so we intentionally wait for the first code block to do it.
        self._setup_functions_complete = len(self._functions) == 0

        self.execution_policies = self.DEFAULT_EXECUTION_POLICY.copy()
        if execution_policies is not None:
            self.execution_policies.update(execution_policies)

    def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
        """(Experimental) Format the functions for a prompt.

        The template includes two variables:
        - `$module_name`: The module name.
        - `$functions`: The functions formatted as stubs with two newlines between each function.

        Args:
            prompt_template (str): The prompt template. Default is the class default.

        Returns:
            str: The formatted prompt.
        """
        template = Template(prompt_template)
        return template.substitute(
            module_name=self._functions_module,
            functions="\n\n".join([to_stub(func) for func in self._functions]),
        )

    @property
    def functions_module(self) -> str:
        """(Experimental) The module name for the functions."""
        return self._functions_module

    @property
    def functions(
        self,
    ) -> list[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]:
        """(Experimental) The functions that are available to the code executor."""
        return self._functions

    @property
    def timeout(self) -> int:
        """(Experimental) The timeout for code execution."""
        return self._timeout

    @property
    def work_dir(self) -> Path:
        """(Experimental) The working directory for the code execution."""
        return self._work_dir

    @property
    def code_extractor(self) -> CodeExtractor:
        """(Experimental) Export a code extractor that can be used by an agent."""
        return MarkdownCodeExtractor()

    @staticmethod
    def sanitize_command(lang: str, code: str) -> None:
        """Sanitize the code block to prevent dangerous commands.

        This approach acknowledges that while Docker or similar
        containerization/sandboxing technologies provide a robust layer of security,
        not all users may have Docker installed or may choose not to use it.
        Therefore, having a baseline level of protection helps mitigate risks for users who,
        either out of choice or necessity, run code outside of a sandboxed environment.

        Raises:
            ValueError: If a known-dangerous shell pattern is found in `code`.
        """
        dangerous_patterns = [
            (r"\brm\s+-rf\b", "Use of 'rm -rf' command is not allowed."),
            (r"\bmv\b.*?\s+/dev/null", "Moving files to /dev/null is not allowed."),
            (r"\bdd\b", "Use of 'dd' command is not allowed."),
            (r">\s*/dev/sd[a-z][1-9]?", "Overwriting disk blocks directly is not allowed."),
            (r":\(\)\{\s*:\|\:&\s*\};:", "Fork bombs are not allowed."),
        ]
        if lang in ["bash", "shell", "sh"]:
            for pattern, message in dangerous_patterns:
                if re.search(pattern, code):
                    raise ValueError(f"Potentially dangerous command detected: {message}")

    def _setup_functions(self) -> None:
        """Write the functions module, pip-install its requirements, and smoke-test it."""
        func_file_content = _build_python_functions_file(self._functions)
        func_file = self._work_dir / f"{self._functions_module}.py"
        func_file.write_text(func_file_content)

        # Collect requirements
        lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)]
        flattened_packages = [item for sublist in lists_of_packages for item in sublist]
        required_packages = list(set(flattened_packages))
        if len(required_packages) > 0:
            logging.info("Ensuring packages are installed in executor.")

            py_executable = self._virtual_env_context.env_exe if self._virtual_env_context else sys.executable

            cmd = [py_executable, "-m", "pip", "install"] + required_packages
            try:
                result = subprocess.run(
                    cmd,
                    cwd=self._work_dir,
                    capture_output=True,
                    text=True,
                    timeout=float(self._timeout),
                    encoding="utf-8",
                )
            except subprocess.TimeoutExpired as e:
                raise ValueError("Pip install timed out") from e

            if result.returncode != 0:
                raise ValueError(f"Pip install failed. {result.stdout}, {result.stderr}")

        # Attempt to load the function file to check for syntax errors, imports etc.
        exec_result = self._execute_code_dont_check_setup([CodeBlock(code=func_file_content, language="python")])

        if exec_result.exit_code != 0:
            raise ValueError(f"Functions failed to load: {exec_result.output}")

        self._setup_functions_complete = True

    def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult:
        """(Experimental) Execute the code blocks and return the result.

        Args:
            code_blocks (List[CodeBlock]): The code blocks to execute.

        Returns:
            CommandLineCodeResult: The result of the code execution.
        """
        if not self._setup_functions_complete:
            self._setup_functions()
        return self._execute_code_dont_check_setup(code_blocks)

    def _execute_code_dont_check_setup(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult:
        """Write each code block to a file and, policy permitting, execute it.

        Execution stops at the first non-zero exit code, timeout, or
        unsupported language.
        """
        logs_all = ""
        file_names = []
        # Fix: default to success so an empty list of code blocks does not
        # raise UnboundLocalError on the final `exitcode` reference.
        exitcode = 0
        for code_block in code_blocks:
            lang, code = code_block.language, code_block.code
            lang = lang.lower()

            LocalCommandLineCodeExecutor.sanitize_command(lang, code)
            code = silence_pip(code, lang)

            if lang in PYTHON_VARIANTS:
                lang = "python"

            if WIN32 and lang in ["sh", "shell"]:
                lang = "ps1"

            if lang not in self.SUPPORTED_LANGUAGES:
                # In case the language is not supported, we return an error message.
                exitcode = 1
                logs_all += "\n" + f"unknown language {lang}"
                break

            execute_code = self.execution_policies.get(lang, False)
            try:
                # Check if there is a filename comment
                filename = _get_file_name_from_content(code, self._work_dir)
            except ValueError:
                return CommandLineCodeResult(exit_code=1, output="Filename is not in the workspace")

            if filename is None:
                # create a file with an automatically generated name
                code_hash = md5(code.encode()).hexdigest()
                filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"

            written_file = (self._work_dir / filename).resolve()
            with written_file.open("w", encoding="utf-8") as f:
                f.write(code)
            file_names.append(written_file)

            if not execute_code:
                # Just return a message that the file is saved.
                logs_all += f"Code saved to {written_file!s}\n"
                exitcode = 0
                continue

            program = _cmd(lang)
            cmd = [program, str(written_file.absolute())]
            env = os.environ.copy()

            if self._virtual_env_context:
                # Put the venv's bin directory first on PATH so its interpreter wins.
                virtual_env_abs_path = os.path.abspath(self._virtual_env_context.bin_path)
                path_with_virtualenv = rf"{virtual_env_abs_path}{os.pathsep}{env['PATH']}"
                env["PATH"] = path_with_virtualenv
                if WIN32:
                    activation_script = os.path.join(virtual_env_abs_path, "activate.bat")
                    cmd = [activation_script, "&&", *cmd]

            try:
                result = subprocess.run(
                    cmd,
                    cwd=self._work_dir,
                    capture_output=True,
                    text=True,
                    timeout=float(self._timeout),
                    env=env,
                    encoding="utf-8",
                )
            except subprocess.TimeoutExpired:
                logs_all += "\n" + TIMEOUT_MSG
                # Same exit code as the timeout command on linux.
                exitcode = 124
                break

            logs_all += result.stderr
            logs_all += result.stdout
            exitcode = result.returncode

            if exitcode != 0:
                break

        code_file = str(file_names[0]) if len(file_names) > 0 else None
        return CommandLineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_file)

    def restart(self) -> None:
        """(Experimental) Restart the code executor."""
        warnings.warn("Restarting local command line code executor is not supported. No action is taken.")
# From stack overflow: https://stackoverflow.com/a/52087847/2214524
class _DeprecatedClassMeta(type):
    """Metaclass that turns a class into a deprecated alias of another class.

    A class declaring `_DeprecatedClassMeta__alias` constructs instances of the
    alias (with a DeprecationWarning) and delegates isinstance/issubclass checks
    to the alias.
    """

    def __new__(cls, name, bases, classdict, *args, **kwargs):  # type: ignore[no-untyped-def]
        alias = classdict.get("_DeprecatedClassMeta__alias")

        if alias is not None:
            # Instantiating the deprecated class warns, then builds the alias instead.
            def new(cls, *args, **kwargs):  # type: ignore[no-untyped-def]
                alias = cls._DeprecatedClassMeta__alias

                if alias is not None:
                    warnings.warn(
                        f"{cls.__name__} has been renamed to {alias.__name__}, the alias will be removed in the future",
                        DeprecationWarning,
                        stacklevel=2,
                    )

                return alias(*args, **kwargs)

            classdict["__new__"] = new
            classdict["_DeprecatedClassMeta__alias"] = alias

        fixed_bases = []

        for b in bases:
            alias = getattr(b, "_DeprecatedClassMeta__alias", None)

            if alias is not None:
                # Subclassing a deprecated class also warns.
                warnings.warn(
                    f"{b.__name__} has been renamed to {alias.__name__}, the alias will be removed in the future",
                    DeprecationWarning,
                    stacklevel=2,
                )

            # Avoid duplicate base classes.
            b = alias or b
            if b not in fixed_bases:
                fixed_bases.append(b)

        fixed_bases = tuple(fixed_bases)  # type: ignore[assignment]

        return super().__new__(cls, name, fixed_bases, classdict, *args, **kwargs)  # type: ignore[call-overload]

    def __instancecheck__(cls, instance):  # type: ignore[no-untyped-def]
        # An object is an "instance" if any of its types passes the subclass check.
        return any(cls.__subclasscheck__(c) for c in {type(instance), instance.__class__})  # type: ignore[no-untyped-call]

    def __subclasscheck__(cls, subclass):  # type: ignore[no-untyped-def]
        # Delegate subclass checks to the alias class.
        if subclass is cls:
            return True
        else:
            return issubclass(subclass, cls._DeprecatedClassMeta__alias)  # type: ignore[attr-defined]
class LocalCommandlineCodeExecutor(metaclass=_DeprecatedClassMeta):
    """LocalCommandlineCodeExecutor renamed to LocalCommandLineCodeExecutor"""

    # Instantiation returns a LocalCommandLineCodeExecutor and emits a DeprecationWarning.
    _DeprecatedClassMeta__alias = LocalCommandLineCodeExecutor
class CommandlineCodeResult(metaclass=_DeprecatedClassMeta):
    """CommandlineCodeResult renamed to CommandLineCodeResult"""

    # Instantiation returns a CommandLineCodeResult and emits a DeprecationWarning.
    _DeprecatedClassMeta__alias = CommandLineCodeResult

View File

@@ -0,0 +1,45 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import re
from typing import Union
from ..code_utils import CODE_BLOCK_PATTERN, UNKNOWN, content_str, infer_lang
from ..doc_utils import export_module
from ..types import UserMessageImageContentPart, UserMessageTextContentPart
from .base import CodeBlock, CodeExtractor
__all__ = ("MarkdownCodeExtractor",)
@export_module("autogen.coding")
class MarkdownCodeExtractor(CodeExtractor):
    """(Experimental) A class that extracts code blocks from a message using Markdown syntax."""

    def extract_code_blocks(
        self, message: Union[str, list[Union[UserMessageTextContentPart, UserMessageImageContentPart]], None]
    ) -> list[CodeBlock]:
        """(Experimental) Extract code blocks from a message. If no code blocks are found,
        return an empty list.

        Args:
            message (str): The message to extract code blocks from.

        Returns:
            List[CodeBlock]: The extracted code blocks or an empty list.
        """
        text = content_str(message)
        fenced = re.findall(CODE_BLOCK_PATTERN, text, flags=re.DOTALL)
        if not fenced:
            return []

        extracted: list[CodeBlock] = []
        for lang, code in fenced:
            # An empty language tag falls back to inference; unknown becomes "".
            language = lang or infer_lang(code)
            if language == UNKNOWN:
                language = ""
            extracted.append(CodeBlock(code=code, language=language))
        return extracted

View File

@@ -0,0 +1,56 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/ag2ai/ag2 are under the MIT License.
# SPDX-License-Identifier: MIT
# Will return the filename relative to the workspace path
import re
from pathlib import Path
from typing import Optional
filename_patterns = [
re.compile(r"^<!-- (filename:)?(.+?) -->", re.DOTALL),
re.compile(r"^/\* (filename:)?(.+?) \*/", re.DOTALL),
re.compile(r"^// (filename:)?(.+?)$", re.DOTALL),
re.compile(r"^# (filename:)?(.+?)$", re.DOTALL),
]
# Raises ValueError if the file is not in the workspace
def _get_file_name_from_content(code: str, workspace_path: Path) -> Optional[str]:
first_line = code.split("\n")[0].strip()
# TODO - support other languages
for pattern in filename_patterns:
matches = pattern.match(first_line)
if matches is not None:
filename = matches.group(2).strip()
# Handle relative paths in the filename
path = Path(filename)
if not path.is_absolute():
path = workspace_path / path
path = path.resolve()
# Throws an error if the file is not in the workspace
relative = path.relative_to(workspace_path.resolve())
return str(relative)
return None
def silence_pip(code: str, lang: str) -> str:
    """Apply -qqq flag to pip install commands."""
    # Jupyter-style "!pip install" for python; plain "pip install" for shells.
    if lang == "python":
        pip_pattern = r"^! ?pip install"
    elif lang in ["bash", "shell", "sh", "pwsh", "powershell", "ps1"]:
        pip_pattern = r"^pip install"
    else:
        # Other languages are returned untouched.
        return code

    # Append "-qqq" right after the install command on every line that lacks it.
    quieted = []
    for line in code.split("\n"):
        found = re.search(pip_pattern, line)
        if found is not None and "-qqq" not in line:
            line = line.replace(found.group(0), found.group(0) + " -qqq")
        quieted.append(line)
    return "\n".join(quieted)

View File

@@ -0,0 +1,34 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
__all__ = ["export_module"]
from typing import Callable, Optional, TypeVar
T = TypeVar("T")
# Global dictionary to store export module mappings
# Key: original symbol name (qualified by module)
# Value: target module where it should be documented
_PDOC_MODULE_EXPORT_MAPPINGS: dict[str, str] = {}
def export_module(module: str) -> Callable[[T], T]:
def decorator(cls: T) -> T:
original_module = getattr(cls, "__module__", None)
if original_module:
fqn = f"{original_module}.{cls.__name__}"
_PDOC_MODULE_EXPORT_MAPPINGS[fqn] = module
return cls
return decorator
def get_target_module(obj: object) -> Optional[str]:
    """Get the target module where an object should be documented."""
    # Objects without __module__ were never registered by export_module.
    if not hasattr(obj, "__module__"):
        return None
    qualified_name = f"{obj.__module__}.{obj.__name__}"
    return _PDOC_MODULE_EXPORT_MAPPINGS.get(qualified_name)

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .base_event import BaseEvent, get_annotated_type_for_event_classes, wrap_event
from .helpers import deprecated_by
__all__ = ["BaseEvent", "deprecated_by", "get_annotated_type_for_event_classes", "wrap_event"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,99 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from abc import ABC
from typing import Annotated, Any, Callable, Literal, Optional, Union
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, create_model
from ..doc_utils import export_module
__all__ = ["BaseEvent", "get_annotated_type_for_event_classes", "get_event_classes", "wrap_event"]
@export_module("autogen.events")
class BaseEvent(BaseModel, ABC):
    """Abstract base for all events; every event carries a unique ``uuid``."""

    uuid: UUID

    def __init__(self, uuid: Optional[UUID] = None, **kwargs: Any) -> None:
        # A fresh UUID is generated when the caller does not supply one.
        super().__init__(uuid=uuid or uuid4(), **kwargs)

    def print(self, f: Optional[Callable[..., Any]] = None) -> None:
        """Print event

        Args:
            f (Optional[Callable[..., Any]], optional): Print function. If none, python's default print will be used.
        """
        ...
def camel2snake(name: str) -> str:
    """Convert a CamelCase identifier to snake_case (e.g. "FooBar" -> "foo_bar")."""
    pieces = [f"_{ch.lower()}" if ch.isupper() else ch for ch in name]
    return "".join(pieces).lstrip("_")
# Registry of generated wrapper classes keyed by their snake_case "type" discriminator.
_event_classes: dict[str, type[BaseModel]] = {}


@export_module("autogen.events")
def wrap_event(event_cls: type[BaseEvent]) -> type[BaseModel]:
    """Wrap an event class with a ``type`` field to be used in a union type.

    This is needed for proper serialization and deserialization of events in a union type.

    Args:
        event_cls (type[BaseEvent]): Event class to wrap. Its name must end with "Event".

    Returns:
        The generated wrapper model, which holds the original event under ``content``
        plus a literal ``type`` discriminator, and is registered in ``_event_classes``.

    Raises:
        ValueError: If the class name does not end with "Event".
    """
    global _event_classes

    if not event_cls.__name__.endswith("Event"):
        raise ValueError("Event class name must end with 'Event'")

    # e.g. "UsageSummaryEvent" -> "usage_summary" (snake_case with "_event" removed).
    type_name = camel2snake(event_cls.__name__)
    type_name = type_name[: -len("_event")]

    class WrapperBase(BaseModel):
        # these types are generated dynamically so we need to disable the type checker
        type: Literal[type_name] = type_name  # type: ignore[valid-type]
        content: event_cls  # type: ignore[valid-type]

        def __init__(self, *args: Any, **data: Any):
            # Deserialization path: data is already in {"type": ..., "content": ...} shape.
            if set(data.keys()) == {"type", "content"} and "content" in data:
                super().__init__(*args, **data)
            else:
                # Construction path: build the inner event from the raw kwargs.
                # NOTE(review): **data is forwarded both into event_cls and into the
                # wrapper's __init__ here — presumably relying on pydantic's handling
                # of the extra keys; confirm before restructuring.
                if "content" in data:
                    content = data.pop("content")
                    super().__init__(*args, content=event_cls(*args, **data, content=content), **data)
                else:
                    super().__init__(content=event_cls(*args, **data), **data)

        def print(self, f: Optional[Callable[..., Any]] = None) -> None:
            # Printing is delegated to the wrapped event instance.
            self.content.print(f)  # type: ignore[attr-defined]

    wrapper_cls = create_model(event_cls.__name__, __base__=WrapperBase)

    # Preserve the original class's docstring and other attributes
    wrapper_cls.__doc__ = event_cls.__doc__
    wrapper_cls.__module__ = event_cls.__module__

    # Copy any other relevant attributes/metadata from the original class
    if hasattr(event_cls, "__annotations__"):
        wrapper_cls.__annotations__ = event_cls.__annotations__

    _event_classes[type_name] = wrapper_cls

    return wrapper_cls
@export_module("autogen.events")
def get_annotated_type_for_event_classes() -> type[Any]:
    """Return the discriminated union of all wrapper classes registered via wrap_event."""
    # this is a dynamic type so we need to disable the type checker
    wrappers = tuple(_event_classes.values())
    union_type = Union[wrappers]  # type: ignore[valid-type]
    return Annotated[union_type, Field(discriminator="type")]  # type: ignore[return-value]
def get_event_classes() -> dict[str, type[BaseModel]]:
    """Return the live registry mapping type discriminators to wrapper classes.

    Note: this is the registry itself, not a copy — mutating it affects
    wrap_event's bookkeeping.
    """
    return _event_classes

View File

@@ -0,0 +1,167 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Literal, Optional, Union
from uuid import UUID
from pydantic import BaseModel
from .base_event import BaseEvent, wrap_event
__all__ = ["UsageSummaryEvent"]
class ModelUsageSummary(BaseModel):
    """Model usage summary.

    One entry per model, mirroring the per-model dicts produced by
    ``_change_usage_summary_format`` (which injects the ``model`` key).
    """

    model: str
    """Model name."""

    completion_tokens: int
    """Number of tokens used for completion."""

    cost: float
    """Cost of the completion."""

    prompt_tokens: int
    """Number of tokens used for prompt."""

    total_tokens: int
    """Total number of tokens used."""
class ActualUsageSummary(BaseModel):
    """Actual usage summary (excluding cached completions).

    Both fields are None when no actual usage summary was supplied.
    """

    usages: Optional[list[ModelUsageSummary]] = None
    """List of model usage summaries."""

    total_cost: Optional[float] = None
    """Total cost."""
class TotalUsageSummary(BaseModel):
    """Total usage summary (including cached completions).

    Both fields are None when no total usage summary was supplied.
    """

    usages: Optional[list[ModelUsageSummary]] = None
    """List of model usage summaries."""

    total_cost: Optional[float] = None
    """Total cost."""
Mode = Literal["both", "total", "actual"]
def _change_usage_summary_format(
actual_usage_summary: Optional[dict[str, Any]] = None, total_usage_summary: Optional[dict[str, Any]] = None
) -> dict[str, dict[str, Any]]:
summary: dict[str, Any] = {}
for usage_type, usage_summary in {"actual": actual_usage_summary, "total": total_usage_summary}.items():
if usage_summary is None:
summary[usage_type] = {"usages": None, "total_cost": None}
continue
usage_summary_altered_format: dict[str, list[dict[str, Any]]] = {"usages": []}
for k, v in usage_summary.items():
if isinstance(k, str) and isinstance(v, dict):
current_usage = {key: value for key, value in v.items()}
current_usage["model"] = k
usage_summary_altered_format["usages"].append(current_usage)
else:
usage_summary_altered_format[k] = v
summary[usage_type] = usage_summary_altered_format
return summary
@wrap_event
class UsageSummaryEvent(BaseEvent):
    """Usage summary message."""

    actual: ActualUsageSummary
    """Actual usage summary."""

    total: TotalUsageSummary
    """Total usage summary."""

    mode: Mode
    """Mode to display the usage summary."""

    def __init__(
        self,
        *,
        uuid: Optional[UUID] = None,
        actual_usage_summary: Optional[dict[str, Any]] = None,
        total_usage_summary: Optional[dict[str, Any]] = None,
        mode: Mode = "both",
    ):
        # Normalize the raw per-model dicts into the Actual/TotalUsageSummary shape.
        summary_dict = _change_usage_summary_format(actual_usage_summary, total_usage_summary)
        super().__init__(uuid=uuid, **summary_dict, mode=mode)

    def _print_usage(
        self,
        usage_summary: Union[ActualUsageSummary, TotalUsageSummary],
        usage_type: str = "total",
        f: Optional[Callable[..., Any]] = None,
    ) -> None:
        """Print one summary section through *f* (defaults to built-in print)."""
        f = f or print
        # "total" includes cached completions, "actual" excludes them.
        word_from_type = "including" if usage_type == "total" else "excluding"

        if usage_summary.usages is None or len(usage_summary.usages) == 0:
            f("No actual cost incurred (all completions are using cache).", flush=True)
            return

        f(f"Usage summary {word_from_type} cached usage: ", flush=True)
        f(f"Total cost: {round(usage_summary.total_cost, 5)}", flush=True)  # type: ignore [arg-type]
        for usage in usage_summary.usages:
            f(
                f"* Model '{usage.model}': cost: {round(usage.cost, 5)}, prompt_tokens: {usage.prompt_tokens}, completion_tokens: {usage.completion_tokens}, total_tokens: {usage.total_tokens}",
                flush=True,
            )

    def print(self, f: Optional[Callable[..., Any]] = None) -> None:
        """Print the usage summary selected by ``self.mode`` ("both", "total" or "actual")."""
        f = f or print

        if self.total.usages is None:
            f('No usage summary. Please call "create" first.', flush=True)
            return

        f("-" * 100, flush=True)
        if self.mode == "both":
            self._print_usage(self.actual, "actual", f)
            f()
            # Repeat totals only when caching made them differ from actual usage.
            if self.total.model_dump_json() != self.actual.model_dump_json():
                self._print_usage(self.total, "total", f)
            else:
                f(
                    "All completions are non-cached: the total cost with cached completions is the same as actual cost.",
                    flush=True,
                )
        elif self.mode == "total":
            self._print_usage(self.total, "total", f)
        elif self.mode == "actual":
            self._print_usage(self.actual, "actual", f)
        else:
            raise ValueError(f'Invalid mode: {self.mode}, choose from "actual", "total", ["actual", "total"]')
        f("-" * 100, flush=True)
@wrap_event
class StreamEvent(BaseEvent):
    """Stream event."""

    content: str
    """Content of the event."""

    def __init__(self, *, uuid: Optional[UUID] = None, content: str) -> None:
        super().__init__(uuid=uuid, content=content)

    def print(self, f: Optional[Callable[..., Any]] = None) -> None:
        """Print the streamed content in green through *f* (defaults to built-in print)."""
        f = f or print
        # Set the terminal text color to green
        f("\033[32m", end="")
        f(self.content, end="", flush=True)
        # Reset the terminal text color
        f("\033[0m\n")

View File

@@ -0,0 +1,36 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import logging
from functools import wraps
from typing import Callable, Optional

from pydantic import BaseModel
logger = logging.getLogger(__name__)
def deprecated_by(
    new_class: type[BaseModel],
    param_mapping: Optional[dict[str, str]] = None,
) -> Callable[[type[BaseModel]], Callable[..., BaseModel]]:
    """Decorator factory that redirects a deprecated model class to its replacement.

    Args:
        new_class: The replacement class to instantiate instead of the old one.
        param_mapping: Maps old constructor keyword names to their new names.
            Defaults to an empty mapping (names pass through unchanged).

    Returns:
        A decorator which replaces the old class with a factory function;
        calling it logs a deprecation warning and returns a ``new_class`` instance.
    """
    # Fix: the annotation was `dict[str, str] = None` (implicit Optional,
    # rejected by PEP 484-conformant type checkers); default behavior is unchanged.
    param_mapping = param_mapping or {}

    def decorator(
        old_class: type[BaseModel],
        param_mapping: dict[str, str] = param_mapping,
    ) -> Callable[..., BaseModel]:
        @wraps(old_class)
        def wrapper(*args, **kwargs) -> BaseModel:
            logger.warning(
                f"{old_class.__name__} is deprecated by {new_class.__name__}. Please import it from {new_class.__module__} and use it instead."
            )
            # Translate old parameters to new parameters
            new_kwargs = {param_mapping.get(k, k): v for k, v in kwargs.items()}
            # Pass the translated parameters to the new class
            return new_class(*args, **new_kwargs)

        return wrapper

    return decorator

View File

@@ -0,0 +1,46 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Any, Callable, Optional
from uuid import UUID
from .base_event import BaseEvent, wrap_event
@wrap_event
class PrintEvent(BaseEvent):
    """Print message"""

    objects: list[str]
    """List of objects to print"""

    sep: str
    """Separator between objects"""

    end: str
    """End of the print"""

    def __init__(
        self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False, uuid: Optional[UUID] = None
    ):
        # Objects are stringified eagerly so the event stays serializable.
        # NOTE(review): `flush` matches print()'s signature but is not stored;
        # print() below always flushes — confirm this is intended.
        objects_as_string = [self._to_json(x) for x in objects]
        super().__init__(uuid=uuid, objects=objects_as_string, sep=sep, end=end)

    def _to_json(self, obj: Any) -> str:
        """Best-effort string form: raw str, pydantic JSON dump, json.dumps, else str()."""
        if isinstance(obj, str):
            return obj
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json()  # type: ignore [no-any-return]
        try:
            return json.dumps(obj)
        except Exception:
            return str(obj)

    def print(self, f: Optional[Callable[..., Any]] = None) -> None:
        """Replay the captured print call through *f* (defaults to built-in print)."""
        f = f or print
        f(*self.objects, sep=self.sep, end=self.end, flush=True)

View File

@@ -0,0 +1,73 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import Any
from .doc_utils import export_module
__all__ = [
"AgentNameConflictError",
"InvalidCarryOverTypeError",
"ModelToolNotSupportedError",
"NoEligibleSpeakerError",
"SenderRequiredError",
"UndefinedNextAgentError",
]
@export_module("autogen")
class AgentNameConflictError(Exception):  # noqa: N818
    """Raised when multiple agents share the same name."""

    # NOTE(review): extra *args/**kwargs are forwarded to Exception.__init__,
    # which does not accept keyword arguments — passing any kwargs would raise
    # a TypeError; confirm whether the **kwargs parameter is intentional.
    def __init__(self, msg: str = "Found multiple agents with the same name.", *args: Any, **kwargs: Any):
        super().__init__(msg, *args, **kwargs)
@export_module("autogen")
class NoEligibleSpeakerError(Exception):  # noqa: N818
    """Exception raised for early termination of a GroupChat."""

    def __init__(self, message: str = "No eligible speakers."):
        super().__init__(message)
        self.message = message
@export_module("autogen")
class SenderRequiredError(Exception):  # noqa: N818
    """Exception raised when the sender is required but not provided."""

    def __init__(self, message: str = "Sender is required but not provided."):
        super().__init__(message)
        self.message = message
@export_module("autogen")
class InvalidCarryOverTypeError(Exception):  # noqa: N818
    """Exception raised when the carryover type is invalid."""

    def __init__(
        self, message: str = "Carryover should be a string or a list of strings. Not adding carryover to the message."
    ):
        super().__init__(message)
        self.message = message
@export_module("autogen")
class UndefinedNextAgentError(Exception):  # noqa: N818
    """Exception raised when the provided next agents list does not overlap with agents in the group."""

    def __init__(self, message: str = "The provided agents list does not overlap with agents in the group."):
        super().__init__(message)
        self.message = message
class ModelToolNotSupportedError(Exception):
    """Exception raised when attempting to use tools with models that do not support them."""

    def __init__(
        self,
        model: str,
    ):
        self.message = (
            f"Tools are not supported with {model} models. "
            "Refer to the documentation at https://platform.openai.com/docs/guides/reasoning#limitations"
        )
        super().__init__(self.message)

View File

@@ -0,0 +1,16 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
from .dependencies import Provider, dependency_provider
from .use import Depends, inject
__all__ = (
"Depends",
"Provider",
"dependency_provider",
"inject",
)

View File

@@ -0,0 +1,80 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
import sys
from importlib.metadata import version as get_version
from typing import Any, Dict, Optional, Tuple, Type
from pydantic import BaseModel, create_model
from pydantic.version import VERSION as PYDANTIC_VERSION
__all__ = (
"PYDANTIC_V2",
"BaseModel",
"ConfigDict",
"ExceptionGroup",
"create_model",
"evaluate_forwardref",
"get_config_base",
)
# True when pydantic 2.x is installed; selects which compat shims are defined below.
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")

default_pydantic_config = {"arbitrary_types_allowed": True}

# Bound below to the version-appropriate forward-reference evaluator.
evaluate_forwardref: Any

# isort: off
if PYDANTIC_V2:
    from pydantic import ConfigDict
    from pydantic._internal._typing_extra import (  # type: ignore[no-redef]
        eval_type_lenient as evaluate_forwardref,
    )

    def model_schema(model: Type[BaseModel]) -> Dict[str, Any]:
        """Return the JSON schema of *model* (pydantic v2 API)."""
        return model.model_json_schema()

    def get_config_base(config_data: Optional[ConfigDict] = None) -> ConfigDict:
        """Return *config_data* or the default config (v2: a plain ConfigDict)."""
        return config_data or ConfigDict(**default_pydantic_config)  # type: ignore[typeddict-item]

    def get_aliases(model: Type[BaseModel]) -> Tuple[str, ...]:
        """Return each field's alias (falling back to the field name), in declaration order."""
        return tuple(f.alias or name for name, f in model.model_fields.items())

    class CreateBaseModel(BaseModel):
        """Just to support FastStream < 0.3.7."""

        model_config = ConfigDict(arbitrary_types_allowed=True)

else:
    from pydantic.typing import evaluate_forwardref as evaluate_forwardref  # type: ignore[no-redef]
    from pydantic.config import get_config, ConfigDict, BaseConfig

    def get_config_base(config_data: Optional[ConfigDict] = None) -> Type[BaseConfig]:  # type: ignore[misc,no-any-unimported]
        """Return a BaseConfig subclass built from *config_data* or the default (v1 API)."""
        return get_config(config_data or ConfigDict(**default_pydantic_config))  # type: ignore[typeddict-item,no-any-unimported,no-any-return]

    def model_schema(model: Type[BaseModel]) -> Dict[str, Any]:
        """Return the JSON schema of *model* (pydantic v1 API)."""
        return model.schema()

    def get_aliases(model: Type[BaseModel]) -> Tuple[str, ...]:
        """Return each field's alias (falling back to the field name), in declaration order."""
        return tuple(f.alias or name for name, f in model.__fields__.items())  # type: ignore[attr-defined]

    class CreateBaseModel(BaseModel):  # type: ignore[no-redef]
        """Just to support FastStream < 0.3.7."""

        class Config:
            arbitrary_types_allowed = True
# anyio 3.x ships its own ExceptionGroup; later versions rely on the
# stdlib (3.11+) or the exceptiongroup backport.
ANYIO_V3 = get_version("anyio").startswith("3.")

if ANYIO_V3:
    from anyio import ExceptionGroup as ExceptionGroup  # type: ignore[attr-defined]
else:
    if sys.version_info < (3, 11):
        # Backport package providing ExceptionGroup before it landed in the stdlib.
        from exceptiongroup import ExceptionGroup as ExceptionGroup
    else:
        # Built-in since Python 3.11; re-export the builtin from this module.
        ExceptionGroup = ExceptionGroup

Some files were not shown because too many files have changed in this diff Show More