CoACT initialize (#292)
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .audio_adapters import TwilioAudioAdapter, WebSocketAudioAdapter
|
||||
from .audio_observer import AudioObserver
|
||||
from .function_observer import FunctionObserver
|
||||
from .realtime_agent import RealtimeAgent
|
||||
from .realtime_observer import RealtimeObserver
|
||||
from .realtime_swarm import register_swarm
|
||||
|
||||
__all__ = [
|
||||
"AudioObserver",
|
||||
"FunctionObserver",
|
||||
"RealtimeAgent",
|
||||
"RealtimeObserver",
|
||||
"TwilioAudioAdapter",
|
||||
"WebSocketAudioAdapter",
|
||||
"register_swarm",
|
||||
]
|
||||
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .twilio_audio_adapter import TwilioAudioAdapter
|
||||
from .websocket_audio_adapter import WebSocketAudioAdapter
|
||||
|
||||
__all__ = ["TwilioAudioAdapter", "WebSocketAudioAdapter"]
|
||||
@@ -0,0 +1,148 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import base64
|
||||
import json
|
||||
from logging import Logger
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from ..realtime_events import AudioDelta, RealtimeEvent, SpeechStarted
|
||||
from ..realtime_observer import RealtimeObserver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..websockets import WebSocketProtocol as WebSocket
|
||||
|
||||
|
||||
LOG_EVENT_TYPES = [
|
||||
"error",
|
||||
"response.content.done",
|
||||
"rate_limits.updated",
|
||||
"response.done",
|
||||
"input_audio_buffer.committed",
|
||||
"input_audio_buffer.speech_stopped",
|
||||
"input_audio_buffer.speech_started",
|
||||
"session.created",
|
||||
]
|
||||
SHOW_TIMING_MATH = False
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.realtime.experimental")
class TwilioAudioAdapter(RealtimeObserver):
    """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa."""

    def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None):
        """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa.

        Args:
            websocket: the websocket connection to the Twilio service
            logger: the logger to use for logging events
        """
        super().__init__(logger=logger)
        self.websocket = websocket

        # Connection specific state
        self.stream_sid = None  # Twilio stream id, set when the "start" event arrives
        self.latest_media_timestamp = 0  # timestamp (ms) of the newest inbound media chunk
        self.last_assistant_item: Optional[str] = None  # id of the most recent assistant audio item
        self.mark_queue: list[str] = []  # marks sent to Twilio, awaiting acknowledgement
        self.response_start_timestamp_twilio: Optional[int] = None  # when the current response started

    async def on_event(self, event: RealtimeEvent) -> None:
        """Receive events from the OpenAI Realtime API, send audio back to Twilio."""
        logger = self.logger

        if isinstance(event, AudioDelta):
            # Round-trip through decode/encode to normalize the base64 payload.
            audio_payload = base64.b64encode(base64.b64decode(event.delta)).decode("utf-8")
            outbound = {
                "event": "media",
                "streamSid": self.stream_sid,
                "media": {"payload": audio_payload},
            }
            await self.websocket.send_json(outbound)

            # First delta of a response: remember when playback began.
            if self.response_start_timestamp_twilio is None:
                self.response_start_timestamp_twilio = self.latest_media_timestamp
                if SHOW_TIMING_MATH:
                    logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_twilio}ms")

            # Update last_assistant_item safely
            if event.item_id:
                self.last_assistant_item = event.item_id

            await self.send_mark()

        # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
        if isinstance(event, SpeechStarted):
            logger.info("Speech start detected.")
            if self.last_assistant_item:
                logger.info(f"Interrupting response with id: {self.last_assistant_item}")
                await self.handle_speech_started_event()

    async def handle_speech_started_event(self) -> None:
        """Handle interruption when the caller's speech starts."""
        logger = self.logger

        logger.info("Handling speech started event.")
        if self.mark_queue and self.response_start_timestamp_twilio is not None:
            # How much of the assistant's audio was actually played back.
            elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_twilio
            if SHOW_TIMING_MATH:
                logger.info(
                    f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_twilio} = {elapsed_time}ms"
                )

            if self.last_assistant_item:
                if SHOW_TIMING_MATH:
                    logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")

                # Drop the unplayed tail of the assistant's audio on the server side.
                await self.realtime_client.truncate_audio(
                    audio_end_ms=elapsed_time,
                    content_index=0,
                    item_id=self.last_assistant_item,
                )

            await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})

            # Reset per-response state for the next turn.
            self.mark_queue.clear()
            self.last_assistant_item = None
            self.response_start_timestamp_twilio = None

    async def send_mark(self) -> None:
        """Send a mark of audio interruption to the Twilio websocket."""
        if self.stream_sid:
            mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}}
            await self.websocket.send_json(mark_event)
            self.mark_queue.append("responsePart")

    async def run_loop(self) -> None:
        """Run the adapter loop."""
        logger = self.logger

        async for message in self.websocket.iter_text():
            try:
                data = json.loads(message)
                event_type = data["event"]
                if event_type == "media":
                    self.latest_media_timestamp = int(data["media"]["timestamp"])
                    await self.realtime_client.send_audio(audio=data["media"]["payload"])
                elif event_type == "start":
                    self.stream_sid = data["start"]["streamSid"]
                    logger.info(f"Incoming stream has started {self.stream_sid}")
                    # Fresh stream: reset all per-connection playback state.
                    self.response_start_timestamp_twilio = None
                    self.latest_media_timestamp = 0
                    self.last_assistant_item = None
                elif event_type == "mark" and self.mark_queue:
                    # Twilio acknowledged a mark; retire the oldest outstanding one.
                    self.mark_queue.pop(0)
            except Exception as e:
                # Best-effort loop: log and keep consuming messages.
                logger.warning(f"Error processing Twilio message: {e}", stack_info=True)

    async def initialize_session(self) -> None:
        """Control initial session with OpenAI."""
        # Twilio media streams use G.711 mu-law audio in both directions.
        session_update = {
            "input_audio_format": "g711_ulaw",
            "output_audio_format": "g711_ulaw",
        }
        await self.realtime_client.session_update(session_update)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
    # mypy-only: verifies TwilioAudioAdapter satisfies the RealtimeObserver interface.
    def twilio_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
        adapter: RealtimeObserver = TwilioAudioAdapter(websocket)
        return adapter
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import base64
|
||||
import json
|
||||
from logging import Logger
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from ..realtime_events import AudioDelta, RealtimeEvent, SpeechStarted
|
||||
from ..realtime_observer import RealtimeObserver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..websockets import WebSocketProtocol as WebSocket
|
||||
|
||||
LOG_EVENT_TYPES = [
|
||||
"error",
|
||||
"response.content.done",
|
||||
"rate_limits.updated",
|
||||
"response.done",
|
||||
"input_audio_buffer.committed",
|
||||
"input_audio_buffer.speech_stopped",
|
||||
"input_audio_buffer.speech_started",
|
||||
"session.created",
|
||||
]
|
||||
SHOW_TIMING_MATH = False
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.realtime.experimental")
class WebSocketAudioAdapter(RealtimeObserver):
    """Adapter for streaming audio between a websocket client and the OpenAI Realtime API."""

    def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None) -> None:
        """Adapter for streaming audio between a websocket client and the OpenAI Realtime API.

        Args:
            websocket (WebSocket): The websocket connection.
            logger (Logger): The logger for the observer.
        """
        super().__init__(logger=logger)
        self.websocket = websocket

        # Connection specific state
        self.stream_sid = None  # stream id, set when the "start" event arrives
        self.latest_media_timestamp = 0  # timestamp (ms) of the newest inbound media chunk
        self.last_assistant_item: Optional[str] = None  # id of the most recent assistant audio item
        self.mark_queue: list[str] = []  # marks sent to the client, awaiting acknowledgement
        self.response_start_timestamp_socket: Optional[int] = None  # when the current response started

    async def on_event(self, event: RealtimeEvent) -> None:
        """Receive events from the OpenAI Realtime API, send audio back to websocket."""
        logger = self.logger

        if isinstance(event, AudioDelta):
            # Round-trip through decode/encode to normalize the base64 payload.
            audio_payload = base64.b64encode(base64.b64decode(event.delta)).decode("utf-8")
            audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}}
            await self.websocket.send_json(audio_delta)

            # First delta of a response: remember when playback began.
            if self.response_start_timestamp_socket is None:
                self.response_start_timestamp_socket = self.latest_media_timestamp
                if SHOW_TIMING_MATH:
                    logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_socket}ms")

            # Update last_assistant_item safely
            if event.item_id:
                self.last_assistant_item = event.item_id

            await self.send_mark()

        # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
        if isinstance(event, SpeechStarted):
            logger.info("Speech start detected.")
            if self.last_assistant_item:
                logger.info(f"Interrupting response with id: {self.last_assistant_item}")
                await self.handle_speech_started_event()

    async def handle_speech_started_event(self) -> None:
        """Handle interruption when the caller's speech starts."""
        logger = self.logger
        logger.info("Handling speech started event.")
        if self.mark_queue and self.response_start_timestamp_socket is not None:
            # How much of the assistant's audio was actually played back.
            elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_socket
            if SHOW_TIMING_MATH:
                logger.info(
                    f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_socket} = {elapsed_time}ms"
                )

            if self.last_assistant_item:
                if SHOW_TIMING_MATH:
                    logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")

                # Drop the unplayed tail of the assistant's audio on the server side.
                await self.realtime_client.truncate_audio(
                    audio_end_ms=elapsed_time,
                    content_index=0,
                    item_id=self.last_assistant_item,
                )

            await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})

            # Reset per-response state for the next turn.
            self.mark_queue.clear()
            self.last_assistant_item = None
            self.response_start_timestamp_socket = None

    async def send_mark(self) -> None:
        """Send a playback mark to the client websocket and record it as outstanding."""
        if self.stream_sid:
            mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}}
            await self.websocket.send_json(mark_event)
            self.mark_queue.append("responsePart")

    async def initialize_session(self) -> None:
        """Control initial session with OpenAI."""
        # Plain websocket clients exchange raw PCM16 audio in both directions.
        session_update = {"input_audio_format": "pcm16", "output_audio_format": "pcm16"}
        await self.realtime_client.session_update(session_update)

    async def run_loop(self) -> None:
        """Reads data from websocket and sends it to the RealtimeClient."""
        logger = self.logger
        async for message in self.websocket.iter_text():
            try:
                data = json.loads(message)
                if data["event"] == "media":
                    self.latest_media_timestamp = int(data["media"]["timestamp"])
                    await self.realtime_client.send_audio(audio=data["media"]["payload"])
                elif data["event"] == "start":
                    self.stream_sid = data["start"]["streamSid"]
                    logger.info(f"Incoming stream has started {self.stream_sid}")
                    # Fresh stream: reset all per-connection playback state.
                    self.response_start_timestamp_socket = None
                    self.latest_media_timestamp = 0
                    self.last_assistant_item = None
                elif data["event"] == "mark":
                    if self.mark_queue:
                        # Client acknowledged a mark; retire the oldest outstanding one.
                        self.mark_queue.pop(0)
            except Exception as e:
                # Best-effort loop: log and keep consuming messages.
                logger.warning(f"Failed to process message: {e}", stack_info=True)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
    # mypy-only: verifies WebSocketAudioAdapter satisfies the RealtimeObserver interface.
    def websocket_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
        adapter: RealtimeObserver = WebSocketAudioAdapter(websocket)
        return adapter
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from .realtime_events import InputAudioBufferDelta, RealtimeEvent
|
||||
from .realtime_observer import RealtimeObserver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.realtime.experimental")
class AudioObserver(RealtimeObserver):
    """Observer for user voice input"""

    def __init__(self, *, logger: Optional["Logger"] = None) -> None:
        """Observer for user voice input"""
        super().__init__(logger=logger)

    async def on_event(self, event: RealtimeEvent) -> None:
        """Observe voice input events from the Realtime.

        Args:
            event (dict[str, Any]): The event from the OpenAI Realtime API.
        """
        # Only audio-buffer deltas are of interest; everything else is ignored.
        if not isinstance(event, InputAudioBufferDelta):
            return
        self.logger.info("Received audio buffer delta")

    async def initialize_session(self) -> None:
        """No need to initialize session from this observer"""

    async def run_loop(self) -> None:
        """Run the observer loop."""
