CoACT initialize (#292)
This commit is contained in:
15
mm_agents/coact/autogen/io/__init__.py
Normal file
15
mm_agents/coact/autogen/io/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from .base import IOStream, InputStream, OutputStream
|
||||
from .console import IOConsole
|
||||
from .websockets import IOWebsockets
|
||||
|
||||
# Set the default input/output stream to the console
|
||||
IOStream.set_global_default(IOConsole())
|
||||
IOStream.set_default(IOConsole())
|
||||
|
||||
__all__ = ("IOConsole", "IOStream", "IOWebsockets", "InputStream", "OutputStream")
|
||||
151
mm_agents/coact/autogen/io/base.py
Normal file
151
mm_agents/coact/autogen/io/base.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Optional, Protocol, Union, runtime_checkable
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from ..events.base_event import BaseEvent
|
||||
|
||||
__all__ = ("IOStream", "InputStream", "OutputStream")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@export_module("autogen.io")
|
||||
class OutputStream(Protocol):
|
||||
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
|
||||
"""Print data to the output stream.
|
||||
|
||||
Args:
|
||||
objects (any): The data to print.
|
||||
sep (str, optional): The separator between objects. Defaults to " ".
|
||||
end (str, optional): The end of the output. Defaults to "\n".
|
||||
flush (bool, optional): Whether to flush the output. Defaults to False.
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
def send(self, message: BaseEvent) -> None:
|
||||
"""Send data to the output stream.
|
||||
|
||||
Args:
|
||||
message (BaseEvent): BaseEvent from autogen.messages.base_message
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@export_module("autogen.io")
|
||||
class InputStream(Protocol):
|
||||
def input(self, prompt: str = "", *, password: bool = False) -> str:
|
||||
"""Read a line from the input stream.
|
||||
|
||||
Args:
|
||||
prompt (str, optional): The prompt to display. Defaults to "".
|
||||
password (bool, optional): Whether to read a password. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: The line read from the input stream.
|
||||
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@export_module("autogen.io")
|
||||
class AsyncInputStream(Protocol):
|
||||
async def input(self, prompt: str = "", *, password: bool = False) -> str:
|
||||
"""Read a line from the input stream.
|
||||
|
||||
Args:
|
||||
prompt (str, optional): The prompt to display. Defaults to "".
|
||||
password (bool, optional): Whether to read a password. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: The line read from the input stream.
|
||||
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@export_module("autogen.io")
|
||||
class IOStreamProtocol(InputStream, OutputStream, Protocol):
|
||||
"""A protocol for input/output streams."""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@export_module("autogen.io")
|
||||
class AsyncIOStreamProtocol(AsyncInputStream, OutputStream, Protocol):
|
||||
"""A protocol for input/output streams."""
|
||||
|
||||
|
||||
iostream_union = Union[IOStreamProtocol, AsyncIOStreamProtocol]
|
||||
|
||||
|
||||
@export_module("autogen.io")
|
||||
class IOStream:
|
||||
"""A protocol for input/output streams."""
|
||||
|
||||
# ContextVar must be used in multithreaded or async environments
|
||||
_default_io_stream: ContextVar[Optional[iostream_union]] = ContextVar("default_iostream", default=None)
|
||||
_default_io_stream.set(None)
|
||||
_global_default: Optional[iostream_union] = None
|
||||
|
||||
@staticmethod
|
||||
def set_global_default(stream: iostream_union) -> None:
|
||||
"""Set the default input/output stream.
|
||||
|
||||
Args:
|
||||
stream (IOStream): The input/output stream to set as the default.
|
||||
"""
|
||||
IOStream._global_default = stream
|
||||
|
||||
@staticmethod
|
||||
def get_global_default() -> iostream_union:
|
||||
"""Get the default input/output stream.
|
||||
|
||||
Returns:
|
||||
IOStream: The default input/output stream.
|
||||
"""
|
||||
if IOStream._global_default is None:
|
||||
raise RuntimeError("No global default IOStream has been set")
|
||||
return IOStream._global_default
|
||||
|
||||
@staticmethod
|
||||
def get_default() -> iostream_union:
|
||||
"""Get the default input/output stream.
|
||||
|
||||
Returns:
|
||||
IOStream: The default input/output stream.
|
||||
"""
|
||||
iostream = IOStream._default_io_stream.get()
|
||||
if iostream is None:
|
||||
iostream = IOStream.get_global_default()
|
||||
# Set the default IOStream of the current context (thread/cooroutine)
|
||||
IOStream.set_default(iostream)
|
||||
return iostream
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def set_default(stream: Optional[iostream_union]) -> Iterator[None]:
|
||||
"""Set the default input/output stream.
|
||||
|
||||
Args:
|
||||
stream (IOStream): The input/output stream to set as the default.
|
||||
"""
|
||||
global _default_io_stream
|
||||
try:
|
||||
token = IOStream._default_io_stream.set(stream)
|
||||
yield
|
||||
finally:
|
||||
IOStream._default_io_stream.reset(token)
|
||||
|
||||
return
|
||||
56
mm_agents/coact/autogen/io/console.py
Normal file
56
mm_agents/coact/autogen/io/console.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import getpass
|
||||
from typing import Any
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from ..events.base_event import BaseEvent
|
||||
from ..events.print_event import PrintEvent
|
||||
from .base import IOStream
|
||||
|
||||
__all__ = ("IOConsole",)
|
||||
|
||||
|
||||
@export_module("autogen.io")
|
||||
class IOConsole(IOStream):
|
||||
"""A console input/output stream."""
|
||||
|
||||
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
|
||||
"""Print data to the output stream.
|
||||
|
||||
Args:
|
||||
objects (any): The data to print.
|
||||
sep (str, optional): The separator between objects. Defaults to " ".
|
||||
end (str, optional): The end of the output. Defaults to "\n".
|
||||
flush (bool, optional): Whether to flush the output. Defaults to False.
|
||||
"""
|
||||
print_message = PrintEvent(*objects, sep=sep, end=end)
|
||||
self.send(print_message)
|
||||
# print(*objects, sep=sep, end=end, flush=flush)
|
||||
|
||||
def send(self, message: BaseEvent) -> None:
|
||||
"""Send a message to the output stream.
|
||||
|
||||
Args:
|
||||
message (Any): The message to send.
|
||||
"""
|
||||
message.print()
|
||||
|
||||
def input(self, prompt: str = "", *, password: bool = False) -> str:
|
||||
"""Read a line from the input stream.
|
||||
|
||||
Args:
|
||||
prompt (str, optional): The prompt to display. Defaults to "".
|
||||
password (bool, optional): Whether to read a password. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: The line read from the input stream.
|
||||
|
||||
"""
|
||||
if password:
|
||||
return getpass.getpass(prompt if prompt != "" else "Password: ")
|
||||
return input(prompt)
|
||||
12
mm_agents/coact/autogen/io/processors/__init__.py
Normal file
12
mm_agents/coact/autogen/io/processors/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from .base import AsyncEventProcessorProtocol, EventProcessorProtocol
|
||||
from .console_event_processor import AsyncConsoleEventProcessor, ConsoleEventProcessor
|
||||
|
||||
__all__ = [
|
||||
"AsyncConsoleEventProcessor",
|
||||
"AsyncEventProcessorProtocol",
|
||||
"ConsoleEventProcessor",
|
||||
"EventProcessorProtocol",
|
||||
]
|
||||
21
mm_agents/coact/autogen/io/processors/base.py
Normal file
21
mm_agents/coact/autogen/io/processors/base.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from ...doc_utils import export_module
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..run_response import AsyncRunResponseProtocol, RunResponseProtocol
|
||||
|
||||
__all__ = ["AsyncEventProcessorProtocol", "EventProcessorProtocol"]
|
||||
|
||||
|
||||
@export_module("autogen.io")
|
||||
class EventProcessorProtocol(Protocol):
|
||||
def process(self, response: "RunResponseProtocol") -> None: ...
|
||||
|
||||
|
||||
@export_module("autogen.io")
|
||||
class AsyncEventProcessorProtocol(Protocol):
|
||||
async def process(self, response: "AsyncRunResponseProtocol") -> None: ...
|
||||
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import getpass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...doc_utils import export_module
|
||||
from ...events.agent_events import InputRequestEvent
|
||||
from ...events.base_event import BaseEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..run_response import AsyncRunResponseProtocol, RunResponseProtocol
|
||||
from .base import AsyncEventProcessorProtocol, EventProcessorProtocol
|
||||
|
||||
|
||||
@export_module("autogen.io")
|
||||
class ConsoleEventProcessor:
|
||||
def process(self, response: "RunResponseProtocol") -> None:
|
||||
for event in response.events:
|
||||
self.process_event(event)
|
||||
|
||||
def process_event(self, event: BaseEvent) -> None:
|
||||
if isinstance(event, InputRequestEvent):
|
||||
prompt = event.content.prompt # type: ignore[attr-defined]
|
||||
if event.content.password: # type: ignore[attr-defined]
|
||||
result = getpass.getpass(prompt if prompt != "" else "Password: ")
|
||||
result = input(prompt)
|
||||
event.content.respond(result) # type: ignore[attr-defined]
|
||||
else:
|
||||
event.print()
|
||||
|
||||
|
||||
@export_module("autogen.io")
|
||||
class AsyncConsoleEventProcessor:
|
||||
async def process(self, response: "AsyncRunResponseProtocol") -> None:
|
||||
async for event in response.events:
|
||||
await self.process_event(event)
|
||||
|
||||
async def process_event(self, event: BaseEvent) -> None:
|
||||
if isinstance(event, InputRequestEvent):
|
||||
prompt = event.content.prompt # type: ignore[attr-defined]
|
||||
if event.content.password: # type: ignore[attr-defined]
|
||||
result = getpass.getpass(prompt if prompt != "" else "Password: ")
|
||||
result = input(prompt)
|
||||
await event.content.respond(result) # type: ignore[attr-defined]
|
||||
else:
|
||||
event.print()
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def check_type_1(x: ConsoleEventProcessor) -> EventProcessorProtocol:
|
||||
return x
|
||||
|
||||
def check_type_2(x: AsyncConsoleEventProcessor) -> AsyncEventProcessorProtocol:
|
||||
return x
|
||||
293
mm_agents/coact/autogen/io/run_response.py
Normal file
293
mm_agents/coact/autogen/io/run_response.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import queue
|
||||
from asyncio import Queue as AsyncQueue
|
||||
from typing import Any, AsyncIterable, Dict, Iterable, Optional, Protocol, Sequence, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogen.tools.tool import Tool
|
||||
|
||||
from ..agentchat.agent import Agent, LLMMessageType
|
||||
from ..agentchat.group.context_variables import ContextVariables
|
||||
from ..events.agent_events import ErrorEvent, InputRequestEvent, RunCompletionEvent
|
||||
from ..events.base_event import BaseEvent
|
||||
from .processors import (
|
||||
AsyncConsoleEventProcessor,
|
||||
AsyncEventProcessorProtocol,
|
||||
ConsoleEventProcessor,
|
||||
EventProcessorProtocol,
|
||||
)
|
||||
from .thread_io_stream import AsyncThreadIOStream, ThreadIOStream
|
||||
|
||||
Message = dict[str, Any]
|
||||
|
||||
|
||||
class RunInfoProtocol(Protocol):
|
||||
@property
|
||||
def uuid(self) -> UUID: ...
|
||||
|
||||
@property
|
||||
def above_run(self) -> Optional["RunResponseProtocol"]: ...
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
cost: float
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class CostBreakdown(BaseModel):
|
||||
total_cost: float
|
||||
models: Dict[str, Usage] = Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, data: dict[str, Any]) -> "CostBreakdown":
|
||||
# Extract total cost
|
||||
total_cost = data.get("total_cost", 0.0)
|
||||
|
||||
# Remove total_cost key to extract models
|
||||
model_usages = {k: Usage(**v) for k, v in data.items() if k != "total_cost"}
|
||||
|
||||
return cls(total_cost=total_cost, models=model_usages)
|
||||
|
||||
|
||||
class Cost(BaseModel):
|
||||
usage_including_cached_inference: CostBreakdown
|
||||
usage_excluding_cached_inference: CostBreakdown
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, data: dict[str, Any]) -> "Cost":
|
||||
return cls(
|
||||
usage_including_cached_inference=CostBreakdown.from_raw(data.get("usage_including_cached_inference", {})),
|
||||
usage_excluding_cached_inference=CostBreakdown.from_raw(data.get("usage_excluding_cached_inference", {})),
|
||||
)
|
||||
|
||||
|
||||
class RunResponseProtocol(RunInfoProtocol, Protocol):
|
||||
@property
|
||||
def events(self) -> Iterable[BaseEvent]: ...
|
||||
|
||||
@property
|
||||
def messages(self) -> Iterable[Message]: ...
|
||||
|
||||
@property
|
||||
def summary(self) -> Optional[str]: ...
|
||||
|
||||
@property
|
||||
def context_variables(self) -> Optional[ContextVariables]: ...
|
||||
|
||||
@property
|
||||
def last_speaker(self) -> Optional[str]: ...
|
||||
|
||||
@property
|
||||
def cost(self) -> Optional[Cost]: ...
|
||||
|
||||
def process(self, processor: Optional[EventProcessorProtocol] = None) -> None: ...
|
||||
|
||||
def set_ui_tools(self, tools: list[Tool]) -> None: ...
|
||||
|
||||
|
||||
class AsyncRunResponseProtocol(RunInfoProtocol, Protocol):
|
||||
@property
|
||||
def events(self) -> AsyncIterable[BaseEvent]: ...
|
||||
|
||||
@property
|
||||
async def messages(self) -> Iterable[Message]: ...
|
||||
|
||||
@property
|
||||
async def summary(self) -> Optional[str]: ...
|
||||
|
||||
@property
|
||||
async def context_variables(self) -> Optional[ContextVariables]: ...
|
||||
|
||||
@property
|
||||
async def last_speaker(self) -> Optional[str]: ...
|
||||
|
||||
@property
|
||||
async def cost(self) -> Optional[Cost]: ...
|
||||
|
||||
async def process(self, processor: Optional[AsyncEventProcessorProtocol] = None) -> None: ...
|
||||
|
||||
def set_ui_tools(self, tools: list[Tool]) -> None: ...
|
||||
|
||||
|
||||
class RunResponse:
|
||||
def __init__(self, iostream: ThreadIOStream, agents: list[Agent]):
|
||||
self.iostream = iostream
|
||||
self.agents = agents
|
||||
self._summary: Optional[str] = None
|
||||
self._messages: Sequence[LLMMessageType] = []
|
||||
self._uuid = uuid4()
|
||||
self._context_variables: Optional[ContextVariables] = None
|
||||
self._last_speaker: Optional[str] = None
|
||||
self._cost: Optional[Cost] = None
|
||||
|
||||
def _queue_generator(self, q: queue.Queue) -> Iterable[BaseEvent]: # type: ignore[type-arg]
|
||||
"""A generator to yield items from the queue until the termination message is found."""
|
||||
while True:
|
||||
try:
|
||||
# Get an item from the queue
|
||||
event = q.get(timeout=0.1) # Adjust timeout as needed
|
||||
|
||||
if isinstance(event, InputRequestEvent):
|
||||
event.content.respond = lambda response: self.iostream._output_stream.put(response) # type: ignore[attr-defined]
|
||||
|
||||
yield event
|
||||
|
||||
if isinstance(event, RunCompletionEvent):
|
||||
self._messages = event.content.history # type: ignore[attr-defined]
|
||||
self._last_speaker = event.content.last_speaker # type: ignore[attr-defined]
|
||||
self._summary = event.content.summary # type: ignore[attr-defined]
|
||||
self._context_variables = event.content.context_variables # type: ignore[attr-defined]
|
||||
self.cost = event.content.cost # type: ignore[attr-defined]
|
||||
break
|
||||
|
||||
if isinstance(event, ErrorEvent):
|
||||
raise event.content.error # type: ignore[attr-defined]
|
||||
except queue.Empty:
|
||||
continue # Wait for more items in the queue
|
||||
|
||||
@property
|
||||
def events(self) -> Iterable[BaseEvent]:
|
||||
return self._queue_generator(self.iostream.input_stream)
|
||||
|
||||
@property
|
||||
def messages(self) -> Iterable[Message]:
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
def summary(self) -> Optional[str]:
|
||||
return self._summary
|
||||
|
||||
@property
|
||||
def above_run(self) -> Optional["RunResponseProtocol"]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def uuid(self) -> UUID:
|
||||
return self._uuid
|
||||
|
||||
@property
|
||||
def context_variables(self) -> Optional[ContextVariables]:
|
||||
return self._context_variables
|
||||
|
||||
@property
|
||||
def last_speaker(self) -> Optional[str]:
|
||||
return self._last_speaker
|
||||
|
||||
@property
|
||||
def cost(self) -> Optional[Cost]:
|
||||
return self._cost
|
||||
|
||||
@cost.setter
|
||||
def cost(self, value: Union[Cost, dict[str, Any]]) -> None:
|
||||
if isinstance(value, dict):
|
||||
self._cost = Cost.from_raw(value)
|
||||
else:
|
||||
self._cost = value
|
||||
|
||||
def process(self, processor: Optional[EventProcessorProtocol] = None) -> None:
|
||||
processor = processor or ConsoleEventProcessor()
|
||||
processor.process(self)
|
||||
|
||||
def set_ui_tools(self, tools: list[Tool]) -> None:
|
||||
"""Set the UI tools for the agents."""
|
||||
for agent in self.agents:
|
||||
agent.set_ui_tools(tools)
|
||||
|
||||
|
||||
class AsyncRunResponse:
|
||||
def __init__(self, iostream: AsyncThreadIOStream, agents: list[Agent]):
|
||||
self.iostream = iostream
|
||||
self.agents = agents
|
||||
self._summary: Optional[str] = None
|
||||
self._messages: Sequence[LLMMessageType] = []
|
||||
self._uuid = uuid4()
|
||||
self._context_variables: Optional[ContextVariables] = None
|
||||
self._last_speaker: Optional[str] = None
|
||||
self._cost: Optional[Cost] = None
|
||||
|
||||
async def _queue_generator(self, q: AsyncQueue[Any]) -> AsyncIterable[BaseEvent]: # type: ignore[type-arg]
|
||||
"""A generator to yield items from the queue until the termination message is found."""
|
||||
while True:
|
||||
try:
|
||||
# Get an item from the queue
|
||||
event = await q.get()
|
||||
|
||||
if isinstance(event, InputRequestEvent):
|
||||
|
||||
async def respond(response: str) -> None:
|
||||
await self.iostream._output_stream.put(response)
|
||||
|
||||
event.content.respond = respond # type: ignore[attr-defined]
|
||||
|
||||
yield event
|
||||
|
||||
if isinstance(event, RunCompletionEvent):
|
||||
self._messages = event.content.history # type: ignore[attr-defined]
|
||||
self._last_speaker = event.content.last_speaker # type: ignore[attr-defined]
|
||||
self._summary = event.content.summary # type: ignore[attr-defined]
|
||||
self._context_variables = event.content.context_variables # type: ignore[attr-defined]
|
||||
self.cost = event.content.cost # type: ignore[attr-defined]
|
||||
break
|
||||
|
||||
if isinstance(event, ErrorEvent):
|
||||
raise event.content.error # type: ignore[attr-defined]
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
@property
|
||||
def events(self) -> AsyncIterable[BaseEvent]:
|
||||
return self._queue_generator(self.iostream.input_stream)
|
||||
|
||||
@property
|
||||
async def messages(self) -> Iterable[Message]:
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
async def summary(self) -> Optional[str]:
|
||||
return self._summary
|
||||
|
||||
@property
|
||||
def above_run(self) -> Optional["RunResponseProtocol"]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def uuid(self) -> UUID:
|
||||
return self._uuid
|
||||
|
||||
@property
|
||||
async def context_variables(self) -> Optional[ContextVariables]:
|
||||
return self._context_variables
|
||||
|
||||
@property
|
||||
async def last_speaker(self) -> Optional[str]:
|
||||
return self._last_speaker
|
||||
|
||||
@property
|
||||
async def cost(self) -> Optional[Cost]:
|
||||
return self._cost
|
||||
|
||||
@cost.setter
|
||||
def cost(self, value: Union[Cost, dict[str, Any]]) -> None:
|
||||
if isinstance(value, dict):
|
||||
self._cost = Cost.from_raw(value)
|
||||
else:
|
||||
self._cost = value
|
||||
|
||||
async def process(self, processor: Optional[AsyncEventProcessorProtocol] = None) -> None:
|
||||
processor = processor or AsyncConsoleEventProcessor()
|
||||
await processor.process(self)
|
||||
|
||||
def set_ui_tools(self, tools: list[Tool]) -> None:
|
||||
"""Set the UI tools for the agents."""
|
||||
for agent in self.agents:
|
||||
agent.set_ui_tools(tools)
|
||||
63
mm_agents/coact/autogen/io/thread_io_stream.py
Normal file
63
mm_agents/coact/autogen/io/thread_io_stream.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import queue
|
||||
from asyncio import Queue as AsyncQueue
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from autogen.io.base import AsyncIOStreamProtocol, IOStreamProtocol
|
||||
|
||||
from ..events.agent_events import InputRequestEvent
|
||||
from ..events.print_event import PrintEvent
|
||||
|
||||
|
||||
class ThreadIOStream:
|
||||
def __init__(self) -> None:
|
||||
self._input_stream: queue.Queue = queue.Queue() # type: ignore[type-arg]
|
||||
self._output_stream: queue.Queue = queue.Queue() # type: ignore[type-arg]
|
||||
|
||||
def input(self, prompt: str = "", *, password: bool = False) -> str:
|
||||
self.send(InputRequestEvent(prompt=prompt, password=password)) # type: ignore[call-arg]
|
||||
return self._output_stream.get() # type: ignore[no-any-return]
|
||||
|
||||
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
|
||||
print_message = PrintEvent(*objects, sep=sep, end=end)
|
||||
self.send(print_message)
|
||||
|
||||
def send(self, message: Any) -> None:
|
||||
self._input_stream.put(message)
|
||||
|
||||
@property
|
||||
def input_stream(self) -> queue.Queue: # type: ignore[type-arg]
|
||||
return self._input_stream
|
||||
|
||||
|
||||
class AsyncThreadIOStream:
|
||||
def __init__(self) -> None:
|
||||
self._input_stream: AsyncQueue = AsyncQueue() # type: ignore[type-arg]
|
||||
self._output_stream: AsyncQueue = AsyncQueue() # type: ignore[type-arg]
|
||||
|
||||
async def input(self, prompt: str = "", *, password: bool = False) -> str:
|
||||
self.send(InputRequestEvent(prompt=prompt, password=password)) # type: ignore[call-arg]
|
||||
return await self._output_stream.get() # type: ignore[no-any-return]
|
||||
|
||||
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
|
||||
print_message = PrintEvent(*objects, sep=sep, end=end)
|
||||
self.send(print_message)
|
||||
|
||||
def send(self, message: Any) -> None:
|
||||
self._input_stream.put_nowait(message)
|
||||
|
||||
@property
|
||||
def input_stream(self) -> AsyncQueue[Any]:
|
||||
return self._input_stream
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def check_type_1(x: ThreadIOStream) -> IOStreamProtocol:
|
||||
return x
|
||||
|
||||
def check_type_2(x: AsyncThreadIOStream) -> AsyncIOStreamProtocol:
|
||||
return x
|
||||
213
mm_agents/coact/autogen/io/websockets.py
Normal file
213
mm_agents/coact/autogen/io/websockets.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import logging
|
||||
import ssl
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from time import sleep
|
||||
from typing import Any, Callable, Iterable, Iterator, Optional, Protocol, Union
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from ..events.base_event import BaseEvent
|
||||
from ..events.print_event import PrintEvent
|
||||
from ..import_utils import optional_import_block, require_optional_import
|
||||
from .base import IOStream
|
||||
|
||||
# Check if the websockets module is available
|
||||
with optional_import_block():
|
||||
from websockets.sync.server import serve as ws_serve
|
||||
|
||||
__all__ = ("IOWebsockets",)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# The following type and protocols are used to define the ServerConnection and WebSocketServer classes
|
||||
# if websockets is not installed, they would be untyped
|
||||
Data = Union[str, bytes]
|
||||
|
||||
|
||||
class ServerConnection(Protocol):
|
||||
def send(self, message: Union[Data, Iterable[Data]]) -> None:
|
||||
"""Send a message to the client.
|
||||
|
||||
Args:
|
||||
message (Union[Data, Iterable[Data]]): The message to send.
|
||||
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
def recv(self, timeout: Optional[float] = None) -> Data:
|
||||
"""Receive a message from the client.
|
||||
|
||||
Args:
|
||||
timeout (Optional[float], optional): The timeout for the receive operation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Data: The message received from the client.
|
||||
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
...
|
||||
|
||||
|
||||
class WebSocketServer(Protocol):
|
||||
def serve_forever(self) -> None:
|
||||
"""Run the server forever."""
|
||||
... # pragma: no cover
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the server."""
|
||||
... # pragma: no cover
|
||||
|
||||
def __enter__(self) -> "WebSocketServer":
|
||||
"""Enter the server context."""
|
||||
... # pragma: no cover
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
"""Exit the server context."""
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
@require_optional_import("websockets", "websockets")
|
||||
@export_module("autogen.io")
|
||||
class IOWebsockets(IOStream):
|
||||
"""A websocket input/output stream."""
|
||||
|
||||
def __init__(self, websocket: ServerConnection) -> None:
|
||||
"""Initialize the websocket input/output stream.
|
||||
|
||||
Args:
|
||||
websocket (ServerConnection): The websocket server.
|
||||
"""
|
||||
self._websocket = websocket
|
||||
|
||||
@staticmethod
|
||||
def _handler(websocket: ServerConnection, on_connect: Callable[["IOWebsockets"], None]) -> None:
|
||||
"""The handler function for the websocket server."""
|
||||
logger.info(f" - IOWebsockets._handler(): Client connected on {websocket}")
|
||||
# create a new IOWebsockets instance using the websocket that is create when a client connects
|
||||
try:
|
||||
iowebsocket = IOWebsockets(websocket)
|
||||
with IOStream.set_default(iowebsocket):
|
||||
# call the on_connect function
|
||||
try:
|
||||
on_connect(iowebsocket)
|
||||
except Exception as e:
|
||||
logger.warning(f" - IOWebsockets._handler(): Error in on_connect: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f" - IOWebsockets._handler(): Unexpected error in IOWebsockets: {e}")
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def run_server_in_thread(
|
||||
*,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8765,
|
||||
on_connect: Callable[["IOWebsockets"], None],
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[str]:
|
||||
"""Factory function to create a websocket input/output stream.
|
||||
|
||||
Args:
|
||||
host (str, optional): The host to bind the server to. Defaults to "127.0.0.1".
|
||||
port (int, optional): The port to bind the server to. Defaults to 8765.
|
||||
on_connect (Callable[[IOWebsockets], None]): The function to be executed on client connection. Typically creates agents and initiate chat.
|
||||
ssl_context (Optional[ssl.SSLContext], optional): The SSL context to use for secure connections. Defaults to None.
|
||||
kwargs (Any): Additional keyword arguments to pass to the websocket server.
|
||||
|
||||
Yields:
|
||||
str: The URI of the websocket server.
|
||||
"""
|
||||
server_dict: dict[str, WebSocketServer] = {}
|
||||
|
||||
def _run_server() -> None:
|
||||
# print(f" - _run_server(): starting server on ws://{host}:{port}", flush=True)
|
||||
with ws_serve(
|
||||
handler=partial(IOWebsockets._handler, on_connect=on_connect),
|
||||
host=host,
|
||||
port=port,
|
||||
ssl_context=ssl_context,
|
||||
**kwargs,
|
||||
) as server:
|
||||
# print(f" - _run_server(): server {server} started on ws://{host}:{port}", flush=True)
|
||||
|
||||
server_dict["server"] = server
|
||||
|
||||
# runs until the server is shutdown
|
||||
server.serve_forever()
|
||||
|
||||
return
|
||||
|
||||
# start server in a separate thread
|
||||
thread = threading.Thread(target=_run_server)
|
||||
thread.start()
|
||||
try:
|
||||
while "server" not in server_dict:
|
||||
sleep(0.1)
|
||||
|
||||
yield f"ws://{host}:{port}"
|
||||
|
||||
finally:
|
||||
# print(f" - run_server_in_thread(): shutting down server on ws://{host}:{port}", flush=True)
|
||||
# gracefully stop server
|
||||
if "server" in server_dict:
|
||||
# print(f" - run_server_in_thread(): shutting down server {server_dict['server']}", flush=True)
|
||||
server_dict["server"].shutdown()
|
||||
|
||||
# wait for the thread to stop
|
||||
if thread:
|
||||
thread.join()
|
||||
|
||||
@property
|
||||
def websocket(self) -> "ServerConnection":
|
||||
"""The URI of the websocket server."""
|
||||
return self._websocket
|
||||
|
||||
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
|
||||
"""Print data to the output stream.
|
||||
|
||||
Args:
|
||||
objects (any): The data to print.
|
||||
sep (str, optional): The separator between objects. Defaults to " ".
|
||||
end (str, optional): The end of the output. Defaults to "\n".
|
||||
flush (bool, optional): Whether to flush the output. Defaults to False.
|
||||
"""
|
||||
print_message = PrintEvent(*objects, sep=sep, end=end)
|
||||
self.send(print_message)
|
||||
|
||||
def send(self, message: BaseEvent) -> None:
|
||||
"""Send a message to the output stream.
|
||||
|
||||
Args:
|
||||
message (Any): The message to send.
|
||||
"""
|
||||
self._websocket.send(message.model_dump_json())
|
||||
|
||||
def input(self, prompt: str = "", *, password: bool = False) -> str:
|
||||
"""Read a line from the input stream.
|
||||
|
||||
Args:
|
||||
prompt (str, optional): The prompt to display. Defaults to "".
|
||||
password (bool, optional): Whether to read a password. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: The line read from the input stream.
|
||||
|
||||
"""
|
||||
if prompt != "":
|
||||
self._websocket.send(prompt)
|
||||
|
||||
msg = self._websocket.recv()
|
||||
|
||||
return msg.decode("utf-8") if isinstance(msg, bytes) else msg
|
||||
Reference in New Issue
Block a user