CoACT initialize (#292)
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .gemini.client import GeminiRealtimeClient
|
||||
from .oai.base_client import OpenAIRealtimeClient
|
||||
from .realtime_client import RealtimeClientProtocol, Role, get_client
|
||||
|
||||
# Public, re-exported API of the realtime clients package.
__all__ = [
    "GeminiRealtimeClient",
    "OpenAIRealtimeClient",
    "RealtimeClientProtocol",
    "Role",
    "get_client",
]
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .client import GeminiRealtimeClient
|
||||
|
||||
# Public API of the gemini realtime client subpackage.
__all__ = ["GeminiRealtimeClient"]
|
||||
@@ -0,0 +1,274 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from logging import Logger, getLogger
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
from ......doc_utils import export_module
|
||||
from ......import_utils import optional_import_block, require_optional_import
|
||||
from ......llm_config import LLMConfig
|
||||
from ...realtime_events import AudioDelta, FunctionCall, RealtimeEvent, SessionCreated
|
||||
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
|
||||
|
||||
with optional_import_block():
|
||||
from websockets.asyncio.client import connect
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
|
||||
from ..realtime_client import RealtimeClientProtocol
|
||||
|
||||
__all__ = ["GeminiRealtimeClient"]

# Module-level fallback logger, used when no logger is passed to the client.
global_logger = getLogger(__name__)


# Host and API version used to build the default BidiGenerateContent websocket URL.
HOST = "generativelanguage.googleapis.com"
API_VERSION = "v1alpha"
|
||||
|
||||
|
||||
@register_realtime_client()
@require_optional_import("websockets", "gemini", except_for=["get_factory", "__init__"])
@export_module("autogen.agentchat.realtime.experimental.clients")
class GeminiRealtimeClient(RealtimeClientBase):
    """(Experimental) Client for Gemini Realtime API."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for Gemini Realtime API.

        Args:
            llm_config: The config for the client.
            logger: The logger for the client.
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger

        # Websocket connection to Gemini; set for the duration of connect(), None otherwise.
        self._connection: Optional["ClientConnection"] = None
        # Only the first entry of config_list is used.
        config = llm_config["config_list"][0]

        self._model: str = config["model"]
        self._voice = config.get("voice", "charon")
        self._temperature: float = config.get("temperature", 0.8)  # type: ignore[union-attr]

        # Responses are requested as audio only (see _initialize_session).
        self._response_modality = "AUDIO"

        self._api_key = config.get("api_key", None)
        # todo: add test with base_url just to make sure it works
        # NOTE(review): when no api_key is configured, the default URL embeds the literal
        # string "None" as the key — confirm a missing key should not raise instead.
        self._base_url: str = config.get(
            "base_url",
            f"wss://{HOST}/ws/google.ai.generativelanguage.{API_VERSION}.GenerativeService.BidiGenerateContent?key={self._api_key}",
        )
        self._final_config: dict[str, Any] = {}
        # Session options recorded before the connection exists; applied by _initialize_session().
        self._pending_session_updates: dict[str, Any] = {}
        # Set to True once read_events() starts; guards sends so nothing is written before setup.
        self._is_reading_events = False

    @property
    def logger(self) -> Logger:
        """Get the logger for the Gemini Realtime API."""
        # Fall back to the module-level logger when none was injected.
        return self._logger or global_logger

    @property
    def connection(self) -> "ClientConnection":
        """Get the Gemini WebSocket connection.

        Raises:
            RuntimeError: If connect() has not been entered yet.
        """
        if self._connection is None:
            raise RuntimeError("Gemini WebSocket is not initialized")
        return self._connection

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the Gemini Realtime API.

        The message is silently dropped unless read_events() is already running
        (i.e. the session has been initialized).

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        msg = {
            "tool_response": {"function_responses": [{"id": call_id, "response": {"result": {"string_value": result}}}]}
        }
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def send_text(self, *, role: Role, text: str, turn_complete: bool = True) -> None:
        """Send a text message to the Gemini Realtime API.

        The message is silently dropped unless read_events() is already running.

        Args:
            role: The role of the message.
            text: The text of the message.
            turn_complete: A flag indicating if the turn is complete.
        """
        msg = {
            "client_content": {
                "turn_complete": turn_complete,
                "turns": [{"role": role, "parts": [{"text": text}]}],
            }
        }
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def send_audio(self, audio: str) -> None:
        """Send audio to the Gemini Realtime API.

        The audio chunk is always queued for logging, but is only forwarded to
        Gemini once read_events() is running.

        Args:
            audio (str): The audio to send (base64-encoded PCM).
        """
        msg = {
            "realtime_input": {
                "media_chunks": [
                    {
                        "data": audio,
                        "mime_type": "audio/pcm",
                    }
                ]
            }
        }
        await self.queue_input_audio_buffer_delta(audio)
        if self._is_reading_events:
            await self.connection.send(json.dumps(msg))

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio (no-op).

        Truncation is not natively supported by the Gemini Realtime API, so this
        only logs and returns.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        self.logger.info("This is not natively supported by Gemini Realtime API.")
        pass

    async def _initialize_session(self) -> None:
        """Initialize the session with the Gemini Realtime API.

        Builds the "setup" message from the options collected via session_update()
        (system instructions and tool schemas) plus the constructor settings
        (model, voice, temperature, response modality) and sends it over the
        websocket. Must run before any other traffic.
        """
        session_config = {
            "setup": {
                "system_instruction": {
                    "role": "system",
                    "parts": [{"text": self._pending_session_updates.get("instructions", "")}],
                },
                "model": f"models/{self._model}",
                "tools": [
                    {
                        # Map OpenAI-style tool schemas onto Gemini function declarations.
                        "function_declarations": [
                            {
                                "name": tool_schema["name"],
                                "description": tool_schema["description"],
                                "parameters": tool_schema["parameters"],
                            }
                            for tool_schema in self._pending_session_updates.get("tools", [])
                        ]
                    },
                ],
                "generation_config": {
                    "response_modalities": [self._response_modality],
                    "speech_config": {"voiceConfig": {"prebuiltVoiceConfig": {"voiceName": self._voice}}},
                    "temperature": self._temperature,
                },
            }
        }

        self.logger.info(f"Sending session update: {session_config}")
        await self.connection.send(json.dumps(session_config))

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Record session updates to be applied when the connection is established.

        Gemini only accepts session configuration in the initial setup message, so
        updates arriving after read_events() has started are ignored with a warning.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        if self._is_reading_events:
            self.logger.warning("Is reading events. Session update will be ignored.")
        else:
            self._pending_session_updates.update(session_options)

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the Gemini Realtime API.

        Opens the websocket for the duration of the context and clears the
        connection attribute on exit, even on error.
        """
        try:
            async with connect(
                self._base_url, additional_headers={"Content-Type": "application/json"}
            ) as self._connection:
                yield
        finally:
            self._connection = None

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read Events from the Gemini Realtime Client.

        Sends the session setup first, then flips _is_reading_events so that the
        send_* methods start forwarding traffic.

        Raises:
            RuntimeError: If called outside of connect().
        """
        if self._connection is None:
            raise RuntimeError("Client is not connected, call connect() first.")
        await self._initialize_session()

        self._is_reading_events = True

        async for event in self._read_events():
            yield event

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the Gemini Realtime connection and yield parsed events."""
        async for raw_message in self.connection:
            # Frames may arrive as bytes; Gemini payloads decode as ASCII JSON.
            message = raw_message.decode("ascii") if isinstance(raw_message, bytes) else raw_message
            events = self._parse_message(json.loads(message))
            for event in events:
                yield event

    def _parse_message(self, response: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the Gemini Realtime API.

        Args:
            response (dict[str, Any]): The response to parse.

        Returns:
            list[RealtimeEvent]: The parsed events.
        """
        if "serverContent" in response and "modelTurn" in response["serverContent"]:
            try:
                # Pop the base64 audio payload out of the raw message so the
                # (potentially large) data is not duplicated in raw_message.
                b64data = response["serverContent"]["modelTurn"]["parts"][0]["inlineData"].pop("data")
                return [
                    AudioDelta(
                        delta=b64data,
                        item_id=None,
                        raw_message=response,
                    )
                ]
            except KeyError:
                # modelTurn without inline audio data — nothing to surface.
                return []
        elif "toolCall" in response:
            # One FunctionCall event per requested function call.
            return [
                FunctionCall(
                    raw_message=response,
                    call_id=call["id"],
                    name=call["name"],
                    arguments=call["args"],
                )
                for call in response["toolCall"]["functionCalls"]
            ]
        elif "setupComplete" in response:
            # Setup ack maps onto the protocol's SessionCreated event.
            return [
                SessionCreated(raw_message=response),
            ]
        else:
            # Unknown message types are passed through as generic events.
            return [RealtimeEvent(raw_message=response)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The LLM config for the client.
            logger: The logger for the client.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Only claim google configs, and only when no extra kwargs were passed.
        if llm_config["config_list"][0].get("api_type") == "google" and list(kwargs.keys()) == []:
            return lambda: GeminiRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
        return None
|
||||
|
||||
|
||||
# needed for mypy to check if GeminiRealtimeClient implements RealtimeClientProtocol
# (the assignment is type-checked only; it never executes at runtime)
if TYPE_CHECKING:
    _client: RealtimeClientProtocol = GeminiRealtimeClient(llm_config={})
|
||||
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .base_client import OpenAIRealtimeClient
|
||||
from .rtc_client import OpenAIRealtimeWebRTCClient
|
||||
|
||||
# Public API of the oai realtime client subpackage.
__all__ = ["OpenAIRealtimeClient", "OpenAIRealtimeWebRTCClient"]
|
||||
@@ -0,0 +1,220 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from logging import Logger, getLogger
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
from ......doc_utils import export_module
|
||||
from ......import_utils import optional_import_block, require_optional_import
|
||||
from ......llm_config import LLMConfig
|
||||
from ...realtime_events import RealtimeEvent
|
||||
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
|
||||
from .utils import parse_oai_message
|
||||
|
||||
with optional_import_block():
|
||||
from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI
|
||||
from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..realtime_client import RealtimeClientProtocol
|
||||
|
||||
__all__ = ["OpenAIRealtimeClient"]

# Module-level fallback logger, used when no logger is passed to the client.
global_logger = getLogger(__name__)
|
||||
|
||||
|
||||
@register_realtime_client()
@require_optional_import("openai>=1.66.2", "openai-realtime", except_for=["get_factory", "__init__"])
@export_module("autogen.agentchat.realtime.experimental.clients")
class OpenAIRealtimeClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client.
            logger: the logger to use for logging events
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger

        # Realtime connection; set for the duration of connect(), None otherwise.
        self._connection: Optional["AsyncRealtimeConnection"] = None

        # Only the first entry of config_list is used.
        self.config = llm_config["config_list"][0]
        # model is passed to self._client.beta.realtime.connect function later
        self._model: str = self.config["model"]
        self._voice: str = self.config.get("voice", "alloy")
        # NOTE(review): temperature is read from the top-level llm_config here, while
        # model/voice come from config_list[0] (GeminiRealtimeClient reads temperature
        # from config_list[0] too) — confirm this asymmetry is intended.
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]

        # AsyncOpenAI client, created lazily on first connect().
        self._client: Optional["AsyncOpenAI"] = None

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API."""
        # Fall back to the module-level logger when none was injected.
        return self._logger or global_logger

    @property
    def connection(self) -> "AsyncRealtimeConnection":
        """Get the OpenAI WebSocket connection.

        Raises:
            RuntimeError: If connect() has not been entered yet.
        """
        if self._connection is None:
            raise RuntimeError("OpenAI WebSocket is not initialized")
        return self._connection

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        Creates a function_call_output conversation item and then asks the API
        to generate a new response based on it.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self.connection.conversation.item.create(
            item={
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        )

        await self.connection.response.create()

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        Cancels any in-flight response first, then creates the message item and
        requests a fresh response.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        await self.connection.response.cancel()
        await self.connection.conversation.item.create(
            item={"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]}
        )
        await self.connection.response.create()

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        The chunk is queued for logging and appended to the input audio buffer.

        Args:
            audio (str): The audio to send.
        """
        await self.queue_input_audio_buffer_delta(audio)
        await self.connection.input_audio_buffer.append(audio=audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self.connection.conversation.item.truncate(
            audio_end_ms=audio_end_ms, content_index=content_index, item_id=item_id
        )

    async def _initialize_session(self) -> None:
        """Control initial session with OpenAI.

        Sends the default session configuration (server VAD, voice, modalities,
        temperature) right after the connection is opened.
        """
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        await self.session_update(session_options=session_update)

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        await self.connection.session.update(session=session_options)  # type: ignore[arg-type]
        logger.info("Sending session update finished")

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API.

        Lazily creates the AsyncOpenAI client, opens the realtime connection for
        the duration of the context, initializes the session, and clears the
        connection attribute on exit, even on error.
        """
        try:
            if not self._client:
                # All client settings come from the first config_list entry.
                self._client = AsyncOpenAI(
                    api_key=self.config.get("api_key", None),
                    organization=self.config.get("organization", None),
                    project=self.config.get("project", None),
                    base_url=self.config.get("base_url", None),
                    websocket_base_url=self.config.get("websocket_base_url", None),
                    timeout=self.config.get("timeout", NOT_GIVEN),
                    max_retries=self.config.get("max_retries", DEFAULT_MAX_RETRIES),
                    default_headers=self.config.get("default_headers", None),
                    default_query=self.config.get("default_query", None),
                )
            async with self._client.beta.realtime.connect(
                model=self._model,
            ) as self._connection:
                await self._initialize_session()
                yield
        finally:
            self._connection = None

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API.

        Raises:
            RuntimeError: If called outside of connect().
        """
        if self._connection is None:
            raise RuntimeError("Client is not connected, call connect() first.")

        try:
            async for event in self._read_events():
                yield event

        finally:
            self._connection = None

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API."""
        # NOTE(review): iterates the Optional attribute directly rather than the
        # `connection` property — read_events() guards against None before calling.
        async for message in self._connection:
            for event in self._parse_message(message.model_dump()):
                yield event

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            RealtimeEvent: The parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Claim configs with api_type "openai" (the default), but only when no
        # extra kwargs were passed (a websocket kwarg selects the WebRTC client).
        if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == []:
            return lambda: OpenAIRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
        return None
|
||||
|
||||
|
||||
# needed for mypy to check if OpenAIRealtimeClient implements RealtimeClientProtocol
# (the assignment is type-checked only; it never executes at runtime)
if TYPE_CHECKING:
    _client: RealtimeClientProtocol = OpenAIRealtimeClient(llm_config={})
|
||||
@@ -0,0 +1,243 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from logging import Logger, getLogger
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
from autogen.import_utils import optional_import_block, require_optional_import
|
||||
|
||||
from ......doc_utils import export_module
|
||||
from ......llm_config import LLMConfig
|
||||
from ...realtime_events import RealtimeEvent
|
||||
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
|
||||
from .utils import parse_oai_message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...websockets import WebSocketProtocol as WebSocket
|
||||
from ..realtime_client import RealtimeClientProtocol
|
||||
|
||||
with optional_import_block():
|
||||
import httpx
|
||||
|
||||
__all__ = ["OpenAIRealtimeWebRTCClient"]

# Module-level fallback logger, used when no logger is passed to the client.
global_logger = getLogger(__name__)
|
||||
|
||||
|
||||
@register_realtime_client()
@require_optional_import("httpx", "openai-realtime", except_for="get_factory")
@export_module("autogen.agentchat.realtime.experimental.clients.oai")
class OpenAIRealtimeWebRTCClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol.

    The actual WebRTC connection to OpenAI is held by a javascript client in the
    browser; this class talks to that client over a websocket, relaying session
    configuration and conversation items one way and realtime events the other.
    """

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        websocket: "WebSocket",
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client.
            websocket: the websocket to use for the connection
            logger: the logger to use for logging events
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        # Websocket to the javascript client that owns the WebRTC connection.
        self._websocket = websocket

        # Only the first entry of config_list is used.
        config = llm_config["config_list"][0]
        self._model: str = config["model"]
        self._voice: str = config.get("voice", "alloy")
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]
        self._config = config
        # REST endpoint used to mint an ephemeral realtime session (see connect()).
        self._base_url = config.get("base_url", "https://api.openai.com/v1/realtime/sessions")

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API."""
        # Fall back to the module-level logger when none was injected.
        return self._logger or global_logger

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        Relays a function_call_output item plus a response.create request through
        the websocket to the javascript client.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        })
        await self._websocket.send_json({"type": "response.create"})

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        Mirrors the websocket client's cancel/create/respond sequence, expressed
        as raw realtime-API events relayed through the websocket.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        # Cancel any in-flight response first (same as connection.response.cancel()).
        await self._websocket.send_json({
            "type": "response.cancel",
        })
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
        })
        # Request a fresh response (same as connection.response.create()).
        await self._websocket.send_json({"type": "response.create"})

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        in case of WebRTC, audio is already sent by js client, so we just queue it in order to be logged.

        Args:
            audio (str): The audio to send.
        """
        await self.queue_input_audio_buffer_delta(audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self._websocket.send_json({
            "type": "conversation.item.truncate",
            "content_index": content_index,
            "item_id": item_id,
            "audio_end_ms": audio_end_ms,
        })

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        In the case of WebRTC we can not send it directly, but we can send it
        to the javascript over the websocket, and rely on it to send session
        update to OpenAI

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        await self._websocket.send_json({"type": "session.update", "session": session_options})
        logger.info("Sending session update finished")

    def session_init_data(self) -> list[dict[str, Any]]:
        """Control initial session with OpenAI.

        Returns:
            list[dict[str, Any]]: Initial session.update events for the javascript
                client to send once its connection to OpenAI is established.
        """
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        return [{"type": "session.update", "session": session_update}]

    # Session initialization is delegated to the javascript client (see connect()).
    async def _initialize_session(self) -> None: ...

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API.

        In the case of WebRTC, we pass connection information over the
        websocket, so that javascript on the other end of websocket open
        actual connection to OpenAI
        """
        try:
            base_url = self._base_url
            api_key = self._config.get("api_key", None)
            headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
            data = {
                "model": self._model,
                "voice": self._voice,
            }
            # Mint an ephemeral realtime session via the REST API; the returned
            # credentials are what the javascript client uses to connect.
            async with httpx.AsyncClient() as client:
                response = await client.post(base_url, headers=headers, json=data)
                response.raise_for_status()
                json_data = response.json()
                json_data["model"] = self._model
            if self._websocket is not None:
                session_init = self.session_init_data()
                # Hand the session credentials and initial config to the js client.
                await self._websocket.send_json({"type": "ag2.init", "config": json_data, "init": session_init})
            yield
        finally:
            # The actual connection teardown is owned by the javascript client.
            pass

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from the OpenAI Realtime API."""
        async for event in self._read_events():
            yield event

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API connection.

        Again, in case of WebRTC, we do not read OpenAI messages directly since we
        do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript
        client on the other side of the websocket that is connected to OpenAI is relaying events to us.
        """
        while True:
            try:
                message_json = await self._websocket.receive_text()
                message = json.loads(message_json)
                for event in self._parse_message(message):
                    yield event
            except Exception as e:
                # Any websocket failure (including a normal disconnect) ends the stream.
                self.logger.exception(f"Error reading from connection {e}")
                break

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            RealtimeEvent: The parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Claim configs with api_type "openai" (the default) only when exactly a
        # "websocket" kwarg was passed — that is what distinguishes this client
        # from the plain OpenAIRealtimeClient.
        if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == ["websocket"]:
            return lambda: OpenAIRealtimeWebRTCClient(llm_config=llm_config, logger=logger, **kwargs)

        return None