|
||||
|
||||
|
||||
if TYPE_CHECKING:
    # mypy-only: verifies AudioObserver satisfies the RealtimeObserver interface.
    # Renamed from `function_observer` (copy-paste from function_observer.py) to match the class.
    audio_observer: RealtimeObserver = AudioObserver()
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .gemini.client import GeminiRealtimeClient
|
||||
from .oai.base_client import OpenAIRealtimeClient
|
||||
from .realtime_client import RealtimeClientProtocol, Role, get_client
|
||||
|
||||
__all__ = [
|
||||
"GeminiRealtimeClient",
|
||||
"OpenAIRealtimeClient",
|
||||
"RealtimeClientProtocol",
|
||||
"Role",
|
||||
"get_client",
|
||||
]
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .client import GeminiRealtimeClient
|
||||
|
||||
__all__ = ["GeminiRealtimeClient"]
|
||||
@@ -0,0 +1,274 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from logging import Logger, getLogger
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
from ......doc_utils import export_module
|
||||
from ......import_utils import optional_import_block, require_optional_import
|
||||
from ......llm_config import LLMConfig
|
||||
from ...realtime_events import AudioDelta, FunctionCall, RealtimeEvent, SessionCreated
|
||||
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
|
||||
|
||||
with optional_import_block():
|
||||
from websockets.asyncio.client import connect
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
|
||||
from ..realtime_client import RealtimeClientProtocol
|
||||
|
||||
__all__ = ["GeminiRealtimeClient"]
|
||||
|
||||
global_logger = getLogger(__name__)
|
||||
|
||||
|
||||
HOST = "generativelanguage.googleapis.com"
|
||||
API_VERSION = "v1alpha"
|
||||
|
||||
|
||||
@register_realtime_client()
@require_optional_import("websockets", "gemini", except_for=["get_factory", "__init__"])
@export_module("autogen.agentchat.realtime.experimental.clients")
class GeminiRealtimeClient(RealtimeClientBase):
    """(Experimental) Client for Gemini Realtime API."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for Gemini Realtime API.

        Args:
            llm_config: The config for the client.
            logger: The logger for the client.
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger

        self._connection: Optional["ClientConnection"] = None
        config = llm_config["config_list"][0]

        self._model: str = config["model"]
        self._voice = config.get("voice", "charon")
        self._temperature: float = config.get("temperature", 0.8)  # type: ignore[union-attr]

        # Only audio responses are requested from the model.
        self._response_modality = "AUDIO"

        self._api_key = config.get("api_key", None)
        # todo: add test with base_url just to make sure it works
        self._base_url: str = config.get(
            "base_url",
            f"wss://{HOST}/ws/google.ai.generativelanguage.{API_VERSION}.GenerativeService.BidiGenerateContent?key={self._api_key}",
        )
        self._final_config: dict[str, Any] = {}
        # Session options recorded before connecting; applied in _initialize_session().
        self._pending_session_updates: dict[str, Any] = {}
        self._is_reading_events = False

    @property
    def logger(self) -> Logger:
        """Get the logger for the Gemini Realtime API."""
        return self._logger or global_logger

    @property
    def connection(self) -> "ClientConnection":
        """Get the Gemini WebSocket connection.

        Raises:
            RuntimeError: If the connection has not been established yet.
        """
        if self._connection is None:
            raise RuntimeError("Gemini WebSocket is not initialized")
        return self._connection

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the Gemini Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        msg = {
            "tool_response": {"function_responses": [{"id": call_id, "response": {"result": {"string_value": result}}}]}
        }
        # Only send once the event loop is consuming; otherwise the message is dropped.
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def send_text(self, *, role: Role, text: str, turn_complete: bool = True) -> None:
        """Send a text message to the Gemini Realtime API.

        Args:
            role: The role of the message.
            text: The text of the message.
            turn_complete: A flag indicating if the turn is complete.
        """
        msg = {
            "client_content": {
                "turn_complete": turn_complete,
                "turns": [{"role": role, "parts": [{"text": text}]}],
            }
        }
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def send_audio(self, audio: str) -> None:
        """Send audio to the Gemini Realtime API.

        Args:
            audio (str): The audio to send.
        """
        msg = {
            "realtime_input": {
                "media_chunks": [
                    {
                        "data": audio,
                        "mime_type": "audio/pcm",
                    }
                ]
            }
        }
        # Mirror the outbound audio to local observers regardless of connection state.
        await self.queue_input_audio_buffer_delta(audio)
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """No-op: audio truncation is not natively supported by the Gemini Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate (ignored).
            content_index (int): The index of the content to truncate (ignored).
            item_id (str): The ID of the item to truncate (ignored).
        """
        self.logger.info("This is not natively supported by Gemini Realtime API.")

    async def _initialize_session(self) -> None:
        """Initialize the session with the Gemini Realtime API."""
        # Build the one-shot "setup" message from all pending session updates.
        session_config = {
            "setup": {
                "system_instruction": {
                    "role": "system",
                    "parts": [{"text": self._pending_session_updates.get("instructions", "")}],
                },
                "model": f"models/{self._model}",
                "tools": [
                    {
                        "function_declarations": [
                            {
                                "name": tool_schema["name"],
                                "description": tool_schema["description"],
                                "parameters": tool_schema["parameters"],
                            }
                            for tool_schema in self._pending_session_updates.get("tools", [])
                        ]
                    },
                ],
                "generation_config": {
                    "response_modalities": [self._response_modality],
                    "speech_config": {"voiceConfig": {"prebuiltVoiceConfig": {"voiceName": self._voice}}},
                    "temperature": self._temperature,
                },
            }
        }

        self.logger.info(f"Sending session update: {session_config}")
        await self.connection.send(json.dumps(session_config))

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Record session updates to be applied when the connection is established.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        # Gemini only accepts session setup once, before events are read.
        if self._is_reading_events:
            self.logger.warning("Is reading events. Session update will be ignored.")
        else:
            self._pending_session_updates.update(session_options)

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the Gemini Realtime API."""
        try:
            async with connect(
                self._base_url, additional_headers={"Content-Type": "application/json"}
            ) as self._connection:
                yield
        finally:
            # Always clear the connection so later calls fail fast instead of using a dead socket.
            self._connection = None

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read Events from the Gemini Realtime Client

        Raises:
            RuntimeError: If connect() has not been called first.
        """
        if self._connection is None:
            raise RuntimeError("Client is not connected, call connect() first.")
        await self._initialize_session()

        self._is_reading_events = True

        async for event in self._read_events():
            yield event

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the Gemini Realtime connection."""
        async for raw_message in self.connection:
            message = raw_message.decode("ascii") if isinstance(raw_message, bytes) else raw_message
            events = self._parse_message(json.loads(message))
            for event in events:
                yield event

    def _parse_message(self, response: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the Gemini Realtime API.

        Args:
            response (dict[str, Any]): The response to parse.

        Returns:
            list[RealtimeEvent]: The parsed events.
        """
        if "serverContent" in response and "modelTurn" in response["serverContent"]:
            try:
                # pop() removes the (potentially large) base64 audio from the raw message.
                b64data = response["serverContent"]["modelTurn"]["parts"][0]["inlineData"].pop("data")
                return [
                    AudioDelta(
                        delta=b64data,
                        item_id=None,
                        raw_message=response,
                    )
                ]
            except KeyError:
                # Model turn without inline audio data; nothing to emit.
                return []
        elif "toolCall" in response:
            return [
                FunctionCall(
                    raw_message=response,
                    call_id=call["id"],
                    name=call["name"],
                    arguments=call["args"],
                )
                for call in response["toolCall"]["functionCalls"]
            ]
        elif "setupComplete" in response:
            return [
                SessionCreated(raw_message=response),
            ]
        else:
            # Unrecognized payloads are surfaced as generic events.
            return [RealtimeEvent(raw_message=response)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The LLM config for the client.
            logger: The logger for the client.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Only claim configs explicitly targeting Google, and only when no extra kwargs are passed.
        if llm_config["config_list"][0].get("api_type") == "google" and not kwargs:
            return lambda: GeminiRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
        return None
|
||||
|
||||
|
||||
# needed for mypy to check if GeminiRealtimeClient implements RealtimeClientProtocol
if TYPE_CHECKING:
    _client: RealtimeClientProtocol = GeminiRealtimeClient(llm_config={})
|
||||
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .base_client import OpenAIRealtimeClient
|
||||
from .rtc_client import OpenAIRealtimeWebRTCClient
|
||||
|
||||
__all__ = ["OpenAIRealtimeClient", "OpenAIRealtimeWebRTCClient"]
|
||||
@@ -0,0 +1,220 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from logging import Logger, getLogger
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
from ......doc_utils import export_module
|
||||
from ......import_utils import optional_import_block, require_optional_import
|
||||
from ......llm_config import LLMConfig
|
||||
from ...realtime_events import RealtimeEvent
|
||||
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
|
||||
from .utils import parse_oai_message
|
||||
|
||||
with optional_import_block():
|
||||
from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI
|
||||
from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..realtime_client import RealtimeClientProtocol
|
||||
|
||||
__all__ = ["OpenAIRealtimeClient"]
|
||||
|
||||
global_logger = getLogger(__name__)
|
||||
|
||||
|
||||
@register_realtime_client()
|
||||
@require_optional_import("openai>=1.66.2", "openai-realtime", except_for=["get_factory", "__init__"])
|
||||
@export_module("autogen.agentchat.realtime.experimental.clients")
|
||||
class OpenAIRealtimeClient(RealtimeClientBase):
|
||||
"""(Experimental) Client for OpenAI Realtime API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
llm_config: Union[LLMConfig, dict[str, Any]],
|
||||
logger: Optional[Logger] = None,
|
||||
) -> None:
|
||||
"""(Experimental) Client for OpenAI Realtime API.
|
||||
|
||||
Args:
|
||||
llm_config: The config for the client.
|
||||
logger: the logger to use for logging events
|
||||
"""
|
||||
super().__init__()
|
||||
self._llm_config = llm_config
|
||||
self._logger = logger
|
||||
|
||||
self._connection: Optional["AsyncRealtimeConnection"] = None
|
||||
|
||||
self.config = llm_config["config_list"][0]
|
||||
# model is passed to self._client.beta.realtime.connect function later
|
||||
self._model: str = self.config["model"]
|
||||
self._voice: str = self.config.get("voice", "alloy")
|
||||
self._temperature: float = llm_config.get("temperature", 0.8) # type: ignore[union-attr]
|
||||
|
||||
self._client: Optional["AsyncOpenAI"] = None
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
"""Get the logger for the OpenAI Realtime API."""
|
||||
return self._logger or global_logger
|
||||
|
||||
@property
|
||||
def connection(self) -> "AsyncRealtimeConnection":
|
||||
"""Get the OpenAI WebSocket connection."""
|
||||
if self._connection is None:
|
||||
raise RuntimeError("OpenAI WebSocket is not initialized")
|
||||
return self._connection
|
||||
|
||||
async def send_function_result(self, call_id: str, result: str) -> None:
|
||||
"""Send the result of a function call to the OpenAI Realtime API.
|
||||
|
||||
Args:
|
||||
call_id (str): The ID of the function call.
|
||||
result (str): The result of the function call.
|
||||
"""
|
||||
await self.connection.conversation.item.create(
|
||||
item={
|
||||
"type": "function_call_output",
|
||||
"call_id": call_id,
|
||||
"output": result,
|
||||
},
|
||||
)
|
||||
|
||||
await self.connection.response.create()
|
||||
|
||||
async def send_text(self, *, role: Role, text: str) -> None:
|
||||
"""Send a text message to the OpenAI Realtime API.
|
||||
|
||||
Args:
|
||||
role (str): The role of the message.
|
||||
text (str): The text of the message.
|
||||
"""
|
||||
await self.connection.response.cancel()
|
||||
await self.connection.conversation.item.create(
|
||||
item={"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]}
|
||||
)
|
||||
await self.connection.response.create()
|
||||
|
||||
async def send_audio(self, audio: str) -> None:
|
||||
"""Send audio to the OpenAI Realtime API.
|
||||
|
||||
Args:
|
||||
audio (str): The audio to send.
|
||||
"""
|
||||
await self.queue_input_audio_buffer_delta(audio)
|
||||
await self.connection.input_audio_buffer.append(audio=audio)
|
||||
|
||||
async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
|
||||
"""Truncate audio in the OpenAI Realtime API.
|
||||
|
||||
Args:
|
||||
audio_end_ms (int): The end of the audio to truncate.
|
||||
content_index (int): The index of the content to truncate.
|
||||
item_id (str): The ID of the item to truncate.
|
||||
"""
|
||||
await self.connection.conversation.item.truncate(
|
||||
audio_end_ms=audio_end_ms, content_index=content_index, item_id=item_id
|
||||
)
|
||||
|
||||
async def _initialize_session(self) -> None:
|
||||
"""Control initial session with OpenAI."""
|
||||
session_update = {
|
||||
"turn_detection": {"type": "server_vad"},
|
||||
"voice": self._voice,
|
||||
"modalities": ["audio", "text"],
|
||||
"temperature": self._temperature,
|
||||
}
|
||||
await self.session_update(session_options=session_update)
|
||||
|
||||
async def session_update(self, session_options: dict[str, Any]) -> None:
|
||||
"""Send a session update to the OpenAI Realtime API.
|
||||
|
||||
Args:
|
||||
session_options (dict[str, Any]): The session options to update.
|
||||
"""
|
||||
logger = self.logger
|
||||
logger.info(f"Sending session update: {session_options}")
|
||||
await self.connection.session.update(session=session_options) # type: ignore[arg-type]
|
||||
logger.info("Sending session update finished")
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(self) -> AsyncGenerator[None, None]:
|
||||
"""Connect to the OpenAI Realtime API."""
|
||||
try:
|
||||
if not self._client:
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.config.get("api_key", None),
|
||||
organization=self.config.get("organization", None),
|
||||
project=self.config.get("project", None),
|
||||
base_url=self.config.get("base_url", None),
|
||||
websocket_base_url=self.config.get("websocket_base_url", None),
|
||||
timeout=self.config.get("timeout", NOT_GIVEN),
|
||||
max_retries=self.config.get("max_retries", DEFAULT_MAX_RETRIES),
|
||||
default_headers=self.config.get("default_headers", None),
|
||||
default_query=self.config.get("default_query", None),
|
||||
)
|
||||
async with self._client.beta.realtime.connect(
|
||||
model=self._model,
|
||||
) as self._connection:
|
||||
await self._initialize_session()
|
||||
yield
|
||||
finally:
|
||||
self._connection = None
|
||||
|
||||
async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
    """Read messages from the OpenAI Realtime API.

    Yields:
        RealtimeEvent: Parsed events from the realtime connection.

    Raises:
        RuntimeError: If called before (or after) `connect()`.
    """
    if self._connection is None:
        raise RuntimeError("Client is not connected, call connect() first.")

    try:
        # Delegate to the queue-backed reader inherited from the base class.
        async for event in self._read_events():
            yield event

    finally:
        # Clear the connection once the stream ends (or the consumer stops
        # iterating) so the client must reconnect before reading again.
        self._connection = None
