294 lines
9.5 KiB
Python
294 lines
9.5 KiB
Python
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
|
|
|
|
import queue
|
|
from asyncio import Queue as AsyncQueue
|
|
from typing import Any, AsyncIterable, Dict, Iterable, Optional, Protocol, Sequence, Union
|
|
from uuid import UUID, uuid4
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from autogen.tools.tool import Tool
|
|
|
|
from ..agentchat.agent import Agent, LLMMessageType
|
|
from ..agentchat.group.context_variables import ContextVariables
|
|
from ..events.agent_events import ErrorEvent, InputRequestEvent, RunCompletionEvent
|
|
from ..events.base_event import BaseEvent
|
|
from .processors import (
|
|
AsyncConsoleEventProcessor,
|
|
AsyncEventProcessorProtocol,
|
|
ConsoleEventProcessor,
|
|
EventProcessorProtocol,
|
|
)
|
|
from .thread_io_stream import AsyncThreadIOStream, ThreadIOStream
|
|
|
|
# A single chat message, represented as a plain string-keyed dict. This alias
# enforces no key schema (presumably OpenAI-style "role"/"content" keys — confirm).
Message = dict[str, Any]
|
|
|
|
|
|
class RunInfoProtocol(Protocol):
    """Identity information every run response exposes.

    Implementations provide a unique identifier and a reference to the run
    above this one, if there is one.
    """

    @property
    def above_run(self) -> Optional["RunResponseProtocol"]:
        """Run above this one, or ``None`` when there is none."""

    @property
    def uuid(self) -> UUID:
        """Unique identifier of this run."""
|
|
|
|
|
|
class Usage(BaseModel):
    # Usage figures for one entry of a cost breakdown (one model name).
    # All fields are required and type-validated by pydantic on construction.
    # Documented with comments rather than a docstring: pydantic copies a class
    # docstring into the generated JSON schema description.

    cost: float  # presumably monetary cost; currency/units are not fixed here — confirm

    prompt_tokens: int  # token count attributed to the prompt (per the field name)

    completion_tokens: int  # token count attributed to the completion (per the field name)

    total_tokens: int  # presumably prompt + completion tokens; not enforced by this model
|
|
|
|
|
|
class CostBreakdown(BaseModel):
    # Aggregate cost plus a per-model usage table.
    # (Comments instead of a class docstring so the pydantic schema is unchanged.)

    total_cost: float
    # Model name -> usage; empty when the raw summary only carries "total_cost".
    models: Dict[str, Usage] = Field(default_factory=dict)

    @classmethod
    def from_raw(cls, data: dict[str, Any]) -> "CostBreakdown":
        """Build a breakdown from a flat usage-summary dict.

        Args:
            data: Mapping whose ``"total_cost"`` entry (default ``0.0`` when
                absent) becomes :attr:`total_cost`. Every other key is treated
                as a model name whose value is a mapping of ``Usage`` fields.

        Returns:
            A new ``CostBreakdown``.

        Raises:
            pydantic.ValidationError: if a per-model entry does not fit ``Usage``.
        """
        grand_total = data.get("total_cost", 0.0)

        per_model: Dict[str, Usage] = {}
        for model_name, raw_usage in data.items():
            if model_name == "total_cost":
                # Already captured above; it is not a model entry.
                continue
            per_model[model_name] = Usage(**raw_usage)

        return cls(total_cost=grand_total, models=per_model)
|
|
|
|
|
|
class Cost(BaseModel):
    # Run cost reported twice: once counting cached inference, once excluding it.
    # (Comments instead of a class docstring so the pydantic schema is unchanged.)

    usage_including_cached_inference: CostBreakdown

    usage_excluding_cached_inference: CostBreakdown

    @classmethod
    def from_raw(cls, data: dict[str, Any]) -> "Cost":
        """Build a ``Cost`` from a raw usage-summary dict.

        Args:
            data: Mapping with optional ``"usage_including_cached_inference"``
                and ``"usage_excluding_cached_inference"`` entries, each in the
                format accepted by ``CostBreakdown.from_raw``. A missing entry
                is treated as ``{}``, i.e. zero total cost and no models.

        Returns:
            A new ``Cost``.
        """
        with_cache = CostBreakdown.from_raw(data.get("usage_including_cached_inference", {}))
        without_cache = CostBreakdown.from_raw(data.get("usage_excluding_cached_inference", {}))
        return cls(
            usage_including_cached_inference=with_cache,
            usage_excluding_cached_inference=without_cache,
        )
