Files
sci-gui-agent-benchmark/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_swarm.py
2025-07-31 10:35:20 +08:00

484 lines
17 KiB
Python

# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import logging
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
import anyio
from asyncer import asyncify, create_task_group, syncify
from ....agentchat.contrib.swarm_agent import AfterWorkOption, initiate_swarm_chat
from ....cache import AbstractCache
from ....code_utils import content_str
from ....doc_utils import export_module
from ... import Agent, ChatResult, ConversableAgent, LLMAgent
from ...utils import consolidate_chat_info, gather_usage_summary
if TYPE_CHECKING:
from .clients import Role
from .realtime_agent import RealtimeAgent
# Public API of this module: only `register_swarm` is exported by `from ... import *`.
__all__ = ["register_swarm"]
# Default system prompt installed on the RealtimeAgent when `register_swarm` receives no
# `system_message`. The two literals are concatenated implicitly, so the first one must end
# with a space to keep the sentences separated.
SWARM_SYSTEM_MESSAGE = (
    "You are a helpful voice assistant. Your task is to listen to user and to coordinate the tasks based on his/her inputs. "
    "You can and will communicate using audio output only."
)

# Role under which questions for the user are sent to the realtime client.
QUESTION_ROLE: "Role" = "user"
# Default template used to phrase a question for the user; the single `{}` placeholder is
# filled with the question text via `str.format`.
QUESTION_MESSAGE = (
    "I have a question/information for myself. DO NOT ANSWER YOURSELF, GET THE ANSWER FROM ME. "
    "repeat the question to me **WITH AUDIO OUTPUT** and AFTER YOU GET THE ANSWER FROM ME call 'answer_task_question' with the answer in first person\n\n"
    "IMPORTANT: repeat just the question, without any additional information or context\n\n"
    "The question is: '{}'\n\n"
)
# Seconds to wait for an answer before the question is sent again.
QUESTION_TIMEOUT_SECONDS = 20

logger = logging.getLogger(__name__)

# Type variable bound to arbitrary callables.
F = TypeVar("F", bound=Callable[..., Any])
def message_to_dict(message: Union[dict[str, Any], str]) -> dict[str, Any]:
    """Normalize a message into dictionary form.

    Args:
        message: A plain string, a dict, or any object accepted by ``dict()``.

    Returns:
        ``{"content": message}`` for a string; the very same object (not a copy) for a dict;
        otherwise ``dict(message)``.
    """
    if isinstance(message, str):
        return {"content": message}
    elif isinstance(message, dict):
        return message
    else:
        return dict(message)


def parse_oai_message(message: Union[dict[str, Any], str], role: str, adressee: "Agent") -> dict[str, Any]:
    """
    Parse a message into an OpenAI-compatible message format.

    The input message is left unmodified: tool responses are copied before their
    ``content`` is converted to ``str``.

    Args:
        message: The message to parse; a plain string is treated as ``{"content": message}``.
        role: The role to assign when the message has neither a ``function``/``tool`` role
            nor an ``override_role``.
        adressee: The agent that will receive the message; its ``name`` is used when the
            message does not carry a ``name`` of its own.

    Returns:
        The parsed message in OpenAI-compatible format.

    Raises:
        ValueError: If the message has none of 'content', 'function_call' or 'tool_calls'.
    """
    message = message_to_dict(message)

    # Keep only the OpenAI-relevant fields, dropping those whose value is None.
    oai_message = {
        key: message[key]
        for key in ("content", "function_call", "tool_calls", "tool_responses", "tool_call_id", "name", "context")
        if key in message and message[key] is not None
    }

    # A function/tool call may legitimately come without text content.
    if "content" not in oai_message:
        if "function_call" in oai_message or "tool_calls" in oai_message:
            oai_message["content"] = None
        else:
            raise ValueError("Message must have either 'content', 'function_call', or 'tool_calls' field.")

    # Determine and assign the role.
    if message.get("role") in ("function", "tool"):
        oai_message["role"] = message["role"]
        # Tool response content must be a string. Build new dicts rather than assigning into
        # the caller's tool_response dicts, which would silently mutate the input message.
        if "tool_responses" in oai_message:
            oai_message["tool_responses"] = [
                {**tool_response, "content": str(tool_response["content"])}
                for tool_response in oai_message["tool_responses"]
            ]
    elif "override_role" in message:
        oai_message["role"] = message["override_role"]
    else:
        oai_message["role"] = role

    # Only 'assistant' messages may carry a function call or tool calls.
    if oai_message.get("function_call") or oai_message.get("tool_calls"):
        oai_message["role"] = "assistant"

    # Fall back to the addressee's name when the message does not name itself.
    if "name" not in oai_message:
        oai_message["name"] = adressee.name

    return oai_message