|
||||
|
||||
async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
    """Read messages from the OpenAI Realtime API connection.

    Yields:
        RealtimeEvent: Events parsed from each raw SDK message.
    """
    async for message in self._connection:
        # A single raw message may map to one or more typed events.
        for event in self._parse_message(message.model_dump()):
            yield event
|
||||
|
||||
def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
    """Parse a message from the OpenAI Realtime API.

    Args:
        message (dict[str, Any]): The message to parse.

    Returns:
        list[RealtimeEvent]: The parsed events (always exactly one).
    """
    parsed_event = parse_oai_message(message)
    return [parsed_event]
|
||||
|
||||
@classmethod
def get_factory(
    cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
    """Create a Realtime API client.

    Args:
        llm_config: The config for the client.
        logger: The logger to use for logging events.
        kwargs: Additional arguments.

    Returns:
        RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
    """
    # This factory only handles plain OpenAI configs with no extra kwargs
    # (extra kwargs such as `websocket` select the WebRTC variant instead).
    api_type = llm_config["config_list"][0].get("api_type", "openai")
    if api_type != "openai" or kwargs:
        return None
    return lambda: OpenAIRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
|
||||
|
||||
|
||||
# needed for mypy to check if OpenAIRealtimeClient implements RealtimeClientProtocol
if TYPE_CHECKING:
    # Type-check only; never executed at runtime.
    _client: RealtimeClientProtocol = OpenAIRealtimeClient(llm_config={})
|
||||
@@ -0,0 +1,243 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from logging import Logger, getLogger
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
from autogen.import_utils import optional_import_block, require_optional_import
|
||||
|
||||
from ......doc_utils import export_module
|
||||
from ......llm_config import LLMConfig
|
||||
from ...realtime_events import RealtimeEvent
|
||||
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
|
||||
from .utils import parse_oai_message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...websockets import WebSocketProtocol as WebSocket
|
||||
from ..realtime_client import RealtimeClientProtocol
|
||||
|
||||
with optional_import_block():
|
||||
import httpx
|
||||
|
||||
__all__ = ["OpenAIRealtimeWebRTCClient"]
|
||||
|
||||
global_logger = getLogger(__name__)
|
||||
|
||||
|
||||
@register_realtime_client()
@require_optional_import("httpx", "openai-realtime", except_for="get_factory")
@export_module("autogen.agentchat.realtime.experimental.clients.oai")
class OpenAIRealtimeWebRTCClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol.

    Unlike the websocket client, this class never talks to OpenAI directly:
    the browser-side JavaScript holds the actual WebRTC connection, and this
    class exchanges JSON control messages with it over `websocket`.
    """

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        websocket: "WebSocket",
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client.
            websocket: the websocket to use for the connection
            logger: the logger to use for logging events
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        self._websocket = websocket

        # First config_list entry drives the model/voice/connection options.
        config = llm_config["config_list"][0]
        self._model: str = config["model"]
        self._voice: str = config.get("voice", "alloy")
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]
        self._config = config
        # Endpoint used to mint an ephemeral realtime session for the browser.
        self._base_url: str = config.get("base_url", "https://api.openai.com/v1/realtime/sessions")

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API."""
        return self._logger or global_logger

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        })
        # Ask the model to generate a response that takes the result into account.
        await self._websocket.send_json({"type": "response.create"})

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        # Cancel any in-flight response before injecting the new message.
        await self._websocket.send_json({
            "type": "response.cancel",
        })
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
        })
        # Trigger a fresh model response for the newly created item.
        await self._websocket.send_json({"type": "response.create"})

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        in case of WebRTC, audio is already sent by js client, so we just queue it in order to be logged.

        Args:
            audio (str): The audio to send.
        """
        await self.queue_input_audio_buffer_delta(audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self._websocket.send_json({
            "type": "conversation.item.truncate",
            "content_index": content_index,
            "item_id": item_id,
            "audio_end_ms": audio_end_ms,
        })

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        In the case of WebRTC we can not send it directly, but we can send it
        to the javascript over the websocket, and rely on it to send session
        update to OpenAI

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        await self._websocket.send_json({"type": "session.update", "session": session_options})
        logger.info("Sending session update finished")

    def session_init_data(self) -> list[dict[str, Any]]:
        """Control initial session with OpenAI.

        Returns:
            list[dict[str, Any]]: Session-update messages for the JS client
                to relay to OpenAI when the connection is established.
        """
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        return [{"type": "session.update", "session": session_update}]

    # No-op: session initialization happens via session_init_data()/connect(),
    # relayed by the JS client, rather than on a direct OpenAI connection.
    async def _initialize_session(self) -> None: ...

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API.

        In the case of WebRTC, we pass connection information over the
        websocket, so that javascript on the other end of websocket open
        actual connection to OpenAI
        """
        try:
            base_url = self._base_url
            api_key = self._config.get("api_key", None)
            headers = {
                "Authorization": f"Bearer {api_key}",  # Use os.getenv to get from environment
                "Content-Type": "application/json",
            }
            data = {
                "model": self._model,
                "voice": self._voice,
            }
            # Mint an ephemeral session token that the browser can use to
            # open the WebRTC connection directly with OpenAI.
            async with httpx.AsyncClient() as client:
                response = await client.post(base_url, headers=headers, json=data)
                response.raise_for_status()
                json_data = response.json()
                json_data["model"] = self._model
            # NOTE(review): `websocket` is a required constructor argument,
            # so this check looks always-true — confirm before removing.
            if self._websocket is not None:
                session_init = self.session_init_data()
                await self._websocket.send_json({"type": "ag2.init", "config": json_data, "init": session_init})
            yield
        finally:
            pass

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from the OpenAI Realtime API."""
        async for event in self._read_events():
            yield event

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API connection.

        Again, in case of WebRTC, we do not read OpenAI messages directly since we
        do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript
        client on the other side of the websocket that is connected to OpenAI is relaying events to us.
        """
        while True:
            try:
                message_json = await self._websocket.receive_text()
                message = json.loads(message_json)
                for event in self._parse_message(message):
                    yield event
            except Exception as e:
                # Any websocket/parse failure ends the stream for this client.
                self.logger.exception(f"Error reading from connection {e}")
                break

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            RealtimeEvent: The parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Selected only for plain OpenAI configs called with exactly the
        # `websocket` keyword argument (distinguishes WebRTC from websocket mode).
        if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == ["websocket"]:
            return lambda: OpenAIRealtimeWebRTCClient(llm_config=llm_config, logger=logger, **kwargs)

        return None
