CoACT initialize (#292)

This commit is contained in:
Linxin Song
2025-07-30 19:35:20 -07:00
committed by GitHub
parent 862d704b8c
commit b968155757
228 changed files with 42386 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Re-export the public pieces of the realtime experimental package from its submodules.
from .audio_adapters import TwilioAudioAdapter, WebSocketAudioAdapter
from .audio_observer import AudioObserver
from .function_observer import FunctionObserver
from .realtime_agent import RealtimeAgent
from .realtime_observer import RealtimeObserver
from .realtime_swarm import register_swarm

# Names considered the public API of this package (used by `import *` and doc tooling).
__all__ = [
    "AudioObserver",
    "FunctionObserver",
    "RealtimeAgent",
    "RealtimeObserver",
    "TwilioAudioAdapter",
    "WebSocketAudioAdapter",
    "register_swarm",
]

View File

@@ -0,0 +1,8 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Re-export the concrete audio adapters so callers can import them from this package directly.
from .twilio_audio_adapter import TwilioAudioAdapter
from .websocket_audio_adapter import WebSocketAudioAdapter

__all__ = ["TwilioAudioAdapter", "WebSocketAudioAdapter"]

View File

@@ -0,0 +1,148 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import base64
import json
from logging import Logger
from typing import TYPE_CHECKING, Optional
from .....doc_utils import export_module
from ..realtime_events import AudioDelta, RealtimeEvent, SpeechStarted
from ..realtime_observer import RealtimeObserver
if TYPE_CHECKING:
from ..websockets import WebSocketProtocol as WebSocket
# OpenAI Realtime API event types considered interesting for logging.
# NOTE(review): not referenced anywhere in this module — presumably kept for
# parity with the upstream Twilio sample; confirm before removing.
LOG_EVENT_TYPES = [
    "error",
    "response.content.done",
    "rate_limits.updated",
    "response.done",
    "input_audio_buffer.committed",
    "input_audio_buffer.speech_stopped",
    "input_audio_buffer.speech_started",
    "session.created",
]
# When True, emit verbose logs for the playback/truncation timestamp arithmetic.
SHOW_TIMING_MATH = False
@export_module("autogen.agentchat.realtime.experimental")
class TwilioAudioAdapter(RealtimeObserver):
    """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa."""

    def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None):
        """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa.

        Args:
            websocket: the websocket connection to the Twilio service
            logger: the logger to use for logging events
        """
        super().__init__(logger=logger)
        self.websocket = websocket

        # Connection specific state
        # Twilio stream SID; set when the "start" event arrives in run_loop().
        self.stream_sid: Optional[str] = None
        # Timestamp (ms) of the most recent inbound "media" chunk from Twilio.
        self.latest_media_timestamp = 0
        # Id of the assistant item currently being played back, if any.
        self.last_assistant_item: Optional[str] = None
        # Pending "mark" names sent to Twilio and not yet acknowledged back.
        self.mark_queue: list[str] = []
        # Timestamp (ms) at which playback of the current response started.
        self.response_start_timestamp_twilio: Optional[int] = None

    async def on_event(self, event: RealtimeEvent) -> None:
        """Receive events from the OpenAI Realtime API, send audio back to Twilio."""
        logger = self.logger
        if isinstance(event, AudioDelta):
            # Decode and re-encode to normalize the base64 payload before forwarding.
            audio_payload = base64.b64encode(base64.b64decode(event.delta)).decode("utf-8")
            audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}}
            await self.websocket.send_json(audio_delta)

            # First delta of a new response: remember when playback started (for truncation math).
            if self.response_start_timestamp_twilio is None:
                self.response_start_timestamp_twilio = self.latest_media_timestamp
                if SHOW_TIMING_MATH:
                    logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_twilio}ms")

            # Update last_assistant_item safely
            if event.item_id:
                self.last_assistant_item = event.item_id

            await self.send_mark()

        # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
        if isinstance(event, SpeechStarted):
            logger.info("Speech start detected.")
            if self.last_assistant_item:
                logger.info(f"Interrupting response with id: {self.last_assistant_item}")
                await self.handle_speech_started_event()

    async def handle_speech_started_event(self) -> None:
        """Handle interruption when the caller's speech starts."""
        logger = self.logger
        logger.info("Handling speech started event.")
        if self.mark_queue and self.response_start_timestamp_twilio is not None:
            # How much of the assistant's audio was actually played before the interruption.
            elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_twilio
            if SHOW_TIMING_MATH:
                logger.info(
                    f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_twilio} = {elapsed_time}ms"
                )

            if self.last_assistant_item:
                if SHOW_TIMING_MATH:
                    logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")

                # Drop the unplayed tail of the assistant's audio on the OpenAI side.
                await self.realtime_client.truncate_audio(
                    audio_end_ms=elapsed_time,
                    content_index=0,
                    item_id=self.last_assistant_item,
                )

            # Tell Twilio to discard any buffered-but-unplayed audio.
            await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})

            self.mark_queue.clear()
            self.last_assistant_item = None
            self.response_start_timestamp_twilio = None

    async def send_mark(self) -> None:
        """Send a mark of audio interruption to the Twilio websocket."""
        if self.stream_sid:
            mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}}
            await self.websocket.send_json(mark_event)
            self.mark_queue.append("responsePart")

    async def run_loop(self) -> None:
        """Run the adapter loop.

        Reads Twilio websocket messages and forwards caller audio to the realtime client.
        """
        logger = self.logger
        async for message in self.websocket.iter_text():
            try:
                data = json.loads(message)
                if data["event"] == "media":
                    self.latest_media_timestamp = int(data["media"]["timestamp"])
                    await self.realtime_client.send_audio(audio=data["media"]["payload"])
                elif data["event"] == "start":
                    self.stream_sid = data["start"]["streamSid"]
                    logger.info(f"Incoming stream has started {self.stream_sid}")
                    # Reset per-stream playback state for the new Twilio stream.
                    self.response_start_timestamp_twilio = None
                    self.latest_media_timestamp = 0
                    self.last_assistant_item = None
                elif data["event"] == "mark":
                    if self.mark_queue:
                        self.mark_queue.pop(0)
            except Exception as e:
                # Best effort: keep the loop alive even if a single message is malformed.
                logger.warning(f"Error processing Twilio message: {e}", stack_info=True)

    async def initialize_session(self) -> None:
        """Control initial session with OpenAI."""
        # Twilio media streams carry G.711 u-law audio in both directions.
        session_update = {
            "input_audio_format": "g711_ulaw",
            "output_audio_format": "g711_ulaw",
        }
        await self.realtime_client.session_update(session_update)
# Static type check only: verifies TwilioAudioAdapter satisfies the RealtimeObserver interface.
if TYPE_CHECKING:

    def twilio_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
        return TwilioAudioAdapter(websocket)

View File