|
||||
|
||||
|
||||
# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
# (type-checked only; wrapped in a function because the client requires a websocket argument)
if TYPE_CHECKING:

    def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol:
        return OpenAIRealtimeWebRTCClient(llm_config={}, websocket=websocket)
|
||||
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from ...realtime_events import (
|
||||
AudioDelta,
|
||||
FunctionCall,
|
||||
InputAudioBufferDelta,
|
||||
RealtimeEvent,
|
||||
SessionCreated,
|
||||
SessionUpdated,
|
||||
SpeechStarted,
|
||||
)
|
||||
|
||||
# Public API of this helper module.
__all__ = ["parse_oai_message"]
|
||||
|
||||
|
||||
def parse_oai_message(message: dict[str, Any]) -> RealtimeEvent:
    """Parse a message from the OpenAI Realtime API.

    Translates a raw API message into its typed realtime event, falling back to
    a plain ``RealtimeEvent`` for message types without a dedicated class.

    Args:
        message (dict[str, Any]): The message to parse.

    Returns:
        RealtimeEvent: The parsed event.
    """
    event_type = message.get("type")

    if event_type == "session.created":
        return SessionCreated(raw_message=message)
    if event_type == "session.updated":
        return SessionUpdated(raw_message=message)
    if event_type == "response.audio.delta":
        return AudioDelta(raw_message=message, delta=message["delta"], item_id=message["item_id"])
    if event_type == "input_audio_buffer.speech_started":
        return SpeechStarted(raw_message=message)
    if event_type == "input_audio_buffer.delta":
        return InputAudioBufferDelta(delta=message["delta"], item_id=None, raw_message=message)
    if event_type == "response.function_call_arguments.done":
        # Function-call arguments arrive JSON-encoded; decode them for the event.
        return FunctionCall(
            raw_message=message,
            call_id=message["call_id"],
            name=message["name"],
            arguments=json.loads(message["arguments"]),
        )

    # Unknown message types are passed through as generic events.
    return RealtimeEvent(raw_message=message)