|
||||
|
||||
|
||||
# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
if TYPE_CHECKING:
    # Type-check only; never executed at runtime.
    def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol:
        return OpenAIRealtimeWebRTCClient(llm_config={}, websocket=websocket)
|
||||
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from ...realtime_events import (
|
||||
AudioDelta,
|
||||
FunctionCall,
|
||||
InputAudioBufferDelta,
|
||||
RealtimeEvent,
|
||||
SessionCreated,
|
||||
SessionUpdated,
|
||||
SpeechStarted,
|
||||
)
|
||||
|
||||
__all__ = ["parse_oai_message"]
|
||||
|
||||
|
||||
def parse_oai_message(message: dict[str, Any]) -> RealtimeEvent:
    """Parse a message from the OpenAI Realtime API into a typed event.

    Args:
        message (dict[str, Any]): The message to parse.

    Returns:
        RealtimeEvent: The parsed event; message types without a dedicated
            class fall back to the generic ``RealtimeEvent`` wrapper.
    """
    message_type = message.get("type")

    if message_type == "session.created":
        return SessionCreated(raw_message=message)
    if message_type == "session.updated":
        return SessionUpdated(raw_message=message)
    if message_type == "response.audio.delta":
        return AudioDelta(raw_message=message, delta=message["delta"], item_id=message["item_id"])
    if message_type == "input_audio_buffer.speech_started":
        return SpeechStarted(raw_message=message)
    if message_type == "input_audio_buffer.delta":
        return InputAudioBufferDelta(delta=message["delta"], item_id=None, raw_message=message)
    if message_type == "response.function_call_arguments.done":
        # Function-call arguments arrive as a JSON string; decode them here.
        return FunctionCall(
            raw_message=message,
            call_id=message["call_id"],
            name=message["name"],
            arguments=json.loads(message["arguments"]),
        )
    return RealtimeEvent(raw_message=message)
