CoACT initialize (#292)

This commit is contained in:
Linxin Song
2025-07-30 19:35:20 -07:00
committed by GitHub
parent 862d704b8c
commit b968155757
228 changed files with 42386 additions and 0 deletions

View File

@@ -0,0 +1,15 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from .base import IOStream, InputStream, OutputStream
from .console import IOConsole
from .websockets import IOWebsockets
# Set the default input/output stream to the console
IOStream.set_global_default(IOConsole())
IOStream.set_default(IOConsole())
__all__ = ("IOConsole", "IOStream", "IOWebsockets", "InputStream", "OutputStream")

View File

@@ -0,0 +1,151 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import logging
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Optional, Protocol, Union, runtime_checkable
from ..doc_utils import export_module
from ..events.base_event import BaseEvent
__all__ = ("IOStream", "InputStream", "OutputStream")
logger = logging.getLogger(__name__)
@runtime_checkable
@export_module("autogen.io")
class OutputStream(Protocol):
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
"""Print data to the output stream.
Args:
objects (any): The data to print.
sep (str, optional): The separator between objects. Defaults to " ".
end (str, optional): The end of the output. Defaults to "\n".
flush (bool, optional): Whether to flush the output. Defaults to False.
"""
... # pragma: no cover
def send(self, message: BaseEvent) -> None:
"""Send data to the output stream.
Args:
message (BaseEvent): BaseEvent from autogen.messages.base_message
"""
...
@runtime_checkable
@export_module("autogen.io")
class InputStream(Protocol):
def input(self, prompt: str = "", *, password: bool = False) -> str:
"""Read a line from the input stream.
Args:
prompt (str, optional): The prompt to display. Defaults to "".
password (bool, optional): Whether to read a password. Defaults to False.
Returns:
str: The line read from the input stream.
"""
... # pragma: no cover
@runtime_checkable
@export_module("autogen.io")
class AsyncInputStream(Protocol):
async def input(self, prompt: str = "", *, password: bool = False) -> str:
"""Read a line from the input stream.
Args:
prompt (str, optional): The prompt to display. Defaults to "".
password (bool, optional): Whether to read a password. Defaults to False.
Returns:
str: The line read from the input stream.
"""
... # pragma: no cover
@runtime_checkable
@export_module("autogen.io")
class IOStreamProtocol(InputStream, OutputStream, Protocol):
"""A protocol for input/output streams."""
@runtime_checkable
@export_module("autogen.io")
class AsyncIOStreamProtocol(AsyncInputStream, OutputStream, Protocol):
"""A protocol for input/output streams."""
iostream_union = Union[IOStreamProtocol, AsyncIOStreamProtocol]
@export_module("autogen.io")
class IOStream:
"""A protocol for input/output streams."""
# ContextVar must be used in multithreaded or async environments
_default_io_stream: ContextVar[Optional[iostream_union]] = ContextVar("default_iostream", default=None)
_default_io_stream.set(None)
_global_default: Optional[iostream_union] = None
@staticmethod
def set_global_default(stream: iostream_union) -> None:
"""Set the default input/output stream.
Args:
stream (IOStream): The input/output stream to set as the default.
"""
IOStream._global_default = stream
@staticmethod
def get_global_default() -> iostream_union:
"""Get the default input/output stream.
Returns:
IOStream: The default input/output stream.
"""
if IOStream._global_default is None:
raise RuntimeError("No global default IOStream has been set")
return IOStream._global_default
@staticmethod
def get_default() -> iostream_union:
"""Get the default input/output stream.
Returns:
IOStream: The default input/output stream.
"""
iostream = IOStream._default_io_stream.get()
if iostream is None:
iostream = IOStream.get_global_default()
# Set the default IOStream of the current context (thread/cooroutine)
IOStream.set_default(iostream)
return iostream
@staticmethod
@contextmanager
def set_default(stream: Optional[iostream_union]) -> Iterator[None]:
"""Set the default input/output stream.
Args:
stream (IOStream): The input/output stream to set as the default.
"""
global _default_io_stream
try:
token = IOStream._default_io_stream.set(stream)
yield
finally:
IOStream._default_io_stream.reset(token)
return

View File

@@ -0,0 +1,56 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import getpass
from typing import Any
from ..doc_utils import export_module
from ..events.base_event import BaseEvent
from ..events.print_event import PrintEvent
from .base import IOStream
__all__ = ("IOConsole",)
@export_module("autogen.io")
class IOConsole(IOStream):
"""A console input/output stream."""
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
"""Print data to the output stream.
Args:
objects (any): The data to print.
sep (str, optional): The separator between objects. Defaults to " ".
end (str, optional): The end of the output. Defaults to "\n".
flush (bool, optional): Whether to flush the output. Defaults to False.
"""
print_message = PrintEvent(*objects, sep=sep, end=end)
self.send(print_message)
# print(*objects, sep=sep, end=end, flush=flush)
def send(self, message: BaseEvent) -> None:
"""Send a message to the output stream.
Args:
message (Any): The message to send.
"""
message.print()
def input(self, prompt: str = "", *, password: bool = False) -> str:
"""Read a line from the input stream.
Args:
prompt (str, optional): The prompt to display. Defaults to "".
password (bool, optional): Whether to read a password. Defaults to False.
Returns:
str: The line read from the input stream.
"""
if password:
return getpass.getpass(prompt if prompt != "" else "Password: ")
return input(prompt)

View File

@@ -0,0 +1,12 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from .base import AsyncEventProcessorProtocol, EventProcessorProtocol
from .console_event_processor import AsyncConsoleEventProcessor, ConsoleEventProcessor
__all__ = [
"AsyncConsoleEventProcessor",
"AsyncEventProcessorProtocol",
"ConsoleEventProcessor",
"EventProcessorProtocol",
]

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Protocol
from ...doc_utils import export_module
if TYPE_CHECKING:
from ..run_response import AsyncRunResponseProtocol, RunResponseProtocol
__all__ = ["AsyncEventProcessorProtocol", "EventProcessorProtocol"]
@export_module("autogen.io")
class EventProcessorProtocol(Protocol):
def process(self, response: "RunResponseProtocol") -> None: ...
@export_module("autogen.io")
class AsyncEventProcessorProtocol(Protocol):
async def process(self, response: "AsyncRunResponseProtocol") -> None: ...

View File

@@ -0,0 +1,56 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import getpass
from typing import TYPE_CHECKING
from ...doc_utils import export_module
from ...events.agent_events import InputRequestEvent
from ...events.base_event import BaseEvent
if TYPE_CHECKING:
from ..run_response import AsyncRunResponseProtocol, RunResponseProtocol
from .base import AsyncEventProcessorProtocol, EventProcessorProtocol
@export_module("autogen.io")
class ConsoleEventProcessor:
def process(self, response: "RunResponseProtocol") -> None:
for event in response.events:
self.process_event(event)
def process_event(self, event: BaseEvent) -> None:
if isinstance(event, InputRequestEvent):
prompt = event.content.prompt # type: ignore[attr-defined]
if event.content.password: # type: ignore[attr-defined]
result = getpass.getpass(prompt if prompt != "" else "Password: ")
result = input(prompt)
event.content.respond(result) # type: ignore[attr-defined]
else:
event.print()
@export_module("autogen.io")
class AsyncConsoleEventProcessor:
async def process(self, response: "AsyncRunResponseProtocol") -> None:
async for event in response.events:
await self.process_event(event)
async def process_event(self, event: BaseEvent) -> None:
if isinstance(event, InputRequestEvent):
prompt = event.content.prompt # type: ignore[attr-defined]
if event.content.password: # type: ignore[attr-defined]
result = getpass.getpass(prompt if prompt != "" else "Password: ")
result = input(prompt)
await event.content.respond(result) # type: ignore[attr-defined]
else:
event.print()
if TYPE_CHECKING:
def check_type_1(x: ConsoleEventProcessor) -> EventProcessorProtocol:
return x
def check_type_2(x: AsyncConsoleEventProcessor) -> AsyncEventProcessorProtocol:
return x

View File

@@ -0,0 +1,293 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import queue
from asyncio import Queue as AsyncQueue
from typing import Any, AsyncIterable, Dict, Iterable, Optional, Protocol, Sequence, Union
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
from autogen.tools.tool import Tool
from ..agentchat.agent import Agent, LLMMessageType
from ..agentchat.group.context_variables import ContextVariables
from ..events.agent_events import ErrorEvent, InputRequestEvent, RunCompletionEvent
from ..events.base_event import BaseEvent
from .processors import (
AsyncConsoleEventProcessor,
AsyncEventProcessorProtocol,
ConsoleEventProcessor,
EventProcessorProtocol,
)
from .thread_io_stream import AsyncThreadIOStream, ThreadIOStream
Message = dict[str, Any]
class RunInfoProtocol(Protocol):
@property
def uuid(self) -> UUID: ...
@property
def above_run(self) -> Optional["RunResponseProtocol"]: ...
class Usage(BaseModel):
cost: float
prompt_tokens: int
completion_tokens: int
total_tokens: int
class CostBreakdown(BaseModel):
total_cost: float
models: Dict[str, Usage] = Field(default_factory=dict)
@classmethod
def from_raw(cls, data: dict[str, Any]) -> "CostBreakdown":
# Extract total cost
total_cost = data.get("total_cost", 0.0)
# Remove total_cost key to extract models
model_usages = {k: Usage(**v) for k, v in data.items() if k != "total_cost"}
return cls(total_cost=total_cost, models=model_usages)
class Cost(BaseModel):
usage_including_cached_inference: CostBreakdown
usage_excluding_cached_inference: CostBreakdown
@classmethod
def from_raw(cls, data: dict[str, Any]) -> "Cost":
return cls(
usage_including_cached_inference=CostBreakdown.from_raw(data.get("usage_including_cached_inference", {})),
usage_excluding_cached_inference=CostBreakdown.from_raw(data.get("usage_excluding_cached_inference", {})),
)
class RunResponseProtocol(RunInfoProtocol, Protocol):
@property
def events(self) -> Iterable[BaseEvent]: ...
@property
def messages(self) -> Iterable[Message]: ...
@property
def summary(self) -> Optional[str]: ...
@property
def context_variables(self) -> Optional[ContextVariables]: ...
@property
def last_speaker(self) -> Optional[str]: ...
@property
def cost(self) -> Optional[Cost]: ...
def process(self, processor: Optional[EventProcessorProtocol] = None) -> None: ...
def set_ui_tools(self, tools: list[Tool]) -> None: ...
class AsyncRunResponseProtocol(RunInfoProtocol, Protocol):
@property
def events(self) -> AsyncIterable[BaseEvent]: ...
@property
async def messages(self) -> Iterable[Message]: ...
@property
async def summary(self) -> Optional[str]: ...
@property
async def context_variables(self) -> Optional[ContextVariables]: ...
@property
async def last_speaker(self) -> Optional[str]: ...
@property
async def cost(self) -> Optional[Cost]: ...
async def process(self, processor: Optional[AsyncEventProcessorProtocol] = None) -> None: ...
def set_ui_tools(self, tools: list[Tool]) -> None: ...
class RunResponse:
def __init__(self, iostream: ThreadIOStream, agents: list[Agent]):
self.iostream = iostream
self.agents = agents
self._summary: Optional[str] = None
self._messages: Sequence[LLMMessageType] = []
self._uuid = uuid4()
self._context_variables: Optional[ContextVariables] = None
self._last_speaker: Optional[str] = None
self._cost: Optional[Cost] = None
def _queue_generator(self, q: queue.Queue) -> Iterable[BaseEvent]: # type: ignore[type-arg]
"""A generator to yield items from the queue until the termination message is found."""
while True:
try:
# Get an item from the queue
event = q.get(timeout=0.1) # Adjust timeout as needed
if isinstance(event, InputRequestEvent):
event.content.respond = lambda response: self.iostream._output_stream.put(response) # type: ignore[attr-defined]
yield event
if isinstance(event, RunCompletionEvent):
self._messages = event.content.history # type: ignore[attr-defined]
self._last_speaker = event.content.last_speaker # type: ignore[attr-defined]
self._summary = event.content.summary # type: ignore[attr-defined]
self._context_variables = event.content.context_variables # type: ignore[attr-defined]
self.cost = event.content.cost # type: ignore[attr-defined]
break
if isinstance(event, ErrorEvent):
raise event.content.error # type: ignore[attr-defined]
except queue.Empty:
continue # Wait for more items in the queue
@property
def events(self) -> Iterable[BaseEvent]:
return self._queue_generator(self.iostream.input_stream)
@property
def messages(self) -> Iterable[Message]:
return self._messages
@property
def summary(self) -> Optional[str]:
return self._summary
@property
def above_run(self) -> Optional["RunResponseProtocol"]:
return None
@property
def uuid(self) -> UUID:
return self._uuid
@property
def context_variables(self) -> Optional[ContextVariables]:
return self._context_variables
@property
def last_speaker(self) -> Optional[str]:
return self._last_speaker
@property
def cost(self) -> Optional[Cost]:
return self._cost
@cost.setter
def cost(self, value: Union[Cost, dict[str, Any]]) -> None:
if isinstance(value, dict):
self._cost = Cost.from_raw(value)
else:
self._cost = value
def process(self, processor: Optional[EventProcessorProtocol] = None) -> None:
processor = processor or ConsoleEventProcessor()
processor.process(self)
def set_ui_tools(self, tools: list[Tool]) -> None:
"""Set the UI tools for the agents."""
for agent in self.agents:
agent.set_ui_tools(tools)
class AsyncRunResponse:
def __init__(self, iostream: AsyncThreadIOStream, agents: list[Agent]):
self.iostream = iostream
self.agents = agents
self._summary: Optional[str] = None
self._messages: Sequence[LLMMessageType] = []
self._uuid = uuid4()
self._context_variables: Optional[ContextVariables] = None
self._last_speaker: Optional[str] = None
self._cost: Optional[Cost] = None
async def _queue_generator(self, q: AsyncQueue[Any]) -> AsyncIterable[BaseEvent]: # type: ignore[type-arg]
"""A generator to yield items from the queue until the termination message is found."""
while True:
try:
# Get an item from the queue
event = await q.get()
if isinstance(event, InputRequestEvent):
async def respond(response: str) -> None:
await self.iostream._output_stream.put(response)
event.content.respond = respond # type: ignore[attr-defined]
yield event
if isinstance(event, RunCompletionEvent):
self._messages = event.content.history # type: ignore[attr-defined]
self._last_speaker = event.content.last_speaker # type: ignore[attr-defined]
self._summary = event.content.summary # type: ignore[attr-defined]
self._context_variables = event.content.context_variables # type: ignore[attr-defined]
self.cost = event.content.cost # type: ignore[attr-defined]
break
if isinstance(event, ErrorEvent):
raise event.content.error # type: ignore[attr-defined]
except queue.Empty:
continue
@property
def events(self) -> AsyncIterable[BaseEvent]:
return self._queue_generator(self.iostream.input_stream)
@property
async def messages(self) -> Iterable[Message]:
return self._messages
@property
async def summary(self) -> Optional[str]:
return self._summary
@property
def above_run(self) -> Optional["RunResponseProtocol"]:
return None
@property
def uuid(self) -> UUID:
return self._uuid
@property
async def context_variables(self) -> Optional[ContextVariables]:
return self._context_variables
@property
async def last_speaker(self) -> Optional[str]:
return self._last_speaker
@property
async def cost(self) -> Optional[Cost]:
return self._cost
@cost.setter
def cost(self, value: Union[Cost, dict[str, Any]]) -> None:
if isinstance(value, dict):
self._cost = Cost.from_raw(value)
else:
self._cost = value
async def process(self, processor: Optional[AsyncEventProcessorProtocol] = None) -> None:
processor = processor or AsyncConsoleEventProcessor()
await processor.process(self)
def set_ui_tools(self, tools: list[Tool]) -> None:
"""Set the UI tools for the agents."""
for agent in self.agents:
agent.set_ui_tools(tools)

View File

@@ -0,0 +1,63 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import queue
from asyncio import Queue as AsyncQueue
from typing import TYPE_CHECKING, Any
from autogen.io.base import AsyncIOStreamProtocol, IOStreamProtocol
from ..events.agent_events import InputRequestEvent
from ..events.print_event import PrintEvent
class ThreadIOStream:
def __init__(self) -> None:
self._input_stream: queue.Queue = queue.Queue() # type: ignore[type-arg]
self._output_stream: queue.Queue = queue.Queue() # type: ignore[type-arg]
def input(self, prompt: str = "", *, password: bool = False) -> str:
self.send(InputRequestEvent(prompt=prompt, password=password)) # type: ignore[call-arg]
return self._output_stream.get() # type: ignore[no-any-return]
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
print_message = PrintEvent(*objects, sep=sep, end=end)
self.send(print_message)
def send(self, message: Any) -> None:
self._input_stream.put(message)
@property
def input_stream(self) -> queue.Queue: # type: ignore[type-arg]
return self._input_stream
class AsyncThreadIOStream:
def __init__(self) -> None:
self._input_stream: AsyncQueue = AsyncQueue() # type: ignore[type-arg]
self._output_stream: AsyncQueue = AsyncQueue() # type: ignore[type-arg]
async def input(self, prompt: str = "", *, password: bool = False) -> str:
self.send(InputRequestEvent(prompt=prompt, password=password)) # type: ignore[call-arg]
return await self._output_stream.get() # type: ignore[no-any-return]
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
print_message = PrintEvent(*objects, sep=sep, end=end)
self.send(print_message)
def send(self, message: Any) -> None:
self._input_stream.put_nowait(message)
@property
def input_stream(self) -> AsyncQueue[Any]:
return self._input_stream
if TYPE_CHECKING:
def check_type_1(x: ThreadIOStream) -> IOStreamProtocol:
return x
def check_type_2(x: AsyncThreadIOStream) -> AsyncIOStreamProtocol:
return x

View File

@@ -0,0 +1,213 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import logging
import ssl
import threading
from contextlib import contextmanager
from functools import partial
from time import sleep
from typing import Any, Callable, Iterable, Iterator, Optional, Protocol, Union
from ..doc_utils import export_module
from ..events.base_event import BaseEvent
from ..events.print_event import PrintEvent
from ..import_utils import optional_import_block, require_optional_import
from .base import IOStream
# Check if the websockets module is available
with optional_import_block():
from websockets.sync.server import serve as ws_serve
__all__ = ("IOWebsockets",)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# The following type and protocols are used to define the ServerConnection and WebSocketServer classes
# if websockets is not installed, they would be untyped
Data = Union[str, bytes]
class ServerConnection(Protocol):
def send(self, message: Union[Data, Iterable[Data]]) -> None:
"""Send a message to the client.
Args:
message (Union[Data, Iterable[Data]]): The message to send.
"""
... # pragma: no cover
def recv(self, timeout: Optional[float] = None) -> Data:
"""Receive a message from the client.
Args:
timeout (Optional[float], optional): The timeout for the receive operation. Defaults to None.
Returns:
Data: The message received from the client.
"""
... # pragma: no cover
def close(self) -> None:
"""Close the connection."""
...
class WebSocketServer(Protocol):
def serve_forever(self) -> None:
"""Run the server forever."""
... # pragma: no cover
def shutdown(self) -> None:
"""Shutdown the server."""
... # pragma: no cover
def __enter__(self) -> "WebSocketServer":
"""Enter the server context."""
... # pragma: no cover
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
"""Exit the server context."""
... # pragma: no cover
@require_optional_import("websockets", "websockets")
@export_module("autogen.io")
class IOWebsockets(IOStream):
"""A websocket input/output stream."""
def __init__(self, websocket: ServerConnection) -> None:
"""Initialize the websocket input/output stream.
Args:
websocket (ServerConnection): The websocket server.
"""
self._websocket = websocket
@staticmethod
def _handler(websocket: ServerConnection, on_connect: Callable[["IOWebsockets"], None]) -> None:
"""The handler function for the websocket server."""
logger.info(f" - IOWebsockets._handler(): Client connected on {websocket}")
# create a new IOWebsockets instance using the websocket that is create when a client connects
try:
iowebsocket = IOWebsockets(websocket)
with IOStream.set_default(iowebsocket):
# call the on_connect function
try:
on_connect(iowebsocket)
except Exception as e:
logger.warning(f" - IOWebsockets._handler(): Error in on_connect: {e}")
except Exception as e:
logger.error(f" - IOWebsockets._handler(): Unexpected error in IOWebsockets: {e}")
@staticmethod
@contextmanager
def run_server_in_thread(
*,
host: str = "127.0.0.1",
port: int = 8765,
on_connect: Callable[["IOWebsockets"], None],
ssl_context: Optional[ssl.SSLContext] = None,
**kwargs: Any,
) -> Iterator[str]:
"""Factory function to create a websocket input/output stream.
Args:
host (str, optional): The host to bind the server to. Defaults to "127.0.0.1".
port (int, optional): The port to bind the server to. Defaults to 8765.
on_connect (Callable[[IOWebsockets], None]): The function to be executed on client connection. Typically creates agents and initiate chat.
ssl_context (Optional[ssl.SSLContext], optional): The SSL context to use for secure connections. Defaults to None.
kwargs (Any): Additional keyword arguments to pass to the websocket server.
Yields:
str: The URI of the websocket server.
"""
server_dict: dict[str, WebSocketServer] = {}
def _run_server() -> None:
# print(f" - _run_server(): starting server on ws://{host}:{port}", flush=True)
with ws_serve(
handler=partial(IOWebsockets._handler, on_connect=on_connect),
host=host,
port=port,
ssl_context=ssl_context,
**kwargs,
) as server:
# print(f" - _run_server(): server {server} started on ws://{host}:{port}", flush=True)
server_dict["server"] = server
# runs until the server is shutdown
server.serve_forever()
return
# start server in a separate thread
thread = threading.Thread(target=_run_server)
thread.start()
try:
while "server" not in server_dict:
sleep(0.1)
yield f"ws://{host}:{port}"
finally:
# print(f" - run_server_in_thread(): shutting down server on ws://{host}:{port}", flush=True)
# gracefully stop server
if "server" in server_dict:
# print(f" - run_server_in_thread(): shutting down server {server_dict['server']}", flush=True)
server_dict["server"].shutdown()
# wait for the thread to stop
if thread:
thread.join()
@property
def websocket(self) -> "ServerConnection":
"""The URI of the websocket server."""
return self._websocket
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
"""Print data to the output stream.
Args:
objects (any): The data to print.
sep (str, optional): The separator between objects. Defaults to " ".
end (str, optional): The end of the output. Defaults to "\n".
flush (bool, optional): Whether to flush the output. Defaults to False.
"""
print_message = PrintEvent(*objects, sep=sep, end=end)
self.send(print_message)
def send(self, message: BaseEvent) -> None:
"""Send a message to the output stream.
Args:
message (Any): The message to send.
"""
self._websocket.send(message.model_dump_json())
def input(self, prompt: str = "", *, password: bool = False) -> str:
"""Read a line from the input stream.
Args:
prompt (str, optional): The prompt to display. Defaults to "".
password (bool, optional): Whether to read a password. Defaults to False.
Returns:
str: The line read from the input stream.
"""
if prompt != "":
self._websocket.send(prompt)
msg = self._websocket.recv()
return msg.decode("utf-8") if isinstance(msg, bytes) else msg