CoACT initialize (#292)

This commit is contained in:
Linxin Song
2025-07-30 19:35:20 -07:00
committed by GitHub
parent 862d704b8c
commit b968155757
228 changed files with 42386 additions and 0 deletions

View File

@@ -0,0 +1,15 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .gemini.client import GeminiRealtimeClient
from .oai.base_client import OpenAIRealtimeClient
from .realtime_client import RealtimeClientProtocol, Role, get_client
__all__ = [
"GeminiRealtimeClient",
"OpenAIRealtimeClient",
"RealtimeClientProtocol",
"Role",
"get_client",
]

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .client import GeminiRealtimeClient
__all__ = ["GeminiRealtimeClient"]

View File

@@ -0,0 +1,274 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from ......doc_utils import export_module
from ......import_utils import optional_import_block, require_optional_import
from ......llm_config import LLMConfig
from ...realtime_events import AudioDelta, FunctionCall, RealtimeEvent, SessionCreated
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
with optional_import_block():
from websockets.asyncio.client import connect
if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection
from ..realtime_client import RealtimeClientProtocol
__all__ = ["GeminiRealtimeClient"]
global_logger = getLogger(__name__)
HOST = "generativelanguage.googleapis.com"
API_VERSION = "v1alpha"
@register_realtime_client()
@require_optional_import("websockets", "gemini", except_for=["get_factory", "__init__"])
@export_module("autogen.agentchat.realtime.experimental.clients")
class GeminiRealtimeClient(RealtimeClientBase):
    """(Experimental) Client for Gemini Realtime API.

    Speaks the BidiGenerateContent websocket protocol directly (no SDK):
    outbound messages are JSON-serialized dicts, inbound messages are parsed
    into RealtimeEvent objects by _parse_message().
    """

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for Gemini Realtime API.

        Args:
            llm_config: The config for the client. Only the first entry of
                ``config_list`` is used.
            logger: The logger for the client. Falls back to the module logger.
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        # Set inside connect() via `as self._connection`; None while disconnected.
        self._connection: Optional["ClientConnection"] = None
        config = llm_config["config_list"][0]
        self._model: str = config["model"]
        self._voice = config.get("voice", "charon")
        self._temperature: float = config.get("temperature", 0.8)  # type: ignore[union-attr]
        self._response_modality = "AUDIO"
        self._api_key = config.get("api_key", None)
        # todo: add test with base_url just to make sure it works
        # NOTE: the API key is embedded in the default websocket URL as a query parameter.
        self._base_url: str = config.get(
            "base_url",
            f"wss://{HOST}/ws/google.ai.generativelanguage.{API_VERSION}.GenerativeService.BidiGenerateContent?key={self._api_key}",
        )
        self._final_config: dict[str, Any] = {}
        # Session options recorded by session_update() before the connection is
        # established; applied once in _initialize_session().
        self._pending_session_updates: dict[str, Any] = {}
        # Sends are silently dropped until read_events() has initialized the session.
        self._is_reading_events = False

    @property
    def logger(self) -> Logger:
        """Get the logger for the Gemini Realtime API."""
        return self._logger or global_logger

    @property
    def connection(self) -> "ClientConnection":
        """Get the Gemini WebSocket connection.

        Raises:
            RuntimeError: If connect() has not established a connection yet.
        """
        if self._connection is None:
            raise RuntimeError("Gemini WebSocket is not initialized")
        return self._connection

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the Gemini Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        msg = {
            "tool_response": {"function_responses": [{"id": call_id, "response": {"result": {"string_value": result}}}]}
        }
        # Dropped (not queued) if the event loop has not started reading yet.
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def send_text(self, *, role: Role, text: str, turn_complete: bool = True) -> None:
        """Send a text message to the Gemini Realtime API.

        Args:
            role: The role of the message.
            text: The text of the message.
            turn_complete: A flag indicating if the turn is complete.
        """
        msg = {
            "client_content": {
                "turn_complete": turn_complete,
                "turns": [{"role": role, "parts": [{"text": text}]}],
            }
        }
        # Dropped (not queued) if the event loop has not started reading yet.
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def send_audio(self, audio: str) -> None:
        """Send audio to the Gemini Realtime API.

        Args:
            audio (str): The audio to send (base64-encoded PCM chunk).
        """
        msg = {
            "realtime_input": {
                "media_chunks": [
                    {
                        "data": audio,
                        "mime_type": "audio/pcm",
                    }
                ]
            }
        }
        # Always queue the delta locally for logging, even if the send is skipped below.
        await self.queue_input_audio_buffer_delta(audio)
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        # Intentional no-op: Gemini has no equivalent of OpenAI's item truncation.
        self.logger.info("This is not natively supported by Gemini Realtime API.")
        pass

    async def _initialize_session(self) -> None:
        """Initialize the session with the Gemini Realtime API.

        Builds the one-shot "setup" message from the pending session updates
        (instructions, tools) and the constructor config (model, voice,
        temperature, response modality), then sends it over the connection.
        """
        session_config = {
            "setup": {
                "system_instruction": {
                    "role": "system",
                    "parts": [{"text": self._pending_session_updates.get("instructions", "")}],
                },
                "model": f"models/{self._model}",
                "tools": [
                    {
                        "function_declarations": [
                            {
                                "name": tool_schema["name"],
                                "description": tool_schema["description"],
                                "parameters": tool_schema["parameters"],
                            }
                            for tool_schema in self._pending_session_updates.get("tools", [])
                        ]
                    },
                ],
                "generation_config": {
                    "response_modalities": [self._response_modality],
                    "speech_config": {"voiceConfig": {"prebuiltVoiceConfig": {"voiceName": self._voice}}},
                    "temperature": self._temperature,
                },
            }
        }
        self.logger.info(f"Sending session update: {session_config}")
        await self.connection.send(json.dumps(session_config))

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Record session updates to be applied when the connection is established.

        Unlike the OpenAI client, updates cannot be sent mid-session: once
        read_events() has started, new options are ignored with a warning.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        if self._is_reading_events:
            self.logger.warning("Is reading events. Session update will be ignored.")
        else:
            self._pending_session_updates.update(session_options)

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the Gemini Realtime API."""
        try:
            # `as self._connection` stores the live websocket on the instance for
            # the duration of the context.
            async with connect(
                self._base_url, additional_headers={"Content-Type": "application/json"}
            ) as self._connection:
                yield
        finally:
            self._connection = None

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read Events from the Gemini Realtime Client

        Raises:
            RuntimeError: If called before connect().
        """
        if self._connection is None:
            raise RuntimeError("Client is not connected, call connect() first.")
        # Session setup must precede any other traffic on the connection.
        await self._initialize_session()
        self._is_reading_events = True
        async for event in self._read_events():
            yield event

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the Gemini Realtime connection."""
        async for raw_message in self.connection:
            # Frames may arrive as bytes; decode before JSON parsing.
            message = raw_message.decode("ascii") if isinstance(raw_message, bytes) else raw_message
            events = self._parse_message(json.loads(message))
            for event in events:
                yield event

    def _parse_message(self, response: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the Gemini Realtime API.

        Args:
            response (dict[str, Any]): The response to parse.

        Returns:
            list[RealtimeEvent]: The parsed events.
        """
        if "serverContent" in response and "modelTurn" in response["serverContent"]:
            try:
                # NOTE: pop() mutates `response` — the audio payload is removed from
                # raw_message so the (potentially large) data is not duplicated.
                b64data = response["serverContent"]["modelTurn"]["parts"][0]["inlineData"].pop("data")
                return [
                    AudioDelta(
                        delta=b64data,
                        item_id=None,
                        raw_message=response,
                    )
                ]
            except KeyError:
                # Model turn without inline audio data — nothing to emit.
                return []
        elif "toolCall" in response:
            return [
                FunctionCall(
                    raw_message=response,
                    call_id=call["id"],
                    name=call["name"],
                    arguments=call["args"],
                )
                for call in response["toolCall"]["functionCalls"]
            ]
        elif "setupComplete" in response:
            # Gemini's setupComplete is mapped onto the generic SessionCreated event.
            return [
                SessionCreated(raw_message=response),
            ]
        else:
            # Unknown message types are passed through as generic events.
            return [RealtimeEvent(raw_message=response)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The LLM config for the client.
            logger: The logger for the client.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Matches only api_type == "google" with no extra kwargs; otherwise lets
        # another registered client claim the config.
        if llm_config["config_list"][0].get("api_type") == "google" and list(kwargs.keys()) == []:
            return lambda: GeminiRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
        return None
# needed for mypy to check if GeminiRealtimeClient implements RealtimeClientProtocol
if TYPE_CHECKING:
    # Never executed at runtime; a type error here means the protocol is not satisfied.
    _client: RealtimeClientProtocol = GeminiRealtimeClient(llm_config={})

View File

@@ -0,0 +1,8 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .base_client import OpenAIRealtimeClient
from .rtc_client import OpenAIRealtimeWebRTCClient
__all__ = ["OpenAIRealtimeClient", "OpenAIRealtimeWebRTCClient"]

View File

@@ -0,0 +1,220 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from ......doc_utils import export_module
from ......import_utils import optional_import_block, require_optional_import
from ......llm_config import LLMConfig
from ...realtime_events import RealtimeEvent
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
from .utils import parse_oai_message
with optional_import_block():
from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI
from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection
if TYPE_CHECKING:
from ..realtime_client import RealtimeClientProtocol
__all__ = ["OpenAIRealtimeClient"]
global_logger = getLogger(__name__)
@register_realtime_client()
@require_optional_import("openai>=1.66.2", "openai-realtime", except_for=["get_factory", "__init__"])
@export_module("autogen.agentchat.realtime.experimental.clients")
class OpenAIRealtimeClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API.

    Wraps the OpenAI SDK's beta realtime websocket connection; outbound calls go
    through the SDK's typed helpers, inbound events are model_dump()-ed dicts
    parsed by parse_oai_message().
    """

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client. Only the first entry of
                ``config_list`` is used.
            logger: the logger to use for logging events
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        # Set inside connect() via `as self._connection`; None while disconnected.
        self._connection: Optional["AsyncRealtimeConnection"] = None
        self.config = llm_config["config_list"][0]
        # model is passed to self._client.beta.realtime.connect function later
        self._model: str = self.config["model"]
        self._voice: str = self.config.get("voice", "alloy")
        # NOTE: temperature is read from the top-level llm_config here, not from
        # the config_list entry (the Gemini client reads it from the entry).
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]
        # AsyncOpenAI client is created lazily on first connect() and reused after.
        self._client: Optional["AsyncOpenAI"] = None

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API."""
        return self._logger or global_logger

    @property
    def connection(self) -> "AsyncRealtimeConnection":
        """Get the OpenAI WebSocket connection.

        Raises:
            RuntimeError: If connect() has not established a connection yet.
        """
        if self._connection is None:
            raise RuntimeError("OpenAI WebSocket is not initialized")
        return self._connection

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self.connection.conversation.item.create(
            item={
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        )
        # Ask the model to produce a response that incorporates the tool output.
        await self.connection.response.create()

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        # Cancel any in-flight response before injecting the new message.
        await self.connection.response.cancel()
        await self.connection.conversation.item.create(
            item={"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]}
        )
        await self.connection.response.create()

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        Args:
            audio (str): The audio to send.
        """
        # Queue the delta locally for logging before forwarding to OpenAI.
        await self.queue_input_audio_buffer_delta(audio)
        await self.connection.input_audio_buffer.append(audio=audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self.connection.conversation.item.truncate(
            audio_end_ms=audio_end_ms, content_index=content_index, item_id=item_id
        )

    async def _initialize_session(self) -> None:
        """Control initial session with OpenAI.

        Sends the default session options (server VAD, voice, modalities,
        temperature) right after the connection is opened.
        """
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        await self.session_update(session_options=session_update)

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        await self.connection.session.update(session=session_options)  # type: ignore[arg-type]
        logger.info("Sending session update finished")

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API."""
        try:
            # Lazily build the AsyncOpenAI client from the first config_list entry;
            # subsequent connects reuse the same client instance.
            if not self._client:
                self._client = AsyncOpenAI(
                    api_key=self.config.get("api_key", None),
                    organization=self.config.get("organization", None),
                    project=self.config.get("project", None),
                    base_url=self.config.get("base_url", None),
                    websocket_base_url=self.config.get("websocket_base_url", None),
                    timeout=self.config.get("timeout", NOT_GIVEN),
                    max_retries=self.config.get("max_retries", DEFAULT_MAX_RETRIES),
                    default_headers=self.config.get("default_headers", None),
                    default_query=self.config.get("default_query", None),
                )
            # `as self._connection` stores the live connection on the instance.
            async with self._client.beta.realtime.connect(
                model=self._model,
            ) as self._connection:
                await self._initialize_session()
                yield
        finally:
            self._connection = None

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API.

        Raises:
            RuntimeError: If called before connect().
        """
        if self._connection is None:
            raise RuntimeError("Client is not connected, call connect() first.")
        try:
            async for event in self._read_events():
                yield event
        finally:
            # Drop the connection reference when the event stream ends, even on error.
            self._connection = None

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API."""
        async for message in self._connection:
            # SDK events are pydantic models; convert to plain dicts for parsing.
            for event in self._parse_message(message.model_dump()):
                yield event

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            RealtimeEvent: The parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Default api_type is "openai"; no extra kwargs means the websocket client
        # (the WebRTC variant claims configs with a "websocket" kwarg).
        if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == []:
            return lambda: OpenAIRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
        return None
# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
if TYPE_CHECKING:
    # Never executed at runtime; a type error here means the protocol is not satisfied.
    _client: RealtimeClientProtocol = OpenAIRealtimeClient(llm_config={})

View File

@@ -0,0 +1,243 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from autogen.import_utils import optional_import_block, require_optional_import
from ......doc_utils import export_module
from ......llm_config import LLMConfig
from ...realtime_events import RealtimeEvent
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
from .utils import parse_oai_message
if TYPE_CHECKING:
from ...websockets import WebSocketProtocol as WebSocket
from ..realtime_client import RealtimeClientProtocol
with optional_import_block():
import httpx
__all__ = ["OpenAIRealtimeWebRTCClient"]
global_logger = getLogger(__name__)
@register_realtime_client()
@require_optional_import("httpx", "openai-realtime", except_for="get_factory")
@export_module("autogen.agentchat.realtime.experimental.clients.oai")
class OpenAIRealtimeWebRTCClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol.

    The actual media connection to OpenAI is held by a JavaScript client in the
    browser; this class only exchanges JSON control messages with that client
    over a server-side websocket.
    """

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        websocket: "WebSocket",
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client. Only the first entry of
                ``config_list`` is used.
            websocket: the websocket to use for the connection
            logger: the logger to use for logging events
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        # Relay websocket to the browser-side JS client (not a connection to OpenAI).
        self._websocket = websocket
        config = llm_config["config_list"][0]
        self._model: str = config["model"]
        self._voice: str = config.get("voice", "alloy")
        # NOTE: temperature is read from the top-level llm_config, mirroring
        # OpenAIRealtimeClient.
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]
        self._config = config
        # REST endpoint used to mint an ephemeral realtime session for the browser.
        self._base_url = config.get("base_url", "https://api.openai.com/v1/realtime/sessions")

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API."""
        return self._logger or global_logger

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        # Relayed as raw realtime-protocol JSON; the JS client forwards it to OpenAI.
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        })
        await self._websocket.send_json({"type": "response.create"})

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        # await self.connection.response.cancel() #why is this here?
        # Cancel any in-flight response before injecting the new message.
        await self._websocket.send_json({
            "type": "response.cancel",
        })
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
        })
        # await self.connection.response.create()
        await self._websocket.send_json({"type": "response.create"})

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        in case of WebRTC, audio is already sent by js client, so we just queue it in order to be logged.

        Args:
            audio (str): The audio to send.
        """
        await self.queue_input_audio_buffer_delta(audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self._websocket.send_json({
            "type": "conversation.item.truncate",
            "content_index": content_index,
            "item_id": item_id,
            "audio_end_ms": audio_end_ms,
        })

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        In the case of WebRTC we can not send it directly, but we can send it
        to the javascript over the websocket, and rely on it to send session
        update to OpenAI

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        # await self.connection.session.update(session=session_options) # type: ignore[arg-type]
        await self._websocket.send_json({"type": "session.update", "session": session_options})
        logger.info("Sending session update finished")

    def session_init_data(self) -> list[dict[str, Any]]:
        """Control initial session with OpenAI.

        Returns:
            list[dict[str, Any]]: Initial session.update messages for the JS
                client to send once its connection to OpenAI is open.
        """
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        return [{"type": "session.update", "session": session_update}]

    # No-op: session initialization is delegated to the JS client via session_init_data().
    async def _initialize_session(self) -> None: ...

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API.

        In the case of WebRTC, we pass connection information over the
        websocket, so that javascript on the other end of websocket open
        actual connection to OpenAI
        """
        try:
            base_url = self._base_url
            api_key = self._config.get("api_key", None)
            headers = {
                "Authorization": f"Bearer {api_key}",  # Use os.getenv to get from environment
                "Content-Type": "application/json",
            }
            data = {
                # "model": "gpt-4o-realtime-preview-2024-12-17",
                "model": self._model,
                "voice": self._voice,
            }
            # Mint an ephemeral session via the REST API; raises on HTTP error.
            async with httpx.AsyncClient() as client:
                response = await client.post(base_url, headers=headers, json=data)
                response.raise_for_status()
                json_data = response.json()
                json_data["model"] = self._model
            # Hand the session config + initial messages to the JS client.
            if self._websocket is not None:
                session_init = self.session_init_data()
                await self._websocket.send_json({"type": "ag2.init", "config": json_data, "init": session_init})
            yield
        finally:
            # Nothing to tear down: the media connection lives in the browser.
            pass

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from the OpenAI Realtime API."""
        async for event in self._read_events():
            yield event

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API connection.

        Again, in case of WebRTC, we do not read OpenAI messages directly since we
        do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript
        client on the other side of the websocket that is connected to OpenAI is relaying events to us.
        """
        while True:
            try:
                message_json = await self._websocket.receive_text()
                message = json.loads(message_json)
                for event in self._parse_message(message):
                    yield event
            except Exception as e:
                # Any receive/parse failure ends the stream (e.g. websocket closed).
                self.logger.exception(f"Error reading from connection {e}")
                break

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            RealtimeEvent: The parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Claims "openai" configs only when exactly a "websocket" kwarg is supplied;
        # otherwise the plain websocket client handles the config.
        if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == ["websocket"]:
            return lambda: OpenAIRealtimeWebRTCClient(llm_config=llm_config, logger=logger, **kwargs)
        return None
# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
if TYPE_CHECKING:
    # Never executed at runtime; a type error here means the protocol is not satisfied.
    def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol:
        return OpenAIRealtimeWebRTCClient(llm_config={}, websocket=websocket)

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Any
from ...realtime_events import (
AudioDelta,
FunctionCall,
InputAudioBufferDelta,
RealtimeEvent,
SessionCreated,
SessionUpdated,
SpeechStarted,
)
__all__ = ["parse_oai_message"]
def parse_oai_message(message: dict[str, Any]) -> RealtimeEvent:
    """Translate a raw OpenAI Realtime API message into a typed RealtimeEvent.

    Args:
        message (dict[str, Any]): The raw message dict received from the API.

    Returns:
        RealtimeEvent: A dedicated event subclass for known message types, or a
            plain RealtimeEvent wrapper for anything unrecognized.
    """
    message_type = message.get("type")
    if message_type == "session.created":
        return SessionCreated(raw_message=message)
    if message_type == "session.updated":
        return SessionUpdated(raw_message=message)
    if message_type == "response.audio.delta":
        return AudioDelta(raw_message=message, delta=message["delta"], item_id=message["item_id"])
    if message_type == "input_audio_buffer.speech_started":
        return SpeechStarted(raw_message=message)
    if message_type == "input_audio_buffer.delta":
        return InputAudioBufferDelta(delta=message["delta"], item_id=None, raw_message=message)
    if message_type == "response.function_call_arguments.done":
        # Function arguments arrive JSON-encoded; decode them for the caller.
        return FunctionCall(
            raw_message=message,
            call_id=message["call_id"],
            name=message["name"],
            arguments=json.loads(message["arguments"]),
        )
    return RealtimeEvent(raw_message=message)