class SwarmableAgent(Agent):
    """A class for an agent that can participate in a swarm chat.

    It keeps one message history per peer agent and produces replies through
    ``check_termination_and_human_reply``, which subclasses must implement.
    """

    def __init__(
        self,
        name: str,
        system_message: str = "You are a helpful AI Assistant.",
        is_termination_msg: Optional[Callable[..., bool]] = None,
        description: Optional[str] = None,
        silent: Optional[bool] = None,
    ):
        """Initialize the agent.

        Args:
            name: The agent's name.
            system_message: The system message; also used as the description when
                ``description`` is None.
            is_termination_msg: Predicate on a message dict. Defaults to checking whether the
                message content equals "TERMINATE".
            description: Optional description of the agent.
            silent: Stored on the instance as ``silent``.
        """
        # Conversation histories keyed by peer agent.
        self._oai_messages: dict[Agent, Any] = defaultdict(list)
        self._system_message = system_message
        self._description = description if description is not None else system_message
        self._is_termination_msg = (
            is_termination_msg
            if is_termination_msg is not None
            else (lambda x: content_str(x.get("content")) == "TERMINATE")
        )
        self.silent = silent
        self._name = name
        # Initialize standalone client cache object.
        self.client_cache = None
        self.previous_cache = None
        # Whether to auto-reply to a given peer on receive; False until `_prepare_chat` sets it.
        self.reply_at_receive: dict[Agent, bool] = defaultdict(bool)

    @property
    def system_message(self) -> str:
        """The agent's system message."""
        return self._system_message

    def update_system_message(self, system_message: str) -> None:
        """Update this agent's system message.

        Args:
            system_message (str): system message for inference.
        """
        self._system_message = system_message

    @property
    def name(self) -> str:
        """The agent's name."""
        return self._name

    @property
    def description(self) -> str:
        """The agent's description."""
        return self._description

    def send(
        self,
        message: Union[dict[str, Any], str],
        recipient: Agent,
        request_reply: Optional[bool] = None,
        silent: Optional[bool] = False,
    ) -> None:
        """Record ``message`` in the history with ``recipient`` and deliver it.

        ``silent`` is accepted for interface compatibility and is not used here.
        """
        self._oai_messages[recipient].append(parse_oai_message(message, "assistant", recipient))
        recipient.receive(message, self, request_reply)

    def receive(
        self,
        message: Union[dict[str, Any], str],
        sender: Agent,
        request_reply: Optional[bool] = None,
        silent: Optional[bool] = False,
    ) -> None:
        """Record an incoming message and, if a reply is requested, send one back.

        A reply is generated unless ``request_reply`` is False, or it is None and
        auto-reply is disabled for ``sender``.
        """
        self._oai_messages[sender].append(parse_oai_message(message, "user", self))
        if request_reply is False or (request_reply is None and self.reply_at_receive[sender] is False):
            return
        reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender)
        if reply is not None:
            self.send(reply, sender, silent=silent)

    def generate_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional["Agent"] = None,
        **kwargs: Any,
    ) -> Union[str, dict[str, Any], None]:
        """Produce a reply via ``check_termination_and_human_reply``.

        Raises:
            ValueError: If neither ``messages`` nor ``sender`` is given.
        """
        if messages is None:
            if sender is None:
                raise ValueError("Either messages or sender must be provided.")
            messages = self._oai_messages[sender]

        _, reply = self.check_termination_and_human_reply(messages=messages, sender=sender, config=None)

        return reply

    def check_termination_and_human_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Union[str, None]]:
        """Decide on termination and produce the reply; must be implemented by subclasses."""
        raise NotImplementedError

    def initiate_chat(
        self,
        recipient: ConversableAgent,
        message: Union[dict[str, Any], str],
        clear_history: bool = True,
        silent: Optional[bool] = False,
        cache: Optional[AbstractCache] = None,
        summary_args: Optional[dict[str, Any]] = None,
        **kwargs: dict[str, Any],
    ) -> ChatResult:
        """Start a chat with ``recipient`` by sending it ``message``.

        ``recipient``'s client cache is swapped for ``cache`` while the chat runs and restored
        afterwards. The summary is taken from the recipient's last message to this agent.

        Returns:
            A ChatResult with the history, summary and usage summary.
        """
        # Snapshot the arguments before any other local is created.
        _chat_info = locals().copy()
        _chat_info["sender"] = self
        consolidate_chat_info(_chat_info, uniform_sender=self)
        recipient._raise_exception_on_async_reply_functions()
        recipient.previous_cache = recipient.client_cache  # type: ignore[attr-defined]
        recipient.client_cache = cache  # type: ignore[attr-defined, assignment]

        self._prepare_chat(recipient, clear_history)
        self.send(message, recipient, silent=silent)
        summary = self._last_msg_as_summary(self, recipient, summary_args)

        recipient.client_cache = recipient.previous_cache  # type: ignore[attr-defined]
        recipient.previous_cache = None  # type: ignore[attr-defined]

        chat_result = ChatResult(
            chat_history=self.chat_messages[recipient],
            summary=summary,
            cost=gather_usage_summary([self, recipient]),  # type: ignore[arg-type]
            human_input=[],
        )
        return chat_result

    async def a_generate_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional["Agent"] = None,
        **kwargs: Any,
    ) -> Union[str, dict[str, Any], None]:
        """Async wrapper that calls the synchronous ``generate_reply``."""
        return self.generate_reply(messages=messages, sender=sender, **kwargs)

    async def a_receive(
        self,
        message: Union[dict[str, Any], str],
        sender: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """Async wrapper that calls the synchronous ``receive``."""
        self.receive(message, sender, request_reply)

    async def a_send(
        self,
        message: Union[dict[str, Any], str],
        recipient: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """Async wrapper that calls the synchronous ``send``."""
        self.send(message, recipient, request_reply)

    @property
    def chat_messages(self) -> dict[Agent, list[dict[str, Any]]]:
        """A dictionary of conversations from agent to list of messages."""
        return self._oai_messages

    def last_message(self, agent: Optional[Agent] = None) -> Optional[dict[str, Any]]:
        """Return the last message exchanged with ``agent``.

        With ``agent`` None: returns None when there are no conversations, and the last
        message of the only conversation when there is exactly one.

        Raises:
            ValueError: If ``agent`` is None and there is more than one conversation.
            KeyError: If ``agent`` has no conversation with this agent.
        """
        if agent is None:
            n_conversations = len(self._oai_messages)
            if n_conversations == 0:
                return None
            if n_conversations == 1:
                for conversation in self._oai_messages.values():
                    return conversation[-1]  # type: ignore[no-any-return]
            raise ValueError("More than one conversation is found. Please specify the sender to get the last message.")
        # `_oai_messages` is a defaultdict (an attribute, not a method). Test membership first,
        # because indexing a missing key would silently create an empty history.
        if agent not in self._oai_messages:
            raise KeyError(
                f"The agent '{agent.name}' is not present in any conversation. No history available for this agent."
            )
        return self._oai_messages[agent][-1]  # type: ignore[no-any-return]

    def _prepare_chat(
        self,
        recipient: ConversableAgent,
        clear_history: bool,
        prepare_recipient: bool = True,
        reply_at_receive: bool = True,
    ) -> None:
        """Set auto-reply for ``recipient``, optionally clear history, and mirror on the recipient."""
        self.reply_at_receive[recipient] = reply_at_receive
        if clear_history:
            self._oai_messages[recipient].clear()
        if prepare_recipient:
            # prepare_recipient=False on the mirrored call prevents infinite recursion.
            recipient._prepare_chat(self, clear_history, False, reply_at_receive)  # type: ignore[arg-type]

    def _raise_exception_on_async_reply_functions(self) -> None:
        """No-op: this agent registers no async reply functions."""
        pass

    def set_ui_tools(self, tools: Optional[list] = None) -> None:
        """Set UI tools for the agent."""
        pass

    def unset_ui_tools(self) -> None:
        """Unset UI tools for the agent."""
        pass

    @staticmethod
    def _last_msg_as_summary(sender: Agent, recipient: Agent, summary_args: Optional[dict[str, Any]]) -> str:
        """Get a chat summary from the last message of the recipient.

        "TERMINATE" is stripped from the text. On IndexError/AttributeError a warning is
        issued and an empty string is returned. ``summary_args`` is not used.
        """
        summary = ""
        try:
            content = recipient.last_message(sender)["content"]  # type: ignore[attr-defined]
            if isinstance(content, str):
                summary = content.replace("TERMINATE", "")
            elif isinstance(content, list):
                summary = "\n".join(
                    x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x
                )
        except (IndexError, AttributeError) as e:
            warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning)
        return summary
# Static-only conformance check: this function exists only for the type checker. By returning
# a SwarmableAgent where an LLMAgent is expected, it makes the checker verify that
# SwarmableAgent structurally satisfies the LLMAgent protocol. Nothing here runs at runtime.
if TYPE_CHECKING:

    def _create_swarmable_agent(
        name: str,
        system_message: str,
        is_termination_msg: Optional[Callable[..., bool]],
        description: Optional[str],
        silent: Optional[bool],
    ) -> LLMAgent:
        """Build a SwarmableAgent typed as LLMAgent for the static protocol check."""
        protocol_instance: LLMAgent = SwarmableAgent(
            silent=silent,
            description=description,
            is_termination_msg=is_termination_msg,
            system_message=system_message,
            name=name,
        )
        return protocol_instance
class SwarmableRealtimeAgent(SwarmableAgent):
    """Swarm participant that relays the swarm's questions to a user through a RealtimeAgent.

    It is passed as ``user_agent`` to the swarm chat started in ``configure_realtime_agent``.
    On its turn, it formats the last swarm message into ``question_message`` and sends it to
    the realtime client. The reply is the value the realtime model passes to the
    ``answer_task_question`` function, which is registered to ``set_answer``.
    """

    def __init__(
        self,
        realtime_agent: "RealtimeAgent",
        initial_agent: ConversableAgent,
        agents: list[ConversableAgent],
        question_message: Optional[str] = None,
    ) -> None:
        """Create the bridge between a RealtimeAgent and a swarm.

        Args:
            realtime_agent: The RealtimeAgent whose client is used to talk to the user. Its
                ``_name`` becomes this agent's name.
            initial_agent: The swarm agent passed as ``initial_agent`` to ``initiate_swarm_chat``.
            agents: The agents passed as ``agents`` to ``initiate_swarm_chat``.
            question_message: Format string with one ``{}`` placeholder used to phrase
                questions to the user. Falls back to ``QUESTION_MESSAGE`` when falsy.
        """
        self._initial_agent = initial_agent
        self._agents = agents
        self._realtime_agent = realtime_agent
        # Set by `set_answer`; awaited/polled while a question is pending.
        self._answer_event: anyio.Event = anyio.Event()
        self._answer: str = ""
        self.question_message = question_message or QUESTION_MESSAGE
        # No system_message is passed, so SwarmableAgent's default system message applies.
        super().__init__(
            name=realtime_agent._name,
            is_termination_msg=None,
            description=None,
            silent=None,
        )

    def reset_answer(self) -> None:
        """Reset the answer event.

        Replaces the event with a fresh, unset one. ``_answer`` itself is not cleared.
        """
        self._answer_event = anyio.Event()

    def set_answer(self, answer: str) -> str:
        """Set the answer to the question.

        Registered as the realtime function ``answer_task_question``. The returned string is
        presumably delivered back to the realtime model as the function result (confirm in
        ``register_realtime_function``).
        """
        self._answer = answer
        self._answer_event.set()
        return "Answer set successfully."

    async def get_answer(self) -> str:
        """Get the answer to the question, waiting until one has been set."""
        await self._answer_event.wait()
        return self._answer

    async def ask_question(self, question: str, question_timeout: int) -> None:
        """Send a question for the user to the agent and wait for the answer.

        If the answer is not received within the timeout, the question is repeated. There is
        no upper bound on the number of repetitions.

        Args:
            question: The question to ask the user.
            question_timeout: The time in seconds to wait for the answer.
        """
        # Reset before sending, so an answer to an earlier question is not taken for this one.
        self.reset_answer()
        realtime_client = self._realtime_agent._realtime_client
        await realtime_client.send_text(role=QUESTION_ROLE, text=question)

        async def _check_event_set(timeout: int = question_timeout) -> bool:
            # Poll once per second, up to `timeout` times.
            # NOTE(review): a timeout <= 0 makes this return False immediately, so the question
            # would be re-sent back-to-back without any sleep.
            for _ in range(timeout):
                if self._answer_event.is_set():
                    return True
                await anyio.sleep(1)
            return False

        while not await _check_event_set():
            await realtime_client.send_text(role=QUESTION_ROLE, text=question)

    def check_termination_and_human_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Optional[str]]:
        """Check if the conversation should be terminated and if the agent should reply.

        Called when its agents turn in the chat conversation.

        The content of the last message is asked to the user through the realtime client.
        The call blocks until ``set_answer`` has provided an answer.

        Args:
            messages (list[dict[str, Any]]): The messages in the conversation.
            sender (Agent): The agent that sent the message.
            config (Optional[Any]): The configuration for the agent.

        Returns:
            ``(False, None)`` when there are no messages. Otherwise ``(True, reply)``, where
            ``reply`` is a dict ``{"role": "user", "content": <answer>}``; the annotation says
            ``Optional[str]``, hence the ``type: ignore``.
        """
        if not messages:
            return False, None

        async def get_input() -> None:
            # The task group exits only once ask_question completes, i.e. once answered.
            async with create_task_group() as tg:
                tg.soonify(self.ask_question)(
                    self.question_message.format(messages[-1]["content"]),
                    question_timeout=QUESTION_TIMEOUT_SECONDS,
                )

        # Run the async question flow from this synchronous method. This presumably executes
        # in the worker thread created by `asyncify(initiate_swarm_chat)` below — confirm.
        syncify(get_input)()

        return True, {"role": "user", "content": self._answer}  # type: ignore[return-value]

    def start_chat(self) -> None:
        """Not supported; the chat is started from `configure_realtime_agent`'s callback."""
        raise NotImplementedError

    def configure_realtime_agent(self, system_message: Optional[str]) -> None:
        """Wire this swarm bridge into the wrapped RealtimeAgent.

        Steps:
          * Sets the realtime agent's system message, defaulting to ``SWARM_SYSTEM_MESSAGE``.
          * Registers ``set_answer`` as the ``answer_task_question`` realtime function.
          * Installs an ``on_observers_ready`` callback that starts the swarm chat in the
            realtime agent's task group.

        Args:
            system_message: System message for the realtime agent. When falsy,
                ``SWARM_SYSTEM_MESSAGE`` is used, and a warning is logged if the realtime agent's
                current system message differs from the default one.
        """
        realtime_agent = self._realtime_agent

        # Local `logger` shadows the module-level logger; messages go to the realtime agent's logger.
        logger = realtime_agent.logger
        if not system_message:
            if realtime_agent.system_message != "You are a helpful AI Assistant.":
                logger.warning(
                    "Overriding system message set up in `__init__`, please use `system_message` parameter of the `register_swarm` function instead."
                )

            system_message = SWARM_SYSTEM_MESSAGE

        realtime_agent._system_message = system_message

        realtime_agent.register_realtime_function(
            name="answer_task_question", description="Answer question from the task"
        )(self.set_answer)

        async def on_observers_ready() -> None:
            # `initiate_swarm_chat` is synchronous, so it runs in a worker thread via `asyncify`,
            # scheduled on the realtime agent's task group. This agent acts as the swarm's user.
            self._realtime_agent._tg.soonify(asyncify(initiate_swarm_chat))(
                initial_agent=self._initial_agent,
                agents=self._agents,
                user_agent=self,  # type: ignore[arg-type]
                messages="Find out what the user wants.",
                after_work=AfterWorkOption.REVERT_TO_USER,
            )

        self._realtime_agent.callbacks.on_observers_ready = on_observers_ready
@export_module("autogen.agentchat.realtime.experimental")
def register_swarm(
    *,
    realtime_agent: "RealtimeAgent",
    initial_agent: ConversableAgent,
    agents: list[ConversableAgent],
    system_message: Optional[str] = None,
    question_message: Optional[str] = None,
) -> None:
    """Attach a swarm of agents to a RealtimeAgent.

    Wraps ``realtime_agent`` in a SwarmableRealtimeAgent and configures it, so the swarm
    chat starts once the realtime agent's observers are ready.

    Args:
        realtime_agent (RealtimeAgent): The RealtimeAgent to create the SwarmableRealtimeAgent from.
        initial_agent (ConversableAgent): The agent that starts the swarm conversation.
        agents (list[ConversableAgent]): All agents participating in the swarm.
        system_message (Optional[str]): System message for the realtime agent; the swarm
            default is used when None.
        question_message (Optional[str]): Template used to phrase questions to the user; the
            default QUESTION_MESSAGE is used when None.
    """
    swarm_bridge = SwarmableRealtimeAgent(
        question_message=question_message,
        agents=agents,
        initial_agent=initial_agent,
        realtime_agent=realtime_agent,
    )
    swarm_bridge.configure_realtime_agent(system_message=system_message)