@@ -0,0 +1,139 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import base64
import json
from logging import Logger
from typing import TYPE_CHECKING, Optional
from .....doc_utils import export_module
from ..realtime_events import AudioDelta, RealtimeEvent, SpeechStarted
from ..realtime_observer import RealtimeObserver
if TYPE_CHECKING:
from ..websockets import WebSocketProtocol as WebSocket
# OpenAI Realtime API event types considered interesting for logging.
# NOTE(review): not referenced anywhere in this module — presumably copied
# from the Twilio adapter; confirm before removing.
LOG_EVENT_TYPES = [
    "error",
    "response.content.done",
    "rate_limits.updated",
    "response.done",
    "input_audio_buffer.committed",
    "input_audio_buffer.speech_stopped",
    "input_audio_buffer.speech_started",
    "session.created",
]
# When True, emit verbose logs for the playback/truncation timestamp arithmetic.
SHOW_TIMING_MATH = False
@export_module("autogen.agentchat.realtime.experimental")
class WebSocketAudioAdapter(RealtimeObserver):
    """Adapter for streaming audio between a client websocket and the OpenAI Realtime API."""

    def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None) -> None:
        """Adapter for streaming audio between a client websocket and the OpenAI Realtime API.

        Args:
            websocket (WebSocket): The websocket connection.
            logger (Logger): The logger for the observer.
        """
        super().__init__(logger=logger)
        self.websocket = websocket

        # Connection specific state
        # Stream SID; set when the "start" event arrives in run_loop().
        self.stream_sid: Optional[str] = None
        # Timestamp (ms) of the most recent inbound "media" chunk.
        self.latest_media_timestamp = 0
        # Id of the assistant item currently being played back, if any.
        self.last_assistant_item: Optional[str] = None
        # Pending "mark" names sent to the client and not yet acknowledged back.
        self.mark_queue: list[str] = []
        # Timestamp (ms) at which playback of the current response started.
        self.response_start_timestamp_socket: Optional[int] = None

    async def on_event(self, event: RealtimeEvent) -> None:
        """Receive events from the OpenAI Realtime API, send audio back to websocket."""
        logger = self.logger
        if isinstance(event, AudioDelta):
            # Decode and re-encode to normalize the base64 payload before forwarding.
            audio_payload = base64.b64encode(base64.b64decode(event.delta)).decode("utf-8")
            audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}}
            await self.websocket.send_json(audio_delta)

            # First delta of a new response: remember when playback started (for truncation math).
            if self.response_start_timestamp_socket is None:
                self.response_start_timestamp_socket = self.latest_media_timestamp
                if SHOW_TIMING_MATH:
                    logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_socket}ms")

            # Update last_assistant_item safely
            if event.item_id:
                self.last_assistant_item = event.item_id

            await self.send_mark()

        # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
        if isinstance(event, SpeechStarted):
            logger.info("Speech start detected.")
            if self.last_assistant_item:
                logger.info(f"Interrupting response with id: {self.last_assistant_item}")
                await self.handle_speech_started_event()

    async def handle_speech_started_event(self) -> None:
        """Handle interruption when the caller's speech starts."""
        logger = self.logger
        logger.info("Handling speech started event.")
        if self.mark_queue and self.response_start_timestamp_socket is not None:
            # How much of the assistant's audio was actually played before the interruption.
            elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_socket
            if SHOW_TIMING_MATH:
                logger.info(
                    f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_socket} = {elapsed_time}ms"
                )

            if self.last_assistant_item:
                if SHOW_TIMING_MATH:
                    logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")

                # Drop the unplayed tail of the assistant's audio on the OpenAI side.
                await self.realtime_client.truncate_audio(
                    audio_end_ms=elapsed_time,
                    content_index=0,
                    item_id=self.last_assistant_item,
                )

            # Tell the client to discard any buffered-but-unplayed audio.
            await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})

            self.mark_queue.clear()
            self.last_assistant_item = None
            self.response_start_timestamp_socket = None

    async def send_mark(self) -> None:
        """Send a playback-progress mark over the websocket and remember it in the queue."""
        if self.stream_sid:
            mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}}
            await self.websocket.send_json(mark_event)
            self.mark_queue.append("responsePart")

    async def initialize_session(self) -> None:
        """Control initial session with OpenAI."""
        # Plain websocket clients exchange 16-bit PCM audio in both directions.
        session_update = {"input_audio_format": "pcm16", "output_audio_format": "pcm16"}
        await self.realtime_client.session_update(session_update)

    async def run_loop(self) -> None:
        """Reads data from websocket and sends it to the RealtimeClient."""
        logger = self.logger
        async for message in self.websocket.iter_text():
            try:
                data = json.loads(message)
                if data["event"] == "media":
                    self.latest_media_timestamp = int(data["media"]["timestamp"])
                    await self.realtime_client.send_audio(audio=data["media"]["payload"])
                elif data["event"] == "start":
                    self.stream_sid = data["start"]["streamSid"]
                    logger.info(f"Incoming stream has started {self.stream_sid}")
                    # Reset per-stream playback state for the new stream.
                    self.response_start_timestamp_socket = None
                    self.latest_media_timestamp = 0
                    self.last_assistant_item = None
                elif data["event"] == "mark":
                    if self.mark_queue:
                        self.mark_queue.pop(0)
            except Exception as e:
                # Best effort: keep the loop alive even if a single message is malformed.
                logger.warning(f"Failed to process message: {e}", stack_info=True)
# Static type check only: verifies WebSocketAudioAdapter satisfies the RealtimeObserver interface.
if TYPE_CHECKING:

    def websocket_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
        return WebSocketAudioAdapter(websocket)

View File

@@ -0,0 +1,42 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional
from ....doc_utils import export_module
from .realtime_events import InputAudioBufferDelta, RealtimeEvent
from .realtime_observer import RealtimeObserver
if TYPE_CHECKING:
from logging import Logger
@export_module("autogen.agentchat.realtime.experimental")
class AudioObserver(RealtimeObserver):
    """Realtime observer that watches the user's voice input stream."""

    def __init__(self, *, logger: Optional["Logger"] = None) -> None:
        """Create the audio observer.

        Args:
            logger: Optional logger; the base observer supplies a default when None.
        """
        super().__init__(logger=logger)

    async def on_event(self, event: RealtimeEvent) -> None:
        """Log incoming audio buffer deltas; every other event type is ignored.

        Args:
            event: An event emitted by the Realtime API.
        """
        if not isinstance(event, InputAudioBufferDelta):
            return
        self.logger.info("Received audio buffer delta")

    async def initialize_session(self) -> None:
        """This observer requires no session initialization."""

    async def run_loop(self) -> None:
        """No background work; the observer is purely event driven."""
# Static type check only: verifies AudioObserver satisfies the RealtimeObserver interface.
if TYPE_CHECKING:
    # Renamed from `function_observer` (copy-paste from function_observer.py):
    # this module defines AudioObserver, not FunctionObserver.
    audio_observer: RealtimeObserver = AudioObserver()

View File

@@ -0,0 +1,15 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Re-export the realtime client implementations and the protocol/factory helpers.
from .gemini.client import GeminiRealtimeClient
from .oai.base_client import OpenAIRealtimeClient
from .realtime_client import RealtimeClientProtocol, Role, get_client

__all__ = [
    "GeminiRealtimeClient",
    "OpenAIRealtimeClient",
    "RealtimeClientProtocol",
    "Role",
    "get_client",
]

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
# Re-export the Gemini realtime client as this package's public API.
from .client import GeminiRealtimeClient

__all__ = ["GeminiRealtimeClient"]

View File

@@ -0,0 +1,274 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from ......doc_utils import export_module
from ......import_utils import optional_import_block, require_optional_import
from ......llm_config import LLMConfig
from ...realtime_events import AudioDelta, FunctionCall, RealtimeEvent, SessionCreated
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
with optional_import_block():
from websockets.asyncio.client import connect
if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection
from ..realtime_client import RealtimeClientProtocol
__all__ = ["GeminiRealtimeClient"]

# Module-level fallback logger, used when no logger is injected into the client.
global_logger = getLogger(__name__)