|
||||
@@ -0,0 +1,190 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from logging import Logger
|
||||
from typing import Any, AsyncContextManager, Callable, Literal, Optional, Protocol, TypeVar, Union, runtime_checkable
|
||||
|
||||
from asyncer import create_task_group
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....llm_config import LLMConfig
|
||||
from ..realtime_events import InputAudioBufferDelta, RealtimeEvent
|
||||
|
||||
__all__ = ["RealtimeClientProtocol", "Role", "get_client", "register_realtime_client"]

# define role literal type for typing: role of a realtime conversation message
Role = Literal["user", "assistant", "system"]
|
||||
|
||||
|
||||
@runtime_checkable
@export_module("autogen.agentchat.realtime.experimental.clients")
class RealtimeClientProtocol(Protocol):
    """Structural interface that every realtime API client must satisfy.

    Concrete implementations register themselves with
    ``@register_realtime_client()`` and are selected through ``get_client``;
    ``@runtime_checkable`` allows ``isinstance`` checks against this protocol.
    """

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to a Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        ...

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to a Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        ...

    async def send_audio(self, audio: str) -> None:
        """Send audio to a Realtime API.

        Args:
            audio (str): The audio to send.
        """
        ...

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in a Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        ...

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to a Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        ...

    def connect(self) -> AsyncContextManager[None]:
        """Return an async context manager that holds the realtime connection open."""
        ...

    def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime Client."""
        ...

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime connection."""
        ...

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from a Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            list[RealtimeEvent]: The parsed events.
        """
        ...

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        ...
|
||||
|
||||
|
||||
class RealtimeClientBase:
    """Queue-based event plumbing shared by realtime client implementations.

    A subclass supplies ``_read_from_connection()`` (an async generator of
    ``RealtimeEvent``); this base class pumps that stream into an internal
    asyncio queue so that locally injected events (e.g. audio deltas) and
    connection events are consumed through a single ordered stream.
    """

    def __init__(self) -> None:
        # Event queue shared by the connection reader and local injection;
        # a ``None`` item is the end-of-stream sentinel.
        # (Name kept as-is: subclasses may access it directly.)
        self._eventQueue: asyncio.Queue = asyncio.Queue()

    async def add_event(self, event: Optional[RealtimeEvent]) -> None:
        """Enqueue an event; ``None`` signals end of stream to the consumer."""
        await self._eventQueue.put(event)

    async def get_event(self) -> Optional[RealtimeEvent]:
        """Dequeue the next event, waiting until one is available."""
        return await self._eventQueue.get()

    async def _read_from_connection_task(self) -> None:
        """Drain the connection's event stream into the queue, then signal EOF."""
        # ``_read_from_connection`` is provided by the concrete subclass.
        async for event in self._read_from_connection():
            await self.add_event(event)
        # Sentinel: tells _read_events the connection stream is exhausted.
        await self.add_event(None)

    async def _read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime Client."""
        async with create_task_group() as tg:
            tg.start_soon(self._read_from_connection_task)
            while True:
                try:
                    # Use the accessor rather than touching the queue directly.
                    event = await self.get_event()
                except Exception:
                    # NOTE(review): any queue failure silently ends the stream;
                    # consider logging before swallowing.
                    break
                if event is None:
                    # End-of-stream sentinel from _read_from_connection_task.
                    break
                yield event

    async def queue_input_audio_buffer_delta(self, audio: str) -> None:
        """Queue an InputAudioBufferDelta event carrying this audio chunk.

        Args:
            audio (str): The audio.
        """
        await self.add_event(InputAudioBufferDelta(delta=audio, item_id=None, raw_message={}))
|
||||
|
||||
|
||||
# Registry of realtime client classes keyed by fully qualified name
# ("module.ClassName"); populated by the @register_realtime_client decorator
# and consulted by get_client.
_realtime_client_classes: dict[str, type[RealtimeClientProtocol]] = {}

# Bound type variable so the registration decorator preserves the concrete
# class type it decorates.
T = TypeVar("T", bound=RealtimeClientProtocol)
|
||||
|
||||
|
||||
def register_realtime_client() -> Callable[[type[T]], type[T]]:
    """Register a Realtime API client.

    Returns:
        Callable[[type[T]], type[T]]: The decorator to register the Realtime API client
    """

    def decorator(client_cls: type[T]) -> type[T]:
        """Add the client class to the module-level registry.

        Args:
            client_cls: The client to register.

        Returns:
            type[T]: The same class, so the decorator is transparent to callers.
        """
        # Mutating the module-level dict needs no ``global`` declaration;
        # ``global`` is only required when *rebinding* the name.
        fqn = f"{client_cls.__module__}.{client_cls.__name__}"
        _realtime_client_classes[fqn] = client_cls
        return client_cls

    return decorator
|
||||
|
||||
|
||||
@export_module("autogen.agentchat.realtime.experimental.clients")
def get_client(llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any) -> "RealtimeClientProtocol":
    """Get a registered Realtime API client.

    Each registered client class is asked for a factory for the given config;
    the first class whose factory matches wins.

    Args:
        llm_config: The config for the client.
        logger: The logger to use for logging events.
        **kwargs: Additional arguments.

    Returns:
        RealtimeClientProtocol: The Realtime API client.

    Raises:
        ValueError: If no registered client class accepts the given config.
    """
    # Read-only access to the module-level registry: no ``global`` needed,
    # and iterate .values() since the keys are unused.
    for client_cls in _realtime_client_classes.values():
        factory = client_cls.get_factory(llm_config=llm_config, logger=logger, **kwargs)
        if factory:
            return factory()

    raise ValueError("Realtime API client not found.")
|
||||
Reference in New Issue
Block a user