|
||||
@@ -0,0 +1,190 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from logging import Logger
|
||||
from typing import Any, AsyncContextManager, Callable, Literal, Optional, Protocol, TypeVar, Union, runtime_checkable
|
||||
|
||||
from asyncer import create_task_group
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....llm_config import LLMConfig
|
||||
from ..realtime_events import InputAudioBufferDelta, RealtimeEvent
|
||||
|
||||
__all__ = ["RealtimeClientProtocol", "Role", "get_client", "register_realtime_client"]
|
||||
|
||||
# Literal type for the role attached to conversation items sent to realtime APIs.
Role = Literal["user", "assistant", "system"]
|
||||
|
||||
|
||||
@runtime_checkable
@export_module("autogen.agentchat.realtime.experimental.clients")
class RealtimeClientProtocol(Protocol):
    """Structural interface every Realtime API client must implement.

    Concrete clients (e.g. the OpenAI websocket and WebRTC clients) are
    checked against this protocol and registered via `register_realtime_client`.
    """

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to a Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        ...

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to a Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        ...

    async def send_audio(self, audio: str) -> None:
        """Send audio to a Realtime API.

        Args:
            audio (str): The audio to send.
        """
        ...

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in a Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        ...

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to a Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        ...

    # Async context manager that establishes (and tears down) the connection.
    def connect(self) -> AsyncContextManager[None]: ...

    def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime Client."""
        ...

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime connection."""
        ...

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from a Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            list[RealtimeEvent]: The parsed events.
        """
        ...

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        ...
|
||||
|
||||
|
||||
class RealtimeClientBase:
    """Shared queue-based event plumbing for Realtime API clients.

    Subclasses provide `_read_from_connection()`; this base class pumps its
    events through an asyncio queue so producers and consumers are decoupled.
    """

    def __init__(self) -> None:
        # Queue of incoming events; a ``None`` sentinel marks end of stream.
        self._eventQueue: asyncio.Queue = asyncio.Queue()

    async def add_event(self, event: Optional[RealtimeEvent]) -> None:
        """Enqueue an event (``None`` signals the end of the stream)."""
        await self._eventQueue.put(event)

    async def get_event(self) -> Optional[RealtimeEvent]:
        """Dequeue the next event, waiting until one is available."""
        return await self._eventQueue.get()

    async def _read_from_connection_task(self) -> None:
        # Drain the concrete connection and forward every event to the queue.
        async for event in self._read_from_connection():
            await self.add_event(event)
        # Sentinel telling _read_events() the connection is exhausted.
        await self.add_event(None)

    async def _read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime Client."""
        async with create_task_group() as tg:
            tg.start_soon(self._read_from_connection_task)
            while True:
                try:
                    event = await self._eventQueue.get()
                    if event is not None:
                        yield event
                    else:
                        # None sentinel: the reader task finished.
                        break
                except Exception:
                    # NOTE(review): errors are swallowed silently and end the
                    # stream — consider logging before breaking.
                    break

    async def queue_input_audio_buffer_delta(self, audio: str) -> None:
        """queue InputAudioBufferDelta.

        Args:
            audio (str): The audio.
        """
        await self.add_event(InputAudioBufferDelta(delta=audio, item_id=None, raw_message=dict()))
|
||||
|
||||
|
||||
# Registry of realtime client classes keyed by fully qualified class name,
# populated by the @register_realtime_client() decorator.
_realtime_client_classes: dict[str, type[RealtimeClientProtocol]] = {}

# Bound TypeVar so the registration decorator preserves the decorated class type.
T = TypeVar("T", bound=RealtimeClientProtocol)
|
||||
|
||||
|
||||
def register_realtime_client() -> Callable[[type[T]], type[T]]:
    """Register a Realtime API client.

    Returns:
        Callable[[type[T]], type[T]]: The decorator to register the Realtime API client
    """

    def decorator(client_cls: type[T]) -> type[T]:
        """Register a Realtime API client.

        Args:
            client_cls: The client to register.
        """
        # Key the registry by the fully qualified class name; return the
        # class unchanged so the decorator is transparent to callers.
        fully_qualified_name = f"{client_cls.__module__}.{client_cls.__name__}"
        _realtime_client_classes[fully_qualified_name] = client_cls
        return client_cls

    return decorator
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.realtime.experimental.clients")
def get_client(llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any) -> "RealtimeClientProtocol":
    """Get a registered Realtime API client.

    Args:
        llm_config: The config for the client.
        logger: The logger to use for logging events.
        **kwargs: Additional arguments.

    Returns:
        RealtimeClientProtocol: The Realtime API client.

    Raises:
        ValueError: If no registered client accepts the given config/kwargs.
    """
    # Ask every registered client class for a factory; the first class whose
    # get_factory() matches the config wins.
    for client_cls in _realtime_client_classes.values():
        factory = client_cls.get_factory(llm_config=llm_config, logger=logger, **kwargs)
        if factory is not None:
            return factory()

    raise ValueError("Realtime API client not found.")
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from asyncer import asyncify
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from .realtime_events import FunctionCall, RealtimeEvent
|
||||
from .realtime_observer import RealtimeObserver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.realtime.experimental")
class FunctionObserver(RealtimeObserver):
    """Observer for handling function calls from the OpenAI Realtime API."""

    def __init__(self, *, logger: Optional["Logger"] = None) -> None:
        """Observer for handling function calls from the OpenAI Realtime API.

        Args:
            logger (Optional[Logger]): The logger for the observer.
        """
        super().__init__(logger=logger)

    async def on_event(self, event: RealtimeEvent) -> None:
        """Handle function call events from the OpenAI Realtime API.

        Args:
            event (RealtimeEvent): The event from the OpenAI Realtime API.
        """
        if isinstance(event, FunctionCall):
            self.logger.info("Received function call event")
            await self.call_function(
                call_id=event.call_id,
                name=event.name,
                kwargs=event.arguments,
            )

    async def call_function(self, call_id: str, name: str, kwargs: dict[str, Any]) -> None:
        """Call a function registered with the agent and send back its result.

        Args:
            call_id (str): The ID of the function call.
            name (str): The name of the function to call.
            kwargs (dict[str, Any]): The arguments to pass to the function.
        """
        # Guard clause: unknown tools are only logged, never raised, so a
        # hallucinated tool name cannot crash the realtime session.
        if name not in self.agent.registered_realtime_tools:
            self.logger.warning(f"Function {name} called, but is not registered with the realtime agent.")
            return

        func = self.agent.registered_realtime_tools[name].func
        # Wrap sync functions so they run in a worker thread without blocking
        # the event loop.
        func = func if asyncio.iscoroutinefunction(func) else asyncify(func)
        try:
            result = await func(**kwargs)
        except Exception:
            result = "Function call failed"
            # logger.exception records the active exception's traceback,
            # unlike logger.info(..., stack_info=True) which only captured
            # the current call stack at INFO level.
            self.logger.exception(f"Function call failed: {name=}, {kwargs=}")

        # Normalize the result to a string, as required by the realtime API.
        if isinstance(result, BaseModel):
            result = result.model_dump_json()
        elif not isinstance(result, str):
            try:
                result = json.dumps(result)
            except Exception:
                # Not JSON-serializable; fall back to its string form.
                result = str(result)

        await self.realtime_client.send_function_result(call_id, result)

    async def initialize_session(self) -> None:
        """Add registered tools to OpenAI with a session update."""
        session_update = {
            "tools": [tool.realtime_tool_schema for tool in self.agent.registered_realtime_tools.values()],
            "tool_choice": "auto",
        }
        await self.realtime_client.session_update(session_update)

    async def run_loop(self) -> None:
        """Run the observer loop.

        All work happens in on_event(), so no extra loop is needed.
        """
        pass
|
||||
|
||||
|
||||
# needed for mypy to check if FunctionObserver implements RealtimeObserver
if TYPE_CHECKING:
    function_observer: RealtimeObserver = FunctionObserver()