View File

@@ -0,0 +1,190 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
from collections.abc import AsyncGenerator
from logging import Logger
from typing import Any, AsyncContextManager, Callable, Literal, Optional, Protocol, TypeVar, Union, runtime_checkable
from asyncer import create_task_group
from .....doc_utils import export_module
from .....llm_config import LLMConfig
from ..realtime_events import InputAudioBufferDelta, RealtimeEvent
__all__ = ["RealtimeClientProtocol", "Role", "get_client", "register_realtime_client"]
# define role literal type for typing
Role = Literal["user", "assistant", "system"]
@runtime_checkable
@export_module("autogen.agentchat.realtime.experimental.clients")
class RealtimeClientProtocol(Protocol):
    """Structural interface every Realtime API client must satisfy.

    Implementations are discovered at runtime via register_realtime_client()
    and selected through their get_factory() classmethod.
    """

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to a Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        ...

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to a Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        ...

    async def send_audio(self, audio: str) -> None:
        """Send audio to a Realtime API.

        Args:
            audio (str): The audio to send.
        """
        ...

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in a Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        ...

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to a Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        ...

    # Async context manager that holds the connection open for its duration.
    def connect(self) -> AsyncContextManager[None]: ...

    def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime Client."""
        ...

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime connection."""
        ...

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from a Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            list[RealtimeEvent]: The parsed events.
        """
        ...

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        ...
class RealtimeClientBase:
    """Shared event-queue plumbing for Realtime client implementations.

    Subclasses provide _read_from_connection(); this base pumps those events
    through an internal asyncio.Queue so that locally generated events (e.g.
    InputAudioBufferDelta) can be interleaved with connection events.
    """

    def __init__(self) -> None:
        # Event queue; a None item is the end-of-stream sentinel.
        self._eventQueue: asyncio.Queue = asyncio.Queue()

    async def add_event(self, event: Optional[RealtimeEvent]) -> None:
        """Enqueue an event (or None to signal end of stream)."""
        await self._eventQueue.put(event)

    async def get_event(self) -> Optional[RealtimeEvent]:
        """Dequeue the next event; None means the stream has ended."""
        return await self._eventQueue.get()

    async def _read_from_connection_task(self) -> None:
        """Background task: pump connection events into the queue, then send the sentinel."""
        async for event in self._read_from_connection():
            await self.add_event(event)
        # Connection iterator exhausted — wake the consumer with the sentinel.
        await self.add_event(None)

    async def _read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime Client."""
        # The task group keeps the pump task alive while the consumer loop runs.
        async with create_task_group() as tg:
            tg.start_soon(self._read_from_connection_task)
            while True:
                try:
                    event = await self._eventQueue.get()
                    if event is not None:
                        yield event
                    else:
                        # Sentinel received: the connection is done.
                        break
                except Exception:
                    # Best-effort: any queue/consumer failure ends the stream quietly.
                    break

    async def queue_input_audio_buffer_delta(self, audio: str) -> None:
        """queue InputAudioBufferDelta.

        Args:
            audio (str): The audio.
        """
        await self.add_event(InputAudioBufferDelta(delta=audio, item_id=None, raw_message=dict()))