|
|
|
|
|
|
class RunResponseProtocol(RunInfoProtocol, Protocol):
    """Synchronous view of a run: its event stream, processing hooks and results.

    The result properties (messages, summary, context variables, last speaker,
    cost) describe the finished run; before completion they may be empty or
    ``None``.
    """

    # --- event stream and processing -------------------------------------

    @property
    def events(self) -> Iterable[BaseEvent]:
        """Iterable of events produced by the run."""

    def process(self, processor: Optional[EventProcessorProtocol] = None) -> None:
        """Consume the run's events with ``processor`` (implementation picks a default when ``None``)."""

    def set_ui_tools(self, tools: list[Tool]) -> None:
        """Register ``tools`` as UI tools for the run."""

    # --- results of the run ------------------------------------------------

    @property
    def messages(self) -> Iterable[Message]:
        """Messages exchanged during the run."""

    @property
    def summary(self) -> Optional[str]:
        """Summary of the run, if any."""

    @property
    def context_variables(self) -> Optional[ContextVariables]:
        """Context variables at the end of the run, if any."""

    @property
    def last_speaker(self) -> Optional[str]:
        """Name of the last speaker, if known."""

    @property
    def cost(self) -> Optional[Cost]:
        """Cost of the run, if reported."""
|
|
|
|
|
|
class AsyncRunResponseProtocol(RunInfoProtocol, Protocol):
    """Asynchronous view of a run: an async event stream plus awaitable results.

    ``events`` is an async iterable. The result members are declared as
    ``async def`` properties, so reading one produces a coroutine that must be
    awaited (e.g. ``await response.summary``).
    """

    # --- event stream and processing -------------------------------------

    @property
    def events(self) -> AsyncIterable[BaseEvent]:
        """Async iterable of events produced by the run."""

    async def process(self, processor: Optional[AsyncEventProcessorProtocol] = None) -> None:
        """Consume the run's events with ``processor`` (implementation picks a default when ``None``)."""

    def set_ui_tools(self, tools: list[Tool]) -> None:
        """Register ``tools`` as UI tools for the run."""

    # --- results of the run (each access must be awaited) -------------------

    @property
    async def messages(self) -> Iterable[Message]:
        """Messages exchanged during the run."""

    @property
    async def summary(self) -> Optional[str]:
        """Summary of the run, if any."""

    @property
    async def context_variables(self) -> Optional[ContextVariables]:
        """Context variables at the end of the run, if any."""

    @property
    async def last_speaker(self) -> Optional[str]:
        """Name of the last speaker, if known."""

    @property
    async def cost(self) -> Optional[Cost]:
        """Cost of the run, if reported."""