|
||||
@@ -0,0 +1,158 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass
|
||||
from logging import Logger, getLogger
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
from anyio import lowlevel
|
||||
from asyncer import create_task_group
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from ....llm_config import LLMConfig
|
||||
from ....tools import Tool
|
||||
from .clients.realtime_client import RealtimeClientProtocol, get_client
|
||||
from .function_observer import FunctionObserver
|
||||
from .realtime_observer import RealtimeObserver
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
global_logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RealtimeAgentCallbacks:
    """Callbacks for the Realtime Agent."""

    # Invoked once all observers have signalled readiness; the default is an
    # awaitable no-op (anyio checkpoint), i.e. an async empty placeholder.
    on_observers_ready: Callable[[], Any] = lambda: lowlevel.checkpoint()
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.realtime.experimental")
class RealtimeAgent:
    """Agent that drives a realtime client and dispatches its events to observers."""

    def __init__(
        self,
        *,
        name: str,
        audio_adapter: Optional[RealtimeObserver] = None,
        system_message: str = "You are a helpful AI Assistant.",
        llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None,
        logger: Optional[Logger] = None,
        observers: Optional[list[RealtimeObserver]] = None,
        **client_kwargs: Any,
    ):
        """(Experimental) Agent for interacting with the Realtime Clients.

        Args:
            name (str): The name of the agent.
            audio_adapter (Optional[RealtimeObserver] = None): The audio adapter for the agent.
            system_message (str): The system message for the agent.
            llm_config (LLMConfig, dict[str, Any], bool): The config for the agent.
            logger (Optional[Logger]): The logger for the agent.
            observers (Optional[list[RealtimeObserver]]): The additional observers for the agent.
            **client_kwargs (Any): The keyword arguments for the client.
        """
        self._logger = logger
        self._name = name
        self._system_message = system_message

        # Falls back to the ambient/current LLM config when none is given.
        llm_config = LLMConfig.get_current_llm_config(llm_config)

        # Pick the registered client implementation matching the config and
        # the extra kwargs (e.g. `websocket=` selects the WebRTC client).
        self._realtime_client: RealtimeClientProtocol = get_client(
            llm_config=llm_config, logger=self.logger, **client_kwargs
        )

        self._registered_realtime_tools: dict[str, Tool] = {}
        # A FunctionObserver is always attached so registered tools work.
        self._observers: list[RealtimeObserver] = observers if observers else []
        self._observers.append(FunctionObserver(logger=logger))
        if audio_adapter:
            self._observers.append(audio_adapter)

        self.callbacks = RealtimeAgentCallbacks()

    @property
    def system_message(self) -> str:
        """Get the system message for the agent."""
        return self._system_message

    @property
    def logger(self) -> Logger:
        """Get the logger for the agent."""
        return self._logger or global_logger

    @property
    def realtime_client(self) -> RealtimeClientProtocol:
        """Get the OpenAI Realtime Client."""
        return self._realtime_client

    @property
    def registered_realtime_tools(self) -> dict[str, Tool]:
        """Get the registered realtime tools."""
        return self._registered_realtime_tools

    def register_observer(self, observer: RealtimeObserver) -> None:
        """Register an observer with the Realtime Agent.

        Args:
            observer (RealtimeObserver): The observer to register.
        """
        self._observers.append(observer)

    async def start_observers(self) -> None:
        """Start all observers in the agent's task group and wait until each is ready."""
        # NOTE: relies on self._tg being set by run() before this is called.
        for observer in self._observers:
            self._tg.soonify(observer.run)(self)

        # wait for the observers to be ready
        for observer in self._observers:
            await observer.wait_for_ready()

        await self.callbacks.on_observers_ready()

    async def run(self) -> None:
        """Run the agent."""
        # everything is run in the same task group to enable easy cancellation using self._tg.cancel_scope.cancel()
        async with create_task_group() as self._tg:  # noqa: SIM117
            # connect with the client first (establishes a connection and initializes a session)
            async with self._realtime_client.connect():
                # start the observers and wait for them to be ready
                await self.realtime_client.session_update(session_options={"instructions": self.system_message})
                await self.start_observers()

                # iterate over the events, fanning each one out to every observer
                async for event in self.realtime_client.read_events():
                    for observer in self._observers:
                        await observer.on_event(event)

    def register_realtime_function(
        self,
        *,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> Callable[[Union[F, Tool]], Tool]:
        """Decorator for registering a function to be used by an agent.

        Args:
            name (str): The name of the function.
            description (str): The description of the function.

        Returns:
            Callable[[Union[F, Tool]], Tool]: The decorator for registering a function.
        """

        def _decorator(func_or_tool: Union[F, Tool]) -> Tool:
            """Decorator for registering a function to be used by an agent.

            Args:
                func_or_tool (Union[F, Tool]): The function or tool to register.

            Returns:
                Tool: The registered tool.
            """
            tool = Tool(func_or_tool=func_or_tool, name=name, description=description)

            self._registered_realtime_tools[tool.name] = tool

            return tool

        return _decorator
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RealtimeEvent(BaseModel):
    """Base event emitted by a Realtime API client; also the fallback for unknown message types."""

    # The unmodified message as received from the realtime service.
    raw_message: dict[str, Any]
|
||||
|
||||
|
||||
class SessionCreated(RealtimeEvent):
    """Emitted when the realtime session has been created."""

    type: Literal["session.created"] = "session.created"
|
||||
|
||||
|
||||
class SessionUpdated(RealtimeEvent):
    """Emitted when the realtime session options have been updated."""

    type: Literal["session.updated"] = "session.updated"
|
||||
|
||||
|
||||
class AudioDelta(RealtimeEvent):
    """Chunk of model-generated response audio."""

    type: Literal["response.audio.delta"] = "response.audio.delta"
    # Audio delta payload as provided by the service (string-encoded).
    delta: str
    # Identifier of the response item this chunk belongs to (opaque to us).
    item_id: Any
|
||||
|
||||
|
||||
class InputAudioBufferDelta(RealtimeEvent):
    """Chunk of user input audio queued into the input buffer."""

    type: Literal["input_audio_buffer.delta"] = "input_audio_buffer.delta"
    # Audio delta payload as provided by the adapter (string-encoded).
    delta: str
    # Item identifier; may be None when the chunk is queued locally.
    item_id: Any
|
||||
|
||||
|
||||
class SpeechStarted(RealtimeEvent):
    """Emitted when voice activity detection notices the user started speaking."""

    type: Literal["input_audio_buffer.speech_started"] = "input_audio_buffer.speech_started"
|
||||
|
||||
|
||||
class FunctionCall(RealtimeEvent):
    """Completed function-call request from the model."""

    type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
    # Name of the tool the model wants to invoke.
    name: str
    # Arguments for the call, already parsed from JSON.
    arguments: dict[str, Any]
    # Correlation id used when sending the function result back.
    call_id: str
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger, getLogger
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from anyio import Event
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from .clients.realtime_client import RealtimeClientProtocol
|
||||
from .realtime_events import RealtimeEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .realtime_agent import RealtimeAgent
|
||||
|
||||
__all__ = ["RealtimeObserver"]
|
||||
|
||||
global_logger = getLogger(__name__)
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.realtime.experimental")
class RealtimeObserver(ABC):
    """Observer for the OpenAI Realtime API.

    Observers are attached to a RealtimeAgent, which calls `run()` once and
    then forwards every client event to `on_event()`.
    """

    def __init__(self, *, logger: Optional[Logger] = None) -> None:
        """Observer for the OpenAI Realtime API.

        Args:
            logger (Logger): The logger for the observer.
        """
        # Set once initialize_session() completes; see wait_for_ready().
        self._ready_event = Event()
        self._agent: Optional[RealtimeAgent] = None
        self._logger = logger

    @property
    def logger(self) -> Logger:
        """Logger for this observer, falling back to the module-level logger."""
        return self._logger or global_logger

    @property
    def agent(self) -> "RealtimeAgent":
        """The agent this observer is attached to (set by run()).

        Raises:
            RuntimeError: If accessed before run() was called.
        """
        if self._agent is None:
            raise RuntimeError("Agent has not been set.")
        return self._agent

    @property
    def realtime_client(self) -> RealtimeClientProtocol:
        """The realtime client of the attached agent.

        Raises:
            RuntimeError: If the agent or its client has not been set yet.
        """
        if self._agent is None:
            raise RuntimeError("Agent has not been set.")
        if self._agent.realtime_client is None:
            raise RuntimeError("Realtime client has not been set.")

        return self._agent.realtime_client

    async def run(self, agent: "RealtimeAgent") -> None:
        """Run the observer with the agent.

        When implementing, be sure to call `self._ready_event.set()` when the observer is ready to process events.

        Args:
            agent (RealtimeAgent): The realtime agent attached to the observer.
        """
        self._agent = agent
        await self.initialize_session()
        # Unblock wait_for_ready() before entering the long-running loop.
        self._ready_event.set()

        await self.run_loop()

    @abstractmethod
    async def run_loop(self) -> None:
        """Run the loop if needed.

        This method is called after the observer is ready to process events.
        Events will be processed by the on_event method, this is just a hook for additional processing.
        Use initialize_session to set up the session.
        """
        ...

    @abstractmethod
    async def initialize_session(self) -> None:
        """Initialize the session for the observer."""
        ...

    async def wait_for_ready(self) -> None:
        """Wait until the observer has signalled it is ready to process events."""
        await self._ready_event.wait()

    @abstractmethod
    async def on_event(self, event: RealtimeEvent) -> None:
        """Handle an event from the OpenAI Realtime API.

        Args:
            event (RealtimeEvent): The event from the OpenAI Realtime API.
        """
        ...

    async def on_close(self) -> None:
        """Handle close of RealtimeClient."""
        ...
|
||||
@@ -0,0 +1,483 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
import anyio
|
||||
from asyncer import asyncify, create_task_group, syncify
|
||||
|
||||
from ....agentchat.contrib.swarm_agent import AfterWorkOption, initiate_swarm_chat
|
||||
from ....cache import AbstractCache
|
||||
from ....code_utils import content_str
|
||||
from ....doc_utils import export_module
|
||||
from ... import Agent, ChatResult, ConversableAgent, LLMAgent
|
||||
from ...utils import consolidate_chat_info, gather_usage_summary
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .clients import Role
|
||||
from .realtime_agent import RealtimeAgent
|
||||
|
||||
__all__ = ["register_swarm"]
|
||||
|
||||
# Default system prompt for the RealtimeAgent fronting the swarm. Used by
# `SwarmableRealtimeAgent.configure_realtime_agent` when no custom system
# message is supplied to `register_swarm`.
# NOTE: a separating space is required between the two fragments; without it
# the prompt reads "...inputs.You can...".
SWARM_SYSTEM_MESSAGE = (
    "You are a helpful voice assistant. Your task is to listen to user and to coordinate the tasks based on his/her inputs. "
    "You can and will communicate using audio output only."
)

# Role under which user-directed questions are injected into the realtime session.
QUESTION_ROLE: "Role" = "user"
# Template relaying a swarm question to the human; `{}` is filled with the question text.
QUESTION_MESSAGE = (
    "I have a question/information for myself. DO NOT ANSWER YOURSELF, GET THE ANSWER FROM ME. "
    "repeat the question to me **WITH AUDIO OUTPUT** and AFTER YOU GET THE ANSWER FROM ME call 'answer_task_question' with the answer in first person\n\n"
    "IMPORTANT: repeat just the question, without any additional information or context\n\n"
    "The question is: '{}'\n\n"
)
# Seconds to wait for an answer before the question is re-asked.
QUESTION_TIMEOUT_SECONDS = 20

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def message_to_dict(message: Union[dict[str, Any], str]) -> dict[str, Any]:
    """Coerce *message* into a dict, wrapping a plain string under 'content'.

    Dicts are returned as-is; any other value is fed to ``dict()``.
    """
    if isinstance(message, dict):
        return message
    if isinstance(message, str):
        return {"content": message}
    return dict(message)
|
||||
|
||||
|
||||
def parse_oai_message(message: Union[dict[str, Any], str], role: str, adressee: Agent) -> dict[str, Any]:
    """
    Parse a message into an OpenAI-compatible message format.

    Args:
        message: The message to parse.
        role: The role associated with the message.
        adressee: The agent that will receive the message.

    Returns:
        The parsed message in OpenAI-compatible format.

    Raises:
        ValueError: If the message lacks required fields like 'content', 'function_call', or 'tool_calls'.
    """
    message = message_to_dict(message)

    # Copy over the recognized fields, dropping any that are None.
    allowed = ("content", "function_call", "tool_calls", "tool_responses", "tool_call_id", "name", "context")
    oai_message = {field: message[field] for field in allowed if message.get(field) is not None}

    # A message must carry content, or at least a function/tool call.
    if "content" not in oai_message:
        if "function_call" not in oai_message and "tool_calls" not in oai_message:
            raise ValueError("Message must have either 'content', 'function_call', or 'tool_calls' field.")
        oai_message["content"] = None

    # Pick the role: incoming function/tool role wins, then an explicit
    # override, then the caller-supplied default.
    incoming_role = message.get("role")
    if incoming_role in ("function", "tool"):
        oai_message["role"] = incoming_role
        # Tool responses must carry string content.
        for tool_response in oai_message.get("tool_responses", []):
            tool_response["content"] = str(tool_response["content"])
    elif "override_role" in message:
        oai_message["role"] = message["override_role"]
    else:
        oai_message["role"] = role

    # Messages carrying calls are always attributed to the assistant.
    if oai_message.get("function_call") or oai_message.get("tool_calls"):
        oai_message["role"] = "assistant"

    # Fall back to the addressee's name when none is given.
    oai_message.setdefault("name", adressee.name)

    return oai_message
|
||||
|
||||
|
||||
class SwarmableAgent(Agent):
    """A class for an agent that can participate in a swarm chat.

    Implements just enough of the ConversableAgent surface (send/receive,
    per-interlocutor history, chat initiation) for `initiate_swarm_chat` to
    drive it as a user proxy. Subclasses supply replies by implementing
    `check_termination_and_human_reply`.
    """

    def __init__(
        self,
        name: str,
        system_message: str = "You are a helpful AI Assistant.",
        is_termination_msg: Optional[Callable[..., bool]] = None,
        description: Optional[str] = None,
        silent: Optional[bool] = None,
    ):
        """Initialize the agent.

        Args:
            name (str): Name of the agent.
            system_message (str): System message for inference.
            is_termination_msg (Optional[Callable]): Predicate deciding whether a
                message terminates the chat; defaults to matching literal "TERMINATE" content.
            description (Optional[str]): Description of the agent; defaults to the system message.
            silent (Optional[bool]): Whether to suppress output.
        """
        # Per-interlocutor conversation history in OpenAI message format.
        self._oai_messages: dict[Agent, Any] = defaultdict(list)

        self._system_message = system_message
        self._description = description if description is not None else system_message
        self._is_termination_msg = (
            is_termination_msg
            if is_termination_msg is not None
            else (lambda x: content_str(x.get("content")) == "TERMINATE")
        )
        self.silent = silent

        self._name = name

        # Initialize standalone client cache object.
        self.client_cache = None
        self.previous_cache = None

        # Whether receive() should auto-generate a reply for a given sender.
        self.reply_at_receive: dict[Agent, bool] = defaultdict(bool)

    @property
    def system_message(self) -> str:
        """The system message used for inference."""
        return self._system_message

    def update_system_message(self, system_message: str) -> None:
        """Update this agent's system message.

        Args:
            system_message (str): system message for inference.
        """
        self._system_message = system_message

    @property
    def name(self) -> str:
        """The name of the agent."""
        return self._name

    @property
    def description(self) -> str:
        """The description of the agent."""
        return self._description

    def send(
        self,
        message: Union[dict[str, Any], str],
        recipient: Agent,
        request_reply: Optional[bool] = None,
        silent: Optional[bool] = False,
    ) -> None:
        """Record *message* in the history for *recipient* and deliver it."""
        self._oai_messages[recipient].append(parse_oai_message(message, "assistant", recipient))
        recipient.receive(message, self, request_reply)

    def receive(
        self,
        message: Union[dict[str, Any], str],
        sender: Agent,
        request_reply: Optional[bool] = None,
        silent: Optional[bool] = False,
    ) -> None:
        """Record an incoming message and, when requested, reply to *sender*."""
        self._oai_messages[sender].append(parse_oai_message(message, "user", self))
        # Reply only when explicitly requested, or when auto-reply is enabled for this sender.
        if request_reply is False or (request_reply is None and self.reply_at_receive[sender] is False):
            return
        reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender)
        if reply is not None:
            self.send(reply, sender, silent=silent)

    def generate_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional["Agent"] = None,
        **kwargs: Any,
    ) -> Union[str, dict[str, Any], None]:
        """Produce a reply to *messages* (or to the history with *sender*).

        Raises:
            ValueError: If neither messages nor sender is provided.
        """
        if messages is None:
            if sender is None:
                raise ValueError("Either messages or sender must be provided.")
            messages = self._oai_messages[sender]

        _, reply = self.check_termination_and_human_reply(messages=messages, sender=sender, config=None)

        return reply

    def check_termination_and_human_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Union[str, None]]:
        """Decide whether to terminate and produce a human reply. Subclasses must override."""
        raise NotImplementedError

    def initiate_chat(
        self,
        recipient: ConversableAgent,
        message: Union[dict[str, Any], str],
        clear_history: bool = True,
        silent: Optional[bool] = False,
        cache: Optional[AbstractCache] = None,
        summary_args: Optional[dict[str, Any]] = {},
        **kwargs: dict[str, Any],
    ) -> ChatResult:
        """Start a chat with *recipient* and return its result.

        Args:
            recipient (ConversableAgent): The agent to chat with.
            message (Union[dict, str]): The opening message.
            clear_history (bool): Whether to clear prior history with the recipient.
            silent (Optional[bool]): Whether to suppress output.
            cache (Optional[AbstractCache]): Client cache to use for the duration of the chat.
            summary_args (Optional[dict]): Extra arguments for summary extraction.
                NOTE: mutable default is safe here — it is only read, never mutated.
            **kwargs: Additional chat info consolidated into the chat context.
        """
        _chat_info = locals().copy()
        _chat_info["sender"] = self
        consolidate_chat_info(_chat_info, uniform_sender=self)
        recipient._raise_exception_on_async_reply_functions()
        # Swap in the requested cache for the duration of the chat.
        recipient.previous_cache = recipient.client_cache  # type: ignore[attr-defined]
        recipient.client_cache = cache  # type: ignore[attr-defined, assignment]

        self._prepare_chat(recipient, clear_history)
        self.send(message, recipient, silent=silent)
        summary = self._last_msg_as_summary(self, recipient, summary_args)

        # Restore the recipient's original cache.
        recipient.client_cache = recipient.previous_cache  # type: ignore[attr-defined]
        recipient.previous_cache = None  # type: ignore[attr-defined]

        chat_result = ChatResult(
            chat_history=self.chat_messages[recipient],
            summary=summary,
            cost=gather_usage_summary([self, recipient]),  # type: ignore[arg-type]
            human_input=[],
        )
        return chat_result

    async def a_generate_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional["Agent"] = None,
        **kwargs: Any,
    ) -> Union[str, dict[str, Any], None]:
        """Async wrapper delegating to the synchronous generate_reply."""
        return self.generate_reply(messages=messages, sender=sender, **kwargs)

    async def a_receive(
        self,
        message: Union[dict[str, Any], str],
        sender: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """Async wrapper delegating to the synchronous receive."""
        self.receive(message, sender, request_reply)

    async def a_send(
        self,
        message: Union[dict[str, Any], str],
        recipient: "Agent",
        request_reply: Optional[bool] = None,
    ) -> None:
        """Async wrapper delegating to the synchronous send."""
        self.send(message, recipient, request_reply)

    @property
    def chat_messages(self) -> dict[Agent, list[dict[str, Any]]]:
        """A dictionary of conversations from agent to list of messages."""
        return self._oai_messages

    def last_message(self, agent: Optional[Agent] = None) -> Optional[dict[str, Any]]:
        """Return the last message exchanged with *agent*.

        When *agent* is None, return the last message of the only conversation;
        None if there is none, ValueError if there are several.

        Raises:
            ValueError: If agent is None and more than one conversation exists.
            KeyError: If there is no conversation with *agent*.
        """
        if agent is None:
            n_conversations = len(self._oai_messages)
            if n_conversations == 0:
                return None
            if n_conversations == 1:
                for conversation in self._oai_messages.values():
                    return conversation[-1]  # type: ignore[no-any-return]
            raise ValueError("More than one conversation is found. Please specify the sender to get the last message.")
        # BUGFIX: was `self._oai_messages()` — calling the defaultdict raised
        # TypeError instead of performing the membership test.
        if agent not in self._oai_messages:
            raise KeyError(
                f"The agent '{agent.name}' is not present in any conversation. No history available for this agent."
            )
        return self._oai_messages[agent][-1]  # type: ignore[no-any-return]

    def _prepare_chat(
        self,
        recipient: ConversableAgent,
        clear_history: bool,
        prepare_recipient: bool = True,
        reply_at_receive: bool = True,
    ) -> None:
        """Set up both sides of a chat before the first message is sent."""
        self.reply_at_receive[recipient] = reply_at_receive
        if clear_history:
            self._oai_messages[recipient].clear()
        if prepare_recipient:
            recipient._prepare_chat(self, clear_history, False, reply_at_receive)  # type: ignore[arg-type]

    def _raise_exception_on_async_reply_functions(self) -> None:
        """No async reply functions are registered on this agent; nothing to check."""
        pass

    def set_ui_tools(self, tools: Optional[list] = None) -> None:
        """Set UI tools for the agent."""
        pass

    def unset_ui_tools(self) -> None:
        """Unset UI tools for the agent."""
        pass

    @staticmethod
    def _last_msg_as_summary(sender: Agent, recipient: Agent, summary_args: Optional[dict[str, Any]]) -> str:
        """Get a chat summary from the last message of the recipient."""
        summary = ""
        try:
            content = recipient.last_message(sender)["content"]  # type: ignore[attr-defined]
            if isinstance(content, str):
                summary = content.replace("TERMINATE", "")
            elif isinstance(content, list):
                # Concatenate the text parts of a multimodal content list.
                summary = "\n".join(
                    x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x
                )
        except (IndexError, AttributeError) as e:
            warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning)
        return summary