# Module-level registry mapping a client's fully qualified class name to its
# class, populated by the @register_realtime_client decorator.
_realtime_client_classes: dict[str, type[RealtimeClientProtocol]] = {}
# Type variable so the registration decorator preserves the decorated class's type.
T = TypeVar("T", bound=RealtimeClientProtocol)
def register_realtime_client() -> Callable[[type[T]], type[T]]:
    """Register a Realtime API client.

    Returns:
        Callable[[type[T]], type[T]]: The decorator to register the Realtime API client
    """

    def decorator(client_cls: type[T]) -> type[T]:
        """Register a Realtime API client class under its fully qualified name.

        Args:
            client_cls: The client to register.
        """
        # No `global` statement needed: the registry dict is mutated in place,
        # never rebound.
        fqn = f"{client_cls.__module__}.{client_cls.__name__}"
        _realtime_client_classes[fqn] = client_cls
        return client_cls

    return decorator
@export_module("autogen.agentchat.realtime.experimental.clients")
def get_client(llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any) -> "RealtimeClientProtocol":
    """Get a registered Realtime API client.

    Args:
        llm_config: The config for the client.
        logger: The logger to use for logging events.
        **kwargs: Additional arguments forwarded to each candidate factory.

    Returns:
        RealtimeClientProtocol: The first registered client whose factory accepts
            this config/kwargs combination.

    Raises:
        ValueError: If no registered client matches.
    """
    # Iterate values directly (keys are fully qualified names, unused here);
    # no `global` statement needed for a read-only reference.
    for client_cls in _realtime_client_classes.values():
        factory = client_cls.get_factory(llm_config=llm_config, logger=logger, **kwargs)
        if factory:
            return factory()
    raise ValueError("Realtime API client not found.")