|
|
|
|
|
|
class RunResponse:
    """Synchronous handle on a single run driven through a ``ThreadIOStream``.

    Events are pulled from ``iostream.input_stream``. Replies to input requests
    are pushed to ``iostream._output_stream``. Messages, summary, context
    variables, last speaker and cost are filled in when a ``RunCompletionEvent``
    is consumed from :attr:`events`. Until then they keep their initial
    empty / ``None`` values.
    """

    def __init__(self, iostream: ThreadIOStream, agents: list[Agent]):
        """Bind the response to ``iostream`` and the participating ``agents``.

        Args:
            iostream: Stream pair used to receive events and send input replies.
            agents: Agents of the run; :meth:`set_ui_tools` forwards to them.
        """
        self.iostream = iostream
        self.agents = agents
        self._summary: Optional[str] = None
        self._messages: Sequence[LLMMessageType] = []
        self._uuid = uuid4()
        self._context_variables: Optional[ContextVariables] = None
        self._last_speaker: Optional[str] = None
        self._cost: Optional[Cost] = None

    def _record_completion(self, event: RunCompletionEvent) -> None:
        """Copy the final results carried by a completion event onto ``self``."""
        content = event.content
        self._messages = content.history  # type: ignore[attr-defined]
        self._last_speaker = content.last_speaker  # type: ignore[attr-defined]
        self._summary = content.summary  # type: ignore[attr-defined]
        self._context_variables = content.context_variables  # type: ignore[attr-defined]
        # Assign through the ``cost`` setter so a raw dict is converted to ``Cost``.
        self.cost = content.cost  # type: ignore[attr-defined]

    def _queue_generator(self, q: queue.Queue) -> Iterable[BaseEvent]:  # type: ignore[type-arg]
        """Yield events from ``q`` until a completion or error event is seen.

        For an ``InputRequestEvent``, ``event.content.respond`` is wired to put
        the response on ``iostream._output_stream`` before the event is yielded.

        Raises:
            Exception: the ``error`` carried by an ``ErrorEvent``, raised after
                that event has been yielded.
        """
        while True:
            # Only the queue read is guarded. Previously the yield and the
            # ErrorEvent re-raise were inside this try, so a run error that
            # happened to be ``queue.Empty`` was swallowed and the loop polled
            # forever.
            try:
                event = q.get(timeout=0.1)  # Adjust timeout as needed
            except queue.Empty:
                continue  # Nothing yet; poll again.

            if isinstance(event, InputRequestEvent):
                event.content.respond = lambda response: self.iostream._output_stream.put(response)  # type: ignore[attr-defined]

            yield event

            if isinstance(event, RunCompletionEvent):
                self._record_completion(event)
                break

            if isinstance(event, ErrorEvent):
                raise event.content.error  # type: ignore[attr-defined]

    @property
    def events(self) -> Iterable[BaseEvent]:
        """Generator over the run's events read from ``iostream.input_stream``."""
        return self._queue_generator(self.iostream.input_stream)

    @property
    def messages(self) -> Iterable[Message]:
        """Message history from the completion event (empty before completion)."""
        return self._messages

    @property
    def summary(self) -> Optional[str]:
        """Summary from the completion event, or ``None`` before completion."""
        return self._summary

    @property
    def above_run(self) -> Optional["RunResponseProtocol"]:
        """Always ``None`` for this implementation."""
        return None

    @property
    def uuid(self) -> UUID:
        """Identifier generated with ``uuid4`` when the response was created."""
        return self._uuid

    @property
    def context_variables(self) -> Optional[ContextVariables]:
        """Context variables from the completion event, or ``None`` before completion."""
        return self._context_variables

    @property
    def last_speaker(self) -> Optional[str]:
        """Last speaker from the completion event, or ``None`` before completion."""
        return self._last_speaker

    @property
    def cost(self) -> Optional[Cost]:
        """Cost from the completion event, or ``None`` before completion."""
        return self._cost

    @cost.setter
    def cost(self, value: Union[Cost, dict[str, Any]]) -> None:
        """Store ``value``, converting a raw dict with ``Cost.from_raw``."""
        if isinstance(value, dict):
            self._cost = Cost.from_raw(value)
        else:
            self._cost = value

    def process(self, processor: Optional[EventProcessorProtocol] = None) -> None:
        """Run ``processor`` over this response (``ConsoleEventProcessor`` by default)."""
        processor = processor or ConsoleEventProcessor()
        processor.process(self)

    def set_ui_tools(self, tools: list[Tool]) -> None:
        """Set the UI tools for the agents."""
        for agent in self.agents:
            agent.set_ui_tools(tools)
