CoACT initialize (#292)
This commit is contained in:
213
mm_agents/coact/autogen/io/websockets.py
Normal file
213
mm_agents/coact/autogen/io/websockets.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import logging
|
||||
import ssl
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from time import sleep
|
||||
from typing import Any, Callable, Iterable, Iterator, Optional, Protocol, Union
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from ..events.base_event import BaseEvent
|
||||
from ..events.print_event import PrintEvent
|
||||
from ..import_utils import optional_import_block, require_optional_import
|
||||
from .base import IOStream
|
||||
|
||||
# Check if the websockets module is available
|
||||
with optional_import_block():
|
||||
from websockets.sync.server import serve as ws_serve
|
||||
|
||||
__all__ = ("IOWebsockets",)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# The following type and protocols are used to define the ServerConnection and WebSocketServer classes
|
||||
# if websockets is not installed, they would be untyped
|
||||
Data = Union[str, bytes]
|
||||
|
||||
|
||||
class ServerConnection(Protocol):
|
||||
def send(self, message: Union[Data, Iterable[Data]]) -> None:
|
||||
"""Send a message to the client.
|
||||
|
||||
Args:
|
||||
message (Union[Data, Iterable[Data]]): The message to send.
|
||||
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
def recv(self, timeout: Optional[float] = None) -> Data:
|
||||
"""Receive a message from the client.
|
||||
|
||||
Args:
|
||||
timeout (Optional[float], optional): The timeout for the receive operation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Data: The message received from the client.
|
||||
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
...
|
||||
|
||||
|
||||
class WebSocketServer(Protocol):
|
||||
def serve_forever(self) -> None:
|
||||
"""Run the server forever."""
|
||||
... # pragma: no cover
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the server."""
|
||||
... # pragma: no cover
|
||||
|
||||
def __enter__(self) -> "WebSocketServer":
|
||||
"""Enter the server context."""
|
||||
... # pragma: no cover
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
"""Exit the server context."""
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
@require_optional_import("websockets", "websockets")
|
||||
@export_module("autogen.io")
|
||||
class IOWebsockets(IOStream):
|
||||
"""A websocket input/output stream."""
|
||||
|
||||
def __init__(self, websocket: ServerConnection) -> None:
|
||||
"""Initialize the websocket input/output stream.
|
||||
|
||||
Args:
|
||||
websocket (ServerConnection): The websocket server.
|
||||
"""
|
||||
self._websocket = websocket
|
||||
|
||||
@staticmethod
|
||||
def _handler(websocket: ServerConnection, on_connect: Callable[["IOWebsockets"], None]) -> None:
|
||||
"""The handler function for the websocket server."""
|
||||
logger.info(f" - IOWebsockets._handler(): Client connected on {websocket}")
|
||||
# create a new IOWebsockets instance using the websocket that is create when a client connects
|
||||
try:
|
||||
iowebsocket = IOWebsockets(websocket)
|
||||
with IOStream.set_default(iowebsocket):
|
||||
# call the on_connect function
|
||||
try:
|
||||
on_connect(iowebsocket)
|
||||
except Exception as e:
|
||||
logger.warning(f" - IOWebsockets._handler(): Error in on_connect: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f" - IOWebsockets._handler(): Unexpected error in IOWebsockets: {e}")
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def run_server_in_thread(
|
||||
*,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8765,
|
||||
on_connect: Callable[["IOWebsockets"], None],
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[str]:
|
||||
"""Factory function to create a websocket input/output stream.
|
||||
|
||||
Args:
|
||||
host (str, optional): The host to bind the server to. Defaults to "127.0.0.1".
|
||||
port (int, optional): The port to bind the server to. Defaults to 8765.
|
||||
on_connect (Callable[[IOWebsockets], None]): The function to be executed on client connection. Typically creates agents and initiate chat.
|
||||
ssl_context (Optional[ssl.SSLContext], optional): The SSL context to use for secure connections. Defaults to None.
|
||||
kwargs (Any): Additional keyword arguments to pass to the websocket server.
|
||||
|
||||
Yields:
|
||||
str: The URI of the websocket server.
|
||||
"""
|
||||
server_dict: dict[str, WebSocketServer] = {}
|
||||
|
||||
def _run_server() -> None:
|
||||
# print(f" - _run_server(): starting server on ws://{host}:{port}", flush=True)
|
||||
with ws_serve(
|
||||
handler=partial(IOWebsockets._handler, on_connect=on_connect),
|
||||
host=host,
|
||||
port=port,
|
||||
ssl_context=ssl_context,
|
||||
**kwargs,
|
||||
) as server:
|
||||
# print(f" - _run_server(): server {server} started on ws://{host}:{port}", flush=True)
|
||||
|
||||
server_dict["server"] = server
|
||||
|
||||
# runs until the server is shutdown
|
||||
server.serve_forever()
|
||||
|
||||
return
|
||||
|
||||
# start server in a separate thread
|
||||
thread = threading.Thread(target=_run_server)
|
||||
thread.start()
|
||||
try:
|
||||
while "server" not in server_dict:
|
||||
sleep(0.1)
|
||||
|
||||
yield f"ws://{host}:{port}"
|
||||
|
||||
finally:
|
||||
# print(f" - run_server_in_thread(): shutting down server on ws://{host}:{port}", flush=True)
|
||||
# gracefully stop server
|
||||
if "server" in server_dict:
|
||||
# print(f" - run_server_in_thread(): shutting down server {server_dict['server']}", flush=True)
|
||||
server_dict["server"].shutdown()
|
||||
|
||||
# wait for the thread to stop
|
||||
if thread:
|
||||
thread.join()
|
||||
|
||||
@property
|
||||
def websocket(self) -> "ServerConnection":
|
||||
"""The URI of the websocket server."""
|
||||
return self._websocket
|
||||
|
||||
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
|
||||
"""Print data to the output stream.
|
||||
|
||||
Args:
|
||||
objects (any): The data to print.
|
||||
sep (str, optional): The separator between objects. Defaults to " ".
|
||||
end (str, optional): The end of the output. Defaults to "\n".
|
||||
flush (bool, optional): Whether to flush the output. Defaults to False.
|
||||
"""
|
||||
print_message = PrintEvent(*objects, sep=sep, end=end)
|
||||
self.send(print_message)
|
||||
|
||||
def send(self, message: BaseEvent) -> None:
|
||||
"""Send a message to the output stream.
|
||||
|
||||
Args:
|
||||
message (Any): The message to send.
|
||||
"""
|
||||
self._websocket.send(message.model_dump_json())
|
||||
|
||||
def input(self, prompt: str = "", *, password: bool = False) -> str:
|
||||
"""Read a line from the input stream.
|
||||
|
||||
Args:
|
||||
prompt (str, optional): The prompt to display. Defaults to "".
|
||||
password (bool, optional): Whether to read a password. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: The line read from the input stream.
|
||||
|
||||
"""
|
||||
if prompt != "":
|
||||
self._websocket.send(prompt)
|
||||
|
||||
msg = self._websocket.recv()
|
||||
|
||||
return msg.decode("utf-8") if isinstance(msg, bytes) else msg
|
||||
Reference in New Issue
Block a user