CoACT initialize (#292)

This commit is contained in:
Linxin Song
2025-07-30 19:35:20 -07:00
committed by GitHub
parent 862d704b8c
commit b968155757
228 changed files with 42386 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .base_client import OpenAIRealtimeClient
from .rtc_client import OpenAIRealtimeWebRTCClient
__all__ = ["OpenAIRealtimeClient", "OpenAIRealtimeWebRTCClient"]

View File

@@ -0,0 +1,220 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from ......doc_utils import export_module
from ......import_utils import optional_import_block, require_optional_import
from ......llm_config import LLMConfig
from ...realtime_events import RealtimeEvent
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
from .utils import parse_oai_message
with optional_import_block():
from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI
from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection
if TYPE_CHECKING:
from ..realtime_client import RealtimeClientProtocol
__all__ = ["OpenAIRealtimeClient"]
global_logger = getLogger(__name__)
@register_realtime_client()
@require_optional_import("openai>=1.66.2", "openai-realtime", except_for=["get_factory", "__init__"])
@export_module("autogen.agentchat.realtime.experimental.clients")
class OpenAIRealtimeClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API.

    The client owns an ``AsyncOpenAI`` instance (created lazily in
    ``connect``) and a realtime connection that is only valid while the
    ``connect`` context manager is active.
    """

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client. The first entry of its
                ``config_list`` supplies ``model`` (required), ``voice``
                (defaults to ``"alloy"``) and the OpenAI client options used
                in ``connect``; ``temperature`` is read from the top level
                (defaults to ``0.8``).
            logger: the logger to use for logging events. When omitted, the
                module-level logger is used.
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger

        # Set by connect() for the lifetime of the context manager.
        self._connection: Optional["AsyncRealtimeConnection"] = None

        self.config = llm_config["config_list"][0]
        # model is passed to self._client.beta.realtime.connect function later
        self._model: str = self.config["model"]
        self._voice: str = self.config.get("voice", "alloy")
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]

        # Created on first connect() and reused afterwards.
        self._client: Optional["AsyncOpenAI"] = None

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API."""
        return self._logger or global_logger

    @property
    def connection(self) -> "AsyncRealtimeConnection":
        """Get the OpenAI WebSocket connection.

        Raises:
            RuntimeError: If no connection is currently open.
        """
        if self._connection is None:
            raise RuntimeError("OpenAI WebSocket is not initialized")
        return self._connection

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        The result is added as a ``function_call_output`` conversation item
        and a new response is requested afterwards.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self.connection.conversation.item.create(
            item={
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        )
        await self.connection.response.create()

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        A response cancel is issued first, then the message is added to the
        conversation and a new response is requested.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        await self.connection.response.cancel()
        await self.connection.conversation.item.create(
            item={"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]}
        )
        await self.connection.response.create()

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        The audio is first queued locally as an input-audio-buffer delta and
        then appended to the remote input audio buffer.

        Args:
            audio (str): The audio to send.
        """
        await self.queue_input_audio_buffer_delta(audio)
        await self.connection.input_audio_buffer.append(audio=audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self.connection.conversation.item.truncate(
            audio_end_ms=audio_end_ms, content_index=content_index, item_id=item_id
        )

    async def _initialize_session(self) -> None:
        """Control initial session with OpenAI.

        Sends a session update with server-side voice activity detection,
        the configured voice and temperature, and audio + text modalities.
        """
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        await self.session_update(session_options=session_update)

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        await self.connection.session.update(session=session_options)  # type: ignore[arg-type]
        logger.info("Sending session update finished")

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API.

        Creates the ``AsyncOpenAI`` client on first use, opens a realtime
        connection for the configured model, initializes the session and
        yields. The stored connection is cleared on exit, including when an
        error occurs.
        """
        try:
            if not self._client:
                self._client = AsyncOpenAI(
                    api_key=self.config.get("api_key", None),
                    organization=self.config.get("organization", None),
                    project=self.config.get("project", None),
                    base_url=self.config.get("base_url", None),
                    websocket_base_url=self.config.get("websocket_base_url", None),
                    timeout=self.config.get("timeout", NOT_GIVEN),
                    max_retries=self.config.get("max_retries", DEFAULT_MAX_RETRIES),
                    default_headers=self.config.get("default_headers", None),
                    default_query=self.config.get("default_query", None),
                )
            async with self._client.beta.realtime.connect(
                model=self._model,
            ) as self._connection:
                await self._initialize_session()
                yield
        finally:
            self._connection = None

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API.

        Raises:
            RuntimeError: If called before ``connect()`` opened a connection.
        """
        if self._connection is None:
            raise RuntimeError("Client is not connected, call connect() first.")

        try:
            async for event in self._read_events():
                yield event
        finally:
            # Reading has stopped; drop the connection reference.
            self._connection = None

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API.

        Raises:
            RuntimeError: If no connection is currently open.
        """
        # Go through the guarded property so a missing connection surfaces as
        # a clear RuntimeError instead of a TypeError from iterating None.
        async for message in self.connection:
            for event in self._parse_message(message.model_dump()):
                yield event

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            list[RealtimeEvent]: A single-element list with the parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            kwargs: Additional arguments.

        Returns:
            A zero-argument factory building the client when the first config
            entry has ``api_type`` ``"openai"`` (the default) and no extra
            keyword arguments were given; otherwise ``None``.
        """
        if llm_config["config_list"][0].get("api_type", "openai") == "openai" and not kwargs:
            # kwargs is necessarily empty here, so there is nothing to forward.
            return lambda: OpenAIRealtimeClient(llm_config=llm_config, logger=logger)
        return None
# needed for mypy to check if OpenAIRealtimeClient implements RealtimeClientProtocol
if TYPE_CHECKING:
_client: RealtimeClientProtocol = OpenAIRealtimeClient(llm_config={})

View File

@@ -0,0 +1,243 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from autogen.import_utils import optional_import_block, require_optional_import
from ......doc_utils import export_module
from ......llm_config import LLMConfig
from ...realtime_events import RealtimeEvent
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
from .utils import parse_oai_message
if TYPE_CHECKING:
from ...websockets import WebSocketProtocol as WebSocket
from ..realtime_client import RealtimeClientProtocol
with optional_import_block():
import httpx
__all__ = ["OpenAIRealtimeWebRTCClient"]
global_logger = getLogger(__name__)
@register_realtime_client()
@require_optional_import("httpx", "openai-realtime", except_for="get_factory")
@export_module("autogen.agentchat.realtime.experimental.clients.oai")
class OpenAIRealtimeWebRTCClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol.

    This client does not hold a connection to OpenAI itself. It exchanges
    JSON messages over the supplied websocket with a javascript peer, which
    is expected to own the actual OpenAI connection.
    """

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        websocket: "WebSocket",
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client. The first entry of its
                ``config_list`` supplies ``model`` (required), ``voice``
                (defaults to ``"alloy"``), ``api_key`` and ``base_url``;
                ``temperature`` is read from the top level (defaults to
                ``0.8``).
            websocket: the websocket to use for the connection
            logger: the logger to use for logging events. When omitted, the
                module-level logger is used.
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        self._websocket = websocket
        config = llm_config["config_list"][0]
        self._model: str = config["model"]
        self._voice: str = config.get("voice", "alloy")
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]
        self._config = config
        self._base_url = config.get("base_url", "https://api.openai.com/v1/realtime/sessions")

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API."""
        return self._logger or global_logger

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        The result is sent over the websocket as a ``conversation.item.create``
        message, followed by a ``response.create`` message.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        })
        await self._websocket.send_json({"type": "response.create"})

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        Sends ``response.cancel``, then the message item, then
        ``response.create`` over the websocket, in that order.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        await self._websocket.send_json({
            "type": "response.cancel",
        })
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
        })
        await self._websocket.send_json({"type": "response.create"})

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        in case of WebRTC, audio is already sent by js client, so we just queue it in order to be logged.

        Args:
            audio (str): The audio to send.
        """
        await self.queue_input_audio_buffer_delta(audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self._websocket.send_json({
            "type": "conversation.item.truncate",
            "content_index": content_index,
            "item_id": item_id,
            "audio_end_ms": audio_end_ms,
        })

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        In the case of WebRTC we can not send it directly, but we can send it
        to the javascript over the websocket, and rely on it to send session
        update to OpenAI

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        await self._websocket.send_json({"type": "session.update", "session": session_options})
        logger.info("Sending session update finished")

    def session_init_data(self) -> list[dict[str, Any]]:
        """Control initial session with OpenAI.

        Returns:
            list[dict[str, Any]]: A single ``session.update`` message with
            server-side voice activity detection, the configured voice and
            temperature, and audio + text modalities.
        """
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        return [{"type": "session.update", "session": session_update}]

    async def _initialize_session(self) -> None:
        """Do nothing; the initial session data is sent from ``connect``."""
        ...

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API.

        In the case of WebRTC, we pass connection information over the
        websocket, so that javascript on the other end of websocket open
        actual connection to OpenAI

        The API key is taken from the config entry, falling back to the
        ``OPENAI_API_KEY`` environment variable. When neither is set, the
        request is sent without an ``Authorization`` header and a warning
        is logged.

        Raises:
            httpx.HTTPStatusError: If the session request returns an error
                status.
        """
        import os  # local import: only needed for the API-key fallback

        api_key = self._config.get("api_key") or os.getenv("OPENAI_API_KEY")
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        else:
            # Never send the literal string "Bearer None" as a credential.
            self.logger.warning(
                "No api_key configured and OPENAI_API_KEY is not set; "
                "requesting realtime session without Authorization header"
            )
        data = {
            "model": self._model,
            "voice": self._voice,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(self._base_url, headers=headers, json=data)
            response.raise_for_status()
            json_data = response.json()
            # Echo the requested model back in the config handed to the peer.
            json_data["model"] = self._model
        if self._websocket is not None:
            session_init = self.session_init_data()
            await self._websocket.send_json({"type": "ag2.init", "config": json_data, "init": session_init})
        yield

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from the OpenAI Realtime API."""
        async for event in self._read_events():
            yield event

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API connection.

        Again, in case of WebRTC, we do not read OpenAI messages directly since we
        do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript
        client on the other side of the websocket that is connected to OpenAI is relaying events to us.

        Any exception while receiving, decoding or parsing a message is
        logged and ends the stream.
        """
        while True:
            try:
                message_json = await self._websocket.receive_text()
                message = json.loads(message_json)
                for event in self._parse_message(message):
                    yield event
            except Exception as e:
                self.logger.exception(f"Error reading from connection {e}")
                break

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            list[RealtimeEvent]: A single-element list with the parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            A zero-argument factory building the client when the first config
            entry has ``api_type`` ``"openai"`` (the default) and ``websocket``
            is the only extra keyword argument; otherwise ``None``.
        """
        is_openai = llm_config["config_list"][0].get("api_type", "openai") == "openai"
        # A dict key view compares equal to a set with the same members.
        if is_openai and kwargs.keys() == {"websocket"}:
            return lambda: OpenAIRealtimeWebRTCClient(llm_config=llm_config, logger=logger, **kwargs)
        return None
# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
if TYPE_CHECKING:
def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol:
return OpenAIRealtimeWebRTCClient(llm_config={}, websocket=websocket)

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Any
from ...realtime_events import (
AudioDelta,
FunctionCall,
InputAudioBufferDelta,
RealtimeEvent,
SessionCreated,
SessionUpdated,
SpeechStarted,
)
__all__ = ["parse_oai_message"]
def parse_oai_message(message: dict[str, Any]) -> RealtimeEvent:
    """Convert a raw OpenAI Realtime API message into a realtime event.

    The message's ``"type"`` field selects the event class; any type not
    handled below is wrapped in a plain ``RealtimeEvent``.

    Args:
        message (dict[str, Any]): The message to parse.

    Returns:
        RealtimeEvent: The parsed event.
    """
    event_type = message.get("type")

    if event_type == "session.created":
        return SessionCreated(raw_message=message)

    if event_type == "session.updated":
        return SessionUpdated(raw_message=message)

    if event_type == "response.audio.delta":
        return AudioDelta(raw_message=message, delta=message["delta"], item_id=message["item_id"])

    if event_type == "input_audio_buffer.speech_started":
        return SpeechStarted(raw_message=message)

    if event_type == "input_audio_buffer.delta":
        return InputAudioBufferDelta(delta=message["delta"], item_id=None, raw_message=message)

    if event_type == "response.function_call_arguments.done":
        # The arguments arrive as a JSON-encoded string and are decoded here.
        return FunctionCall(
            raw_message=message,
            call_id=message["call_id"],
            name=message["name"],
            arguments=json.loads(message["arguments"]),
        )

    return RealtimeEvent(raw_message=message)