|
|
|
|
|
|
class AsyncRunResponse:
    """Asynchronous handle on a single run driven through an ``AsyncThreadIOStream``.

    Events are awaited from ``iostream.input_stream``. Replies to input
    requests are awaited onto ``iostream._output_stream``. Messages, summary,
    context variables, last speaker and cost are filled in when a
    ``RunCompletionEvent`` is consumed from :attr:`events`. The result
    properties are ``async`` and must be awaited.
    """

    def __init__(self, iostream: AsyncThreadIOStream, agents: list[Agent]):
        """Bind the response to ``iostream`` and the participating ``agents``.

        Args:
            iostream: Stream pair used to receive events and send input replies.
            agents: Agents of the run; :meth:`set_ui_tools` forwards to them.
        """
        self.iostream = iostream
        self.agents = agents
        self._summary: Optional[str] = None
        self._messages: Sequence[LLMMessageType] = []
        self._uuid = uuid4()
        self._context_variables: Optional[ContextVariables] = None
        self._last_speaker: Optional[str] = None
        self._cost: Optional[Cost] = None

    def _record_completion(self, event: RunCompletionEvent) -> None:
        """Copy the final results carried by a completion event onto ``self``."""
        content = event.content
        self._messages = content.history  # type: ignore[attr-defined]
        self._last_speaker = content.last_speaker  # type: ignore[attr-defined]
        self._summary = content.summary  # type: ignore[attr-defined]
        self._context_variables = content.context_variables  # type: ignore[attr-defined]
        # Assign through the ``cost`` setter so a raw dict is converted to ``Cost``.
        self.cost = content.cost  # type: ignore[attr-defined]

    async def _queue_generator(self, q: AsyncQueue[Any]) -> AsyncIterable[BaseEvent]:  # type: ignore[type-arg]
        """Yield events from ``q`` until a completion or error event is seen.

        For an ``InputRequestEvent``, ``event.content.respond`` is replaced by a
        coroutine function that awaits putting the response on
        ``iostream._output_stream``.

        Raises:
            Exception: the ``error`` carried by an ``ErrorEvent``, raised after
                that event has been yielded.
        """
        while True:
            # ``asyncio.Queue.get`` suspends until an item is available and never
            # raises on an empty queue. The former ``except queue.Empty`` was
            # therefore dead code, except that it swallowed a run error that
            # happened to be ``queue.Empty``, which then hung the generator.
            event = await q.get()

            if isinstance(event, InputRequestEvent):

                async def respond(response: str) -> None:
                    await self.iostream._output_stream.put(response)

                event.content.respond = respond  # type: ignore[attr-defined]

            yield event

            if isinstance(event, RunCompletionEvent):
                self._record_completion(event)
                break

            if isinstance(event, ErrorEvent):
                raise event.content.error  # type: ignore[attr-defined]

    @property
    def events(self) -> AsyncIterable[BaseEvent]:
        """Async generator over the run's events read from ``iostream.input_stream``."""
        return self._queue_generator(self.iostream.input_stream)

    @property
    async def messages(self) -> Iterable[Message]:
        """Message history from the completion event (empty before completion)."""
        return self._messages

    @property
    async def summary(self) -> Optional[str]:
        """Summary from the completion event, or ``None`` before completion."""
        return self._summary

    @property
    def above_run(self) -> Optional["RunResponseProtocol"]:
        """Always ``None`` for this implementation."""
        return None

    @property
    def uuid(self) -> UUID:
        """Identifier generated with ``uuid4`` when the response was created."""
        return self._uuid

    @property
    async def context_variables(self) -> Optional[ContextVariables]:
        """Context variables from the completion event, or ``None`` before completion."""
        return self._context_variables

    @property
    async def last_speaker(self) -> Optional[str]:
        """Last speaker from the completion event, or ``None`` before completion."""
        return self._last_speaker

    @property
    async def cost(self) -> Optional[Cost]:
        """Cost from the completion event, or ``None`` before completion."""
        return self._cost

    @cost.setter
    def cost(self, value: Union[Cost, dict[str, Any]]) -> None:
        """Store ``value``, converting a raw dict with ``Cost.from_raw``."""
        if isinstance(value, dict):
            self._cost = Cost.from_raw(value)
        else:
            self._cost = value

    async def process(self, processor: Optional[AsyncEventProcessorProtocol] = None) -> None:
        """Run ``processor`` over this response (``AsyncConsoleEventProcessor`` by default)."""
        processor = processor or AsyncConsoleEventProcessor()
        await processor.process(self)

    def set_ui_tools(self, tools: list[Tool]) -> None:
        """Set the UI tools for the agents."""
        for agent in self.agents:
            agent.set_ui_tools(tools)
|