CoACT initialize (#292)
This commit is contained in:
293
mm_agents/coact/autogen/io/run_response.py
Normal file
293
mm_agents/coact/autogen/io/run_response.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import queue
|
||||
from asyncio import Queue as AsyncQueue
|
||||
from typing import Any, AsyncIterable, Dict, Iterable, Optional, Protocol, Sequence, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogen.tools.tool import Tool
|
||||
|
||||
from ..agentchat.agent import Agent, LLMMessageType
|
||||
from ..agentchat.group.context_variables import ContextVariables
|
||||
from ..events.agent_events import ErrorEvent, InputRequestEvent, RunCompletionEvent
|
||||
from ..events.base_event import BaseEvent
|
||||
from .processors import (
|
||||
AsyncConsoleEventProcessor,
|
||||
AsyncEventProcessorProtocol,
|
||||
ConsoleEventProcessor,
|
||||
EventProcessorProtocol,
|
||||
)
|
||||
from .thread_io_stream import AsyncThreadIOStream, ThreadIOStream
|
||||
|
||||
Message = dict[str, Any]
|
||||
|
||||
|
||||
class RunInfoProtocol(Protocol):
|
||||
@property
|
||||
def uuid(self) -> UUID: ...
|
||||
|
||||
@property
|
||||
def above_run(self) -> Optional["RunResponseProtocol"]: ...
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
cost: float
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class CostBreakdown(BaseModel):
|
||||
total_cost: float
|
||||
models: Dict[str, Usage] = Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, data: dict[str, Any]) -> "CostBreakdown":
|
||||
# Extract total cost
|
||||
total_cost = data.get("total_cost", 0.0)
|
||||
|
||||
# Remove total_cost key to extract models
|
||||
model_usages = {k: Usage(**v) for k, v in data.items() if k != "total_cost"}
|
||||
|
||||
return cls(total_cost=total_cost, models=model_usages)
|
||||
|
||||
|
||||
class Cost(BaseModel):
|
||||
usage_including_cached_inference: CostBreakdown
|
||||
usage_excluding_cached_inference: CostBreakdown
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, data: dict[str, Any]) -> "Cost":
|
||||
return cls(
|
||||
usage_including_cached_inference=CostBreakdown.from_raw(data.get("usage_including_cached_inference", {})),
|
||||
usage_excluding_cached_inference=CostBreakdown.from_raw(data.get("usage_excluding_cached_inference", {})),
|
||||
)
|
||||
|
||||
|
||||
class RunResponseProtocol(RunInfoProtocol, Protocol):
|
||||
@property
|
||||
def events(self) -> Iterable[BaseEvent]: ...
|
||||
|
||||
@property
|
||||
def messages(self) -> Iterable[Message]: ...
|
||||
|
||||
@property
|
||||
def summary(self) -> Optional[str]: ...
|
||||
|
||||
@property
|
||||
def context_variables(self) -> Optional[ContextVariables]: ...
|
||||
|
||||
@property
|
||||
def last_speaker(self) -> Optional[str]: ...
|
||||
|
||||
@property
|
||||
def cost(self) -> Optional[Cost]: ...
|
||||
|
||||
def process(self, processor: Optional[EventProcessorProtocol] = None) -> None: ...
|
||||
|
||||
def set_ui_tools(self, tools: list[Tool]) -> None: ...
|
||||
|
||||
|
||||
class AsyncRunResponseProtocol(RunInfoProtocol, Protocol):
|
||||
@property
|
||||
def events(self) -> AsyncIterable[BaseEvent]: ...
|
||||
|
||||
@property
|
||||
async def messages(self) -> Iterable[Message]: ...
|
||||
|
||||
@property
|
||||
async def summary(self) -> Optional[str]: ...
|
||||
|
||||
@property
|
||||
async def context_variables(self) -> Optional[ContextVariables]: ...
|
||||
|
||||
@property
|
||||
async def last_speaker(self) -> Optional[str]: ...
|
||||
|
||||
@property
|
||||
async def cost(self) -> Optional[Cost]: ...
|
||||
|
||||
async def process(self, processor: Optional[AsyncEventProcessorProtocol] = None) -> None: ...
|
||||
|
||||
def set_ui_tools(self, tools: list[Tool]) -> None: ...
|
||||
|
||||
|
||||
class RunResponse:
|
||||
def __init__(self, iostream: ThreadIOStream, agents: list[Agent]):
|
||||
self.iostream = iostream
|
||||
self.agents = agents
|
||||
self._summary: Optional[str] = None
|
||||
self._messages: Sequence[LLMMessageType] = []
|
||||
self._uuid = uuid4()
|
||||
self._context_variables: Optional[ContextVariables] = None
|
||||
self._last_speaker: Optional[str] = None
|
||||
self._cost: Optional[Cost] = None
|
||||
|
||||
def _queue_generator(self, q: queue.Queue) -> Iterable[BaseEvent]: # type: ignore[type-arg]
|
||||
"""A generator to yield items from the queue until the termination message is found."""
|
||||
while True:
|
||||
try:
|
||||
# Get an item from the queue
|
||||
event = q.get(timeout=0.1) # Adjust timeout as needed
|
||||
|
||||
if isinstance(event, InputRequestEvent):
|
||||
event.content.respond = lambda response: self.iostream._output_stream.put(response) # type: ignore[attr-defined]
|
||||
|
||||
yield event
|
||||
|
||||
if isinstance(event, RunCompletionEvent):
|
||||
self._messages = event.content.history # type: ignore[attr-defined]
|
||||
self._last_speaker = event.content.last_speaker # type: ignore[attr-defined]
|
||||
self._summary = event.content.summary # type: ignore[attr-defined]
|
||||
self._context_variables = event.content.context_variables # type: ignore[attr-defined]
|
||||
self.cost = event.content.cost # type: ignore[attr-defined]
|
||||
break
|
||||
|
||||
if isinstance(event, ErrorEvent):
|
||||
raise event.content.error # type: ignore[attr-defined]
|
||||
except queue.Empty:
|
||||
continue # Wait for more items in the queue
|
||||
|
||||
@property
|
||||
def events(self) -> Iterable[BaseEvent]:
|
||||
return self._queue_generator(self.iostream.input_stream)
|
||||
|
||||
@property
|
||||
def messages(self) -> Iterable[Message]:
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
def summary(self) -> Optional[str]:
|
||||
return self._summary
|
||||
|
||||
@property
|
||||
def above_run(self) -> Optional["RunResponseProtocol"]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def uuid(self) -> UUID:
|
||||
return self._uuid
|
||||
|
||||
@property
|
||||
def context_variables(self) -> Optional[ContextVariables]:
|
||||
return self._context_variables
|
||||
|
||||
@property
|
||||
def last_speaker(self) -> Optional[str]:
|
||||
return self._last_speaker
|
||||
|
||||
@property
|
||||
def cost(self) -> Optional[Cost]:
|
||||
return self._cost
|
||||
|
||||
@cost.setter
|
||||
def cost(self, value: Union[Cost, dict[str, Any]]) -> None:
|
||||
if isinstance(value, dict):
|
||||
self._cost = Cost.from_raw(value)
|
||||
else:
|
||||
self._cost = value
|
||||
|
||||
def process(self, processor: Optional[EventProcessorProtocol] = None) -> None:
|
||||
processor = processor or ConsoleEventProcessor()
|
||||
processor.process(self)
|
||||
|
||||
def set_ui_tools(self, tools: list[Tool]) -> None:
|
||||
"""Set the UI tools for the agents."""
|
||||
for agent in self.agents:
|
||||
agent.set_ui_tools(tools)
|
||||
|
||||
|
||||
class AsyncRunResponse:
|
||||
def __init__(self, iostream: AsyncThreadIOStream, agents: list[Agent]):
|
||||
self.iostream = iostream
|
||||
self.agents = agents
|
||||
self._summary: Optional[str] = None
|
||||
self._messages: Sequence[LLMMessageType] = []
|
||||
self._uuid = uuid4()
|
||||
self._context_variables: Optional[ContextVariables] = None
|
||||
self._last_speaker: Optional[str] = None
|
||||
self._cost: Optional[Cost] = None
|
||||
|
||||
async def _queue_generator(self, q: AsyncQueue[Any]) -> AsyncIterable[BaseEvent]: # type: ignore[type-arg]
|
||||
"""A generator to yield items from the queue until the termination message is found."""
|
||||
while True:
|
||||
try:
|
||||
# Get an item from the queue
|
||||
event = await q.get()
|
||||
|
||||
if isinstance(event, InputRequestEvent):
|
||||
|
||||
async def respond(response: str) -> None:
|
||||
await self.iostream._output_stream.put(response)
|
||||
|
||||
event.content.respond = respond # type: ignore[attr-defined]
|
||||
|
||||
yield event
|
||||
|
||||
if isinstance(event, RunCompletionEvent):
|
||||
self._messages = event.content.history # type: ignore[attr-defined]
|
||||
self._last_speaker = event.content.last_speaker # type: ignore[attr-defined]
|
||||
self._summary = event.content.summary # type: ignore[attr-defined]
|
||||
self._context_variables = event.content.context_variables # type: ignore[attr-defined]
|
||||
self.cost = event.content.cost # type: ignore[attr-defined]
|
||||
break
|
||||
|
||||
if isinstance(event, ErrorEvent):
|
||||
raise event.content.error # type: ignore[attr-defined]
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
@property
|
||||
def events(self) -> AsyncIterable[BaseEvent]:
|
||||
return self._queue_generator(self.iostream.input_stream)
|
||||
|
||||
@property
|
||||
async def messages(self) -> Iterable[Message]:
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
async def summary(self) -> Optional[str]:
|
||||
return self._summary
|
||||
|
||||
@property
|
||||
def above_run(self) -> Optional["RunResponseProtocol"]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def uuid(self) -> UUID:
|
||||
return self._uuid
|
||||
|
||||
@property
|
||||
async def context_variables(self) -> Optional[ContextVariables]:
|
||||
return self._context_variables
|
||||
|
||||
@property
|
||||
async def last_speaker(self) -> Optional[str]:
|
||||
return self._last_speaker
|
||||
|
||||
@property
|
||||
async def cost(self) -> Optional[Cost]:
|
||||
return self._cost
|
||||
|
||||
@cost.setter
|
||||
def cost(self, value: Union[Cost, dict[str, Any]]) -> None:
|
||||
if isinstance(value, dict):
|
||||
self._cost = Cost.from_raw(value)
|
||||
else:
|
||||
self._cost = value
|
||||
|
||||
async def process(self, processor: Optional[AsyncEventProcessorProtocol] = None) -> None:
|
||||
processor = processor or AsyncConsoleEventProcessor()
|
||||
await processor.process(self)
|
||||
|
||||
def set_ui_tools(self, tools: list[Tool]) -> None:
|
||||
"""Set the UI tools for the agents."""
|
||||
for agent in self.agents:
|
||||
agent.set_ui_tools(tools)
|
||||
Reference in New Issue
Block a user