CoACT initialize (#292)
This commit is contained in:
64
mm_agents/coact/autogen/agentchat/group/__init__.py
Normal file
64
mm_agents/coact/autogen/agentchat/group/__init__.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
__all__: list[str] = []
|
||||
|
||||
from .available_condition import ExpressionAvailableCondition, StringAvailableCondition
|
||||
from .context_condition import ExpressionContextCondition, StringContextCondition
|
||||
from .context_expression import ContextExpression
|
||||
from .context_str import ContextStr
|
||||
from .context_variables import ContextVariables
|
||||
from .handoffs import Handoffs
|
||||
from .llm_condition import ContextStrLLMCondition, StringLLMCondition
|
||||
from .on_condition import OnCondition
|
||||
from .on_context_condition import OnContextCondition
|
||||
from .reply_result import ReplyResult
|
||||
from .speaker_selection_result import SpeakerSelectionResult
|
||||
from .targets.group_chat_target import GroupChatConfig, GroupChatTarget
|
||||
|
||||
"""
|
||||
from .targets.group_manager_target import (
|
||||
GroupManagerSelectionMessageContextStr,
|
||||
GroupManagerSelectionMessageString,
|
||||
GroupManagerTarget,
|
||||
)
|
||||
"""
|
||||
from .targets.transition_target import (
|
||||
AgentNameTarget,
|
||||
AgentTarget,
|
||||
AskUserTarget,
|
||||
NestedChatTarget,
|
||||
RevertToUserTarget,
|
||||
StayTarget,
|
||||
TerminateTarget,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentNameTarget",
|
||||
"AgentTarget",
|
||||
"AskUserTarget",
|
||||
"ContextExpression",
|
||||
"ContextStr",
|
||||
"ContextStrLLMCondition",
|
||||
"ContextVariables",
|
||||
"ExpressionAvailableCondition",
|
||||
"ExpressionContextCondition",
|
||||
"GroupChatConfig",
|
||||
"GroupChatTarget",
|
||||
# "GroupManagerSelectionMessageContextStr",
|
||||
# "GroupManagerSelectionMessageString",
|
||||
# "GroupManagerTarget",
|
||||
"Handoffs",
|
||||
"NestedChatTarget",
|
||||
"OnCondition",
|
||||
"OnContextCondition",
|
||||
"ReplyResult",
|
||||
"RevertToUserTarget",
|
||||
"SpeakerSelectionResult",
|
||||
"StayTarget",
|
||||
"StringAvailableCondition",
|
||||
"StringContextCondition",
|
||||
"StringLLMCondition",
|
||||
"TerminateTarget",
|
||||
]
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .context_expression import ContextExpression
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Avoid circular import
|
||||
from ..conversable_agent import ConversableAgent
|
||||
|
||||
__all__ = ["AvailableCondition", "ExpressionAvailableCondition", "StringAvailableCondition"]
|
||||
|
||||
|
||||
class AvailableCondition(BaseModel):
|
||||
"""Protocol for determining if a condition is available to be evaluated."""
|
||||
|
||||
def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool:
|
||||
"""Determine if the condition should be considered for evaluation.
|
||||
|
||||
Args:
|
||||
agent: The agent evaluating the condition
|
||||
messages: The conversation history
|
||||
|
||||
Returns:
|
||||
True if the condition should be evaluated, False otherwise
|
||||
"""
|
||||
raise NotImplementedError("Requires subclasses to implement.")
|
||||
|
||||
|
||||
class StringAvailableCondition(AvailableCondition):
|
||||
"""String-based available condition.
|
||||
|
||||
This condition checks if a named context variable exists and is truthy.
|
||||
"""
|
||||
|
||||
context_variable: str
|
||||
|
||||
def __init__(self, context_variable: str, **data: Any) -> None:
|
||||
"""Initialize with a context variable name as a positional parameter.
|
||||
|
||||
Args:
|
||||
context_variable: The name of the context variable to check
|
||||
data: Additional data for the parent class
|
||||
"""
|
||||
super().__init__(context_variable=context_variable, **data)
|
||||
|
||||
def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool:
|
||||
"""Check if the named context variable is truthy.
|
||||
|
||||
Args:
|
||||
agent: The agent with context variables
|
||||
messages: The conversation history (not used)
|
||||
|
||||
Returns:
|
||||
True if the variable exists and is truthy, False otherwise
|
||||
"""
|
||||
return bool(agent.context_variables.get(self.context_variable, False))
|
||||
|
||||
|
||||
class ExpressionAvailableCondition(AvailableCondition):
|
||||
"""Expression-based available condition.
|
||||
|
||||
This condition evaluates a ContextExpression against the context variables.
|
||||
"""
|
||||
|
||||
expression: ContextExpression
|
||||
|
||||
def __init__(self, expression: ContextExpression, **data: Any) -> None:
|
||||
"""Initialize with an expression as a positional parameter.
|
||||
|
||||
Args:
|
||||
expression: The context expression to evaluate
|
||||
data: Additional data for the parent class
|
||||
"""
|
||||
super().__init__(expression=expression, **data)
|
||||
|
||||
def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool:
|
||||
"""Evaluate the expression against the context variables.
|
||||
|
||||
Args:
|
||||
agent: The agent with context variables
|
||||
messages: The conversation history (not used)
|
||||
|
||||
Returns:
|
||||
Boolean result of the expression evaluation
|
||||
"""
|
||||
return self.expression.evaluate(agent.context_variables)
|
||||
77
mm_agents/coact/autogen/agentchat/group/context_condition.py
Normal file
77
mm_agents/coact/autogen/agentchat/group/context_condition.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .context_expression import ContextExpression
|
||||
from .context_variables import ContextVariables
|
||||
|
||||
__all__ = ["ContextCondition", "ExpressionContextCondition", "StringContextCondition"]
|
||||
|
||||
|
||||
class ContextCondition(BaseModel):
|
||||
"""Protocol for conditions evaluated directly using context variables."""
|
||||
|
||||
def evaluate(self, context_variables: ContextVariables) -> bool:
|
||||
"""Evaluate the condition to a boolean result.
|
||||
|
||||
Args:
|
||||
context_variables: The context variables to evaluate against
|
||||
|
||||
Returns:
|
||||
Boolean result of the condition evaluation
|
||||
"""
|
||||
raise NotImplementedError("Requires subclasses to implement.")
|
||||
|
||||
|
||||
class StringContextCondition(ContextCondition):
|
||||
"""Simple string-based context condition.
|
||||
|
||||
This condition checks if a named context variable exists and is truthy.
|
||||
"""
|
||||
|
||||
variable_name: str
|
||||
|
||||
def evaluate(self, context_variables: ContextVariables) -> bool:
|
||||
"""Check if the named context variable is truthy.
|
||||
|
||||
Args:
|
||||
context_variables: The context variables to check against
|
||||
|
||||
Returns:
|
||||
True if the variable exists and is truthy, False otherwise
|
||||
"""
|
||||
return bool(context_variables.get(self.variable_name, False))
|
||||
|
||||
|
||||
class ExpressionContextCondition(ContextCondition):
|
||||
"""Complex expression-based context condition.
|
||||
|
||||
This condition evaluates a ContextExpression against the context variables.
|
||||
"""
|
||||
|
||||
expression: ContextExpression
|
||||
|
||||
def __init__(self, expression: ContextExpression, **data: Any) -> None:
|
||||
"""Initialize with an expression as a positional parameter.
|
||||
|
||||
Args:
|
||||
expression: The context expression to evaluate
|
||||
data: Additional data for the parent class
|
||||
"""
|
||||
super().__init__(expression=expression, **data)
|
||||
|
||||
def evaluate(self, context_variables: ContextVariables) -> bool:
|
||||
"""Evaluate the expression against the context variables.
|
||||
|
||||
Args:
|
||||
context_variables: The context variables to evaluate against
|
||||
|
||||
Returns:
|
||||
Boolean result of the expression evaluation
|
||||
"""
|
||||
return self.expression.evaluate(context_variables)
|
||||
238
mm_agents/coact/autogen/agentchat/group/context_expression.py
Normal file
238
mm_agents/coact/autogen/agentchat/group/context_expression.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import ast
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ...doc_utils import export_module
|
||||
from .context_variables import ContextVariables
|
||||
|
||||
|
||||
@dataclass
|
||||
@export_module("autogen")
|
||||
class ContextExpression:
|
||||
"""A class to evaluate logical expressions using context variables.
|
||||
|
||||
Args:
|
||||
expression (str): A string containing a logical expression with context variable references.
|
||||
- Variable references use ${var_name} syntax: ${logged_in}, ${attempts}
|
||||
- String literals can use normal quotes: 'hello', "world"
|
||||
- Supported operators:
|
||||
- Logical: not/!, and/&, or/|
|
||||
- Comparison: >, <, >=, <=, ==, !=
|
||||
- Supported functions:
|
||||
- len(${var_name}): Gets the length of a list, string, or other collection
|
||||
- Parentheses can be used for grouping
|
||||
- Examples:
|
||||
- "not ${logged_in} and ${is_admin} or ${guest_checkout}"
|
||||
- "!${logged_in} & ${is_admin} | ${guest_checkout}"
|
||||
- "len(${orders}) > 0 & ${user_active}"
|
||||
- "len(${cart_items}) == 0 | ${checkout_started}"
|
||||
|
||||
Raises:
|
||||
SyntaxError: If the expression cannot be parsed
|
||||
ValueError: If the expression contains disallowed operations
|
||||
"""
|
||||
|
||||
expression: str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Validate the expression immediately upon creation
|
||||
try:
|
||||
# Extract variable references and replace with placeholders
|
||||
self._variable_names = self._extract_variable_names(self.expression)
|
||||
|
||||
# Convert symbolic operators to Python keywords
|
||||
python_expr = self._convert_to_python_syntax(self.expression)
|
||||
|
||||
# Sanitize for AST parsing
|
||||
sanitized_expr = self._prepare_for_ast(python_expr)
|
||||
|
||||
# Use ast to parse and validate the expression
|
||||
self._ast = ast.parse(sanitized_expr, mode="eval")
|
||||
|
||||
# Verify it only contains allowed operations
|
||||
self._validate_operations(self._ast.body)
|
||||
|
||||
# Store the Python-syntax version for evaluation
|
||||
self._python_expr = python_expr
|
||||
|
||||
except SyntaxError as e:
|
||||
raise SyntaxError(f"Invalid expression syntax in '{self.expression}': {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error validating expression '{self.expression}': {str(e)}")
|
||||
|
||||
def _extract_variable_names(self, expr: str) -> list[str]:
|
||||
"""Extract all variable references ${var_name} from the expression."""
|
||||
# Find all patterns like ${var_name}
|
||||
matches = re.findall(r"\${([^}]*)}", expr)
|
||||
return matches
|
||||
|
||||
def _convert_to_python_syntax(self, expr: str) -> str:
|
||||
"""Convert symbolic operators to Python keywords."""
|
||||
# We need to be careful about operators inside string literals
|
||||
# First, temporarily replace string literals with placeholders
|
||||
string_literals = []
|
||||
|
||||
def replace_string_literal(match: re.Match[str]) -> str:
|
||||
string_literals.append(match.group(0))
|
||||
return f"__STRING_LITERAL_{len(string_literals) - 1}__"
|
||||
|
||||
# Replace both single and double quoted strings
|
||||
expr_without_strings = re.sub(r"'[^']*'|\"[^\"]*\"", replace_string_literal, expr)
|
||||
|
||||
# Handle the NOT operator (!) - no parentheses handling needed
|
||||
# Replace standalone ! before variables or expressions
|
||||
expr_without_strings = re.sub(r"!\s*(\${|\()", "not \\1", expr_without_strings)
|
||||
|
||||
# Handle AND and OR operators - simpler approach without parentheses handling
|
||||
expr_without_strings = re.sub(r"\s+&\s+", " and ", expr_without_strings)
|
||||
expr_without_strings = re.sub(r"\s+\|\s+", " or ", expr_without_strings)
|
||||
|
||||
# Now put string literals back
|
||||
for i, literal in enumerate(string_literals):
|
||||
expr_without_strings = expr_without_strings.replace(f"__STRING_LITERAL_{i}__", literal)
|
||||
|
||||
return expr_without_strings
|
||||
|
||||
def _prepare_for_ast(self, expr: str) -> str:
|
||||
"""Convert the expression to valid Python for AST parsing by replacing variables with placeholders."""
|
||||
# Replace ${var_name} with var_name for AST parsing
|
||||
processed_expr = expr
|
||||
for var_name in self._variable_names:
|
||||
processed_expr = processed_expr.replace(f"${{{var_name}}}", var_name)
|
||||
|
||||
return processed_expr
|
||||
|
||||
def _validate_operations(self, node: ast.AST) -> None:
|
||||
"""Recursively validate that only allowed operations exist in the AST."""
|
||||
allowed_node_types = (
|
||||
# Boolean operations
|
||||
ast.BoolOp,
|
||||
ast.UnaryOp,
|
||||
ast.And,
|
||||
ast.Or,
|
||||
ast.Not,
|
||||
# Comparison operations
|
||||
ast.Compare,
|
||||
ast.Eq,
|
||||
ast.NotEq,
|
||||
ast.Lt,
|
||||
ast.LtE,
|
||||
ast.Gt,
|
||||
ast.GtE,
|
||||
# Basic nodes
|
||||
ast.Name,
|
||||
ast.Load,
|
||||
ast.Constant,
|
||||
ast.Expression,
|
||||
# Support for basic numeric operations in comparisons
|
||||
ast.Num,
|
||||
ast.NameConstant,
|
||||
# Support for negative numbers
|
||||
ast.USub,
|
||||
ast.UnaryOp,
|
||||
# Support for string literals
|
||||
ast.Str,
|
||||
ast.Constant,
|
||||
# Support for function calls (specifically len())
|
||||
ast.Call,
|
||||
)
|
||||
|
||||
if not isinstance(node, allowed_node_types):
|
||||
raise ValueError(f"Operation type {type(node).__name__} is not allowed in logical expressions")
|
||||
|
||||
# Special validation for function calls - only allow len()
|
||||
if isinstance(node, ast.Call):
|
||||
if not (isinstance(node.func, ast.Name) and node.func.id == "len"):
|
||||
raise ValueError(f"Only the len() function is allowed, got: {getattr(node.func, 'id', 'unknown')}")
|
||||
if len(node.args) != 1:
|
||||
raise ValueError(f"len() function must have exactly one argument, got {len(node.args)}")
|
||||
|
||||
# Special validation for Compare nodes
|
||||
if isinstance(node, ast.Compare):
|
||||
for op in node.ops:
|
||||
if not isinstance(op, (ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE)):
|
||||
raise ValueError(f"Comparison operator {type(op).__name__} is not allowed")
|
||||
|
||||
# Recursively check child nodes
|
||||
for child in ast.iter_child_nodes(node):
|
||||
self._validate_operations(child)
|
||||
|
||||
def evaluate(self, context_variables: ContextVariables) -> bool:
|
||||
"""Evaluate the expression using the provided context variables.
|
||||
|
||||
Args:
|
||||
context_variables: Dictionary of context variables to use for evaluation
|
||||
|
||||
Returns:
|
||||
bool: The result of evaluating the expression
|
||||
|
||||
Raises:
|
||||
KeyError: If a variable referenced in the expression is not found in the context
|
||||
"""
|
||||
# Create a modified expression that we can safely evaluate
|
||||
eval_expr = self._python_expr # Use the Python-syntax version
|
||||
|
||||
# First, handle len() functions with variable references inside
|
||||
len_pattern = r"len\(\${([^}]*)}\)"
|
||||
len_matches = list(re.finditer(len_pattern, eval_expr))
|
||||
|
||||
# Process all len() operations first
|
||||
for match in len_matches:
|
||||
var_name = match.group(1)
|
||||
# Check if variable exists in context, raise KeyError if not
|
||||
if not context_variables.contains(var_name):
|
||||
raise KeyError(f"Missing context variable: '{var_name}'")
|
||||
|
||||
var_value = context_variables.get(var_name)
|
||||
|
||||
# Calculate the length - works for lists, strings, dictionaries, etc.
|
||||
try:
|
||||
length_value = len(var_value) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
# If the value doesn't support len(), treat as 0
|
||||
length_value = 0
|
||||
|
||||
# Replace the len() expression with the actual length
|
||||
full_match = match.group(0)
|
||||
eval_expr = eval_expr.replace(full_match, str(length_value))
|
||||
|
||||
# Then replace remaining variable references with their values
|
||||
for var_name in self._variable_names:
|
||||
# Skip variables that were already processed in len() expressions
|
||||
if any(m.group(1) == var_name for m in len_matches):
|
||||
continue
|
||||
|
||||
# Check if variable exists in context, raise KeyError if not
|
||||
if not context_variables.contains(var_name):
|
||||
raise KeyError(f"Missing context variable: '{var_name}'")
|
||||
|
||||
# Get the value from context
|
||||
var_value = context_variables.get(var_name)
|
||||
|
||||
# Format the value appropriately based on its type
|
||||
if isinstance(var_value, (bool, int, float)):
|
||||
formatted_value = str(var_value)
|
||||
elif isinstance(var_value, str):
|
||||
formatted_value = f"'{var_value}'" # Quote strings
|
||||
elif isinstance(var_value, (list, dict, tuple)):
|
||||
# For collections, convert to their boolean evaluation
|
||||
formatted_value = str(bool(var_value))
|
||||
else:
|
||||
formatted_value = str(var_value)
|
||||
|
||||
# Replace the variable reference with the formatted value
|
||||
eval_expr = eval_expr.replace(f"${{{var_name}}}", formatted_value)
|
||||
|
||||
try:
|
||||
return eval(eval_expr) # type: ignore[no-any-return]
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error evaluating expression '{self.expression}' (are you sure you're using ${{my_context_variable_key}}): {str(e)}"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ContextExpression('{self.expression}')"
|
||||
41
mm_agents/coact/autogen/agentchat/group/context_str.py
Normal file
41
mm_agents/coact/autogen/agentchat/group/context_str.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .context_variables import ContextVariables
|
||||
|
||||
__all__ = ["ContextStr"]
|
||||
|
||||
|
||||
class ContextStr(BaseModel):
|
||||
"""A string that requires context variable substitution.
|
||||
|
||||
Use the format method to substitute context variables into the string.
|
||||
"""
|
||||
|
||||
"""The string to be substituted with context variables. It is expected that the string will contain `{var}` placeholders and that string format will be able to replace all values."""
|
||||
template: str
|
||||
|
||||
def format(self, context_variables: ContextVariables) -> Optional[str]:
|
||||
"""Substitute context variables into the string.
|
||||
|
||||
Args:
|
||||
context_variables (ContextVariables): The context variables to substitute into the string.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The formatted string with context variables substituted.
|
||||
"""
|
||||
|
||||
context = context_variables.to_dict()
|
||||
|
||||
if not context:
|
||||
return self.template
|
||||
|
||||
return self.template.format(**context)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ContextStr, unformatted: {self.template}"
|
||||
192
mm_agents/coact/autogen/agentchat/group/context_variables.py
Normal file
192
mm_agents/coact/autogen/agentchat/group/context_variables.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Generator, Iterable, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
__all__ = ["ContextVariables"]
|
||||
|
||||
# Parameter name for context variables
|
||||
# Use the value in functions and they will be substituted with the context variables:
|
||||
# e.g. def my_function(context_variables: ContextVariables, my_other_parameters: Any) -> Any:
|
||||
__CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables"
|
||||
|
||||
|
||||
class ContextVariables(BaseModel):
|
||||
"""
|
||||
Stores and manages context variables for agentic workflows.
|
||||
|
||||
Utilises a dictionary-like interface for setting, getting, and removing variables.
|
||||
"""
|
||||
|
||||
# Internal storage for context variables
|
||||
data: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __init__(self, data: Optional[dict[str, Any]] = None, **kwargs: Any) -> None:
|
||||
"""Initialize with data dictionary as an optional positional parameter.
|
||||
|
||||
Args:
|
||||
data: Initial dictionary of context variables (optional)
|
||||
kwargs: Additional keyword arguments for the parent class
|
||||
"""
|
||||
init_data = data or {}
|
||||
super().__init__(data=init_data, **kwargs)
|
||||
|
||||
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
|
||||
"""
|
||||
Get a value from the context by key.
|
||||
|
||||
Args:
|
||||
key: The key to retrieve
|
||||
default: The default value to return if key is not found
|
||||
|
||||
Returns:
|
||||
The value associated with the key or default if not found
|
||||
"""
|
||||
return self.data.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
"""
|
||||
Set a value in the context by key.
|
||||
|
||||
Args:
|
||||
key: The key to set
|
||||
value: The value to store
|
||||
"""
|
||||
self.data[key] = value
|
||||
|
||||
def remove(self, key: str) -> bool:
|
||||
"""
|
||||
Remove a key from the context.
|
||||
|
||||
Args:
|
||||
key: The key to remove
|
||||
|
||||
Returns:
|
||||
True if the key was removed, False if it didn't exist
|
||||
"""
|
||||
if key in self.data:
|
||||
del self.data[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def keys(self) -> Iterable[str]:
|
||||
"""
|
||||
Get all keys in the context.
|
||||
|
||||
Returns:
|
||||
An iterable of all keys
|
||||
"""
|
||||
return self.data.keys()
|
||||
|
||||
def values(self) -> Iterable[Any]:
|
||||
"""
|
||||
Get all values in the context.
|
||||
|
||||
Returns:
|
||||
An iterable of all values
|
||||
"""
|
||||
return self.data.values()
|
||||
|
||||
def items(self) -> Iterable[tuple[str, Any]]:
|
||||
"""
|
||||
Get all key-value pairs in the context.
|
||||
|
||||
Returns:
|
||||
An iterable of all key-value pairs
|
||||
"""
|
||||
return self.data.items()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all keys and values from the context."""
|
||||
self.data.clear()
|
||||
|
||||
def contains(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a key exists in the context.
|
||||
|
||||
Args:
|
||||
key: The key to check
|
||||
|
||||
Returns:
|
||||
True if the key exists, False otherwise
|
||||
"""
|
||||
return key in self.data
|
||||
|
||||
def update(self, other: dict[str, Any]) -> None:
|
||||
"""
|
||||
Update context with key-value pairs from another dictionary.
|
||||
|
||||
Args:
|
||||
other: Dictionary containing key-value pairs to add
|
||||
"""
|
||||
self.data.update(other)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert context variables to a dictionary.
|
||||
|
||||
Returns:
|
||||
Dictionary representation of all context variables
|
||||
"""
|
||||
return self.data.copy()
|
||||
|
||||
# Dictionary-compatible interface
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Get a value using dictionary syntax: context[key]"""
|
||||
try:
|
||||
return self.data[key]
|
||||
except KeyError:
|
||||
raise KeyError(f"Context variable '{key}' not found")
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
"""Set a value using dictionary syntax: context[key] = value"""
|
||||
self.data[key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""Delete a key using dictionary syntax: del context[key]"""
|
||||
try:
|
||||
del self.data[key]
|
||||
except KeyError:
|
||||
raise KeyError(f"Cannot delete non-existent context variable '{key}'")
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""Check if key exists using 'in' operator: key in context"""
|
||||
return key in self.data
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Get the number of items: len(context)"""
|
||||
return len(self.data)
|
||||
|
||||
def __iter__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
"""Iterate over keys: for key in context"""
|
||||
for key in self.data:
|
||||
yield (key, self.data[key])
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of context variables."""
|
||||
return f"ContextVariables({self.data})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Detailed representation of context variables."""
|
||||
return f"ContextVariables(data={self.data!r})"
|
||||
|
||||
# Utility methods
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ContextVariables":
|
||||
"""
|
||||
Create a new ContextVariables instance from a dictionary.
|
||||
|
||||
E.g.:
|
||||
my_context = {"user_id": "12345", "settings": {"theme": "dark"}}
|
||||
context = ContextVariables.from_dict(my_context)
|
||||
|
||||
Args:
|
||||
data: Dictionary of key-value pairs
|
||||
|
||||
Returns:
|
||||
New ContextVariables instance
|
||||
"""
|
||||
return cls(data=data)
|
||||
202
mm_agents/coact/autogen/agentchat/group/group_tool_executor.py
Normal file
202
mm_agents/coact/autogen/agentchat/group/group_tool_executor.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
from typing import Annotated, Any, Callable, Optional
|
||||
|
||||
from ...oai import OpenAIWrapper
|
||||
from ...tools import Depends, Tool
|
||||
from ...tools.dependency_injection import inject_params, on
|
||||
from ..agent import Agent
|
||||
from ..conversable_agent import ConversableAgent
|
||||
from .context_variables import __CONTEXT_VARIABLES_PARAM_NAME__, ContextVariables
|
||||
from .reply_result import ReplyResult
|
||||
from .targets.transition_target import TransitionTarget
|
||||
|
||||
__TOOL_EXECUTOR_NAME__ = "_Group_Tool_Executor"
|
||||
|
||||
|
||||
class GroupToolExecutor(ConversableAgent):
|
||||
"""Tool executor for the group chat initiated with initiate_group_chat"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
name=__TOOL_EXECUTOR_NAME__,
|
||||
system_message="Tool Execution, do not use this agent directly.",
|
||||
human_input_mode="NEVER",
|
||||
code_execution_config=False,
|
||||
)
|
||||
|
||||
# Store the next target from a tool call
|
||||
self._group_next_target: Optional[TransitionTarget] = None
|
||||
|
||||
# Primary tool reply function for handling the tool reply and the ReplyResult and TransitionTarget returns
|
||||
self.register_reply([Agent, None], self._generate_group_tool_reply, remove_other_reply_funcs=True)
|
||||
|
||||
def set_next_target(self, next_target: TransitionTarget) -> None:
|
||||
"""Sets the next target to transition to, used in the determine_next_agent function."""
|
||||
self._group_next_target = next_target
|
||||
|
||||
def get_next_target(self) -> TransitionTarget:
|
||||
"""Gets the next target to transition to."""
|
||||
"""Returns the next target to transition to, if it exists."""
|
||||
if self._group_next_target is None:
|
||||
raise ValueError(
|
||||
"No next target set. Please set a next target before calling this method. Use has_next_target() to check if a next target exists."
|
||||
)
|
||||
return self._group_next_target
|
||||
|
||||
def has_next_target(self) -> bool:
|
||||
"""Checks if there is a next target to transition to."""
|
||||
return self._group_next_target is not None
|
||||
|
||||
def clear_next_target(self) -> None:
|
||||
"""Clears the next target to transition to."""
|
||||
self._group_next_target = None
|
||||
|
||||
def _modify_context_variables_param(
|
||||
self, f: Callable[..., Any], context_variables: ContextVariables
|
||||
) -> Callable[..., Any]:
|
||||
"""Modifies the context_variables parameter to use dependency injection and link it to the group context variables.
|
||||
|
||||
This essentially changes:
|
||||
def some_function(some_variable: int, context_variables: ContextVariables) -> str:
|
||||
|
||||
to:
|
||||
|
||||
def some_function(some_variable: int, context_variables: Annotated[ContextVariables, Depends(on(self.context_variables))]) -> str:
|
||||
"""
|
||||
sig = inspect.signature(f)
|
||||
|
||||
# Check if context_variables parameter exists and update it if so
|
||||
if __CONTEXT_VARIABLES_PARAM_NAME__ in sig.parameters:
|
||||
new_params = []
|
||||
for name, param in sig.parameters.items():
|
||||
if name == __CONTEXT_VARIABLES_PARAM_NAME__:
|
||||
# Replace with new annotation using Depends
|
||||
new_param = param.replace(annotation=Annotated[ContextVariables, Depends(on(context_variables))])
|
||||
new_params.append(new_param)
|
||||
else:
|
||||
new_params.append(param)
|
||||
|
||||
# Update signature
|
||||
new_sig = sig.replace(parameters=new_params)
|
||||
f.__signature__ = new_sig # type: ignore[attr-defined]
|
||||
|
||||
return f
|
||||
|
||||
def _change_tool_context_variables_to_depends(
|
||||
self, agent: ConversableAgent, current_tool: Tool, context_variables: ContextVariables
|
||||
) -> None:
|
||||
"""Checks for the context_variables parameter in the tool and updates it to use dependency injection."""
|
||||
|
||||
# If the tool has a context_variables parameter, remove the tool and reregister it without the parameter
|
||||
if __CONTEXT_VARIABLES_PARAM_NAME__ in current_tool.tool_schema["function"]["parameters"]["properties"]:
|
||||
# We'll replace the tool, so start with getting the underlying function
|
||||
tool_func = current_tool._func
|
||||
|
||||
# Remove the Tool from the agent
|
||||
name = current_tool._name
|
||||
description = current_tool._description
|
||||
agent.remove_tool_for_llm(current_tool)
|
||||
|
||||
# Recreate the tool without the context_variables parameter
|
||||
tool_func = self._modify_context_variables_param(current_tool._func, context_variables)
|
||||
tool_func = inject_params(tool_func)
|
||||
new_tool = ConversableAgent._create_tool_if_needed(
|
||||
func_or_tool=tool_func, name=name, description=description
|
||||
)
|
||||
|
||||
# Re-register with the agent
|
||||
agent.register_for_llm()(new_tool)
|
||||
|
||||
def register_agents_functions(self, agents: list[ConversableAgent], context_variables: ContextVariables) -> None:
|
||||
"""Adds the functions of the agents to the group tool executor."""
|
||||
for agent in agents:
|
||||
# As we're moving towards tools and away from function maps, this may not be used
|
||||
self._function_map.update(agent._function_map)
|
||||
|
||||
# Update any agent tools that have context_variables parameters to use Dependency Injection
|
||||
for tool in agent.tools:
|
||||
self._change_tool_context_variables_to_depends(agent, tool, context_variables)
|
||||
|
||||
# Add all tools to the Tool Executor agent
|
||||
for tool in agent.tools:
|
||||
self.register_for_execution(serialize=False, silent_override=True)(tool)
|
||||
|
||||
def _generate_group_tool_reply(
|
||||
self,
|
||||
agent: ConversableAgent,
|
||||
messages: Optional[list[dict[str, Any]]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
config: Optional[OpenAIWrapper] = None,
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""Pre-processes and generates tool call replies.
|
||||
|
||||
This function:
|
||||
1. Adds context_variables back to the tool call for the function, if necessary.
|
||||
2. Generates the tool calls reply.
|
||||
3. Updates context_variables and next_agent based on the tool call response."""
|
||||
|
||||
if config is None:
|
||||
config = agent # type: ignore[assignment]
|
||||
if messages is None:
|
||||
messages = agent._oai_messages[sender]
|
||||
|
||||
message = messages[-1]
|
||||
if "tool_calls" in message:
|
||||
tool_call_count = len(message["tool_calls"])
|
||||
|
||||
# Loop through tool calls individually (so context can be updated after each function call)
|
||||
next_target: Optional[TransitionTarget] = None
|
||||
tool_responses_inner = []
|
||||
contents = []
|
||||
for index in range(tool_call_count):
|
||||
message_copy = deepcopy(message)
|
||||
|
||||
# 1. add context_variables to the tool call arguments
|
||||
tool_call = message_copy["tool_calls"][index]
|
||||
|
||||
# Ensure we are only executing the one tool at a time
|
||||
message_copy["tool_calls"] = [tool_call]
|
||||
|
||||
# 2. generate tool calls reply
|
||||
_, tool_message = agent.generate_tool_calls_reply([message_copy])
|
||||
|
||||
if tool_message is None:
|
||||
raise ValueError("Tool call did not return a message")
|
||||
|
||||
# 3. update context_variables and next_agent, convert content to string
|
||||
for tool_response in tool_message["tool_responses"]:
|
||||
content = tool_response.get("content")
|
||||
|
||||
# Tool Call returns that are a target are either a ReplyResult or a TransitionTarget are the next agent
|
||||
if isinstance(content, ReplyResult):
|
||||
if content.context_variables and content.context_variables.to_dict() != {}:
|
||||
agent.context_variables.update(content.context_variables.to_dict())
|
||||
if content.target is not None:
|
||||
next_target = content.target
|
||||
elif isinstance(content, TransitionTarget):
|
||||
next_target = content
|
||||
|
||||
# Serialize the content to a string
|
||||
if content is not None:
|
||||
tool_response["content"] = str(content)
|
||||
|
||||
tool_responses_inner.append(tool_response)
|
||||
contents.append(str(tool_response["content"]))
|
||||
|
||||
self._group_next_target = next_target # type: ignore[attr-defined]
|
||||
|
||||
# Put the tool responses and content strings back into the response message
|
||||
# Caters for multiple tool calls
|
||||
if tool_message is None:
|
||||
raise ValueError("Tool call did not return a message")
|
||||
|
||||
tool_message["tool_responses"] = tool_responses_inner
|
||||
tool_message["content"] = "\n".join(contents)
|
||||
|
||||
return True, tool_message
|
||||
return False, None
|
||||
636
mm_agents/coact/autogen/agentchat/group/group_utils.py
Normal file
636
mm_agents/coact/autogen/agentchat/group/group_utils.py
Normal file
@@ -0,0 +1,636 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
from ..agent import Agent
|
||||
from ..groupchat import GroupChat, GroupChatManager
|
||||
from .context_variables import ContextVariables
|
||||
from .group_tool_executor import GroupToolExecutor
|
||||
from .targets.group_manager_target import GroupManagerTarget
|
||||
from .targets.transition_target import (
|
||||
AgentNameTarget,
|
||||
AgentTarget,
|
||||
TransitionTarget,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..conversable_agent import ConversableAgent
|
||||
|
||||
# Utility functions for group chat preparation and management
|
||||
# These are extracted from multi_agent_chat.py to avoid circular imports
|
||||
|
||||
|
||||
def update_conditional_functions(agent: "ConversableAgent", messages: list[dict[str, Any]]) -> None:
|
||||
"""Updates the agent's functions based on the OnCondition's available condition.
|
||||
|
||||
All functions are removed and then added back if they are available
|
||||
"""
|
||||
for on_condition in agent.handoffs.llm_conditions:
|
||||
is_available = on_condition.available.is_available(agent, messages) if on_condition.available else True
|
||||
|
||||
# Remove it from their tools
|
||||
for tool in agent.tools:
|
||||
if tool.name == on_condition.llm_function_name:
|
||||
agent.remove_tool_for_llm(tool)
|
||||
break
|
||||
|
||||
# then add the function if it is available, so that the function signature is updated
|
||||
if is_available:
|
||||
agent._add_single_function(
|
||||
_create_on_condition_handoff_function(on_condition.target),
|
||||
on_condition.llm_function_name,
|
||||
on_condition.condition.get_prompt(agent, messages),
|
||||
)
|
||||
|
||||
|
||||
def establish_group_agent(agent: "ConversableAgent") -> None:
|
||||
"""Establish the group agent with the group-related attributes and hooks. Not for the tool executor.
|
||||
|
||||
Args:
|
||||
agent ("ConversableAgent"): The agent to establish as a group agent.
|
||||
"""
|
||||
|
||||
def _group_agent_str(self: "ConversableAgent") -> str:
|
||||
"""Customise the __str__ method to show the agent name for transition messages."""
|
||||
return f"Group agent --> {self.name}"
|
||||
|
||||
# Register the hook to update agent state (except tool executor)
|
||||
agent.register_hook("update_agent_state", update_conditional_functions)
|
||||
|
||||
# Register a reply function to run Python function-based OnContextConditions before any other reply function
|
||||
agent.register_reply(trigger=([Agent, None]), reply_func=_run_oncontextconditions, position=0)
|
||||
|
||||
agent._get_display_name = MethodType(_group_agent_str, agent) # type: ignore[method-assign]
|
||||
|
||||
# Mark this agent as established as a group agent
|
||||
agent._group_is_established = True # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def link_agents_to_group_manager(agents: list[Agent], group_chat_manager: Agent) -> None:
|
||||
"""Link all agents to the GroupChatManager so they can access the underlying GroupChat and other agents.
|
||||
|
||||
This is primarily used so that agents can get to the tool executor to help set the next agent.
|
||||
|
||||
Does not link the Tool Executor agent.
|
||||
"""
|
||||
for agent in agents:
|
||||
agent._group_manager = group_chat_manager # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _evaluate_after_works_conditions(
|
||||
agent: "ConversableAgent",
|
||||
groupchat: GroupChat,
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> Optional[Union[Agent, str]]:
|
||||
"""Evaluate after_works context conditions for an agent.
|
||||
|
||||
Args:
|
||||
agent: The agent to evaluate after_works conditions for
|
||||
groupchat: The current group chat
|
||||
user_agent: Optional user proxy agent
|
||||
|
||||
Returns:
|
||||
The resolved speaker selection result if a condition matches, None otherwise
|
||||
"""
|
||||
if not hasattr(agent, "handoffs") or not agent.handoffs.after_works: # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
for after_work_condition in agent.handoffs.after_works: # type: ignore[attr-defined]
|
||||
# Check if condition is available
|
||||
is_available = (
|
||||
after_work_condition.available.is_available(agent, groupchat.messages)
|
||||
if after_work_condition.available
|
||||
else True
|
||||
)
|
||||
|
||||
# Evaluate the condition (None condition means always true)
|
||||
if is_available and (
|
||||
after_work_condition.condition is None or after_work_condition.condition.evaluate(agent.context_variables)
|
||||
):
|
||||
# Condition matched, resolve and return
|
||||
return after_work_condition.target.resolve(
|
||||
groupchat,
|
||||
agent,
|
||||
user_agent,
|
||||
).get_speaker_selection_result(groupchat)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _run_oncontextconditions(
|
||||
agent: "ConversableAgent",
|
||||
messages: Optional[list[dict[str, Any]]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
config: Optional[Any] = None,
|
||||
) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
|
||||
"""Run OnContextConditions for an agent before any other reply function."""
|
||||
for on_condition in agent.handoffs.context_conditions: # type: ignore[attr-defined]
|
||||
is_available = (
|
||||
on_condition.available.is_available(agent, messages if messages else []) if on_condition.available else True
|
||||
)
|
||||
|
||||
if is_available and (
|
||||
on_condition.condition is None or on_condition.condition.evaluate(agent.context_variables)
|
||||
):
|
||||
# Condition has been met, we'll set the Tool Executor's next target
|
||||
# attribute and that will be picked up on the next iteration when
|
||||
# _determine_next_agent is called
|
||||
for agent in agent._group_manager.groupchat.agents: # type: ignore[attr-defined]
|
||||
if isinstance(agent, GroupToolExecutor):
|
||||
agent.set_next_target(on_condition.target)
|
||||
break
|
||||
|
||||
transfer_name = on_condition.target.display_name()
|
||||
|
||||
return True, "[Handing off to " + transfer_name + "]"
|
||||
|
||||
return False, None
|
||||
|
||||
|
||||
def _create_on_condition_handoff_function(target: TransitionTarget) -> Callable[[], TransitionTarget]:
|
||||
"""Creates a function that will be used by the tool call reply function when the condition is met.
|
||||
|
||||
Args:
|
||||
target (TransitionTarget): The target to transfer to.
|
||||
|
||||
Returns:
|
||||
Callable: The transfer function.
|
||||
"""
|
||||
|
||||
def transfer_to_target() -> TransitionTarget:
|
||||
return target
|
||||
|
||||
return transfer_to_target
|
||||
|
||||
|
||||
def create_on_condition_handoff_functions(agent: "ConversableAgent") -> None:
|
||||
"""Creates the functions for the OnConditions so that the current tool handling works.
|
||||
|
||||
Args:
|
||||
agent ("ConversableAgent"): The agent to create the functions for.
|
||||
"""
|
||||
# Populate the function names for the handoffs
|
||||
agent.handoffs.set_llm_function_names()
|
||||
|
||||
# Create a function for each OnCondition
|
||||
for on_condition in agent.handoffs.llm_conditions:
|
||||
# Create a function that will be called when the condition is met
|
||||
agent._add_single_function(
|
||||
_create_on_condition_handoff_function(on_condition.target),
|
||||
on_condition.llm_function_name,
|
||||
on_condition.condition.get_prompt(agent, []),
|
||||
)
|
||||
|
||||
|
||||
def ensure_handoff_agents_in_group(agents: list["ConversableAgent"]) -> None:
|
||||
"""Ensure the agents in handoffs are in the group chat."""
|
||||
agent_names = [agent.name for agent in agents]
|
||||
for agent in agents:
|
||||
for llm_conditions in agent.handoffs.llm_conditions:
|
||||
if (
|
||||
isinstance(llm_conditions.target, (AgentTarget, AgentNameTarget))
|
||||
and llm_conditions.target.agent_name not in agent_names
|
||||
):
|
||||
raise ValueError("Agent in OnCondition Hand-offs must be in the agents list")
|
||||
for context_conditions in agent.handoffs.context_conditions:
|
||||
if (
|
||||
isinstance(context_conditions.target, (AgentTarget, AgentNameTarget))
|
||||
and context_conditions.target.agent_name not in agent_names
|
||||
):
|
||||
raise ValueError("Agent in OnContextCondition Hand-offs must be in the agents list")
|
||||
# Check after_works targets
|
||||
for after_work_condition in agent.handoffs.after_works:
|
||||
if (
|
||||
isinstance(after_work_condition.target, (AgentTarget, AgentNameTarget))
|
||||
and after_work_condition.target.agent_name not in agent_names
|
||||
):
|
||||
raise ValueError("Agent in after work target Hand-offs must be in the agents list")
|
||||
|
||||
|
||||
def prepare_exclude_transit_messages(agents: list["ConversableAgent"]) -> None:
|
||||
"""Preparation for excluding transit messages by getting all tool names and registering a hook on agents to remove those messages."""
|
||||
# get all transit functions names
|
||||
to_be_removed: list[str] = []
|
||||
for agent in agents:
|
||||
for on_condition in agent.handoffs.llm_conditions:
|
||||
if on_condition.llm_function_name:
|
||||
to_be_removed.append(on_condition.llm_function_name)
|
||||
else:
|
||||
raise ValueError("OnCondition must have a function name")
|
||||
|
||||
remove_function = make_remove_function(to_be_removed)
|
||||
|
||||
# register hook to remove transit messages for group agents
|
||||
for agent in agents:
|
||||
agent.register_hook("process_all_messages_before_reply", remove_function)
|
||||
|
||||
|
||||
def prepare_group_agents(
|
||||
agents: list["ConversableAgent"],
|
||||
context_variables: ContextVariables,
|
||||
exclude_transit_message: bool = True,
|
||||
) -> tuple[GroupToolExecutor, list["ConversableAgent"]]:
|
||||
"""Validates agents, create the tool executor, wrap necessary targets in agents.
|
||||
|
||||
Args:
|
||||
agents (list["ConversableAgent"]): List of all agents in the conversation.
|
||||
context_variables (ContextVariables): Context variables to assign to all agents.
|
||||
exclude_transit_message (bool): Whether to exclude transit messages from the agents.
|
||||
|
||||
Returns:
|
||||
"ConversableAgent": The tool executor agent.
|
||||
list["ConversableAgent"]: List of wrapped agents.
|
||||
"""
|
||||
# Initialise all agents as group agents
|
||||
for agent in agents:
|
||||
if not hasattr(agent, "_group_is_established"):
|
||||
establish_group_agent(agent)
|
||||
|
||||
# Ensure all agents in hand-off after-works are in the passed in agents list
|
||||
ensure_handoff_agents_in_group(agents)
|
||||
|
||||
# Create Tool Executor for the group
|
||||
tool_execution = GroupToolExecutor()
|
||||
|
||||
# Wrap handoff targets in agents that need to be wrapped
|
||||
wrapped_chat_agents: list["ConversableAgent"] = []
|
||||
for agent in agents:
|
||||
wrap_agent_handoff_targets(agent, wrapped_chat_agents)
|
||||
|
||||
# Create the functions for the OnConditions so that the current tool handling works
|
||||
for agent in agents:
|
||||
create_on_condition_handoff_functions(agent)
|
||||
|
||||
# Register all the agents' functions with the tool executor and
|
||||
# use dependency injection for the context variables parameter
|
||||
# Update tool execution agent with all the functions from all the agents
|
||||
tool_execution.register_agents_functions(agents + wrapped_chat_agents, context_variables)
|
||||
|
||||
if exclude_transit_message:
|
||||
prepare_exclude_transit_messages(agents)
|
||||
|
||||
return tool_execution, wrapped_chat_agents
|
||||
|
||||
|
||||
def wrap_agent_handoff_targets(agent: "ConversableAgent", wrapped_agent_list: list["ConversableAgent"]) -> None:
|
||||
"""Wrap handoff targets in agents that need to be wrapped to be part of the group chat.
|
||||
|
||||
Example is NestedChatTarget.
|
||||
|
||||
Args:
|
||||
agent ("ConversableAgent"): The agent to wrap the handoff targets for.
|
||||
wrapped_agent_list (list["ConversableAgent"]): List of wrapped chat agents that will be appended to.
|
||||
"""
|
||||
# Wrap OnCondition targets
|
||||
for i, handoff_oncondition_requiring_wrapping in enumerate(agent.handoffs.get_llm_conditions_requiring_wrapping()):
|
||||
# Create wrapper agent
|
||||
wrapper_agent = handoff_oncondition_requiring_wrapping.target.create_wrapper_agent(parent_agent=agent, index=i)
|
||||
wrapped_agent_list.append(wrapper_agent)
|
||||
|
||||
# Change this handoff target to point to the newly created agent
|
||||
handoff_oncondition_requiring_wrapping.target = AgentTarget(wrapper_agent)
|
||||
|
||||
for i, handoff_oncontextcondition_requiring_wrapping in enumerate(
|
||||
agent.handoffs.get_context_conditions_requiring_wrapping()
|
||||
):
|
||||
# Create wrapper agent
|
||||
wrapper_agent = handoff_oncontextcondition_requiring_wrapping.target.create_wrapper_agent(
|
||||
parent_agent=agent, index=i
|
||||
)
|
||||
wrapped_agent_list.append(wrapper_agent)
|
||||
|
||||
# Change this handoff target to point to the newly created agent
|
||||
handoff_oncontextcondition_requiring_wrapping.target = AgentTarget(wrapper_agent)
|
||||
|
||||
|
||||
def process_initial_messages(
|
||||
messages: Union[list[dict[str, Any]], str],
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
agents: list["ConversableAgent"],
|
||||
wrapped_agents: list["ConversableAgent"],
|
||||
) -> tuple[list[dict[str, Any]], Optional["ConversableAgent"], list[str], list[Agent]]:
|
||||
"""Process initial messages, validating agent names against messages, and determining the last agent to speak.
|
||||
|
||||
Args:
|
||||
messages: Initial messages to process.
|
||||
user_agent: Optional user proxy agent passed in to a_/initiate_group_chat.
|
||||
agents: Agents in the group.
|
||||
wrapped_agents: List of wrapped agents.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: Processed message(s).
|
||||
Agent: Last agent to speak.
|
||||
list[str]: List of agent names.
|
||||
list[Agent]: List of temporary user proxy agents to add to GroupChat.
|
||||
"""
|
||||
from ..conversable_agent import ConversableAgent # NEED SOLUTION
|
||||
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
group_agent_names = [agent.name for agent in agents + wrapped_agents]
|
||||
|
||||
# If there's only one message and there's no identified group agent
|
||||
# Start with a user proxy agent, creating one if they haven't passed one in
|
||||
last_agent: Optional[ConversableAgent]
|
||||
temp_user_proxy: Optional[ConversableAgent] = None
|
||||
temp_user_list: list[Agent] = []
|
||||
if len(messages) == 1 and "name" not in messages[0] and not user_agent:
|
||||
temp_user_proxy = ConversableAgent(name="_User", code_execution_config=False, human_input_mode="ALWAYS")
|
||||
last_agent = temp_user_proxy
|
||||
temp_user_list.append(temp_user_proxy)
|
||||
else:
|
||||
last_message = messages[0]
|
||||
if "name" in last_message:
|
||||
if last_message["name"] in group_agent_names:
|
||||
last_agent = next(agent for agent in agents + wrapped_agents if agent.name == last_message["name"]) # type: ignore[assignment]
|
||||
elif user_agent and last_message["name"] == user_agent.name:
|
||||
last_agent = user_agent
|
||||
else:
|
||||
raise ValueError(f"Invalid group agent name in last message: {last_message['name']}")
|
||||
else:
|
||||
last_agent = user_agent if user_agent else temp_user_proxy
|
||||
|
||||
return messages, last_agent, group_agent_names, temp_user_list
|
||||
|
||||
|
||||
def setup_context_variables(
|
||||
tool_execution: "ConversableAgent",
|
||||
agents: list["ConversableAgent"],
|
||||
manager: GroupChatManager,
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
context_variables: ContextVariables,
|
||||
) -> None:
|
||||
"""Assign a common context_variables reference to all agents in the group, including the tool executor, group chat manager, and user proxy agent.
|
||||
|
||||
Args:
|
||||
tool_execution: The tool execution agent.
|
||||
agents: List of all agents in the conversation.
|
||||
manager: GroupChatManager instance.
|
||||
user_agent: Optional user proxy agent.
|
||||
context_variables: Context variables to assign to all agents.
|
||||
"""
|
||||
for agent in agents + [tool_execution] + [manager] + ([user_agent] if user_agent else []):
|
||||
agent.context_variables = context_variables
|
||||
|
||||
|
||||
def cleanup_temp_user_messages(chat_result: Any) -> None:
|
||||
"""Remove temporary user proxy agent name from messages before returning.
|
||||
|
||||
Args:
|
||||
chat_result: ChatResult instance.
|
||||
"""
|
||||
for message in chat_result.chat_history:
|
||||
if "name" in message and message["name"] == "_User":
|
||||
del message["name"]
|
||||
|
||||
|
||||
def get_last_agent_speaker(
|
||||
groupchat: GroupChat, group_agent_names: list[str], tool_executor: GroupToolExecutor
|
||||
) -> Agent:
|
||||
"""Get the last group agent from the group chat messages. Not including the tool executor."""
|
||||
last_group_speaker = None
|
||||
for message in reversed(groupchat.messages):
|
||||
if "name" in message and message["name"] in group_agent_names and message["name"] != tool_executor.name:
|
||||
agent = groupchat.agent_by_name(name=message["name"])
|
||||
if agent:
|
||||
last_group_speaker = agent
|
||||
break
|
||||
if last_group_speaker is None:
|
||||
raise ValueError("No group agent found in the message history")
|
||||
|
||||
return last_group_speaker
|
||||
|
||||
|
||||
def determine_next_agent(
|
||||
last_speaker: "ConversableAgent",
|
||||
groupchat: GroupChat,
|
||||
initial_agent: "ConversableAgent",
|
||||
use_initial_agent: bool,
|
||||
tool_executor: GroupToolExecutor,
|
||||
group_agent_names: list[str],
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
group_after_work: TransitionTarget,
|
||||
) -> Optional[Union[Agent, str]]:
|
||||
"""Determine the next agent in the conversation.
|
||||
|
||||
Args:
|
||||
last_speaker ("ConversableAgent"): The last agent to speak.
|
||||
groupchat (GroupChat): GroupChat instance.
|
||||
initial_agent ("ConversableAgent"): The initial agent in the conversation.
|
||||
use_initial_agent (bool): Whether to use the initial agent straight away.
|
||||
tool_executor ("ConversableAgent"): The tool execution agent.
|
||||
group_agent_names (list[str]): List of agent names.
|
||||
user_agent (UserProxyAgent): Optional user proxy agent.
|
||||
group_after_work (TransitionTarget): Group-level Transition option when an agent doesn't select the next agent.
|
||||
|
||||
Returns:
|
||||
Optional[Union[Agent, str]]: The next agent or speaker selection method.
|
||||
"""
|
||||
|
||||
# Logic for determining the next target (anything based on Transition Target: an agent, wrapped agent, TerminateTarget, StayTarget, RevertToUserTarget, GroupManagerTarget, etc.
|
||||
# 1. If it's the first response -> initial agent
|
||||
# 2. If the last message is a tool call -> tool execution agent
|
||||
# 3. If the Tool Executor has determined a next target (e.g. ReplyResult specified target) -> transition to tool reply target
|
||||
# 4. If the user last spoke -> return to the previous agent
|
||||
# NOW "AFTER WORK":
|
||||
# 5. Get the After Work condition (if the agent doesn't have one, get the group-level one)
|
||||
# 6. Resolve and return the After Work condition -> agent / wrapped agent / TerminateTarget / StayTarget / RevertToUserTarget / GroupManagerTarget / etc.
|
||||
|
||||
# 1. If it's the first response, return the initial agent
|
||||
if use_initial_agent:
|
||||
return initial_agent
|
||||
|
||||
# 2. If the last message is a tool call, return the tool execution agent
|
||||
if "tool_calls" in groupchat.messages[-1]:
|
||||
return tool_executor
|
||||
|
||||
# 3. If the Tool Executor has determined a next target, return that
|
||||
if tool_executor.has_next_target():
|
||||
next_agent = tool_executor.get_next_target()
|
||||
tool_executor.clear_next_target()
|
||||
|
||||
if next_agent.can_resolve_for_speaker_selection():
|
||||
return next_agent.resolve(groupchat, last_speaker, user_agent).get_speaker_selection_result(groupchat)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Tool Executor next target must be a valid TransitionTarget that can resolve for speaker selection."
|
||||
)
|
||||
|
||||
# get the last group agent
|
||||
last_agent_speaker = get_last_agent_speaker(groupchat, group_agent_names, tool_executor)
|
||||
|
||||
# If we are returning from a tool execution, return to the last agent that spoke
|
||||
if groupchat.messages[-1]["role"] == "tool":
|
||||
return last_agent_speaker
|
||||
|
||||
# If the user last spoke, return to the agent prior to them (if they don't have an after work, otherwise it's treated like any other agent)
|
||||
if user_agent and last_speaker == user_agent:
|
||||
if not user_agent.handoffs.after_works:
|
||||
return last_agent_speaker
|
||||
else:
|
||||
last_agent_speaker = user_agent
|
||||
|
||||
# AFTER WORK:
|
||||
|
||||
# First, try to evaluate after_works context conditions
|
||||
after_works_result = _evaluate_after_works_conditions(
|
||||
last_agent_speaker, # type: ignore[arg-type]
|
||||
groupchat,
|
||||
user_agent,
|
||||
)
|
||||
if after_works_result is not None:
|
||||
return after_works_result
|
||||
|
||||
# If no after_works conditions matched, use the group-level after_work
|
||||
# Resolve the next agent, termination, or speaker selection method
|
||||
resolved_speaker_selection_result = group_after_work.resolve(
|
||||
groupchat,
|
||||
last_agent_speaker, # type: ignore[arg-type]
|
||||
user_agent,
|
||||
).get_speaker_selection_result(groupchat)
|
||||
|
||||
return resolved_speaker_selection_result
|
||||
|
||||
|
||||
def create_group_transition(
|
||||
initial_agent: "ConversableAgent",
|
||||
tool_execution: GroupToolExecutor,
|
||||
group_agent_names: list[str],
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
group_after_work: TransitionTarget,
|
||||
) -> Callable[["ConversableAgent", GroupChat], Optional[Union[Agent, str]]]:
|
||||
"""Creates a transition function for group chat with enclosed state for the use_initial_agent.
|
||||
|
||||
Args:
|
||||
initial_agent ("ConversableAgent"): The first agent to speak
|
||||
tool_execution (GroupToolExecutor): The tool execution agent
|
||||
group_agent_names (list[str]): List of all agent names
|
||||
user_agent (UserProxyAgent): Optional user proxy agent
|
||||
group_after_work (TransitionTarget): Group-level after work
|
||||
|
||||
Returns:
|
||||
Callable[["ConversableAgent", GroupChat], Optional[Union[Agent, str]]]: The transition function
|
||||
"""
|
||||
# Create enclosed state, this will be set once per creation so will only be True on the first execution
|
||||
# of group_transition
|
||||
state = {"use_initial_agent": True}
|
||||
|
||||
def group_transition(last_speaker: "ConversableAgent", groupchat: GroupChat) -> Optional[Union[Agent, str]]:
|
||||
result = determine_next_agent(
|
||||
last_speaker=last_speaker,
|
||||
groupchat=groupchat,
|
||||
initial_agent=initial_agent,
|
||||
use_initial_agent=state["use_initial_agent"],
|
||||
tool_executor=tool_execution,
|
||||
group_agent_names=group_agent_names,
|
||||
user_agent=user_agent,
|
||||
group_after_work=group_after_work,
|
||||
)
|
||||
state["use_initial_agent"] = False
|
||||
return result
|
||||
|
||||
return group_transition
|
||||
|
||||
|
||||
def create_group_manager(
|
||||
groupchat: GroupChat,
|
||||
group_manager_args: Optional[dict[str, Any]],
|
||||
agents: list["ConversableAgent"],
|
||||
group_after_work: TransitionTarget,
|
||||
) -> GroupChatManager:
|
||||
"""Create a GroupChatManager for the group chat utilising any arguments passed in and ensure an LLM Config exists if needed
|
||||
|
||||
Args:
|
||||
groupchat (GroupChat): The groupchat.
|
||||
group_manager_args (dict[str, Any]): Group manager arguments to create the GroupChatManager.
|
||||
agents (list["ConversableAgent"]): List of agents in the group to check handoffs and after work.
|
||||
group_after_work (TransitionTarget): Group-level after work to check.
|
||||
|
||||
Returns:
|
||||
GroupChatManager: GroupChatManager instance.
|
||||
"""
|
||||
manager_args = (group_manager_args or {}).copy()
|
||||
if "groupchat" in manager_args:
|
||||
raise ValueError("'groupchat' cannot be specified in group_manager_args as it is set by initiate_group_chat")
|
||||
manager = GroupChatManager(groupchat, **manager_args)
|
||||
|
||||
# Ensure that our manager has an LLM Config if we have any GroupManagerTarget targets used
|
||||
if manager.llm_config is False:
|
||||
has_group_manager_target = False
|
||||
|
||||
if isinstance(group_after_work, GroupManagerTarget):
|
||||
# Check group after work
|
||||
has_group_manager_target = True
|
||||
else:
|
||||
# Check agent hand-offs and after work
|
||||
for agent in agents:
|
||||
if (
|
||||
len(agent.handoffs.get_context_conditions_by_target_type(GroupManagerTarget)) > 0
|
||||
or len(agent.handoffs.get_llm_conditions_by_target_type(GroupManagerTarget)) > 0
|
||||
or any(isinstance(aw.target, GroupManagerTarget) for aw in agent.handoffs.after_works)
|
||||
):
|
||||
has_group_manager_target = True
|
||||
break
|
||||
|
||||
if has_group_manager_target:
|
||||
raise ValueError(
|
||||
"The group manager doesn't have an LLM Config and it is required for any targets or after works using a GroupManagerTarget. Use the 'llm_config' in the group_manager_args parameter to specify the LLM Config for the group manager."
|
||||
)
|
||||
|
||||
return manager
|
||||
|
||||
|
||||
def make_remove_function(tool_msgs_to_remove: list[str]) -> Callable[[list[dict[str, Any]]], list[dict[str, Any]]]:
|
||||
"""Create a function to remove messages with tool calls from the messages list.
|
||||
|
||||
The returned function can be registered as a hook to "process_all_messages_before_reply"" to remove messages with tool calls.
|
||||
"""
|
||||
|
||||
def remove_messages(messages: list[dict[str, Any]], tool_msgs_to_remove: list[str]) -> list[dict[str, Any]]:
|
||||
copied = copy.deepcopy(messages)
|
||||
new_messages = []
|
||||
removed_tool_ids = []
|
||||
for message in copied:
|
||||
# remove tool calls
|
||||
if message.get("tool_calls") is not None:
|
||||
filtered_tool_calls = []
|
||||
for tool_call in message["tool_calls"]:
|
||||
if tool_call.get("function") is not None and tool_call["function"]["name"] in tool_msgs_to_remove:
|
||||
# remove
|
||||
removed_tool_ids.append(tool_call["id"])
|
||||
else:
|
||||
filtered_tool_calls.append(tool_call)
|
||||
if len(filtered_tool_calls) > 0:
|
||||
message["tool_calls"] = filtered_tool_calls
|
||||
else:
|
||||
del message["tool_calls"]
|
||||
if (
|
||||
message.get("content") is None
|
||||
or message.get("content") == ""
|
||||
or message.get("content") == "None"
|
||||
):
|
||||
continue # if no tool call and no content, skip this message
|
||||
# else: keep the message with tool_calls removed
|
||||
# remove corresponding tool responses
|
||||
elif message.get("tool_responses") is not None:
|
||||
filtered_tool_responses = []
|
||||
for tool_response in message["tool_responses"]:
|
||||
if tool_response["tool_call_id"] not in removed_tool_ids:
|
||||
filtered_tool_responses.append(tool_response)
|
||||
|
||||
if len(filtered_tool_responses) > 0:
|
||||
message["tool_responses"] = filtered_tool_responses
|
||||
else:
|
||||
continue
|
||||
|
||||
new_messages.append(message)
|
||||
|
||||
return new_messages
|
||||
|
||||
return partial(remove_messages, tool_msgs_to_remove=tool_msgs_to_remove)
|
||||
320
mm_agents/coact/autogen/agentchat/group/handoffs.py
Normal file
320
mm_agents/coact/autogen/agentchat/group/handoffs.py
Normal file
@@ -0,0 +1,320 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Union, overload
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .on_condition import OnCondition
|
||||
from .on_context_condition import OnContextCondition
|
||||
from .targets.transition_target import TransitionTarget
|
||||
|
||||
__all__ = ["Handoffs"]
|
||||
|
||||
|
||||
class Handoffs(BaseModel):
|
||||
"""
|
||||
Container for all handoff transition conditions of a ConversableAgent.
|
||||
|
||||
Three types of conditions can be added, each with a different order and time of use:
|
||||
1. OnContextConditions (evaluated without an LLM)
|
||||
2. OnConditions (evaluated with an LLM)
|
||||
3. After work TransitionTarget (if no other transition is triggered)
|
||||
|
||||
Supports method chaining:
|
||||
agent.handoffs.add_context_conditions([condition1]) \
|
||||
.add_llm_condition(condition2) \
|
||||
.set_after_work(after_work)
|
||||
"""
|
||||
|
||||
context_conditions: list[OnContextCondition] = Field(default_factory=list)
|
||||
llm_conditions: list[OnCondition] = Field(default_factory=list)
|
||||
after_works: list[OnContextCondition] = Field(default_factory=list)
|
||||
|
||||
def add_context_condition(self, condition: OnContextCondition) -> "Handoffs":
|
||||
"""
|
||||
Add a single context condition.
|
||||
|
||||
Args:
|
||||
condition: The OnContextCondition to add
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
# Validate that it is an OnContextCondition
|
||||
if not isinstance(condition, OnContextCondition):
|
||||
raise TypeError(f"Expected an OnContextCondition instance, got {type(condition).__name__}")
|
||||
|
||||
self.context_conditions.append(condition)
|
||||
return self
|
||||
|
||||
def add_context_conditions(self, conditions: list[OnContextCondition]) -> "Handoffs":
|
||||
"""
|
||||
Add multiple context conditions.
|
||||
|
||||
Args:
|
||||
conditions: List of OnContextConditions to add
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
# Validate that it is a list of OnContextConditions
|
||||
if not all(isinstance(condition, OnContextCondition) for condition in conditions):
|
||||
raise TypeError("All conditions must be of type OnContextCondition")
|
||||
|
||||
self.context_conditions.extend(conditions)
|
||||
return self
|
||||
|
||||
def add_llm_condition(self, condition: OnCondition) -> "Handoffs":
|
||||
"""
|
||||
Add a single LLM condition.
|
||||
|
||||
Args:
|
||||
condition: The OnCondition to add
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
# Validate that it is an OnCondition
|
||||
if not isinstance(condition, OnCondition):
|
||||
raise TypeError(f"Expected an OnCondition instance, got {type(condition).__name__}")
|
||||
|
||||
self.llm_conditions.append(condition)
|
||||
return self
|
||||
|
||||
def add_llm_conditions(self, conditions: list[OnCondition]) -> "Handoffs":
|
||||
"""
|
||||
Add multiple LLM conditions.
|
||||
|
||||
Args:
|
||||
conditions: List of OnConditions to add
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
# Validate that it is a list of OnConditions
|
||||
if not all(isinstance(condition, OnCondition) for condition in conditions):
|
||||
raise TypeError("All conditions must be of type OnCondition")
|
||||
|
||||
self.llm_conditions.extend(conditions)
|
||||
return self
|
||||
|
||||
def set_after_work(self, target: TransitionTarget) -> "Handoffs":
|
||||
"""
|
||||
Set the after work target (replaces all after_works with single entry).
|
||||
|
||||
For backward compatibility, this creates an OnContextCondition with no condition (always true).
|
||||
|
||||
Args:
|
||||
target: The after work TransitionTarget to set
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
if not isinstance(target, TransitionTarget):
|
||||
raise TypeError(f"Expected a TransitionTarget instance, got {type(target).__name__}")
|
||||
|
||||
# Create OnContextCondition with no condition (always true)
|
||||
after_work_condition = OnContextCondition(target=target, condition=None)
|
||||
self.after_works = [after_work_condition]
|
||||
return self
|
||||
|
||||
def add_after_work(self, condition: OnContextCondition) -> "Handoffs":
|
||||
"""
|
||||
Add a single after-work condition.
|
||||
|
||||
If the condition has condition=None, it will replace any existing
|
||||
condition=None entry and be placed at the end.
|
||||
|
||||
Args:
|
||||
condition: The OnContextCondition to add
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
if not isinstance(condition, OnContextCondition):
|
||||
raise TypeError(f"Expected an OnContextCondition instance, got {type(condition).__name__}")
|
||||
|
||||
if condition.condition is None:
|
||||
# Remove any existing condition=None entries
|
||||
self.after_works = [c for c in self.after_works if c.condition is not None]
|
||||
# Add the new one at the end
|
||||
self.after_works.append(condition)
|
||||
else:
|
||||
# For regular conditions, check if we need to move condition=None to the end
|
||||
none_conditions = [c for c in self.after_works if c.condition is None]
|
||||
if none_conditions:
|
||||
# Remove the None condition temporarily
|
||||
self.after_works = [c for c in self.after_works if c.condition is not None]
|
||||
# Add the new regular condition
|
||||
self.after_works.append(condition)
|
||||
# Re-add the None condition at the end
|
||||
self.after_works.append(none_conditions[0])
|
||||
else:
|
||||
# No None condition exists, just append
|
||||
self.after_works.append(condition)
|
||||
|
||||
return self
|
||||
|
||||
def add_after_works(self, conditions: list[OnContextCondition]) -> "Handoffs":
|
||||
"""
|
||||
Add multiple after-work conditions.
|
||||
|
||||
Special handling for condition=None entries:
|
||||
- Only one condition=None entry is allowed (the fallback)
|
||||
- It will always be placed at the end of the list
|
||||
- If multiple condition=None entries are provided, only the last one is kept
|
||||
|
||||
Args:
|
||||
conditions: List of OnContextConditions to add
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
# Validate that it is a list of OnContextConditions
|
||||
if not all(isinstance(condition, OnContextCondition) for condition in conditions):
|
||||
raise TypeError("All conditions must be of type OnContextCondition")
|
||||
|
||||
# Separate conditions with None and without None
|
||||
none_conditions = [c for c in conditions if c.condition is None]
|
||||
regular_conditions = [c for c in conditions if c.condition is not None]
|
||||
|
||||
# Remove any existing condition=None entries
|
||||
self.after_works = [c for c in self.after_works if c.condition is not None]
|
||||
|
||||
# Add regular conditions
|
||||
self.after_works.extend(regular_conditions)
|
||||
|
||||
# Add at most one None condition at the end
|
||||
if none_conditions:
|
||||
self.after_works.append(none_conditions[-1]) # Use the last one if multiple provided
|
||||
|
||||
return self
|
||||
|
||||
@overload
|
||||
def add(self, condition: OnContextCondition) -> "Handoffs": ...
|
||||
|
||||
@overload
|
||||
def add(self, condition: OnCondition) -> "Handoffs": ...
|
||||
|
||||
def add(self, condition: Union[OnContextCondition, OnCondition]) -> "Handoffs":
|
||||
"""
|
||||
Add a single condition (OnContextCondition or OnCondition).
|
||||
|
||||
Args:
|
||||
condition: The condition to add (OnContextCondition or OnCondition)
|
||||
|
||||
Raises:
|
||||
TypeError: If the condition type is not supported
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
# This add method is a helper method designed to make it easier for
|
||||
# adding handoffs without worrying about the specific type.
|
||||
if isinstance(condition, OnContextCondition):
|
||||
return self.add_context_condition(condition)
|
||||
elif isinstance(condition, OnCondition):
|
||||
return self.add_llm_condition(condition)
|
||||
else:
|
||||
raise TypeError(f"Unsupported condition type: {type(condition).__name__}")
|
||||
|
||||
def add_many(self, conditions: list[Union[OnContextCondition, OnCondition]]) -> "Handoffs":
|
||||
"""
|
||||
Add multiple conditions of any supported types (OnContextCondition and OnCondition).
|
||||
|
||||
Args:
|
||||
conditions: List of conditions to add
|
||||
|
||||
Raises:
|
||||
TypeError: If an unsupported condition type is provided
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
# This add_many method is a helper method designed to make it easier for
|
||||
# adding handoffs without worrying about the specific type.
|
||||
context_conditions = []
|
||||
llm_conditions = []
|
||||
|
||||
for condition in conditions:
|
||||
if isinstance(condition, OnContextCondition):
|
||||
context_conditions.append(condition)
|
||||
elif isinstance(condition, OnCondition):
|
||||
llm_conditions.append(condition)
|
||||
else:
|
||||
raise TypeError(f"Unsupported condition type: {type(condition).__name__}")
|
||||
|
||||
if context_conditions:
|
||||
self.add_context_conditions(context_conditions)
|
||||
if llm_conditions:
|
||||
self.add_llm_conditions(llm_conditions)
|
||||
|
||||
return self
|
||||
|
||||
def clear(self) -> "Handoffs":
|
||||
"""
|
||||
Clear all handoff conditions.
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
self.context_conditions.clear()
|
||||
self.llm_conditions.clear()
|
||||
self.after_works.clear()
|
||||
return self
|
||||
|
||||
def get_llm_conditions_by_target_type(self, target_type: type) -> list[OnCondition]:
|
||||
"""
|
||||
Get OnConditions for a specific target type.
|
||||
|
||||
Args:
|
||||
target_type: The type of condition to retrieve
|
||||
|
||||
Returns:
|
||||
List of conditions of the specified type, or None if none exist
|
||||
"""
|
||||
return [on_condition for on_condition in self.llm_conditions if on_condition.has_target_type(target_type)]
|
||||
|
||||
def get_context_conditions_by_target_type(self, target_type: type) -> list[OnContextCondition]:
|
||||
"""
|
||||
Get OnContextConditions for a specific target type.
|
||||
|
||||
Args:
|
||||
target_type: The type of condition to retrieve
|
||||
|
||||
Returns:
|
||||
List of conditions of the specified type, or None if none exist
|
||||
"""
|
||||
return [
|
||||
on_context_condition
|
||||
for on_context_condition in self.context_conditions
|
||||
if on_context_condition.has_target_type(target_type)
|
||||
]
|
||||
|
||||
def get_llm_conditions_requiring_wrapping(self) -> list[OnCondition]:
|
||||
"""
|
||||
Get LLM conditions that have targets that require wrapping.
|
||||
|
||||
Returns:
|
||||
List of LLM conditions that require wrapping
|
||||
"""
|
||||
return [condition for condition in self.llm_conditions if condition.target_requires_wrapping()]
|
||||
|
||||
def get_context_conditions_requiring_wrapping(self) -> list[OnContextCondition]:
|
||||
"""
|
||||
Get context conditions that have targets that require wrapping.
|
||||
|
||||
Returns:
|
||||
List of context conditions that require wrapping
|
||||
"""
|
||||
return [condition for condition in self.context_conditions if condition.target_requires_wrapping()]
|
||||
|
||||
def set_llm_function_names(self) -> None:
|
||||
"""
|
||||
Set the LLM function names for all LLM conditions, creating unique names for each function.
|
||||
"""
|
||||
for i, condition in enumerate(self.llm_conditions):
|
||||
# Function names are made unique and allow multiple OnCondition's to the same agent
|
||||
condition.llm_function_name = f"transfer_to_{condition.target.normalized_name()}_{i + 1}"
|
||||
93
mm_agents/coact/autogen/agentchat/group/llm_condition.py
Normal file
93
mm_agents/coact/autogen/agentchat/group/llm_condition.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .context_str import ContextStr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Avoid circular import
|
||||
from ..conversable_agent import ConversableAgent
|
||||
|
||||
__all__ = ["ContextStrLLMCondition", "LLMCondition", "StringLLMCondition"]
|
||||
|
||||
|
||||
class LLMCondition(BaseModel):
|
||||
"""Protocol for conditions evaluated by an LLM."""
|
||||
|
||||
def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str:
|
||||
"""Get the prompt text for LLM evaluation.
|
||||
|
||||
Args:
|
||||
agent: The agent evaluating the condition
|
||||
messages: The conversation history
|
||||
|
||||
Returns:
|
||||
The prompt text to be evaluated by the LLM
|
||||
"""
|
||||
raise NotImplementedError("Requires subclasses to implement.")
|
||||
|
||||
|
||||
class StringLLMCondition(LLMCondition):
|
||||
"""Simple string-based LLM condition.
|
||||
|
||||
This condition provides a static string prompt to be evaluated by an LLM.
|
||||
"""
|
||||
|
||||
prompt: str
|
||||
|
||||
def __init__(self, prompt: str, **data: Any) -> None:
|
||||
"""Initialize with a prompt string as a positional parameter.
|
||||
|
||||
Args:
|
||||
prompt: The static prompt string to evaluate
|
||||
data: Additional data for the parent class
|
||||
"""
|
||||
super().__init__(prompt=prompt, **data)
|
||||
|
||||
def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str:
|
||||
"""Return the static prompt string.
|
||||
|
||||
Args:
|
||||
agent: The agent evaluating the condition (not used)
|
||||
messages: The conversation history (not used)
|
||||
|
||||
Returns:
|
||||
The static prompt string
|
||||
"""
|
||||
return self.prompt
|
||||
|
||||
|
||||
class ContextStrLLMCondition(LLMCondition):
|
||||
"""Context variable-based LLM condition.
|
||||
|
||||
This condition uses a ContextStr object with context variable placeholders that
|
||||
will be substituted before being evaluated by an LLM.
|
||||
"""
|
||||
|
||||
context_str: ContextStr
|
||||
|
||||
def __init__(self, context_str: ContextStr, **data: Any) -> None:
|
||||
"""Initialize with a context string as a positional parameter.
|
||||
|
||||
Args:
|
||||
context_str: The ContextStr object with variable placeholders
|
||||
data: Additional data for the parent class
|
||||
"""
|
||||
super().__init__(context_str=context_str, **data)
|
||||
|
||||
def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str:
|
||||
"""Return the prompt with context variables substituted.
|
||||
|
||||
Args:
|
||||
agent: The agent evaluating the condition (provides context variables)
|
||||
messages: The conversation history (not used)
|
||||
|
||||
Returns:
|
||||
The prompt with context variables substituted
|
||||
"""
|
||||
result = self.context_str.format(agent.context_variables)
|
||||
return result if result is not None else ""
|
||||
237
mm_agents/coact/autogen/agentchat/group/multi_agent_chat.py
Normal file
237
mm_agents/coact/autogen/agentchat/group/multi_agent_chat.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from ...doc_utils import export_module
|
||||
from ...events.agent_events import ErrorEvent, RunCompletionEvent
|
||||
from ...io.base import IOStream
|
||||
from ...io.run_response import AsyncRunResponse, AsyncRunResponseProtocol, RunResponse, RunResponseProtocol
|
||||
from ...io.thread_io_stream import AsyncThreadIOStream, ThreadIOStream
|
||||
from ..chat import ChatResult
|
||||
from .context_variables import ContextVariables
|
||||
from .group_utils import cleanup_temp_user_messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..agent import Agent
|
||||
from .patterns.pattern import Pattern
|
||||
|
||||
__all__ = [
|
||||
"a_initiate_group_chat",
|
||||
"a_run_group_chat",
|
||||
"initiate_group_chat",
|
||||
"run_group_chat",
|
||||
]
|
||||
|
||||
|
||||
@export_module("autogen")
|
||||
def initiate_group_chat(
|
||||
pattern: "Pattern",
|
||||
messages: Union[list[dict[str, Any]], str],
|
||||
max_rounds: int = 20,
|
||||
) -> tuple[ChatResult, ContextVariables, "Agent"]:
|
||||
"""Initialize and run a group chat using a pattern for configuration.
|
||||
|
||||
Args:
|
||||
pattern: Pattern object that encapsulates the chat configuration.
|
||||
messages: Initial message(s).
|
||||
max_rounds: Maximum number of conversation rounds.
|
||||
|
||||
Returns:
|
||||
ChatResult: Conversations chat history.
|
||||
ContextVariables: Updated Context variables.
|
||||
"ConversableAgent": Last speaker.
|
||||
"""
|
||||
# Let the pattern prepare the group chat and all its components
|
||||
# Only passing the necessary parameters that aren't already in the pattern
|
||||
(
|
||||
_, # agents,
|
||||
_, # wrapped_agents,
|
||||
_, # user_agent,
|
||||
context_variables,
|
||||
_, # initial_agent,
|
||||
_, # group_after_work,
|
||||
_, # tool_execution,
|
||||
_, # groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
_, # group_agent_names,
|
||||
_, # temp_user_list,
|
||||
) = pattern.prepare_group_chat(
|
||||
max_rounds=max_rounds,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Start or resume the conversation
|
||||
if len(processed_messages) > 1:
|
||||
last_agent, last_message = manager.resume(messages=processed_messages)
|
||||
clear_history = False
|
||||
else:
|
||||
last_message = processed_messages[0]
|
||||
clear_history = True
|
||||
|
||||
if last_agent is None:
|
||||
raise ValueError("No agent selected to start the conversation")
|
||||
|
||||
chat_result = last_agent.initiate_chat(
|
||||
manager,
|
||||
message=last_message,
|
||||
clear_history=clear_history,
|
||||
summary_method=pattern.summary_method,
|
||||
)
|
||||
|
||||
cleanup_temp_user_messages(chat_result)
|
||||
|
||||
return chat_result, context_variables, manager.last_speaker
|
||||
|
||||
|
||||
@export_module("autogen.agentchat")
|
||||
async def a_initiate_group_chat(
|
||||
pattern: "Pattern",
|
||||
messages: Union[list[dict[str, Any]], str],
|
||||
max_rounds: int = 20,
|
||||
) -> tuple[ChatResult, ContextVariables, "Agent"]:
|
||||
"""Initialize and run a group chat using a pattern for configuration, asynchronously.
|
||||
|
||||
Args:
|
||||
pattern: Pattern object that encapsulates the chat configuration.
|
||||
messages: Initial message(s).
|
||||
max_rounds: Maximum number of conversation rounds.
|
||||
|
||||
Returns:
|
||||
ChatResult: Conversations chat history.
|
||||
ContextVariables: Updated Context variables.
|
||||
"ConversableAgent": Last speaker.
|
||||
"""
|
||||
# Let the pattern prepare the group chat and all its components
|
||||
# Only passing the necessary parameters that aren't already in the pattern
|
||||
(
|
||||
_, # agents,
|
||||
_, # wrapped_agents,
|
||||
_, # user_agent,
|
||||
context_variables,
|
||||
_, # initial_agent,
|
||||
_, # group_after_work,
|
||||
_, # tool_execution,
|
||||
_, # groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
_, # group_agent_names,
|
||||
_, # temp_user_list,
|
||||
) = pattern.prepare_group_chat(
|
||||
max_rounds=max_rounds,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Start or resume the conversation
|
||||
if len(processed_messages) > 1:
|
||||
last_agent, last_message = await manager.a_resume(messages=processed_messages)
|
||||
clear_history = False
|
||||
else:
|
||||
last_message = processed_messages[0]
|
||||
clear_history = True
|
||||
|
||||
if last_agent is None:
|
||||
raise ValueError("No agent selected to start the conversation")
|
||||
|
||||
chat_result = await last_agent.a_initiate_chat(
|
||||
manager,
|
||||
message=last_message, # type: ignore[arg-type]
|
||||
clear_history=clear_history,
|
||||
summary_method=pattern.summary_method,
|
||||
)
|
||||
|
||||
cleanup_temp_user_messages(chat_result)
|
||||
|
||||
return chat_result, context_variables, manager.last_speaker
|
||||
|
||||
|
||||
@export_module("autogen.agentchat")
|
||||
def run_group_chat(
|
||||
pattern: "Pattern",
|
||||
messages: Union[list[dict[str, Any]], str],
|
||||
max_rounds: int = 20,
|
||||
) -> RunResponseProtocol:
|
||||
iostream = ThreadIOStream()
|
||||
# todo: add agents
|
||||
response = RunResponse(iostream, agents=[])
|
||||
|
||||
def _initiate_group_chat(
|
||||
pattern: "Pattern" = pattern,
|
||||
messages: Union[list[dict[str, Any]], str] = messages,
|
||||
max_rounds: int = max_rounds,
|
||||
iostream: ThreadIOStream = iostream,
|
||||
response: RunResponse = response,
|
||||
) -> None:
|
||||
with IOStream.set_default(iostream):
|
||||
try:
|
||||
chat_result, context_vars, agent = initiate_group_chat(
|
||||
pattern=pattern,
|
||||
messages=messages,
|
||||
max_rounds=max_rounds,
|
||||
)
|
||||
|
||||
IOStream.get_default().send(
|
||||
RunCompletionEvent( # type: ignore[call-arg]
|
||||
history=chat_result.chat_history,
|
||||
summary=chat_result.summary,
|
||||
cost=chat_result.cost,
|
||||
last_speaker=agent.name,
|
||||
context_variables=context_vars,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
response.iostream.send(ErrorEvent(error=e)) # type: ignore[call-arg]
|
||||
|
||||
threading.Thread(
|
||||
target=_initiate_group_chat,
|
||||
).start()
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@export_module("autogen.agentchat")
|
||||
async def a_run_group_chat(
|
||||
pattern: "Pattern",
|
||||
messages: Union[list[dict[str, Any]], str],
|
||||
max_rounds: int = 20,
|
||||
) -> AsyncRunResponseProtocol:
|
||||
iostream = AsyncThreadIOStream()
|
||||
# todo: add agents
|
||||
response = AsyncRunResponse(iostream, agents=[])
|
||||
|
||||
async def _initiate_group_chat(
|
||||
pattern: "Pattern" = pattern,
|
||||
messages: Union[list[dict[str, Any]], str] = messages,
|
||||
max_rounds: int = max_rounds,
|
||||
iostream: AsyncThreadIOStream = iostream,
|
||||
response: AsyncRunResponse = response,
|
||||
) -> None:
|
||||
with IOStream.set_default(iostream):
|
||||
try:
|
||||
chat_result, context_vars, agent = await a_initiate_group_chat(
|
||||
pattern=pattern,
|
||||
messages=messages,
|
||||
max_rounds=max_rounds,
|
||||
)
|
||||
|
||||
IOStream.get_default().send(
|
||||
RunCompletionEvent( # type: ignore[call-arg]
|
||||
history=chat_result.chat_history,
|
||||
summary=chat_result.summary,
|
||||
cost=chat_result.cost,
|
||||
last_speaker=agent.name,
|
||||
context_variables=context_vars,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
response.iostream.send(ErrorEvent(error=e)) # type: ignore[call-arg]
|
||||
|
||||
asyncio.create_task(_initiate_group_chat())
|
||||
|
||||
return response
|
||||
58
mm_agents/coact/autogen/agentchat/group/on_condition.py
Normal file
58
mm_agents/coact/autogen/agentchat/group/on_condition.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...doc_utils import export_module
|
||||
from .available_condition import AvailableCondition
|
||||
from .llm_condition import LLMCondition
|
||||
from .targets.transition_target import TransitionTarget
|
||||
|
||||
__all__ = [
|
||||
"OnCondition",
|
||||
]
|
||||
|
||||
|
||||
@export_module("autogen")
|
||||
class OnCondition(BaseModel): # noqa: N801
|
||||
"""Defines a condition for transitioning to another agent or nested chats.
|
||||
|
||||
This is for LLM-based condition evaluation where these conditions are translated into tools and attached to the agent.
|
||||
|
||||
These are evaluated after the OnCondition conditions but before the after work condition.
|
||||
|
||||
Args:
|
||||
target (TransitionTarget): The transition (essentially an agent) to hand off to.
|
||||
condition (LLMCondition): The condition for transitioning to the target agent, evaluated by the LLM.
|
||||
available (AvailableCondition): Optional condition to determine if this OnCondition is included for the LLM to evaluate based on context variables using classes like StringAvailableCondition and ContextExpressionAvailableCondition.
|
||||
llm_function_name (Optional[str]): The name of the LLM function to use for this condition.
|
||||
"""
|
||||
|
||||
target: TransitionTarget
|
||||
condition: LLMCondition
|
||||
available: Optional[AvailableCondition] = None
|
||||
llm_function_name: Optional[str] = None
|
||||
|
||||
def has_target_type(self, target_type: type) -> bool:
|
||||
"""
|
||||
Check if the target type matches the specified type.
|
||||
|
||||
Args:
|
||||
target_type (type): The target type to check against, which should be a subclass of TransitionTarget
|
||||
|
||||
Returns:
|
||||
bool: True if the target type matches, False otherwise
|
||||
"""
|
||||
return isinstance(self.target, target_type)
|
||||
|
||||
def target_requires_wrapping(self) -> bool:
|
||||
"""
|
||||
Check if the target requires wrapping in an agent.
|
||||
|
||||
Returns:
|
||||
bool: True if the target requires wrapping, False otherwise
|
||||
"""
|
||||
return self.target.needs_agent_wrapper()
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .available_condition import AvailableCondition
|
||||
from .context_condition import ContextCondition
|
||||
from .targets.transition_target import TransitionTarget
|
||||
|
||||
__all__ = [
|
||||
"OnContextCondition",
|
||||
]
|
||||
|
||||
|
||||
class OnContextCondition(BaseModel): # noqa: N801
|
||||
"""Defines a condition for transitioning to another agent or nested chats using context variables and the ContextExpression class.
|
||||
|
||||
This is for context variable-based condition evaluation (does not use the agent's LLM).
|
||||
|
||||
These are evaluated before the OnCondition and after work conditions.
|
||||
|
||||
Args:
|
||||
target (TransitionTarget): The transition (essentially an agent) to hand off to.
|
||||
condition (Optional[ContextCondition]): The context variable based condition for transitioning to the target agent. If None, the condition always evaluates to True.
|
||||
available (AvailableCondition): Optional condition to determine if this OnCondition is included for the LLM to evaluate based on context variables using classes like StringAvailableCondition and ContextExpressionAvailableCondition.
|
||||
"""
|
||||
|
||||
target: TransitionTarget
|
||||
condition: Optional[ContextCondition] = None
|
||||
available: Optional[AvailableCondition] = None
|
||||
|
||||
def has_target_type(self, target_type: type) -> bool:
|
||||
"""
|
||||
Check if the target type matches the specified type.
|
||||
|
||||
Args:
|
||||
target_type (type): The target type to check against. Should be a subclass of TransitionTarget.
|
||||
|
||||
Returns:
|
||||
bool: True if the target type matches, False otherwise
|
||||
"""
|
||||
return isinstance(self.target, target_type)
|
||||
|
||||
def target_requires_wrapping(self) -> bool:
|
||||
"""
|
||||
Check if the target requires wrapping in an agent.
|
||||
|
||||
Returns:
|
||||
bool: True if the target requires wrapping, False otherwise
|
||||
"""
|
||||
return self.target.needs_agent_wrapper()
|
||||
18
mm_agents/coact/autogen/agentchat/group/patterns/__init__.py
Normal file
18
mm_agents/coact/autogen/agentchat/group/patterns/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
|
||||
from .auto import AutoPattern
|
||||
from .manual import ManualPattern
|
||||
from .pattern import DefaultPattern
|
||||
from .random import RandomPattern
|
||||
from .round_robin import RoundRobinPattern
|
||||
|
||||
__all__ = [
|
||||
"AutoPattern",
|
||||
"DefaultPattern",
|
||||
"ManualPattern",
|
||||
"RandomPattern",
|
||||
"RoundRobinPattern",
|
||||
]
|
||||
159
mm_agents/coact/autogen/agentchat/group/patterns/auto.py
Normal file
159
mm_agents/coact/autogen/agentchat/group/patterns/auto.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
|
||||
|
||||
from ..context_variables import ContextVariables
|
||||
from ..targets.group_manager_target import GroupManagerSelectionMessage, GroupManagerTarget
|
||||
from ..targets.transition_target import TransitionTarget
|
||||
from .pattern import Pattern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...conversable_agent import ConversableAgent
|
||||
from ...groupchat import GroupChat, GroupChatManager
|
||||
from ..group_tool_executor import GroupToolExecutor
|
||||
|
||||
|
||||
class AutoPattern(Pattern):
|
||||
"""AutoPattern implements a flexible pattern where agents are selected based on their expertise.
|
||||
|
||||
In this pattern, a group manager automatically selects the next agent to speak based on the context
|
||||
of the conversation and agent descriptions. The after_work is always set to "group_manager" as
|
||||
this is the defining characteristic of this pattern.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_agent: "ConversableAgent",
|
||||
agents: list["ConversableAgent"],
|
||||
user_agent: Optional["ConversableAgent"] = None,
|
||||
group_manager_args: Optional[dict[str, Any]] = None,
|
||||
context_variables: Optional[ContextVariables] = None,
|
||||
selection_message: Optional[GroupManagerSelectionMessage] = None,
|
||||
exclude_transit_message: bool = True,
|
||||
summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
|
||||
):
|
||||
"""Initialize the AutoPattern.
|
||||
|
||||
The after_work is always set to group_manager selection, which is the defining
|
||||
characteristic of this pattern. You can customize the selection message used
|
||||
by the group manager when selecting the next agent.
|
||||
|
||||
Args:
|
||||
initial_agent: The first agent to speak in the group chat.
|
||||
agents: List of all agents participating in the chat.
|
||||
user_agent: Optional user proxy agent.
|
||||
group_manager_args: Optional arguments for the GroupChatManager.
|
||||
context_variables: Initial context variables for the chat.
|
||||
selection_message: Custom message to use when the group manager is selecting agents.
|
||||
exclude_transit_message: Whether to exclude transit messages from the conversation.
|
||||
summary_method: Method for summarizing the conversation.
|
||||
"""
|
||||
# Create the group_manager after_work with the provided selection message
|
||||
group_manager_after_work = GroupManagerTarget(selection_message=selection_message)
|
||||
|
||||
super().__init__(
|
||||
initial_agent=initial_agent,
|
||||
agents=agents,
|
||||
user_agent=user_agent,
|
||||
group_manager_args=group_manager_args,
|
||||
context_variables=context_variables,
|
||||
group_after_work=group_manager_after_work,
|
||||
exclude_transit_message=exclude_transit_message,
|
||||
summary_method=summary_method,
|
||||
)
|
||||
|
||||
# Store the selection message for potential use
|
||||
self.selection_message = selection_message
|
||||
|
||||
def prepare_group_chat(
|
||||
self,
|
||||
max_rounds: int,
|
||||
messages: Union[list[dict[str, Any]], str],
|
||||
) -> Tuple[
|
||||
list["ConversableAgent"],
|
||||
list["ConversableAgent"],
|
||||
Optional["ConversableAgent"],
|
||||
ContextVariables,
|
||||
"ConversableAgent",
|
||||
TransitionTarget,
|
||||
"GroupToolExecutor",
|
||||
"GroupChat",
|
||||
"GroupChatManager",
|
||||
list[dict[str, Any]],
|
||||
Any,
|
||||
list[str],
|
||||
list[Any],
|
||||
]:
|
||||
"""Prepare the group chat for organic agent selection.
|
||||
|
||||
Ensures that:
|
||||
1. The group manager has a valid LLM config
|
||||
2. All agents have appropriate descriptions for the group manager to use
|
||||
|
||||
Args:
|
||||
max_rounds: Maximum number of conversation rounds.
|
||||
messages: Initial message(s) to start the conversation.
|
||||
|
||||
Returns:
|
||||
Tuple containing all necessary components for the group chat.
|
||||
"""
|
||||
# Validate that group_manager_args has an LLM config which is required for this pattern
|
||||
if not self.group_manager_args.get("llm_config", False):
|
||||
# Check if any agent has an LLM config we can use
|
||||
has_llm_config = any(getattr(agent, "llm_config", False) for agent in self.agents)
|
||||
|
||||
if not has_llm_config:
|
||||
raise ValueError(
|
||||
"AutoPattern requires the group_manager_args to include an llm_config, "
|
||||
"or at least one agent to have an llm_config"
|
||||
)
|
||||
|
||||
# Check that all agents have descriptions for effective group manager selection
|
||||
for agent in self.agents:
|
||||
if not hasattr(agent, "description") or not agent.description:
|
||||
agent.description = f"Agent {agent.name}"
|
||||
|
||||
# Use the parent class's implementation to prepare the agents and group chat
|
||||
components = super().prepare_group_chat(
|
||||
max_rounds=max_rounds,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Extract the group_after_work and the rest of the components
|
||||
(
|
||||
agents,
|
||||
wrapped_agents,
|
||||
user_agent,
|
||||
context_variables,
|
||||
initial_agent,
|
||||
_,
|
||||
tool_executor,
|
||||
groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
group_agent_names,
|
||||
temp_user_list,
|
||||
) = components
|
||||
|
||||
# Ensure we're using the group_manager after_work
|
||||
group_after_work = self.group_after_work
|
||||
|
||||
# Return all components with our group_after_work
|
||||
return (
|
||||
agents,
|
||||
wrapped_agents,
|
||||
user_agent,
|
||||
context_variables,
|
||||
initial_agent,
|
||||
group_after_work,
|
||||
tool_executor,
|
||||
groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
group_agent_names,
|
||||
temp_user_list,
|
||||
)
|
||||
176
mm_agents/coact/autogen/agentchat/group/patterns/manual.py
Normal file
176
mm_agents/coact/autogen/agentchat/group/patterns/manual.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
|
||||
|
||||
from ..context_variables import ContextVariables
|
||||
from ..group_tool_executor import GroupToolExecutor
|
||||
from ..targets.transition_target import AskUserTarget, TransitionTarget
|
||||
from .pattern import Pattern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...conversable_agent import ConversableAgent
|
||||
from ...groupchat import GroupChat, GroupChatManager
|
||||
|
||||
|
||||
class ManualPattern(Pattern):
|
||||
"""ManualPattern will ask the user to nominate the next agent to speak at each turn."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_agent: "ConversableAgent",
|
||||
agents: list["ConversableAgent"],
|
||||
user_agent: Optional["ConversableAgent"] = None,
|
||||
group_manager_args: Optional[dict[str, Any]] = None,
|
||||
context_variables: Optional[ContextVariables] = None,
|
||||
exclude_transit_message: bool = True,
|
||||
summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
|
||||
):
|
||||
"""Initialize the ManualPattern.
|
||||
|
||||
The after_work is always set to ask_user, which will prompt the user for the next agent
|
||||
|
||||
Args:
|
||||
initial_agent: The first agent to speak in the group chat.
|
||||
agents: List of all agents participating in the chat.
|
||||
user_agent: Optional user proxy agent.
|
||||
group_manager_args: Optional arguments for the GroupChatManager.
|
||||
context_variables: Initial context variables for the chat.
|
||||
exclude_transit_message: Whether to exclude transit messages from the conversation.
|
||||
summary_method: Method for summarizing the conversation.
|
||||
"""
|
||||
# The group after work will be to ask the user
|
||||
group_after_work = AskUserTarget()
|
||||
|
||||
super().__init__(
|
||||
initial_agent=initial_agent,
|
||||
agents=agents,
|
||||
user_agent=user_agent,
|
||||
group_manager_args=group_manager_args,
|
||||
context_variables=context_variables,
|
||||
group_after_work=group_after_work,
|
||||
exclude_transit_message=exclude_transit_message,
|
||||
summary_method=summary_method,
|
||||
)
|
||||
|
||||
def prepare_group_chat(
|
||||
self,
|
||||
max_rounds: int,
|
||||
messages: Union[list[dict[str, Any]], str],
|
||||
) -> Tuple[
|
||||
list["ConversableAgent"],
|
||||
list["ConversableAgent"],
|
||||
Optional["ConversableAgent"],
|
||||
ContextVariables,
|
||||
"ConversableAgent",
|
||||
TransitionTarget,
|
||||
"GroupToolExecutor",
|
||||
"GroupChat",
|
||||
"GroupChatManager",
|
||||
list[dict[str, Any]],
|
||||
Any,
|
||||
list[str],
|
||||
list[Any],
|
||||
]:
|
||||
"""Prepare the group chat for organic agent selection.
|
||||
|
||||
Ensures that:
|
||||
1. The group manager has a valid LLM config
|
||||
2. All agents have appropriate descriptions for the group manager to use
|
||||
|
||||
Args:
|
||||
max_rounds: Maximum number of conversation rounds.
|
||||
messages: Initial message(s) to start the conversation.
|
||||
|
||||
Returns:
|
||||
Tuple containing all necessary components for the group chat.
|
||||
"""
|
||||
# Use the parent class's implementation to prepare the agents and group chat
|
||||
components = super().prepare_group_chat(
|
||||
max_rounds=max_rounds,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Extract the group_after_work and the rest of the components
|
||||
(
|
||||
agents,
|
||||
wrapped_agents,
|
||||
user_agent,
|
||||
context_variables,
|
||||
initial_agent,
|
||||
_,
|
||||
tool_executor,
|
||||
groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
group_agent_names,
|
||||
temp_user_list,
|
||||
) = components
|
||||
|
||||
# Ensure we're using the group_manager after_work
|
||||
group_after_work = self.group_after_work
|
||||
|
||||
# Set up the allowed speaker transitions to exclude user_agent and GroupToolExecutor
|
||||
self._setup_allowed_transitions(groupchat, user_agent, tool_executor)
|
||||
|
||||
# Return all components with our group_after_work
|
||||
return (
|
||||
agents,
|
||||
wrapped_agents,
|
||||
user_agent,
|
||||
context_variables,
|
||||
initial_agent,
|
||||
group_after_work,
|
||||
tool_executor,
|
||||
groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
group_agent_names,
|
||||
temp_user_list,
|
||||
)
|
||||
|
||||
def _setup_allowed_transitions(
|
||||
self, groupchat: "GroupChat", user_agent: Optional["ConversableAgent"], tool_executor: "GroupToolExecutor"
|
||||
) -> None:
|
||||
"""Set up the allowed speaker transitions for the group chat so that when a user selects the next agent the tool executor and user agent don't appear as options.
|
||||
|
||||
Creates transitions where:
|
||||
1. Any agent can speak after any other agent, including themselves
|
||||
2. The user_agent and GroupToolExecutor are excluded from transitions
|
||||
|
||||
Args:
|
||||
groupchat: The GroupChat instance to configure
|
||||
user_agent: The user agent to exclude from transitions
|
||||
tool_executor: The GroupToolExecutor to exclude from transitions
|
||||
"""
|
||||
# NOTE: THIS IS NOT WORKING - THE TRANSITIONS ARE NOT BEING KEPT?!
|
||||
"""
|
||||
# Get all agents in the group chat
|
||||
all_agents = groupchat.agents
|
||||
|
||||
# Filter out user_agent and group tool executor
|
||||
eligible_agents = []
|
||||
for agent in all_agents:
|
||||
# Skip user_agent
|
||||
if agent == user_agent:
|
||||
continue
|
||||
|
||||
# Skip GroupToolExecutor
|
||||
if isinstance(agent, GroupToolExecutor):
|
||||
continue
|
||||
|
||||
eligible_agents.append(agent)
|
||||
|
||||
# Create a fully connected graph among eligible agents
|
||||
# Each agent can be followed by any other eligible agent
|
||||
allowed_transitions = {}
|
||||
for agent in eligible_agents:
|
||||
# For each agent, every other eligible agent can follow
|
||||
allowed_transitions[agent] = eligible_agents
|
||||
|
||||
# Set the transitions in the group chat
|
||||
groupchat.allowed_speaker_transitions_dict = allowed_transitions
|
||||
"""
|
||||
294
mm_agents/coact/autogen/agentchat/group/patterns/pattern.py
Normal file
294
mm_agents/coact/autogen/agentchat/group/patterns/pattern.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Patterns of agent orchestrations
|
||||
# Uses the group chat or the agents' handoffs to create a pattern
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
|
||||
|
||||
from ..context_variables import ContextVariables
|
||||
from ..group_utils import (
|
||||
create_group_manager,
|
||||
create_group_transition,
|
||||
link_agents_to_group_manager,
|
||||
prepare_group_agents,
|
||||
process_initial_messages,
|
||||
setup_context_variables,
|
||||
)
|
||||
from ..targets.transition_target import TerminateTarget, TransitionTarget
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...agent import Agent
|
||||
from ...conversable_agent import ConversableAgent
|
||||
from ...groupchat import GroupChat, GroupChatManager
|
||||
from ..group_tool_executor import GroupToolExecutor
|
||||
|
||||
|
||||
class Pattern(ABC):
|
||||
"""Base abstract class for all orchestration patterns.
|
||||
|
||||
Patterns provide a reusable way to define how agents interact within a group chat.
|
||||
Each pattern encapsulates the logic for setting up agents, configuring handoffs,
|
||||
and determining the flow of conversation.
|
||||
|
||||
This is an abstract base class and should not be instantiated directly.
|
||||
Use one of the concrete pattern implementations like AutoPattern,
|
||||
RoundRobinPattern, RandomPattern, or ManualPattern.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_agent: "ConversableAgent",
|
||||
agents: list["ConversableAgent"],
|
||||
user_agent: Optional["ConversableAgent"] = None,
|
||||
group_manager_args: Optional[dict[str, Any]] = None,
|
||||
context_variables: Optional[ContextVariables] = None,
|
||||
group_after_work: Optional[TransitionTarget] = None,
|
||||
exclude_transit_message: bool = True,
|
||||
summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
|
||||
):
|
||||
"""Initialize the pattern with the required components.
|
||||
|
||||
Args:
|
||||
initial_agent: The first agent to speak in the group chat.
|
||||
agents: List of all agents participating in the chat.
|
||||
user_agent: Optional user proxy agent.
|
||||
group_manager_args: Optional arguments for the GroupChatManager.
|
||||
context_variables: Initial context variables for the chat.
|
||||
group_after_work: Default after work transition behavior when no specific next agent is determined.
|
||||
exclude_transit_message: Whether to exclude transit messages from the conversation.
|
||||
summary_method: Method for summarizing the conversation.
|
||||
"""
|
||||
self.initial_agent = initial_agent
|
||||
self.agents = agents
|
||||
self.user_agent = user_agent
|
||||
self.group_manager_args = group_manager_args or {}
|
||||
self.context_variables = context_variables or ContextVariables()
|
||||
self.group_after_work = group_after_work if group_after_work is not None else TerminateTarget()
|
||||
self.exclude_transit_message = exclude_transit_message
|
||||
self.summary_method = summary_method
|
||||
|
||||
@abstractmethod
|
||||
def prepare_group_chat(
|
||||
self,
|
||||
max_rounds: int,
|
||||
messages: Union[list[dict[str, Any]], str],
|
||||
) -> Tuple[
|
||||
list["ConversableAgent"],
|
||||
list["ConversableAgent"],
|
||||
Optional["ConversableAgent"],
|
||||
ContextVariables,
|
||||
"ConversableAgent",
|
||||
TransitionTarget,
|
||||
"GroupToolExecutor",
|
||||
"GroupChat",
|
||||
"GroupChatManager",
|
||||
list[dict[str, Any]],
|
||||
"ConversableAgent",
|
||||
list[str],
|
||||
list["Agent"],
|
||||
]:
|
||||
"""Prepare the group chat for orchestration.
|
||||
|
||||
This is the main method called by initiate_group_chat to set up the pattern.
|
||||
Subclasses must implement or extend this method to define pattern-specific behavior.
|
||||
|
||||
Args:
|
||||
max_rounds: Maximum number of conversation rounds.
|
||||
messages: Initial message(s) to start the conversation.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- List of agents involved in the group chat
|
||||
- List of wrapped agents
|
||||
- User agent, if applicable
|
||||
- Context variables for the group chat
|
||||
- Initial agent for the group chat
|
||||
- Group-level after work transition for the group chat
|
||||
- Tool executor for the group chat
|
||||
- GroupChat instance
|
||||
- GroupChatManager instance
|
||||
- Processed messages
|
||||
- Last agent to speak
|
||||
- List of group agent names
|
||||
- List of temporary user agents
|
||||
"""
|
||||
from ...groupchat import GroupChat
|
||||
|
||||
# Prepare the agents using the existing helper function
|
||||
tool_executor, wrapped_agents = prepare_group_agents(
|
||||
self.agents, self.context_variables, self.exclude_transit_message
|
||||
)
|
||||
|
||||
# Process the initial messages BEFORE creating the GroupChat
|
||||
# This will create a temporary user agent if needed
|
||||
processed_messages, last_agent, group_agent_names, temp_user_list = process_initial_messages(
|
||||
messages, self.user_agent, self.agents, wrapped_agents
|
||||
)
|
||||
|
||||
# Create transition function (has enclosed state for initial agent)
|
||||
group_transition = create_group_transition(
|
||||
initial_agent=self.initial_agent,
|
||||
tool_execution=tool_executor,
|
||||
group_agent_names=group_agent_names,
|
||||
user_agent=self.user_agent,
|
||||
group_after_work=self.group_after_work,
|
||||
)
|
||||
|
||||
# Create the group chat - now we use temp_user_list if no user_agent
|
||||
groupchat = GroupChat(
|
||||
agents=[tool_executor]
|
||||
+ self.agents
|
||||
+ wrapped_agents
|
||||
+ ([self.user_agent] if self.user_agent else temp_user_list),
|
||||
messages=[],
|
||||
max_round=max_rounds,
|
||||
speaker_selection_method=group_transition,
|
||||
)
|
||||
|
||||
# Create the group manager
|
||||
manager = create_group_manager(groupchat, self.group_manager_args, self.agents, self.group_after_work)
|
||||
|
||||
# Point all agent's context variables to this function's context_variables
|
||||
setup_context_variables(
|
||||
tool_execution=tool_executor,
|
||||
agents=self.agents,
|
||||
manager=manager,
|
||||
user_agent=self.user_agent,
|
||||
context_variables=self.context_variables,
|
||||
)
|
||||
|
||||
# Link all agents with the GroupChatManager to allow access to the group chat
|
||||
link_agents_to_group_manager(groupchat.agents, manager)
|
||||
|
||||
return (
|
||||
self.agents,
|
||||
wrapped_agents,
|
||||
self.user_agent,
|
||||
self.context_variables,
|
||||
self.initial_agent,
|
||||
self.group_after_work,
|
||||
tool_executor,
|
||||
groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
group_agent_names,
|
||||
temp_user_list,
|
||||
) # type: ignore[return-value]
|
||||
|
||||
@classmethod
|
||||
def create_default(
|
||||
cls,
|
||||
initial_agent: "ConversableAgent",
|
||||
agents: list["ConversableAgent"],
|
||||
user_agent: Optional["ConversableAgent"] = None,
|
||||
group_manager_args: Optional[dict[str, Any]] = None,
|
||||
context_variables: Optional[ContextVariables] = None,
|
||||
exclude_transit_message: bool = True,
|
||||
summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
|
||||
) -> "DefaultPattern":
|
||||
"""Create a default pattern with minimal configuration.
|
||||
|
||||
This replaces the need for a separate BasePattern class by providing
|
||||
a factory method that creates a simple DefaultPattern instance.
|
||||
|
||||
Args:
|
||||
initial_agent: The first agent to speak in the group chat.
|
||||
agents: List of all agents participating in the chat.
|
||||
user_agent: Optional user proxy agent.
|
||||
group_manager_args: Optional arguments for the GroupChatManager.
|
||||
context_variables: Initial context variables for the chat.
|
||||
exclude_transit_message: Whether to exclude transit messages from the conversation.
|
||||
summary_method: Method for summarizing the conversation.
|
||||
|
||||
Returns:
|
||||
A DefaultPattern instance with basic configuration.
|
||||
"""
|
||||
return DefaultPattern(
|
||||
initial_agent=initial_agent,
|
||||
agents=agents,
|
||||
user_agent=user_agent,
|
||||
group_manager_args=group_manager_args,
|
||||
context_variables=context_variables,
|
||||
exclude_transit_message=exclude_transit_message,
|
||||
summary_method=summary_method,
|
||||
)
|
||||
|
||||
|
||||
class DefaultPattern(Pattern):
|
||||
"""DefaultPattern implements a minimal pattern for simple agent interactions.
|
||||
|
||||
This replaces the previous BasePattern and provides a concrete implementation
|
||||
of the Pattern abstract base class.
|
||||
"""
|
||||
|
||||
def prepare_group_chat(
|
||||
self,
|
||||
max_rounds: int,
|
||||
messages: Union[list[dict[str, Any]], str],
|
||||
) -> Tuple[
|
||||
list["ConversableAgent"],
|
||||
list["ConversableAgent"],
|
||||
Optional["ConversableAgent"],
|
||||
ContextVariables,
|
||||
"ConversableAgent",
|
||||
TransitionTarget,
|
||||
"GroupToolExecutor",
|
||||
"GroupChat",
|
||||
"GroupChatManager",
|
||||
list[dict[str, Any]],
|
||||
Any,
|
||||
list[str],
|
||||
list[Any],
|
||||
]:
|
||||
"""Prepare the group chat with default configuration.
|
||||
|
||||
This implementation calls the parent class method but ensures that
|
||||
the group_after_work in the returned tuple is the pattern's own.
|
||||
|
||||
Args:
|
||||
max_rounds: Maximum number of conversation rounds.
|
||||
messages: Initial message(s) to start the conversation.
|
||||
|
||||
Returns:
|
||||
Tuple containing all necessary components for the group chat.
|
||||
"""
|
||||
# Use the parent class's implementation to prepare the agents and group chat
|
||||
(
|
||||
agents,
|
||||
wrapped_agents,
|
||||
user_agent,
|
||||
context_variables,
|
||||
initial_agent,
|
||||
_, # Ignore the group_after_work from parent
|
||||
tool_executor,
|
||||
groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
group_agent_names,
|
||||
temp_user_list,
|
||||
) = super().prepare_group_chat(
|
||||
max_rounds=max_rounds,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Return all components with our group_after_work
|
||||
return (
|
||||
agents,
|
||||
wrapped_agents,
|
||||
user_agent,
|
||||
context_variables,
|
||||
initial_agent,
|
||||
self.group_after_work, # Use our own group_after_work
|
||||
tool_executor,
|
||||
groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
group_agent_names,
|
||||
temp_user_list,
|
||||
)
|
||||
106
mm_agents/coact/autogen/agentchat/group/patterns/random.py
Normal file
106
mm_agents/coact/autogen/agentchat/group/patterns/random.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
|
||||
|
||||
from ..context_variables import ContextVariables
|
||||
from ..targets.transition_target import RandomAgentTarget, TransitionTarget
|
||||
from .pattern import Pattern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...conversable_agent import ConversableAgent
|
||||
from ...groupchat import GroupChat, GroupChatManager
|
||||
from ..group_tool_executor import GroupToolExecutor
|
||||
|
||||
|
||||
class RandomPattern(Pattern):
|
||||
"""RandomPattern implements a random agent selection process."""
|
||||
|
||||
def _generate_handoffs(
|
||||
self,
|
||||
initial_agent: "ConversableAgent",
|
||||
agents: list["ConversableAgent"],
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> None:
|
||||
"""Generate handoffs between agents in a random fashion."""
|
||||
agent_list = agents + ([user_agent] if user_agent is not None else [])
|
||||
|
||||
for agent in agent_list:
|
||||
# Get the list of agents except itself
|
||||
other_agents = [a for a in agent_list if a != agent]
|
||||
|
||||
# Create a random after work
|
||||
agent.handoffs.set_after_work(target=RandomAgentTarget(agents=other_agents))
|
||||
|
||||
def prepare_group_chat(
|
||||
self,
|
||||
max_rounds: int,
|
||||
messages: Union[list[dict[str, Any]], str],
|
||||
) -> Tuple[
|
||||
list["ConversableAgent"],
|
||||
list["ConversableAgent"],
|
||||
Optional["ConversableAgent"],
|
||||
ContextVariables,
|
||||
"ConversableAgent",
|
||||
TransitionTarget,
|
||||
"GroupToolExecutor",
|
||||
"GroupChat",
|
||||
"GroupChatManager",
|
||||
list[dict[str, Any]],
|
||||
Any,
|
||||
list[str],
|
||||
list[Any],
|
||||
]:
|
||||
"""Prepare the group chat for organic agent selection.
|
||||
|
||||
Ensures that:
|
||||
1. The group manager has a valid LLM config
|
||||
2. All agents have appropriate descriptions for the group manager to use
|
||||
|
||||
Args:
|
||||
max_rounds: Maximum number of conversation rounds.
|
||||
messages: Initial message(s) to start the conversation.
|
||||
|
||||
Returns:
|
||||
Tuple containing all necessary components for the group chat.
|
||||
"""
|
||||
# Use the parent class's implementation to prepare the agents and group chat
|
||||
(
|
||||
agents,
|
||||
wrapped_agents,
|
||||
user_agent,
|
||||
context_variables,
|
||||
initial_agent,
|
||||
group_after_work,
|
||||
tool_executor,
|
||||
groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
group_agent_names,
|
||||
temp_user_list,
|
||||
) = super().prepare_group_chat(
|
||||
max_rounds=max_rounds,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Create the random handoffs between agents
|
||||
self._generate_handoffs(initial_agent=initial_agent, agents=agents, user_agent=user_agent)
|
||||
|
||||
# Return all components with our group_after_work
|
||||
return (
|
||||
agents,
|
||||
wrapped_agents,
|
||||
user_agent,
|
||||
context_variables,
|
||||
initial_agent,
|
||||
group_after_work,
|
||||
tool_executor,
|
||||
groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
group_agent_names,
|
||||
temp_user_list,
|
||||
)
|
||||
117
mm_agents/coact/autogen/agentchat/group/patterns/round_robin.py
Normal file
117
mm_agents/coact/autogen/agentchat/group/patterns/round_robin.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
|
||||
|
||||
from ..context_variables import ContextVariables
|
||||
from ..targets.transition_target import AgentTarget, TransitionTarget
|
||||
from .pattern import Pattern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...conversable_agent import ConversableAgent
|
||||
from ...groupchat import GroupChat, GroupChatManager
|
||||
from ..group_tool_executor import GroupToolExecutor
|
||||
|
||||
|
||||
class RoundRobinPattern(Pattern):
|
||||
"""RoundRobinPattern implements a round robin with handoffs between agents."""
|
||||
|
||||
def _generate_handoffs(
|
||||
self,
|
||||
initial_agent: "ConversableAgent",
|
||||
agents: list["ConversableAgent"],
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> None:
|
||||
"""Generate handoffs between agents in a round-robin fashion."""
|
||||
# Create a list of the agents and the user_agent but put the initial_agent first
|
||||
agent_list = [initial_agent]
|
||||
|
||||
# Add the rest of the agents, excluding the initial_agent and user_agent
|
||||
for agent in agents:
|
||||
if agent != initial_agent and (user_agent is None or agent != user_agent):
|
||||
agent_list.append(agent)
|
||||
|
||||
# Add the user_agent last if it exists
|
||||
if user_agent is not None:
|
||||
agent_list.append(user_agent)
|
||||
|
||||
# Create handoffs in a round-robin fashion
|
||||
for i, agent in enumerate(agent_list):
|
||||
# Last agent hands off to the first agent
|
||||
# Otherwise agent hands off to the next one
|
||||
handoff_target = agent_list[0] if i == len(agent_list) - 1 else agent_list[i + 1]
|
||||
|
||||
agent.handoffs.set_after_work(target=AgentTarget(agent=handoff_target))
|
||||
|
||||
def prepare_group_chat(
|
||||
self,
|
||||
max_rounds: int,
|
||||
messages: Union[list[dict[str, Any]], str],
|
||||
) -> Tuple[
|
||||
list["ConversableAgent"],
|
||||
list["ConversableAgent"],
|
||||
Optional["ConversableAgent"],
|
||||
ContextVariables,
|
||||
"ConversableAgent",
|
||||
TransitionTarget,
|
||||
"GroupToolExecutor",
|
||||
"GroupChat",
|
||||
"GroupChatManager",
|
||||
list[dict[str, Any]],
|
||||
Any,
|
||||
list[str],
|
||||
list[Any],
|
||||
]:
|
||||
"""Prepare the group chat for organic agent selection.
|
||||
|
||||
Ensures that:
|
||||
1. The group manager has a valid LLM config
|
||||
2. All agents have appropriate descriptions for the group manager to use
|
||||
|
||||
Args:
|
||||
max_rounds: Maximum number of conversation rounds.
|
||||
messages: Initial message(s) to start the conversation.
|
||||
|
||||
Returns:
|
||||
Tuple containing all necessary components for the group chat.
|
||||
"""
|
||||
# Use the parent class's implementation to prepare the agents and group chat
|
||||
(
|
||||
agents,
|
||||
wrapped_agents,
|
||||
user_agent,
|
||||
context_variables,
|
||||
initial_agent,
|
||||
group_after_work,
|
||||
tool_executor,
|
||||
groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
group_agent_names,
|
||||
temp_user_list,
|
||||
) = super().prepare_group_chat(
|
||||
max_rounds=max_rounds,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Create the handoffs between agents
|
||||
self._generate_handoffs(initial_agent=initial_agent, agents=agents, user_agent=user_agent)
|
||||
|
||||
# Return all components with our group_after_work
|
||||
return (
|
||||
agents,
|
||||
wrapped_agents,
|
||||
user_agent,
|
||||
context_variables,
|
||||
initial_agent,
|
||||
group_after_work,
|
||||
tool_executor,
|
||||
groupchat,
|
||||
manager,
|
||||
processed_messages,
|
||||
last_agent,
|
||||
group_agent_names,
|
||||
temp_user_list,
|
||||
)
|
||||
26
mm_agents/coact/autogen/agentchat/group/reply_result.py
Normal file
26
mm_agents/coact/autogen/agentchat/group/reply_result.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
__all__ = ["ReplyResult"]
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .context_variables import ContextVariables
|
||||
from .targets.transition_target import TransitionTarget
|
||||
|
||||
|
||||
class ReplyResult(BaseModel):
|
||||
"""Result of a tool call that is used to provide the return message and the target to transition to."""
|
||||
|
||||
message: str
|
||||
target: Optional[TransitionTarget] = None
|
||||
context_variables: Optional[ContextVariables] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""The string representation for ReplyResult will be just the message."""
|
||||
return self.message
|
||||
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..agent import Agent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Avoid circular import
|
||||
from ..groupchat import GroupChat
|
||||
|
||||
|
||||
class SpeakerSelectionResult(BaseModel):
|
||||
"""Represents a speaker selection result that will be returned to GroupChat._prepare_and_select_agents to determine the next speaker.
|
||||
|
||||
This class can return an Agent, a None to end the conversation, or a string for a speaker selection method.
|
||||
"""
|
||||
|
||||
terminate: Optional[bool] = None
|
||||
agent_name: Optional[str] = None
|
||||
speaker_selection_method: Optional[str] = None
|
||||
|
||||
def get_speaker_selection_result(self, groupchat: "GroupChat") -> Optional[Union[Agent, str]]:
|
||||
"""Get the speaker selection result. If None, the conversation will end."""
|
||||
if self.agent_name is not None:
|
||||
# Find the agent by name in the groupchat
|
||||
for agent in groupchat.agents:
|
||||
if agent.name == self.agent_name:
|
||||
return agent
|
||||
raise ValueError(f"Agent '{self.agent_name}' not found in groupchat.")
|
||||
elif self.speaker_selection_method is not None:
|
||||
return self.speaker_selection_method
|
||||
elif self.terminate is not None and self.terminate:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unable to establish speaker selection result. No terminate, agent, or speaker selection method provided."
|
||||
)
|
||||
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
@@ -0,0 +1,132 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from ...agent import Agent
|
||||
from ..speaker_selection_result import SpeakerSelectionResult
|
||||
from .transition_target import AgentTarget, TransitionTarget
|
||||
from .transition_utils import __AGENT_WRAPPER_PREFIX__
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...conversable_agent import ConversableAgent
|
||||
from ...groupchat import GroupChat
|
||||
from ..patterns.pattern import Pattern
|
||||
|
||||
|
||||
__all__ = ["GroupChatConfig", "GroupChatTarget"]
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.group")
|
||||
class GroupChatConfig(BaseModel):
|
||||
"""Configuration for a group chat transition target.
|
||||
|
||||
Note: If context_variables are not passed in, the outer context variables will be passed in"""
|
||||
|
||||
pattern: "Pattern"
|
||||
messages: Union[list[dict[str, Any]], str]
|
||||
max_rounds: int = 20
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.group")
|
||||
class GroupChatTarget(TransitionTarget):
|
||||
"""Target that represents a group chat."""
|
||||
|
||||
group_chat_config: GroupChatConfig
|
||||
|
||||
def can_resolve_for_speaker_selection(self) -> bool:
|
||||
"""Check if the target can resolve for speaker selection. For GroupChatTarget the chat must be encapsulated into an agent."""
|
||||
return False
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
groupchat: "GroupChat",
|
||||
current_agent: "ConversableAgent",
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> SpeakerSelectionResult:
|
||||
"""Resolve to the nested chat configuration."""
|
||||
raise NotImplementedError(
|
||||
"GroupChatTarget does not support the resolve method. An agent should be used to encapsulate this nested chat and then the target changed to an AgentTarget."
|
||||
)
|
||||
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name for the target."""
|
||||
return "a group chat"
|
||||
|
||||
def normalized_name(self) -> str:
|
||||
"""Get a normalized name for the target that has no spaces, used for function calling."""
|
||||
return "group_chat"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||
return "Transfer to group chat"
|
||||
|
||||
def needs_agent_wrapper(self) -> bool:
|
||||
"""Check if the target needs to be wrapped in an agent. GroupChatTarget must be wrapped in an agent."""
|
||||
return True
|
||||
|
||||
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||
"""Create a wrapper agent for the group chat."""
|
||||
from autogen.agentchat import initiate_group_chat
|
||||
|
||||
from ...conversable_agent import ConversableAgent # to avoid circular import
|
||||
|
||||
# Create the wrapper agent with a name that identifies it as a wrapped group chat
|
||||
group_chat_agent = ConversableAgent(
|
||||
name=f"{__AGENT_WRAPPER_PREFIX__}group_{parent_agent.name}_{index + 1}",
|
||||
# Copy LLM config from parent agent to ensure it can generate replies if needed
|
||||
llm_config=parent_agent.llm_config,
|
||||
)
|
||||
|
||||
# Store the config directly on the agent
|
||||
group_chat_agent._group_chat_config = self.group_chat_config # type: ignore[attr-defined]
|
||||
|
||||
# Define the reply function that will run the group chat
|
||||
def group_chat_reply(
|
||||
agent: "ConversableAgent",
|
||||
messages: Optional[list[dict[str, Any]]] = None,
|
||||
sender: Optional["Agent"] = None,
|
||||
config: Optional[Any] = None,
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""Run the inner group chat and return its results as a reply."""
|
||||
# Get the configuration stored directly on the agent
|
||||
group_config = agent._group_chat_config # type: ignore[attr-defined]
|
||||
|
||||
# Pull through the second last message from the outer chat (the last message will be the handoff message)
|
||||
# This may need work to make sure we get the right message(s) from the outer chat
|
||||
message = (
|
||||
messages[-2]["content"]
|
||||
if messages and len(messages) >= 2 and "content" in messages[-2]
|
||||
else "No message to pass through."
|
||||
)
|
||||
|
||||
try:
|
||||
# Run the group chat with direct agent references from the config
|
||||
result, _, _ = initiate_group_chat(
|
||||
pattern=group_config.pattern,
|
||||
messages=message,
|
||||
max_rounds=group_config.max_rounds,
|
||||
)
|
||||
|
||||
# Return the summary from the chat result summary
|
||||
return True, {"content": result.summary}
|
||||
|
||||
except Exception as e:
|
||||
# Handle any errors during execution
|
||||
return True, {"content": f"Error running group chat: {str(e)}"}
|
||||
|
||||
# Register the reply function with the wrapper agent
|
||||
group_chat_agent.register_reply(
|
||||
trigger=[ConversableAgent, None],
|
||||
reply_func=group_chat_reply,
|
||||
remove_other_reply_funcs=True, # Use only this reply function
|
||||
)
|
||||
|
||||
# After the group chat completes, transition back to the parent agent
|
||||
group_chat_agent.handoffs.set_after_work(AgentTarget(parent_agent))
|
||||
|
||||
return group_chat_agent
|
||||
@@ -0,0 +1,151 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from ..context_str import ContextStr
|
||||
from ..group_tool_executor import GroupToolExecutor
|
||||
from ..speaker_selection_result import SpeakerSelectionResult
|
||||
from .transition_target import TransitionTarget
|
||||
from .transition_utils import __AGENT_WRAPPER_PREFIX__
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Avoid circular import
|
||||
from ...conversable_agent import ConversableAgent
|
||||
from ...groupchat import GroupChat
|
||||
|
||||
__all__ = ["GroupManagerTarget"]
|
||||
|
||||
|
||||
def prepare_groupchat_auto_speaker(
|
||||
groupchat: "GroupChat",
|
||||
last_group_agent: "ConversableAgent",
|
||||
group_chat_manager_selection_msg: Optional[Any],
|
||||
) -> None:
|
||||
"""Prepare the group chat for auto speaker selection, includes updating or restore the groupchat speaker selection message.
|
||||
|
||||
Tool Executor and wrapped agents will be removed from the available agents list.
|
||||
|
||||
Args:
|
||||
groupchat (GroupChat): GroupChat instance.
|
||||
last_group_agent ("ConversableAgent"): The last group agent for which the LLM config is used
|
||||
group_chat_manager_selection_msg (GroupManagerSelectionMessage): Optional message to use for the agent selection (in internal group chat).
|
||||
"""
|
||||
from ...groupchat import SELECT_SPEAKER_PROMPT_TEMPLATE
|
||||
|
||||
def substitute_agentlist(template: str) -> str:
|
||||
# Run through group chat's string substitution first for {agentlist}
|
||||
# We need to do this so that the next substitution doesn't fail with agentlist
|
||||
# and we can remove the tool executor and wrapped chats from the available agents list
|
||||
agent_list = [
|
||||
agent
|
||||
for agent in groupchat.agents
|
||||
if not isinstance(agent, GroupToolExecutor) and not agent.name.startswith(__AGENT_WRAPPER_PREFIX__)
|
||||
]
|
||||
|
||||
groupchat.select_speaker_prompt_template = template
|
||||
return groupchat.select_speaker_prompt(agent_list)
|
||||
|
||||
# Use the default speaker selection prompt if one is not specified, otherwise use the specified one
|
||||
groupchat.select_speaker_prompt_template = substitute_agentlist(
|
||||
SELECT_SPEAKER_PROMPT_TEMPLATE
|
||||
if group_chat_manager_selection_msg is None
|
||||
else group_chat_manager_selection_msg.get_message(last_group_agent)
|
||||
)
|
||||
|
||||
|
||||
# GroupManagerSelectionMessage protocol and implementations
|
||||
@export_module("autogen.agentchat.group")
|
||||
class GroupManagerSelectionMessage(BaseModel):
|
||||
"""Base class for all GroupManager selection message types."""
|
||||
|
||||
def get_message(self, agent: "ConversableAgent") -> str:
|
||||
"""Get the formatted message."""
|
||||
raise NotImplementedError("Requires subclasses to implement.")
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.group")
|
||||
class GroupManagerSelectionMessageString(GroupManagerSelectionMessage):
|
||||
"""Selection message that uses a plain string template."""
|
||||
|
||||
message: str
|
||||
|
||||
def get_message(self, agent: "ConversableAgent") -> str:
|
||||
"""Get the message string."""
|
||||
return self.message
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.group")
|
||||
class GroupManagerSelectionMessageContextStr(GroupManagerSelectionMessage):
|
||||
"""Selection message that uses a ContextStr template."""
|
||||
|
||||
context_str_template: str
|
||||
|
||||
# We will replace {agentlist} with another term and return it later for use with the internal group chat auto speaker selection
|
||||
# Otherwise our format will fail
|
||||
@field_validator("context_str_template", mode="before")
|
||||
def _replace_agentlist_placeholder(cls: Type["GroupManagerSelectionMessageContextStr"], v: Any) -> Union[str, Any]: # noqa: N805
|
||||
"""Replace {agentlist} placeholder before validation/assignment."""
|
||||
if isinstance(v, str):
|
||||
if "{agentlist}" in v:
|
||||
return v.replace("{agentlist}", "<<agent_list>>") # Perform the replacement
|
||||
else:
|
||||
return v # If no replacement is needed, return the original value
|
||||
return ""
|
||||
|
||||
def get_message(self, agent: "ConversableAgent") -> str:
|
||||
"""Get the formatted message with context variables substituted."""
|
||||
context_str = ContextStr(template=self.context_str_template)
|
||||
format_result = context_str.format(agent.context_variables)
|
||||
if format_result is None:
|
||||
return ""
|
||||
|
||||
return format_result.replace(
|
||||
"<<agent_list>>", "{agentlist}"
|
||||
) # Restore agentlist so it can be substituted by the internal group chat auto speaker selection
|
||||
|
||||
|
||||
class GroupManagerTarget(TransitionTarget):
|
||||
"""Target that represents an agent by name."""
|
||||
|
||||
selection_message: Optional[GroupManagerSelectionMessage] = None
|
||||
|
||||
def can_resolve_for_speaker_selection(self) -> bool:
|
||||
"""Check if the target can resolve for speaker selection."""
|
||||
return True
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
groupchat: "GroupChat",
|
||||
current_agent: "ConversableAgent",
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> SpeakerSelectionResult:
|
||||
"""Resolve to the speaker selection for the group."""
|
||||
if self.selection_message is not None:
|
||||
prepare_groupchat_auto_speaker(groupchat, current_agent, self.selection_message)
|
||||
|
||||
return SpeakerSelectionResult(speaker_selection_method="auto")
|
||||
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name for the target."""
|
||||
return "the group manager"
|
||||
|
||||
def normalized_name(self) -> str:
|
||||
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||
return self.display_name()
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||
return "Transfer to the group manager"
|
||||
|
||||
def needs_agent_wrapper(self) -> bool:
|
||||
"""Check if the target needs to be wrapped in an agent."""
|
||||
return False
|
||||
|
||||
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||
"""Create a wrapper agent for the target if needed."""
|
||||
raise NotImplementedError("GroupManagerTarget does not require wrapping in an agent.")
|
||||
@@ -0,0 +1,413 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..speaker_selection_result import SpeakerSelectionResult
|
||||
from .transition_utils import __AGENT_WRAPPER_PREFIX__
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Avoid circular import
|
||||
from ...conversable_agent import ConversableAgent
|
||||
from ...groupchat import GroupChat
|
||||
|
||||
__all__ = [
|
||||
"AgentNameTarget",
|
||||
"AgentTarget",
|
||||
"AskUserTarget",
|
||||
"NestedChatTarget",
|
||||
"RandomAgentTarget",
|
||||
"RevertToUserTarget",
|
||||
"StayTarget",
|
||||
"TerminateTarget",
|
||||
"TransitionTarget",
|
||||
]
|
||||
|
||||
# Common options for transitions
|
||||
# terminate: Terminate the conversation
|
||||
# revert_to_user: Revert to the user agent
|
||||
# stay: Stay with the current agent
|
||||
# group_manager: Use the group manager (auto speaker selection)
|
||||
# ask_user: Use the user manager (ask the user, aka manual)
|
||||
# TransitionOption = Literal["terminate", "revert_to_user", "stay", "group_manager", "ask_user"]
|
||||
|
||||
|
||||
class TransitionTarget(BaseModel):
|
||||
"""Base class for all transition targets across OnCondition, OnContextCondition, and after work."""
|
||||
|
||||
def can_resolve_for_speaker_selection(self) -> bool:
|
||||
"""Check if the target can resolve to an option for speaker selection (Agent, 'None' to end, Str for speaker selection method). In the case of a nested chat, this will return False as it should be encapsulated in an agent."""
|
||||
return False
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
groupchat: "GroupChat",
|
||||
current_agent: "ConversableAgent",
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> SpeakerSelectionResult:
|
||||
"""Resolve to a speaker selection result (Agent, None for termination, or str for speaker selection method)."""
|
||||
raise NotImplementedError("Requires subclasses to implement.")
|
||||
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name for the target."""
|
||||
raise NotImplementedError("Requires subclasses to implement.")
|
||||
|
||||
def normalized_name(self) -> str:
|
||||
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||
raise NotImplementedError("Requires subclasses to implement.")
|
||||
|
||||
def needs_agent_wrapper(self) -> bool:
|
||||
"""Check if the target needs to be wrapped in an agent."""
|
||||
raise NotImplementedError("Requires subclasses to implement.")
|
||||
|
||||
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||
"""Create a wrapper agent for the target if needed."""
|
||||
raise NotImplementedError("Requires subclasses to implement.")
|
||||
|
||||
|
||||
class AgentTarget(TransitionTarget):
|
||||
"""Target that represents a direct agent reference."""
|
||||
|
||||
agent_name: str
|
||||
|
||||
def __init__(self, agent: "ConversableAgent", **data: Any) -> None: # type: ignore[no-untyped-def]
|
||||
# Store the name from the agent for serialization
|
||||
super().__init__(agent_name=agent.name, **data)
|
||||
|
||||
def can_resolve_for_speaker_selection(self) -> bool:
|
||||
"""Check if the target can resolve for speaker selection."""
|
||||
return True
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
groupchat: "GroupChat",
|
||||
current_agent: "ConversableAgent",
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> SpeakerSelectionResult:
|
||||
"""Resolve to the actual agent object from the groupchat."""
|
||||
return SpeakerSelectionResult(agent_name=self.agent_name)
|
||||
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name for the target."""
|
||||
return f"{self.agent_name}"
|
||||
|
||||
def normalized_name(self) -> str:
|
||||
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||
return self.display_name()
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||
return f"Transfer to {self.agent_name}"
|
||||
|
||||
def needs_agent_wrapper(self) -> bool:
|
||||
"""Check if the target needs to be wrapped in an agent."""
|
||||
return False
|
||||
|
||||
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||
"""Create a wrapper agent for the target if needed."""
|
||||
raise NotImplementedError("AgentTarget does not require wrapping in an agent.")
|
||||
|
||||
|
||||
class AgentNameTarget(TransitionTarget):
|
||||
"""Target that represents an agent by name."""
|
||||
|
||||
agent_name: str
|
||||
|
||||
def __init__(self, agent_name: str, **data: Any) -> None:
|
||||
"""Initialize with agent name as a positional parameter."""
|
||||
super().__init__(agent_name=agent_name, **data)
|
||||
|
||||
def can_resolve_for_speaker_selection(self) -> bool:
|
||||
"""Check if the target can resolve for speaker selection."""
|
||||
return True
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
groupchat: "GroupChat",
|
||||
current_agent: "ConversableAgent",
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> SpeakerSelectionResult:
|
||||
"""Resolve to the agent name string."""
|
||||
return SpeakerSelectionResult(agent_name=self.agent_name)
|
||||
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name for the target."""
|
||||
return f"{self.agent_name}"
|
||||
|
||||
def normalized_name(self) -> str:
|
||||
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||
return self.display_name()
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||
return f"Transfer to {self.agent_name}"
|
||||
|
||||
def needs_agent_wrapper(self) -> bool:
|
||||
"""Check if the target needs to be wrapped in an agent."""
|
||||
return False
|
||||
|
||||
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||
"""Create a wrapper agent for the target if needed."""
|
||||
raise NotImplementedError("AgentNameTarget does not require wrapping in an agent.")
|
||||
|
||||
|
||||
class NestedChatTarget(TransitionTarget):
|
||||
"""Target that represents a nested chat configuration."""
|
||||
|
||||
nested_chat_config: dict[str, Any]
|
||||
|
||||
def can_resolve_for_speaker_selection(self) -> bool:
|
||||
"""Check if the target can resolve for speaker selection. For NestedChatTarget the nested chat must be encapsulated into an agent."""
|
||||
return False
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
groupchat: "GroupChat",
|
||||
current_agent: "ConversableAgent",
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> SpeakerSelectionResult:
|
||||
"""Resolve to the nested chat configuration."""
|
||||
raise NotImplementedError(
|
||||
"NestedChatTarget does not support the resolve method. An agent should be used to encapsulate this nested chat and then the target changed to an AgentTarget."
|
||||
)
|
||||
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name for the target."""
|
||||
return "a nested chat"
|
||||
|
||||
def normalized_name(self) -> str:
|
||||
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||
return "nested_chat"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||
return "Transfer to nested chat"
|
||||
|
||||
def needs_agent_wrapper(self) -> bool:
|
||||
"""Check if the target needs to be wrapped in an agent. NestedChatTarget must be wrapped in an agent."""
|
||||
return True
|
||||
|
||||
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||
"""Create a wrapper agent for the nested chat."""
|
||||
from ...conversable_agent import ConversableAgent # to avoid circular import - NEED SOLUTION
|
||||
|
||||
nested_chat_agent = ConversableAgent(name=f"{__AGENT_WRAPPER_PREFIX__}nested_{parent_agent.name}_{index + 1}")
|
||||
|
||||
nested_chat_agent.register_nested_chats(
|
||||
self.nested_chat_config["chat_queue"],
|
||||
reply_func_from_nested_chats=self.nested_chat_config.get("reply_func_from_nested_chats")
|
||||
or "summary_from_nested_chats",
|
||||
config=self.nested_chat_config.get("config"),
|
||||
trigger=lambda sender: True,
|
||||
position=0,
|
||||
use_async=self.nested_chat_config.get("use_async", False),
|
||||
)
|
||||
|
||||
# After the nested chat is complete, transfer back to the parent agent
|
||||
nested_chat_agent.handoffs.set_after_work(AgentTarget(parent_agent))
|
||||
|
||||
return nested_chat_agent
|
||||
|
||||
|
||||
class TerminateTarget(TransitionTarget):
|
||||
"""Target that represents a termination of the conversation."""
|
||||
|
||||
def can_resolve_for_speaker_selection(self) -> bool:
|
||||
"""Check if the target can resolve for speaker selection."""
|
||||
return True
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
groupchat: "GroupChat",
|
||||
current_agent: "ConversableAgent",
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> SpeakerSelectionResult:
|
||||
"""Resolve to termination."""
|
||||
return SpeakerSelectionResult(terminate=True)
|
||||
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name for the target."""
|
||||
return "Terminate"
|
||||
|
||||
def normalized_name(self) -> str:
|
||||
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||
return "terminate"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||
return "Terminate"
|
||||
|
||||
def needs_agent_wrapper(self) -> bool:
|
||||
"""Check if the target needs to be wrapped in an agent."""
|
||||
return False
|
||||
|
||||
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||
"""Create a wrapper agent for the target if needed."""
|
||||
raise NotImplementedError("TerminateTarget does not require wrapping in an agent.")
|
||||
|
||||
|
||||
class StayTarget(TransitionTarget):
|
||||
"""Target that represents staying with the current agent."""
|
||||
|
||||
def can_resolve_for_speaker_selection(self) -> bool:
|
||||
"""Check if the target can resolve for speaker selection."""
|
||||
return True
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
groupchat: "GroupChat",
|
||||
current_agent: "ConversableAgent",
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> SpeakerSelectionResult:
|
||||
"""Resolve to staying with the current agent."""
|
||||
return SpeakerSelectionResult(agent_name=current_agent.name)
|
||||
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name for the target."""
|
||||
return "Stay"
|
||||
|
||||
def normalized_name(self) -> str:
|
||||
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||
return "stay"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||
return "Stay with agent"
|
||||
|
||||
def needs_agent_wrapper(self) -> bool:
|
||||
"""Check if the target needs to be wrapped in an agent."""
|
||||
return False
|
||||
|
||||
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||
"""Create a wrapper agent for the target if needed."""
|
||||
raise NotImplementedError("StayTarget does not require wrapping in an agent.")
|
||||
|
||||
|
||||
class RevertToUserTarget(TransitionTarget):
|
||||
"""Target that represents reverting to the user agent."""
|
||||
|
||||
def can_resolve_for_speaker_selection(self) -> bool:
|
||||
"""Check if the target can resolve for speaker selection."""
|
||||
return True
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
groupchat: "GroupChat",
|
||||
current_agent: "ConversableAgent",
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> SpeakerSelectionResult:
|
||||
"""Resolve to reverting to the user agent."""
|
||||
if user_agent is None:
|
||||
raise ValueError("User agent must be provided to the chat for the revert_to_user option.")
|
||||
return SpeakerSelectionResult(agent_name=user_agent.name)
|
||||
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name for the target."""
|
||||
return "Revert to User"
|
||||
|
||||
def normalized_name(self) -> str:
|
||||
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||
return "revert_to_user"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||
return "Revert to User"
|
||||
|
||||
def needs_agent_wrapper(self) -> bool:
|
||||
"""Check if the target needs to be wrapped in an agent."""
|
||||
return False
|
||||
|
||||
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||
"""Create a wrapper agent for the target if needed."""
|
||||
raise NotImplementedError("RevertToUserTarget does not require wrapping in an agent.")
|
||||
|
||||
|
||||
class AskUserTarget(TransitionTarget):
|
||||
"""Target that represents asking the user for input."""
|
||||
|
||||
def can_resolve_for_speaker_selection(self) -> bool:
|
||||
"""Check if the target can resolve for speaker selection."""
|
||||
return True
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
groupchat: "GroupChat",
|
||||
current_agent: "ConversableAgent",
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> SpeakerSelectionResult:
|
||||
"""Resolve to asking the user for input."""
|
||||
return SpeakerSelectionResult(speaker_selection_method="manual")
|
||||
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name for the target."""
|
||||
return "Ask User"
|
||||
|
||||
def normalized_name(self) -> str:
|
||||
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||
return "ask_user"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||
return "Ask User"
|
||||
|
||||
def needs_agent_wrapper(self) -> bool:
|
||||
"""Check if the target needs to be wrapped in an agent."""
|
||||
return False
|
||||
|
||||
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||
"""Create a wrapper agent for the target if needed."""
|
||||
raise NotImplementedError("AskUserTarget does not require wrapping in an agent.")
|
||||
|
||||
|
||||
class RandomAgentTarget(TransitionTarget):
|
||||
"""Target that represents a random selection from a list of agents."""
|
||||
|
||||
agent_names: list[str]
|
||||
nominated_name: str = "<Not Randomly Selected Yet>"
|
||||
|
||||
def __init__(self, agents: list["ConversableAgent"], **data: Any) -> None: # type: ignore[no-untyped-def]
|
||||
# Store the name from the agent for serialization
|
||||
super().__init__(agent_names=[agent.name for agent in agents], **data)
|
||||
|
||||
def can_resolve_for_speaker_selection(self) -> bool:
|
||||
"""Check if the target can resolve for speaker selection."""
|
||||
return True
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
groupchat: "GroupChat",
|
||||
current_agent: "ConversableAgent",
|
||||
user_agent: Optional["ConversableAgent"],
|
||||
) -> SpeakerSelectionResult:
|
||||
"""Resolve to the actual agent object from the groupchat, choosing a random agent (except the current one)"""
|
||||
# Randomly select the next agent
|
||||
self.nominated_name = random.choice([name for name in self.agent_names if name != current_agent.name])
|
||||
|
||||
return SpeakerSelectionResult(agent_name=self.nominated_name)
|
||||
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name for the target."""
|
||||
return self.nominated_name
|
||||
|
||||
def normalized_name(self) -> str:
|
||||
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||
return self.display_name()
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation for RandomAgentTarget, can be shown as a function call message."""
|
||||
return f"Transfer to {self.nominated_name}"
|
||||
|
||||
def needs_agent_wrapper(self) -> bool:
|
||||
"""Check if the target needs to be wrapped in an agent."""
|
||||
return False
|
||||
|
||||
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||
"""Create a wrapper agent for the target if needed."""
|
||||
raise NotImplementedError("RandomAgentTarget does not require wrapping in an agent.")
|
||||
|
||||
|
||||
# TODO: Consider adding a SequentialChatTarget class
|
||||
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Prefix for all wrapped agent names
|
||||
__AGENT_WRAPPER_PREFIX__ = "wrapped_"
|
||||
Reference in New Issue
Block a user