# Components of the Gemini Live websocket endpoint URL (see _base_url in __init__).
HOST = "generativelanguage.googleapis.com"
API_VERSION = "v1alpha"
@register_realtime_client()
@require_optional_import("websockets", "gemini", except_for=["get_factory", "__init__"])
@export_module("autogen.agentchat.realtime.experimental.clients")
class GeminiRealtimeClient(RealtimeClientBase):
    """(Experimental) Client for Gemini Realtime API."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for Gemini Realtime API.

        Args:
            llm_config: The config for the client.
            logger: The logger for the client.
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        # Live websocket connection; bound only while connect() is entered.
        self._connection: Optional["ClientConnection"] = None

        # Only the first entry of config_list is used by this client.
        config = llm_config["config_list"][0]
        self._model: str = config["model"]
        self._voice = config.get("voice", "charon")
        self._temperature: float = config.get("temperature", 0.8)  # type: ignore[union-attr]
        # This client requests audio responses from Gemini.
        self._response_modality = "AUDIO"

        self._api_key = config.get("api_key", None)
        # todo: add test with base_url just to make sure it works
        self._base_url: str = config.get(
            "base_url",
            f"wss://{HOST}/ws/google.ai.generativelanguage.{API_VERSION}.GenerativeService.BidiGenerateContent?key={self._api_key}",
        )

        self._final_config: dict[str, Any] = {}
        # Session options recorded before connection; applied in _initialize_session().
        self._pending_session_updates: dict[str, Any] = {}
        # Set by read_events(); send_* methods only transmit while this is True.
        self._is_reading_events = False

    @property
    def logger(self) -> Logger:
        """Get the logger for the Gemini Realtime API."""
        return self._logger or global_logger

    @property
    def connection(self) -> "ClientConnection":
        """Get the Gemini WebSocket connection.

        Raises:
            RuntimeError: If connect() has not been entered.
        """
        if self._connection is None:
            raise RuntimeError("Gemini WebSocket is not initialized")
        return self._connection

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the Gemini Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        msg = {
            "tool_response": {"function_responses": [{"id": call_id, "response": {"result": {"string_value": result}}}]}
        }
        # NOTE(review): silently dropped unless read_events() has started — confirm intended.
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def send_text(self, *, role: Role, text: str, turn_complete: bool = True) -> None:
        """Send a text message to the Gemini Realtime API.

        Args:
            role: The role of the message.
            text: The text of the message.
            turn_complete: A flag indicating if the turn is complete.
        """
        msg = {
            "client_content": {
                "turn_complete": turn_complete,
                "turns": [{"role": role, "parts": [{"text": text}]}],
            }
        }
        # NOTE(review): silently dropped unless read_events() has started — confirm intended.
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def send_audio(self, audio: str) -> None:
        """Send audio to the Gemini Realtime API.

        Args:
            audio (str): The audio to send.
        """
        msg = {
            "realtime_input": {
                "media_chunks": [
                    {
                        "data": audio,
                        "mime_type": "audio/pcm",
                    }
                ]
            }
        }
        # Mirror the outbound audio into the local event queue (e.g. for logging/observers).
        await self.queue_input_audio_buffer_delta(audio)
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """No-op: audio truncation is not natively supported by the Gemini Realtime API."""
        self.logger.info("This is not natively supported by Gemini Realtime API.")
        pass

    async def _initialize_session(self) -> None:
        """Initialize the session with the Gemini Realtime API.

        Builds the one-shot "setup" message from the pending session updates
        (instructions and tool schemas) and the constructor-derived options.
        """
        session_config = {
            "setup": {
                "system_instruction": {
                    "role": "system",
                    "parts": [{"text": self._pending_session_updates.get("instructions", "")}],
                },
                "model": f"models/{self._model}",
                "tools": [
                    {
                        "function_declarations": [
                            {
                                "name": tool_schema["name"],
                                "description": tool_schema["description"],
                                "parameters": tool_schema["parameters"],
                            }
                            for tool_schema in self._pending_session_updates.get("tools", [])
                        ]
                    },
                ],
                "generation_config": {
                    "response_modalities": [self._response_modality],
                    "speech_config": {"voiceConfig": {"prebuiltVoiceConfig": {"voiceName": self._voice}}},
                    "temperature": self._temperature,
                },
            }
        }
        self.logger.info(f"Sending session update: {session_config}")
        await self.connection.send(json.dumps(session_config))

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Record session updates to be applied when the connection is established.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        # Gemini's setup message can only be sent once, so late updates are refused.
        if self._is_reading_events:
            self.logger.warning("Is reading events. Session update will be ignored.")
        else:
            self._pending_session_updates.update(session_options)

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the Gemini Realtime API.

        Binds the websocket to self._connection for the lifetime of the context
        and clears it on exit.
        """
        try:
            async with connect(
                self._base_url, additional_headers={"Content-Type": "application/json"}
            ) as self._connection:
                yield
        finally:
            self._connection = None

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read Events from the Gemini Realtime Client.

        Raises:
            RuntimeError: If called before connect().
        """
        if self._connection is None:
            raise RuntimeError("Client is not connected, call connect() first.")

        # The setup message must be sent before any other traffic.
        await self._initialize_session()
        self._is_reading_events = True

        async for event in self._read_events():
            yield event

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the Gemini Realtime connection."""
        async for raw_message in self.connection:
            # Gemini may deliver frames as bytes; decode before JSON parsing.
            message = raw_message.decode("ascii") if isinstance(raw_message, bytes) else raw_message
            events = self._parse_message(json.loads(message))
            for event in events:
                yield event

    def _parse_message(self, response: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the Gemini Realtime API.

        Args:
            response (dict[str, Any]): The response to parse.

        Returns:
            list[RealtimeEvent]: The parsed events.
        """
        if "serverContent" in response and "modelTurn" in response["serverContent"]:
            try:
                # pop() strips the bulky base64 audio out of the raw_message that is kept on the event.
                b64data = response["serverContent"]["modelTurn"]["parts"][0]["inlineData"].pop("data")
                return [
                    AudioDelta(
                        delta=b64data,
                        item_id=None,
                        raw_message=response,
                    )
                ]
            except KeyError:
                # Model turn without inline audio data: nothing to emit.
                return []
        elif "toolCall" in response:
            return [
                FunctionCall(
                    raw_message=response,
                    call_id=call["id"],
                    name=call["name"],
                    arguments=call["args"],
                )
                for call in response["toolCall"]["functionCalls"]
            ]
        elif "setupComplete" in response:
            return [
                SessionCreated(raw_message=response),
            ]
        else:
            # Unknown message types are surfaced as generic events.
            return [RealtimeEvent(raw_message=response)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The LLM config for the client.
            logger: The logger for the client.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Only claim configs whose api_type is "google" and that pass no extra kwargs.
        if llm_config["config_list"][0].get("api_type") == "google" and list(kwargs.keys()) == []:
            return lambda: GeminiRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
        return None
# needed for mypy to check if GeminiRealtimeClient implements RealtimeClientProtocol
# (type-check-time only; never executed at runtime)
if TYPE_CHECKING:
    _client: RealtimeClientProtocol = GeminiRealtimeClient(llm_config={})

View File

@@ -0,0 +1,8 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .base_client import OpenAIRealtimeClient
from .rtc_client import OpenAIRealtimeWebRTCClient
__all__ = ["OpenAIRealtimeClient", "OpenAIRealtimeWebRTCClient"]

View File

@@ -0,0 +1,220 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from ......doc_utils import export_module
from ......import_utils import optional_import_block, require_optional_import
from ......llm_config import LLMConfig
from ...realtime_events import RealtimeEvent
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
from .utils import parse_oai_message
with optional_import_block():
from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI
from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection
if TYPE_CHECKING:
from ..realtime_client import RealtimeClientProtocol
__all__ = ["OpenAIRealtimeClient"]

# Module-level fallback logger, used when no logger is injected into the client.
global_logger = getLogger(__name__)
@register_realtime_client()
@require_optional_import("openai>=1.66.2", "openai-realtime", except_for=["get_factory", "__init__"])
@export_module("autogen.agentchat.realtime.experimental.clients")
class OpenAIRealtimeClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client.
            logger: the logger to use for logging events
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        # Live realtime connection; bound only while connect() is entered.
        self._connection: Optional["AsyncRealtimeConnection"] = None

        # Only the first entry of config_list is used by this client.
        self.config = llm_config["config_list"][0]
        # model is passed to self._client.beta.realtime.connect function later
        self._model: str = self.config["model"]
        self._voice: str = self.config.get("voice", "alloy")
        # NOTE(review): temperature is read from llm_config here, but from config_list[0]
        # in the Gemini client — confirm which source is intended.
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]

        # Lazily created in connect() so construction stays side-effect free.
        self._client: Optional["AsyncOpenAI"] = None

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API."""
        return self._logger or global_logger

    @property
    def connection(self) -> "AsyncRealtimeConnection":
        """Get the OpenAI WebSocket connection.

        Raises:
            RuntimeError: If connect() has not been entered.
        """
        if self._connection is None:
            raise RuntimeError("OpenAI WebSocket is not initialized")
        return self._connection

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self.connection.conversation.item.create(
            item={
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        )
        # Ask the model to continue now that the tool output is available.
        await self.connection.response.create()

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        # Stop any in-flight response before injecting the new message.
        await self.connection.response.cancel()
        await self.connection.conversation.item.create(
            item={"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]}
        )
        await self.connection.response.create()

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        Args:
            audio (str): The audio to send.
        """
        # Mirror the outbound audio into the local event queue (e.g. for logging/observers).
        await self.queue_input_audio_buffer_delta(audio)
        await self.connection.input_audio_buffer.append(audio=audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self.connection.conversation.item.truncate(
            audio_end_ms=audio_end_ms, content_index=content_index, item_id=item_id
        )

    async def _initialize_session(self) -> None:
        """Control initial session with OpenAI."""
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        await self.session_update(session_options=session_update)

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        await self.connection.session.update(session=session_options)  # type: ignore[arg-type]
        logger.info("Sending session update finished")

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API.

        Lazily builds the AsyncOpenAI client from the config entry, opens the
        realtime connection (bound to self._connection for the context's
        lifetime), and applies the initial session options.
        """
        try:
            if not self._client:
                self._client = AsyncOpenAI(
                    api_key=self.config.get("api_key", None),
                    organization=self.config.get("organization", None),
                    project=self.config.get("project", None),
                    base_url=self.config.get("base_url", None),
                    websocket_base_url=self.config.get("websocket_base_url", None),
                    timeout=self.config.get("timeout", NOT_GIVEN),
                    max_retries=self.config.get("max_retries", DEFAULT_MAX_RETRIES),
                    default_headers=self.config.get("default_headers", None),
                    default_query=self.config.get("default_query", None),
                )
            async with self._client.beta.realtime.connect(
                model=self._model,
            ) as self._connection:
                await self._initialize_session()
                yield
        finally:
            self._connection = None

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API.

        Raises:
            RuntimeError: If called before connect().
        """
        if self._connection is None:
            raise RuntimeError("Client is not connected, call connect() first.")

        try:
            async for event in self._read_events():
                yield event

        finally:
            self._connection = None

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API."""
        # NOTE(review): iterates self._connection directly (not the `connection` property);
        # assumes read_events() has already verified the connection exists.
        async for message in self._connection:
            for event in self._parse_message(message.model_dump()):
                yield event

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            RealtimeEvent: The parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Only claim configs whose api_type is "openai" (the default) and that pass no extra kwargs.
        if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == []:
            return lambda: OpenAIRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
        return None
# needed for mypy to check if OpenAIRealtimeClient implements RealtimeClientProtocol
# (comment previously named OpenAIRealtimeWebRTCClient — copy-paste; the check below
# instantiates OpenAIRealtimeClient)
if TYPE_CHECKING:
    _client: RealtimeClientProtocol = OpenAIRealtimeClient(llm_config={})

View File

@@ -0,0 +1,243 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from autogen.import_utils import optional_import_block, require_optional_import
from ......doc_utils import export_module
from ......llm_config import LLMConfig
from ...realtime_events import RealtimeEvent
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
from .utils import parse_oai_message
if TYPE_CHECKING:
from ...websockets import WebSocketProtocol as WebSocket
from ..realtime_client import RealtimeClientProtocol
with optional_import_block():
import httpx
__all__ = ["OpenAIRealtimeWebRTCClient"]

# Module-level fallback logger, used when no logger is injected into the client.
global_logger = getLogger(__name__)
@register_realtime_client()
@require_optional_import("httpx", "openai-realtime", except_for="get_factory")
@export_module("autogen.agentchat.realtime.experimental.clients.oai")
class OpenAIRealtimeWebRTCClient(RealtimeClientBase):
"""(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol."""
    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        websocket: "WebSocket",
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client.
            websocket: the websocket to use for the connection
            logger: the logger to use for logging events
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        # Websocket to the browser-side javascript that owns the actual WebRTC connection.
        self._websocket = websocket

        # Only the first entry of config_list is used by this client.
        config = llm_config["config_list"][0]
        self._model: str = config["model"]
        self._voice: str = config.get("voice", "alloy")
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]
        self._config = config
        # REST endpoint used to mint an ephemeral realtime session for the browser client.
        self._base_url = config.get("base_url", "https://api.openai.com/v1/realtime/sessions")
    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API (falls back to the module logger)."""
        return self._logger or global_logger
    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        })
        # Ask the model to continue now that the tool output is available.
        await self._websocket.send_json({"type": "response.create"})
    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        The messages are relayed over the websocket; the js client forwards
        them on the WebRTC data channel.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        # Cancel any in-flight response before injecting the new message
        # (mirrors connection.response.cancel() in the websocket client).
        await self._websocket.send_json({
            "type": "response.cancel",
        })
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
        })
        # Mirrors connection.response.create() in the websocket client.
        await self._websocket.send_json({"type": "response.create"})
    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        in case of WebRTC, audio is already sent by js client, so we just queue it in order to be logged.

        Args:
            audio (str): The audio to send.
        """
        await self.queue_input_audio_buffer_delta(audio)
    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Relayed via the websocket to the js client, which forwards it to OpenAI.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self._websocket.send_json({
            "type": "conversation.item.truncate",
            "content_index": content_index,
            "item_id": item_id,
            "audio_end_ms": audio_end_ms,
        })
async def session_update(self, session_options: dict[str, Any]) -> None:
    """Send a session update to the OpenAI Realtime API.

    In the case of WebRTC we can not send it directly; we send it to the
    javascript client over the websocket and rely on it to forward the
    session update to OpenAI.

    Args:
        session_options (dict[str, Any]): The session options to update.
    """
    self.logger.info(f"Sending session update: {session_options}")
    await self._websocket.send_json({"type": "session.update", "session": session_options})
    self.logger.info("Sending session update finished")
def session_init_data(self) -> list[dict[str, Any]]:
    """Build the initial session.update message(s) used to configure the OpenAI session."""
    return [
        {
            "type": "session.update",
            "session": {
                "turn_detection": {"type": "server_vad"},
                "voice": self._voice,
                "modalities": ["audio", "text"],
                "temperature": self._temperature,
            },
        }
    ]
# Intentionally a no-op: for WebRTC the session is initialized by the JS client
# using the payload produced by session_init_data() (sent in connect()).
async def _initialize_session(self) -> None: ...
@asynccontextmanager
async def connect(self) -> AsyncGenerator[None, None]:
    """Connect to the OpenAI Realtime API.

    In the case of WebRTC, we pass connection information over the
    websocket, so that javascript on the other end of websocket open
    actual connection to OpenAI.
    """
    try:
        base_url = self._base_url
        # NOTE(review): api_key is read from the config dict, not the environment —
        # confirm the config is populated from env by the caller if needed.
        api_key = self._config.get("api_key", None)
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        data = {
            "model": self._model,
            "voice": self._voice,
        }
        # Create an ephemeral realtime session; the response JSON contains the
        # credentials the JS client needs to open the actual WebRTC connection.
        async with httpx.AsyncClient() as client:
            response = await client.post(base_url, headers=headers, json=data)
            response.raise_for_status()
            json_data = response.json()
            json_data["model"] = self._model
        if self._websocket is not None:
            session_init = self.session_init_data()
            # Hand the session config plus the initial session.update to the JS client.
            await self._websocket.send_json({"type": "ag2.init", "config": json_data, "init": session_init})
        yield
    finally:
        # No teardown needed: the JS client owns the connection to OpenAI.
        pass
async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
    """Read events from the OpenAI Realtime API.

    Delegates to the base-class event queue so that locally queued events
    (e.g. logged audio deltas) are interleaved with connection events.
    """
    async for event in self._read_events():
        yield event
async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
    """Read messages from the OpenAI Realtime API connection.

    Again, in case of WebRTC, we do not read OpenAI messages directly since we
    do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript
    client on the other side of the websocket that is connected to OpenAI is relaying events to us.
    """
    while True:
        try:
            # Each websocket frame carries one JSON-encoded realtime message.
            message_json = await self._websocket.receive_text()
            message = json.loads(message_json)
            for event in self._parse_message(message):
                yield event
        except Exception as e:
            # Any receive/parse failure ends the stream; the base class will
            # see the generator finish and emit its end-of-stream sentinel.
            self.logger.exception(f"Error reading from connection {e}")
            break
def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
    """Parse a message from the OpenAI Realtime API.

    Args:
        message (dict[str, Any]): The message to parse.

    Returns:
        list[RealtimeEvent]: A single-element list with the parsed event.
    """
    event = parse_oai_message(message)
    return [event]
@classmethod
def get_factory(
    cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
    """Create a Realtime API client.

    Args:
        llm_config: The config for the client.
        logger: The logger to use for logging events.
        **kwargs: Additional arguments.

    Returns:
        RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
    """
    # Match only when the first config entry is for OpenAI (the default api_type)
    # and the caller passed exactly one extra kwarg: the signalling "websocket".
    # NOTE(review): assumes llm_config supports ["config_list"] indexing for both
    # LLMConfig and dict — confirm LLMConfig implements __getitem__.
    if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == ["websocket"]:
        return lambda: OpenAIRealtimeWebRTCClient(llm_config=llm_config, logger=logger, **kwargs)
    return None
# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
if TYPE_CHECKING:

    def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol:
        # Never executed at runtime; exists purely so mypy verifies protocol conformance.
        return OpenAIRealtimeWebRTCClient(llm_config={}, websocket=websocket)

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Any
from ...realtime_events import (
AudioDelta,
FunctionCall,
InputAudioBufferDelta,
RealtimeEvent,
SessionCreated,
SessionUpdated,
SpeechStarted,
)
__all__ = ["parse_oai_message"]
def parse_oai_message(message: dict[str, Any]) -> RealtimeEvent:
    """Parse a raw OpenAI Realtime API message into a typed event.

    Args:
        message (dict[str, Any]): The message to parse.

    Returns:
        RealtimeEvent: The parsed event; unknown types are wrapped unmodified.
    """
    msg_type = message.get("type")
    if msg_type == "session.created":
        return SessionCreated(raw_message=message)
    if msg_type == "session.updated":
        return SessionUpdated(raw_message=message)
    if msg_type == "response.audio.delta":
        return AudioDelta(raw_message=message, delta=message["delta"], item_id=message["item_id"])
    if msg_type == "input_audio_buffer.speech_started":
        return SpeechStarted(raw_message=message)
    if msg_type == "input_audio_buffer.delta":
        return InputAudioBufferDelta(delta=message["delta"], item_id=None, raw_message=message)
    if msg_type == "response.function_call_arguments.done":
        # Function-call arguments arrive as a JSON string; decode them here.
        return FunctionCall(
            raw_message=message,
            call_id=message["call_id"],
            name=message["name"],
            arguments=json.loads(message["arguments"]),
        )
    # Fallback: keep the raw message so no information is lost.
    return RealtimeEvent(raw_message=message)

View File

@@ -0,0 +1,190 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
from collections.abc import AsyncGenerator
from logging import Logger
from typing import Any, AsyncContextManager, Callable, Literal, Optional, Protocol, TypeVar, Union, runtime_checkable
from asyncer import create_task_group
from .....doc_utils import export_module
from .....llm_config import LLMConfig
from ..realtime_events import InputAudioBufferDelta, RealtimeEvent
__all__ = ["RealtimeClientProtocol", "Role", "get_client", "register_realtime_client"]
# define role literal type for typing
Role = Literal["user", "assistant", "system"]
@runtime_checkable
@export_module("autogen.agentchat.realtime.experimental.clients")
class RealtimeClientProtocol(Protocol):
    """Structural interface that every Realtime API client implementation must satisfy."""

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to a Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        ...

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to a Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        ...

    async def send_audio(self, audio: str) -> None:
        """Send audio to a Realtime API.

        Args:
            audio (str): The audio to send.
        """
        ...

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in a Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        ...

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to a Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        ...

    def connect(self) -> AsyncContextManager[None]:
        """Return an async context manager that holds the connection open while entered."""
        ...

    def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime Client."""
        ...

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime connection."""
        ...

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from a Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            list[RealtimeEvent]: The parsed events.
        """
        ...

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        ...
class RealtimeClientBase:
    """Shared event-queue plumbing for realtime client implementations.

    Subclasses provide ``_read_from_connection()``; this base class pumps those
    events through an internal queue so that locally generated events (e.g.
    queued input-audio deltas) can be interleaved with connection events.
    """

    def __init__(self) -> None:
        # None is used as the end-of-stream sentinel in the queue.
        self._eventQueue: "asyncio.Queue[Optional[RealtimeEvent]]" = asyncio.Queue()

    async def add_event(self, event: "Optional[RealtimeEvent]") -> None:
        """Enqueue an event (or the None end-of-stream sentinel)."""
        await self._eventQueue.put(event)

    async def get_event(self) -> "Optional[RealtimeEvent]":
        """Dequeue the next event; None means the stream has ended."""
        return await self._eventQueue.get()

    async def _read_from_connection_task(self) -> None:
        """Drain the connection into the queue, then signal end-of-stream."""
        async for event in self._read_from_connection():
            await self.add_event(event)
        # Sentinel: tells _read_events that the connection is exhausted.
        await self.add_event(None)

    async def _read_events(self) -> "AsyncGenerator[RealtimeEvent, None]":
        """Read events from a Realtime Client until the end-of-stream sentinel."""
        async with create_task_group() as tg:
            tg.start_soon(self._read_from_connection_task)
            while True:
                try:
                    event = await self._eventQueue.get()
                    if event is not None:
                        yield event
                    else:
                        break
                except Exception:
                    # Treat any failure (including cancellation at the yield
                    # point) as end of stream.
                    break

    async def queue_input_audio_buffer_delta(self, audio: str) -> None:
        """Queue an InputAudioBufferDelta event for the given audio chunk.

        Args:
            audio (str): The audio.
        """
        await self.add_event(InputAudioBufferDelta(delta=audio, item_id=None, raw_message=dict()))
# Registry of all realtime client classes, keyed by fully qualified class name.
_realtime_client_classes: dict[str, type[RealtimeClientProtocol]] = {}

# Type variable bound to the protocol, used by the registration decorator.
T = TypeVar("T", bound=RealtimeClientProtocol)
def register_realtime_client() -> Callable[[type[T]], type[T]]:
    """Register a Realtime API client.

    Returns:
        Callable[[type[T]], type[T]]: The decorator to register the Realtime API client
    """

    def decorator(client_cls: type[T]) -> type[T]:
        """Add the decorated client class to the module-level registry.

        Args:
            client_cls: The client to register.
        """
        global _realtime_client_classes
        qualified_name = f"{client_cls.__module__}.{client_cls.__name__}"
        _realtime_client_classes[qualified_name] = client_cls
        return client_cls

    return decorator
@export_module("autogen.agentchat.realtime.experimental.clients")
def get_client(llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any) -> "RealtimeClientProtocol":
    """Get a registered Realtime API client.

    Args:
        llm_config: The config for the client.
        logger: The logger to use for logging events.
        **kwargs: Additional arguments forwarded to each client factory.

    Returns:
        RealtimeClientProtocol: The Realtime API client.

    Raises:
        ValueError: If no registered client accepts the given config/kwargs.
    """
    # Read-only access needs no `global` statement; iterate values directly
    # since the registry keys are unused here.
    for client_cls in _realtime_client_classes.values():
        factory = client_cls.get_factory(llm_config=llm_config, logger=logger, **kwargs)
        if factory:
            return factory()
    raise ValueError("Realtime API client not found.")

View File

@@ -0,0 +1,85 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
from typing import TYPE_CHECKING, Any, Optional
from asyncer import asyncify
from pydantic import BaseModel
from ....doc_utils import export_module
from .realtime_events import FunctionCall, RealtimeEvent
from .realtime_observer import RealtimeObserver
if TYPE_CHECKING:
from logging import Logger
@export_module("autogen.agentchat.realtime.experimental")
class FunctionObserver(RealtimeObserver):
    """Observer for handling function calls from the OpenAI Realtime API."""

    def __init__(self, *, logger: Optional["Logger"] = None) -> None:
        """Observer for handling function calls from the OpenAI Realtime API.

        Args:
            logger (Optional[Logger]): The logger for the observer.
        """
        super().__init__(logger=logger)

    async def on_event(self, event: RealtimeEvent) -> None:
        """Handle function call events from the OpenAI Realtime API.

        Args:
            event (RealtimeEvent): The event from the OpenAI Realtime API.
        """
        if isinstance(event, FunctionCall):
            self.logger.info("Received function call event")
            await self.call_function(
                call_id=event.call_id,
                name=event.name,
                kwargs=event.arguments,
            )

    async def call_function(self, call_id: str, name: str, kwargs: dict[str, Any]) -> None:
        """Call a function registered with the agent and send its result back.

        Args:
            call_id (str): The ID of the function call.
            name (str): The name of the function to call.
            kwargs (dict[str, Any]): The arguments to pass to the function.
        """
        if name in self.agent.registered_realtime_tools:
            func = self.agent.registered_realtime_tools[name].func
            # Wrap sync functions so both sync and async tools can be awaited uniformly.
            func = func if asyncio.iscoroutinefunction(func) else asyncify(func)
            try:
                result = await func(**kwargs)
            except Exception:
                # Best-effort: report the failure to the model instead of crashing.
                result = "Function call failed"
                self.logger.info(f"Function call failed: {name=}, {kwargs=}", stack_info=True)

            # Normalize the result to a string for the realtime API.
            if isinstance(result, BaseModel):
                result = result.model_dump_json()
            elif not isinstance(result, str):
                try:
                    result = json.dumps(result)
                except Exception:
                    result = str(result)

            await self.realtime_client.send_function_result(call_id, result)
        else:
            self.logger.warning(f"Function {name} called, but is not registered with the realtime agent.")

    async def initialize_session(self) -> None:
        """Add registered tools to OpenAI with a session update."""
        session_update = {
            "tools": [tool.realtime_tool_schema for tool in self.agent.registered_realtime_tools.values()],
            "tool_choice": "auto",
        }
        await self.realtime_client.session_update(session_update)

    async def run_loop(self) -> None:
        """Run the observer loop (no-op: all work happens in on_event)."""
        pass
# Static-only check that FunctionObserver satisfies the RealtimeObserver interface.
if TYPE_CHECKING:
    function_observer: RealtimeObserver = FunctionObserver()

View File

@@ -0,0 +1,158 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from logging import Logger, getLogger
from typing import Any, Callable, Optional, TypeVar, Union
from anyio import lowlevel
from asyncer import create_task_group
from ....doc_utils import export_module
from ....llm_config import LLMConfig
from ....tools import Tool
from .clients.realtime_client import RealtimeClientProtocol, get_client
from .function_observer import FunctionObserver
from .realtime_observer import RealtimeObserver
F = TypeVar("F", bound=Callable[..., Any])
global_logger = getLogger(__name__)
@dataclass
class RealtimeAgentCallbacks:
    """Callbacks for the Realtime Agent."""

    # Called once all observers are ready. Defaults to an async no-op:
    # anyio's lowlevel.checkpoint() just yields control to the event loop.
    on_observers_ready: Callable[[], Any] = lambda: lowlevel.checkpoint()
@export_module("autogen.agentchat.realtime.experimental")
class RealtimeAgent:
    """(Experimental) Agent for interacting with Realtime Clients via observers."""

    def __init__(
        self,
        *,
        name: str,
        audio_adapter: Optional[RealtimeObserver] = None,
        system_message: str = "You are a helpful AI Assistant.",
        llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None,
        logger: Optional[Logger] = None,
        observers: Optional[list[RealtimeObserver]] = None,
        **client_kwargs: Any,
    ):
        """(Experimental) Agent for interacting with the Realtime Clients.

        Args:
            name (str): The name of the agent.
            audio_adapter (Optional[RealtimeObserver] = None): The audio adapter for the agent.
            system_message (str): The system message for the agent.
            llm_config (LLMConfig, dict[str, Any], bool): The config for the agent.
            logger (Optional[Logger]): The logger for the agent.
            observers (Optional[list[RealtimeObserver]]): The additional observers for the agent.
            **client_kwargs (Any): The keyword arguments for the client.
        """
        self._logger = logger
        self._name = name
        self._system_message = system_message
        # Resolve the effective LLM config (falls back to the context-scoped one).
        llm_config = LLMConfig.get_current_llm_config(llm_config)
        # Pick the first registered client whose factory matches the config/kwargs.
        self._realtime_client: RealtimeClientProtocol = get_client(
            llm_config=llm_config, logger=self.logger, **client_kwargs
        )
        self._registered_realtime_tools: dict[str, Tool] = {}
        # A FunctionObserver is always attached so registered tools get dispatched.
        self._observers: list[RealtimeObserver] = observers if observers else []
        self._observers.append(FunctionObserver(logger=logger))
        if audio_adapter:
            self._observers.append(audio_adapter)

        self.callbacks = RealtimeAgentCallbacks()

    @property
    def system_message(self) -> str:
        """Get the system message for the agent."""
        return self._system_message

    @property
    def logger(self) -> Logger:
        """Get the logger for the agent."""
        return self._logger or global_logger

    @property
    def realtime_client(self) -> RealtimeClientProtocol:
        """Get the OpenAI Realtime Client."""
        return self._realtime_client

    @property
    def registered_realtime_tools(self) -> dict[str, Tool]:
        """Get the registered realtime tools."""
        return self._registered_realtime_tools

    def register_observer(self, observer: RealtimeObserver) -> None:
        """Register an observer with the Realtime Agent.

        Args:
            observer (RealtimeObserver): The observer to register.
        """
        self._observers.append(observer)

    async def start_observers(self) -> None:
        """Start all observers inside the agent's task group and wait until each is ready.

        Must be called from within run(), which sets self._tg.
        """
        for observer in self._observers:
            self._tg.soonify(observer.run)(self)

        # wait for the observers to be ready
        for observer in self._observers:
            await observer.wait_for_ready()

        await self.callbacks.on_observers_ready()

    async def run(self) -> None:
        """Run the agent."""
        # everything is run in the same task group to enable easy cancellation using self._tg.cancel_scope.cancel()
        async with create_task_group() as self._tg:  # noqa: SIM117
            # connect with the client first (establishes a connection and initializes a session)
            async with self._realtime_client.connect():
                # start the observers and wait for them to be ready
                await self.realtime_client.session_update(session_options={"instructions": self.system_message})
                await self.start_observers()

                # iterate over the events: every event is fanned out to every observer
                async for event in self.realtime_client.read_events():
                    for observer in self._observers:
                        await observer.on_event(event)

    def register_realtime_function(
        self,
        *,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> Callable[[Union[F, Tool]], Tool]:
        """Decorator for registering a function to be used by an agent.

        Args:
            name (str): The name of the function.
            description (str): The description of the function.

        Returns:
            Callable[[Union[F, Tool]], Tool]: The decorator for registering a function.
        """

        def _decorator(func_or_tool: Union[F, Tool]) -> Tool:
            """Decorator for registering a function to be used by an agent.

            Args:
                func_or_tool (Union[F, Tool]): The function or tool to register.

            Returns:
                Tool: The registered tool.
            """
            tool = Tool(func_or_tool=func_or_tool, name=name, description=description)

            self._registered_realtime_tools[tool.name] = tool

            return tool

        return _decorator

View File

@@ -0,0 +1,42 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Literal
from pydantic import BaseModel
class RealtimeEvent(BaseModel):
    """Base class for events exchanged with a Realtime API; also wraps unknown event types."""

    # The original, unparsed message exactly as received from the API.
    raw_message: dict[str, Any]
class SessionCreated(RealtimeEvent):
    """Emitted once when the realtime session is first created."""

    type: Literal["session.created"] = "session.created"
class SessionUpdated(RealtimeEvent):
    """Emitted after a session.update request has been applied."""

    type: Literal["session.updated"] = "session.updated"
class AudioDelta(RealtimeEvent):
    """A chunk of model-generated output audio."""

    type: Literal["response.audio.delta"] = "response.audio.delta"
    # Audio payload chunk (as delivered by the API).
    delta: str
    # Identifier of the conversation item this chunk belongs to.
    item_id: Any
class InputAudioBufferDelta(RealtimeEvent):
    """A chunk of input (user) audio appended to the input audio buffer."""

    type: Literal["input_audio_buffer.delta"] = "input_audio_buffer.delta"
    # Audio payload chunk.
    delta: str
    # Item identifier; may be None for locally queued chunks.
    item_id: Any
class SpeechStarted(RealtimeEvent):
    """Emitted when the server's voice-activity detection notices the user started speaking."""

    type: Literal["input_audio_buffer.speech_started"] = "input_audio_buffer.speech_started"
class FunctionCall(RealtimeEvent):
    """Emitted when the model requests a function/tool call with complete arguments."""

    type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
    # Name of the function to invoke.
    name: str
    # Decoded keyword arguments for the call.
    arguments: dict[str, Any]
    # Correlation id used when sending the function result back.
    call_id: str

View File

@@ -0,0 +1,100 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Optional
from anyio import Event
from ....doc_utils import export_module
from .clients.realtime_client import RealtimeClientProtocol
from .realtime_events import RealtimeEvent
if TYPE_CHECKING:
from .realtime_agent import RealtimeAgent
__all__ = ["RealtimeObserver"]
global_logger = getLogger(__name__)
@export_module("autogen.agentchat.realtime.experimental")
class RealtimeObserver(ABC):
    """Observer for the OpenAI Realtime API."""

    def __init__(self, *, logger: Optional[Logger] = None) -> None:
        """Observer for the OpenAI Realtime API.

        Args:
            logger (Logger): The logger for the observer.
        """
        # Set once initialize_session() has completed in run().
        self._ready_event = Event()
        self._agent: Optional[RealtimeAgent] = None
        self._logger = logger

    @property
    def logger(self) -> Logger:
        """Logger for the observer, falling back to the module-level logger."""
        return self._logger or global_logger

    @property
    def agent(self) -> "RealtimeAgent":
        """The realtime agent this observer is attached to; raises if run() has not been called."""
        if self._agent is None:
            raise RuntimeError("Agent has not been set.")

        return self._agent

    @property
    def realtime_client(self) -> RealtimeClientProtocol:
        """The realtime client of the attached agent; raises if not attached yet."""
        if self._agent is None:
            raise RuntimeError("Agent has not been set.")
        if self._agent.realtime_client is None:
            raise RuntimeError("Realtime client has not been set.")
        return self._agent.realtime_client

    async def run(self, agent: "RealtimeAgent") -> None:
        """Run the observer with the agent.

        When implementing, be sure to call `self._ready_event.set()` when the observer is ready to process events.

        Args:
            agent (RealtimeAgent): The realtime agent attached to the observer.
        """
        self._agent = agent
        await self.initialize_session()
        self._ready_event.set()

        await self.run_loop()

    @abstractmethod
    async def run_loop(self) -> None:
        """Run the loop if needed.

        This method is called after the observer is ready to process events.
        Events will be processed by the on_event method, this is just a hook for additional processing.
        Use initialize_session to set up the session.
        """
        ...

    @abstractmethod
    async def initialize_session(self) -> None:
        """Initialize the session for the observer."""
        ...

    async def wait_for_ready(self) -> None:
        """Block until the observer has finished session initialization."""
        await self._ready_event.wait()

    @abstractmethod
    async def on_event(self, event: RealtimeEvent) -> None:
        """Handle an event from the OpenAI Realtime API.

        Args:
            event (RealtimeEvent): The event from the OpenAI Realtime API.
        """
        ...

    async def on_close(self) -> None:
        """Handle close of RealtimeClient."""
        ...

View File

@@ -0,0 +1,483 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import logging
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
import anyio
from asyncer import asyncify, create_task_group, syncify
from ....agentchat.contrib.swarm_agent import AfterWorkOption, initiate_swarm_chat
from ....cache import AbstractCache
from ....code_utils import content_str
from ....doc_utils import export_module
from ... import Agent, ChatResult, ConversableAgent, LLMAgent
from ...utils import consolidate_chat_info, gather_usage_summary
if TYPE_CHECKING:
from .clients import Role
from .realtime_agent import RealtimeAgent
__all__ = ["register_swarm"]
# Default system message for the realtime voice assistant coordinating a swarm.
SWARM_SYSTEM_MESSAGE = (
    "You are a helpful voice assistant. Your task is to listen to user and to coordinate the tasks based on his/her inputs."
    "You can and will communicate using audio output only."
)

# Role used when relaying swarm questions to the realtime model.
QUESTION_ROLE: "Role" = "user"
# Prompt template instructing the model to relay a swarm question to the user
# verbatim (audio) and report back via the 'answer_task_question' tool.
QUESTION_MESSAGE = (
    "I have a question/information for myself. DO NOT ANSWER YOURSELF, GET THE ANSWER FROM ME. "
    "repeat the question to me **WITH AUDIO OUTPUT** and AFTER YOU GET THE ANSWER FROM ME call 'answer_task_question' with the answer in first person\n\n"
    "IMPORTANT: repeat just the question, without any additional information or context\n\n"
    "The question is: '{}'\n\n"
)
# Seconds to wait for an answer before re-asking the question.
QUESTION_TIMEOUT_SECONDS = 20

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])
def message_to_dict(message: Union[dict[str, Any], str]) -> dict[str, Any]:
    """Coerce a message into dict form: strings become {"content": ...}, dicts pass through."""
    if isinstance(message, dict):
        return message
    if isinstance(message, str):
        return {"content": message}
    # Anything else (e.g. an iterable of key/value pairs) is converted directly.
    return dict(message)
def parse_oai_message(message: Union[dict[str, Any], str], role: str, adressee: Agent) -> dict[str, Any]:
    """
    Parse a message into an OpenAI-compatible message format.

    Args:
        message: The message to parse.
        role: The role associated with the message.
        adressee: The agent that will receive the message.

    Returns:
        The parsed message in OpenAI-compatible format.

    Raises:
        ValueError: If the message lacks required fields like 'content', 'function_call', or 'tool_calls'.
    """
    source = message_to_dict(message)

    # Copy over only the recognized, non-None fields.
    relevant_keys = ("content", "function_call", "tool_calls", "tool_responses", "tool_call_id", "name", "context")
    oai_message = {key: source[key] for key in relevant_keys if source.get(key) is not None}

    # A message without content is only valid when it carries a call.
    if "content" not in oai_message:
        if "function_call" not in oai_message and "tool_calls" not in oai_message:
            raise ValueError("Message must have either 'content', 'function_call', or 'tool_calls' field.")
        oai_message["content"] = None

    # Role precedence: explicit function/tool role, then an override, then the caller's role.
    source_role = source.get("role")
    if source_role in ["function", "tool"]:
        oai_message["role"] = source_role
        # Tool responses must always carry string content.
        for tool_response in oai_message.get("tool_responses", []):
            tool_response["content"] = str(tool_response["content"])
    elif "override_role" in source:
        oai_message["role"] = source["override_role"]
    else:
        oai_message["role"] = role

    # Messages that carry calls are always attributed to the assistant.
    if oai_message.get("function_call") or oai_message.get("tool_calls"):
        oai_message["role"] = "assistant"

    # Default the name to the receiving agent's name.
    oai_message.setdefault("name", adressee.name)

    return oai_message
class SwarmableAgent(Agent):
    """A class for an agent that can participate in a swarm chat."""

    def __init__(
        self,
        name: str,
        system_message: str = "You are a helpful AI Assistant.",
        is_termination_msg: Optional[Callable[..., bool]] = None,
        description: Optional[str] = None,
        silent: Optional[bool] = None,
    ):
        """Initialize the swarmable agent.

        Args:
            name: The name of the agent.
            system_message: System message used for inference.
            is_termination_msg: Predicate deciding whether a message terminates the chat;
                defaults to checking for the literal content "TERMINATE".
            description: Human-readable description; defaults to the system message.
            silent: Whether the agent should suppress output.
        """
        # Per-peer OpenAI-format message history.
        self._oai_messages: dict[Agent, Any] = defaultdict(list)
        self._system_message = system_message
        self._description = description if description is not None else system_message
        self._is_termination_msg = (
            is_termination_msg
            if is_termination_msg is not None
            else (lambda x: content_str(x.get("content")) == "TERMINATE")
        )
        self.silent = silent
        self._name = name

        # Initialize standalone client cache object.
        self.client_cache = None
        self.previous_cache = None

        # Whether receive() should auto-reply, tracked per sender.
        self.reply_at_receive: dict[Agent, bool] = defaultdict(bool)

    @property
    def system_message(self) -> str:
        """The system message used for inference."""
        return self._system_message

    def update_system_message(self, system_message: str) -> None:
        """Update this agent's system message.

        Args:
            system_message (str): system message for inference.
        """
        self._system_message = system_message

    @property
    def name(self) -> str:
        """The name of the agent."""
        return self._name

    @property
    def description(self) -> str:
        """The description of the agent."""
        return self._description

    def send(
        self,
        message: Union[dict[str, Any], str],
        recipient: Agent,
        request_reply: Optional[bool] = None,
        silent: Optional[bool] = False,
    ) -> None:
        """Record the outgoing message and deliver it to the recipient."""
        self._oai_messages[recipient].append(parse_oai_message(message, "assistant", recipient))
        recipient.receive(message, self, request_reply)

    def receive(
        self,
        message: Union[dict[str, Any], str],
        sender: Agent,
        request_reply: Optional[bool] = None,
        silent: Optional[bool] = False,
    ) -> None:
        """Record the incoming message and optionally generate and send a reply."""
        self._oai_messages[sender].append(parse_oai_message(message, "user", self))
        if request_reply is False or (request_reply is None and self.reply_at_receive[sender] is False):
            return
        reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender)
        if reply is not None:
            self.send(reply, sender, silent=silent)

    def generate_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional["Agent"] = None,
        **kwargs: Any,
    ) -> Union[str, dict[str, Any], None]:
        """Generate a reply, delegating to check_termination_and_human_reply.

        Raises:
            ValueError: If neither messages nor sender is provided.
        """
        if messages is None:
            if sender is None:
                raise ValueError("Either messages or sender must be provided.")
            messages = self._oai_messages[sender]

        _, reply = self.check_termination_and_human_reply(messages=messages, sender=sender, config=None)

        return reply

    def check_termination_and_human_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Union[str, None]]:
        """Subclasses must implement the termination/human-reply check."""
        raise NotImplementedError

    def initiate_chat(
        self,
        recipient: ConversableAgent,
        message: Union[dict[str, Any], str],
        clear_history: bool = True,
        silent: Optional[bool] = False,
        cache: Optional[AbstractCache] = None,
        summary_args: Optional[dict[str, Any]] = {},  # noqa: B006 — never mutated; kept for interface stability
        **kwargs: dict[str, Any],
    ) -> ChatResult:
        """Initiate a chat with the recipient and return the chat result.

        Args:
            recipient: The agent to chat with.
            message: The opening message.
            clear_history: Whether to clear previous history with the recipient.
            silent: Whether to suppress output.
            cache: Optional client cache installed on the recipient for the chat.
            summary_args: Extra arguments for the summary method.
            **kwargs: Additional chat info.
        """
        # Must be the first statement: captures the call arguments only.
        _chat_info = locals().copy()
        _chat_info["sender"] = self
        consolidate_chat_info(_chat_info, uniform_sender=self)
        recipient._raise_exception_on_async_reply_functions()
        # Temporarily install the provided cache on the recipient for this chat.
        recipient.previous_cache = recipient.client_cache  # type: ignore[attr-defined]
        recipient.client_cache = cache  # type: ignore[attr-defined, assignment]

        self._prepare_chat(recipient, clear_history)
        self.send(message, recipient, silent=silent)
        summary = self._last_msg_as_summary(self, recipient, summary_args)

        # Restore the recipient's original cache.
        recipient.client_cache = recipient.previous_cache  # type: ignore[attr-defined]
        recipient.previous_cache = None  # type: ignore[attr-defined]

        chat_result = ChatResult(
            chat_history=self.chat_messages[recipient],
            summary=summary,
            cost=gather_usage_summary([self, recipient]),  # type: ignore[arg-type]
            human_input=[],
        )
        return chat_result

    async def a_generate_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional["Agent"] = None,
        **kwargs: Any,
    ) -> Union[str, dict[str, Any], None]:
        """Async wrapper around generate_reply (the work itself is synchronous)."""
        return self.generate_reply(messages=messages, sender=sender, **kwargs)

    async def a_receive(
        self,
        message: Union[dict[str, Any], str],
        sender: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """Async wrapper around receive."""
        self.receive(message, sender, request_reply)

    async def a_send(
        self,
        message: Union[dict[str, Any], str],
        recipient: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """Async wrapper around send."""
        self.send(message, recipient, request_reply)

    @property
    def chat_messages(self) -> dict[Agent, list[dict[str, Any]]]:
        """A dictionary of conversations from agent to list of messages."""
        return self._oai_messages

    def last_message(self, agent: Optional[Agent] = None) -> Optional[dict[str, Any]]:
        """Return the last message exchanged with `agent`.

        If `agent` is omitted, it is inferred when exactly one conversation exists.

        Raises:
            ValueError: If `agent` is omitted and multiple conversations exist.
            KeyError: If there is no conversation with `agent`.
        """
        if agent is None:
            n_conversations = len(self._oai_messages)
            if n_conversations == 0:
                return None
            if n_conversations == 1:
                for conversation in self._oai_messages.values():
                    return conversation[-1]  # type: ignore[no-any-return]
            raise ValueError("More than one conversation is found. Please specify the sender to get the last message.")
        # BUG FIX: was `self._oai_messages()` — calling the defaultdict raised
        # TypeError instead of performing the membership test.
        if agent not in self._oai_messages:
            raise KeyError(
                f"The agent '{agent.name}' is not present in any conversation. No history available for this agent."
            )
        return self._oai_messages[agent][-1]  # type: ignore[no-any-return]

    def _prepare_chat(
        self,
        recipient: ConversableAgent,
        clear_history: bool,
        prepare_recipient: bool = True,
        reply_at_receive: bool = True,
    ) -> None:
        """Prepare both sides of a chat, optionally clearing history with the recipient."""
        self.reply_at_receive[recipient] = reply_at_receive
        if clear_history:
            self._oai_messages[recipient].clear()
        if prepare_recipient:
            recipient._prepare_chat(self, clear_history, False, reply_at_receive)  # type: ignore[arg-type]

    def _raise_exception_on_async_reply_functions(self) -> None:
        """No async reply functions are registered on this agent; nothing to check."""
        pass

    def set_ui_tools(self, tools: Optional[list] = None) -> None:
        """Set UI tools for the agent."""
        pass

    def unset_ui_tools(self) -> None:
        """Unset UI tools for the agent."""
        pass

    @staticmethod
    def _last_msg_as_summary(sender: Agent, recipient: Agent, summary_args: Optional[dict[str, Any]]) -> str:
        """Get a chat summary from the last message of the recipient."""
        summary = ""
        try:
            content = recipient.last_message(sender)["content"]  # type: ignore[attr-defined]
            if isinstance(content, str):
                summary = content.replace("TERMINATE", "")
            elif isinstance(content, list):
                # Strip TERMINATE markers from each text part of a multimodal message.
                summary = "\n".join(
                    x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x
                )
        except (IndexError, AttributeError) as e:
            warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning)
        return summary
# check that the SwarmableAgent class is implementing LLMAgent protocol
if TYPE_CHECKING:

    def _create_swarmable_agent(
        name: str,
        system_message: str,
        is_termination_msg: Optional[Callable[..., bool]],
        description: Optional[str],
        silent: Optional[bool],
    ) -> LLMAgent:
        # Never executed at runtime; exists purely so mypy verifies protocol conformance.
        return SwarmableAgent(
            name=name,
            system_message=system_message,
            is_termination_msg=is_termination_msg,
            description=description,
            silent=silent,
        )
class SwarmableRealtimeAgent(SwarmableAgent):
    """Adapts a RealtimeAgent so it can participate in a swarm chat."""

    def __init__(
        self,
        realtime_agent: "RealtimeAgent",
        initial_agent: ConversableAgent,
        agents: list[ConversableAgent],
        question_message: Optional[str] = None,
    ) -> None:
        """Wrap a realtime agent for swarm participation.

        Args:
            realtime_agent: The realtime agent to adapt.
            initial_agent: The swarm agent that starts the conversation.
            agents: All agents participating in the swarm.
            question_message: Optional template for relaying questions to the user;
                defaults to QUESTION_MESSAGE.
        """
        self._initial_agent = initial_agent
        self._agents = agents
        self._realtime_agent = realtime_agent
        # Event used to signal that the user's answer has arrived.
        self._answer_event: anyio.Event = anyio.Event()
        self._answer: str = ""
        self.question_message = question_message or QUESTION_MESSAGE
        super().__init__(
            name=realtime_agent._name,
            is_termination_msg=None,
            description=None,
            silent=None,
        )
def reset_answer(self) -> None:
"""Reset the answer event."""
self._answer_event = anyio.Event()
def set_answer(self, answer: str) -> str:
"""Set the answer to the question."""
self._answer = answer
self._answer_event.set()
return "Answer set successfully."
async def get_answer(self) -> str:
"""Get the answer to the question."""
await self._answer_event.wait()
return self._answer
async def ask_question(self, question: str, question_timeout: int) -> None:
"""Send a question for the user to the agent and wait for the answer.
If the answer is not received within the timeout, the question is repeated.
Args:
question: The question to ask the user.
question_timeout: The time in seconds to wait for the answer.
"""
self.reset_answer()
realtime_client = self._realtime_agent._realtime_client
await realtime_client.send_text(role=QUESTION_ROLE, text=question)
async def _check_event_set(timeout: int = question_timeout) -> bool:
for _ in range(timeout):
if self._answer_event.is_set():
return True
await anyio.sleep(1)
return False
while not await _check_event_set():
await realtime_client.send_text(role=QUESTION_ROLE, text=question)
def check_termination_and_human_reply(
self,
messages: Optional[list[dict[str, Any]]] = None,
sender: Optional[Agent] = None,
config: Optional[Any] = None,
) -> tuple[bool, Optional[str]]:
"""Check if the conversation should be terminated and if the agent should reply.
Called when its agents turn in the chat conversation.
Args:
messages (list[dict[str, Any]]): The messages in the conversation.
sender (Agent): The agent that sent the message.
config (Optional[Any]): The configuration for the agent.
"""
if not messages:
return False, None
async def get_input() -> None:
async with create_task_group() as tg:
tg.soonify(self.ask_question)(
self.question_message.format(messages[-1]["content"]),
question_timeout=QUESTION_TIMEOUT_SECONDS,
)
syncify(get_input)()
return True, {"role": "user", "content": self._answer} # type: ignore[return-value]
def start_chat(self) -> None:
raise NotImplementedError
def configure_realtime_agent(self, system_message: Optional[str]) -> None:
realtime_agent = self._realtime_agent
logger = realtime_agent.logger
if not system_message:
if realtime_agent.system_message != "You are a helpful AI Assistant.":
logger.warning(
"Overriding system message set up in `__init__`, please use `system_message` parameter of the `register_swarm` function instead."
)
system_message = SWARM_SYSTEM_MESSAGE
realtime_agent._system_message = system_message
realtime_agent.register_realtime_function(
name="answer_task_question", description="Answer question from the task"
)(self.set_answer)
async def on_observers_ready() -> None:
self._realtime_agent._tg.soonify(asyncify(initiate_swarm_chat))(
initial_agent=self._initial_agent,
agents=self._agents,
user_agent=self, # type: ignore[arg-type]
messages="Find out what the user wants.",
after_work=AfterWorkOption.REVERT_TO_USER,
)
self._realtime_agent.callbacks.on_observers_ready = on_observers_ready
@export_module("autogen.agentchat.realtime.experimental")
def register_swarm(
    *,
    realtime_agent: "RealtimeAgent",
    initial_agent: ConversableAgent,
    agents: list[ConversableAgent],
    system_message: Optional[str] = None,
    question_message: Optional[str] = None,
) -> None:
    """Create a SwarmableRealtimeAgent.

    Args:
        realtime_agent (RealtimeAgent): The RealtimeAgent to create the SwarmableRealtimeAgent from.
        initial_agent (ConversableAgent): The initial agent.
        agents (list[ConversableAgent]): The agents in the swarm.
        system_message (Optional[str]): The system message to set for the agent. If None, the default system message is used.
        question_message (Optional[str]): The question message to set for the agent. If None, the default QUESTION_MESSAGE is used.
    """
    # Wrap the realtime agent so it can act as the swarm's user-facing agent...
    adapter = SwarmableRealtimeAgent(
        realtime_agent=realtime_agent,
        initial_agent=initial_agent,
        agents=agents,
        question_message=question_message,
    )
    # ...then wire system message, answer tool, and chat-start callback onto it.
    adapter.configure_realtime_agent(system_message=system_message)

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from collections.abc import AsyncIterator
from typing import Any, Protocol, runtime_checkable
__all__ = ["WebSocketProtocol"]
@runtime_checkable
class WebSocketProtocol(Protocol):
    """WebSocket protocol for sending and receiving JSON data modelled after FastAPI's WebSocket."""

    async def send_json(self, data: Any, mode: str = "text") -> None:
        """Send `data` as JSON over the websocket (`mode` is "text" or "binary")."""
        ...

    async def receive_json(self, mode: str = "text") -> Any:
        """Receive and deserialize a JSON message from the websocket."""
        ...

    async def receive_text(self) -> str:
        """Receive a single text message from the websocket."""
        ...

    def iter_text(self) -> AsyncIterator[str]:
        """Return an async iterator over incoming text messages."""
        ...