|
||||
|
||||
|
||||
# check that the SwarmableAgent class is implementing LLMAgent protocol
if TYPE_CHECKING:
    # Static-only scaffold: never executed at runtime. If SwarmableAgent stops
    # satisfying the LLMAgent protocol, the type checker flags the return below.
    def _create_swarmable_agent(
        name: str,
        system_message: str,
        is_termination_msg: Optional[Callable[..., bool]],
        description: Optional[str],
        silent: Optional[bool],
    ) -> LLMAgent:
        """Type-checker-only helper asserting SwarmableAgent conforms to LLMAgent."""
        return SwarmableAgent(
            name=name,
            system_message=system_message,
            is_termination_msg=is_termination_msg,
            description=description,
            silent=silent,
        )
|
||||
|
||||
|
||||
class SwarmableRealtimeAgent(SwarmableAgent):
    """Swarm user-proxy that bridges a RealtimeAgent (voice) into a swarm chat.

    Questions from the swarm are relayed to the human over the realtime audio
    channel, and the spoken answer is fed back as the user reply.
    """

    def __init__(
        self,
        realtime_agent: "RealtimeAgent",
        initial_agent: ConversableAgent,
        agents: list[ConversableAgent],
        question_message: Optional[str] = None,
    ) -> None:
        """Wrap *realtime_agent* so it can act as the swarm's user agent.

        Args:
            realtime_agent (RealtimeAgent): The realtime (voice) agent to bridge.
            initial_agent (ConversableAgent): First agent of the swarm chat.
            agents (list[ConversableAgent]): All agents participating in the swarm.
            question_message (Optional[str]): Template used to relay questions;
                defaults to QUESTION_MESSAGE.
        """
        self._initial_agent = initial_agent
        self._agents = agents
        self._realtime_agent = realtime_agent

        # Signalled by set_answer() when the human has answered the current question.
        self._answer_event: anyio.Event = anyio.Event()
        self._answer: str = ""
        self.question_message = question_message or QUESTION_MESSAGE

        super().__init__(
            name=realtime_agent._name,
            is_termination_msg=None,
            description=None,
            silent=None,
        )

    def reset_answer(self) -> None:
        """Reset the answer event."""
        # anyio.Event cannot be cleared; a fresh instance replaces the old one.
        self._answer_event = anyio.Event()

    def set_answer(self, answer: str) -> str:
        """Set the answer to the question."""
        # Registered as the 'answer_task_question' realtime function; the return
        # string is what the realtime model sees as the tool result.
        self._answer = answer
        self._answer_event.set()
        return "Answer set successfully."

    async def get_answer(self) -> str:
        """Get the answer to the question."""
        await self._answer_event.wait()
        return self._answer

    async def ask_question(self, question: str, question_timeout: int) -> None:
        """Send a question for the user to the agent and wait for the answer.
        If the answer is not received within the timeout, the question is repeated.

        Args:
            question: The question to ask the user.
            question_timeout: The time in seconds to wait for the answer.
        """
        self.reset_answer()
        realtime_client = self._realtime_agent._realtime_client
        await realtime_client.send_text(role=QUESTION_ROLE, text=question)

        # Poll once per second so the wait can be abandoned after `timeout` seconds.
        async def _check_event_set(timeout: int = question_timeout) -> bool:
            for _ in range(timeout):
                if self._answer_event.is_set():
                    return True
                await anyio.sleep(1)
            return False

        # Keep re-asking until the answer event fires.
        while not await _check_event_set():
            await realtime_client.send_text(role=QUESTION_ROLE, text=question)

    def check_termination_and_human_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Optional[str]]:
        """Check if the conversation should be terminated and if the agent should reply.

        Called when its agents turn in the chat conversation.

        Args:
            messages (list[dict[str, Any]]): The messages in the conversation.
            sender (Agent): The agent that sent the message.
            config (Optional[Any]): The configuration for the agent.
        """
        if not messages:
            return False, None

        # Bridge into async land: relay the last message as a question and block
        # (synchronously, via syncify) until the human answers.
        async def get_input() -> None:
            async with create_task_group() as tg:
                tg.soonify(self.ask_question)(
                    self.question_message.format(messages[-1]["content"]),
                    question_timeout=QUESTION_TIMEOUT_SECONDS,
                )

        syncify(get_input)()

        # The dict is accepted downstream even though the annotation says str.
        return True, {"role": "user", "content": self._answer}  # type: ignore[return-value]

    def start_chat(self) -> None:
        # Chats are started via the on_observers_ready callback, not directly.
        raise NotImplementedError

    def configure_realtime_agent(self, system_message: Optional[str]) -> None:
        """Install system message, answer function and swarm-start callback on the realtime agent."""
        realtime_agent = self._realtime_agent

        logger = realtime_agent.logger
        if not system_message:
            # Warn if the caller set a custom message in __init__ that is about to be replaced.
            if realtime_agent.system_message != "You are a helpful AI Assistant.":
                logger.warning(
                    "Overriding system message set up in `__init__`, please use `system_message` parameter of the `register_swarm` function instead."
                )
            system_message = SWARM_SYSTEM_MESSAGE

        realtime_agent._system_message = system_message

        # Expose set_answer to the realtime model as a callable function.
        realtime_agent.register_realtime_function(
            name="answer_task_question", description="Answer question from the task"
        )(self.set_answer)

        # Launch the swarm chat (in a worker thread via asyncify) once all observers are ready.
        async def on_observers_ready() -> None:
            self._realtime_agent._tg.soonify(asyncify(initiate_swarm_chat))(
                initial_agent=self._initial_agent,
                agents=self._agents,
                user_agent=self,  # type: ignore[arg-type]
                messages="Find out what the user wants.",
                after_work=AfterWorkOption.REVERT_TO_USER,
            )

        self._realtime_agent.callbacks.on_observers_ready = on_observers_ready
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.realtime.experimental")
|
||||
def register_swarm(
|
||||
*,
|
||||
realtime_agent: "RealtimeAgent",
|
||||
initial_agent: ConversableAgent,
|
||||
agents: list[ConversableAgent],
|
||||
system_message: Optional[str] = None,
|
||||
question_message: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Create a SwarmableRealtimeAgent.
|
||||
|
||||
Args:
|
||||
realtime_agent (RealtimeAgent): The RealtimeAgent to create the SwarmableRealtimeAgent from.
|
||||
initial_agent (ConversableAgent): The initial agent.
|
||||
agents (list[ConversableAgent]): The agents in the swarm.
|
||||
system_message (Optional[str]): The system message to set for the agent. If None, the default system message is used.
|
||||
question_message (Optional[str]): The question message to set for the agent. If None, the default QUESTION_MESSAGE is used.
|
||||
"""
|
||||
swarmable_agent = SwarmableRealtimeAgent(
|
||||
realtime_agent=realtime_agent, initial_agent=initial_agent, agents=agents, question_message=question_message
|
||||
)
|
||||
|
||||
swarmable_agent.configure_realtime_agent(system_message=system_message)
|
||||
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
__all__ = ["WebSocketProtocol"]
|
||||
|
||||
|
||||
@runtime_checkable
class WebSocketProtocol(Protocol):
    """WebSocket protocol for sending and receiving JSON data modelled after FastAPI's WebSocket."""

    async def send_json(self, data: Any, mode: str = "text") -> None:
        """Send *data* as JSON; `mode` mirrors FastAPI's text/binary framing switch."""
        ...

    async def receive_json(self, mode: str = "text") -> Any:
        """Receive a message and return it decoded from JSON."""
        ...

    async def receive_text(self) -> str:
        """Receive a single raw text message."""
        ...

    def iter_text(self) -> AsyncIterator[str]:
        """Async-iterate incoming text messages until the connection closes."""
        ...
|
||||
Reference in New Issue
Block a user