CoACT initialize (#292)
This commit is contained in:
11
mm_agents/coact/autogen/logger/__init__.py
Normal file
11
mm_agents/coact/autogen/logger/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from .file_logger import FileLogger
|
||||
from .logger_factory import LoggerFactory
|
||||
from .sqlite_logger import SqliteLogger
|
||||
|
||||
__all__ = ("FileLogger", "LoggerFactory", "SqliteLogger")
|
||||
128
mm_agents/coact/autogen/logger/base_logger.py
Normal file
128
mm_agents/coact/autogen/logger/base_logger.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
from .. import Agent, ConversableAgent, OpenAIWrapper
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
ConfigItem = dict[str, Union[str, list[str]]]
|
||||
LLMConfig = dict[str, Union[None, float, int, ConfigItem, list[ConfigItem]]]
|
||||
|
||||
|
||||
class BaseLogger(ABC):
    """Abstract interface for runtime loggers.

    Concrete implementations (e.g. file- or sqlite-backed) record the lifecycle
    of agents, wrappers, clients, chat completions, events and tool calls for a
    single logging session identified by the id returned from `start()`.
    """

    @abstractmethod
    def start(self) -> str:
        """Open a connection to the logging database, and start recording.

        Returns:
            session_id (str): a unique id for the logging session
        """
        ...

    @abstractmethod
    def log_chat_completion(
        self,
        invocation_id: uuid.UUID,
        client_id: int,
        wrapper_id: int,
        source: str | Agent,
        request: dict[str, float | str | list[dict[str, str]]],
        response: str | ChatCompletion,
        is_cached: int,
        cost: float,
        start_time: str,
    ) -> None:
        """Log a chat completion to database.

        In AG2, chat completions are somewhat complicated because they are handled by the `autogen.oai.OpenAIWrapper` class.
        One invocation to `create` can lead to multiple underlying OpenAI calls, depending on the llm_config list used, and
        any errors or retries.

        Args:
            invocation_id (uuid): A unique identifier for the invocation to the OpenAIWrapper.create method call
            client_id (int): A unique identifier for the underlying OpenAI client instance
            wrapper_id (int): A unique identifier for the OpenAIWrapper instance
            source (str or Agent): The source/creator of the event as a string name or an Agent instance
            request (dict): A dictionary representing the request or call to the OpenAI client endpoint
            response (str or ChatCompletion): The response from OpenAI
            is_cached (int): 1 if the response was a cache hit, 0 otherwise
            cost (float): The cost for OpenAI response
            start_time (str): A string representing the moment the request was initiated
        """
        ...

    @abstractmethod
    def log_new_agent(self, agent: ConversableAgent, init_args: dict[str, Any]) -> None:
        """Log the birth of a new agent.

        Args:
            agent (ConversableAgent): The agent to log.
            init_args (dict): The arguments passed to the construct the conversable agent
        """
        ...

    @abstractmethod
    def log_event(self, source: str | Agent, name: str, **kwargs: dict[str, Any]) -> None:
        """Log an event for an agent.

        Args:
            source (str or Agent): The source/creator of the event as a string name or an Agent instance
            name (str): The name of the event
            kwargs (dict): The event information to log
        """
        ...

    @abstractmethod
    def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: dict[str, LLMConfig | list[LLMConfig]]) -> None:
        """Log the birth of a new OpenAIWrapper.

        Args:
            wrapper (OpenAIWrapper): The wrapper to log.
            init_args (dict): The arguments passed to the construct the wrapper
        """
        ...

    @abstractmethod
    def log_new_client(self, client: AzureOpenAI | OpenAI, wrapper: OpenAIWrapper, init_args: dict[str, Any]) -> None:
        """Log the birth of a new client.

        Args:
            client: The client to log.
            wrapper: The wrapper that created the client.
            init_args: The arguments passed to the construct the client.
        """
        ...

    @abstractmethod
    def log_function_use(self, source: str | Agent, function: F, args: dict[str, Any], returns: Any) -> None:
        """Log the use of a registered function (could be a tool)

        Args:
            source (str or Agent): The source/creator of the event as a string name or an Agent instance
            function (F): The function information
            args (dict): The function args to log
            returns (any): The return
        """
        ...

    @abstractmethod
    def stop(self) -> None:
        """Close the connection to the logging database, and stop logging."""
        ...

    @abstractmethod
    def get_connection(self) -> None | sqlite3.Connection:
        """Return a connection to the logging database."""
        ...
|
||||
261
mm_agents/coact/autogen/logger/file_logger.py
Normal file
261
mm_agents/coact/autogen/logger/file_logger.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from .base_logger import BaseLogger, LLMConfig
|
||||
from .logger_utils import get_current_ts, to_dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
from .. import Agent, ConversableAgent, OpenAIWrapper
|
||||
from ..oai.anthropic import AnthropicClient
|
||||
from ..oai.bedrock import BedrockClient
|
||||
from ..oai.cerebras import CerebrasClient
|
||||
from ..oai.cohere import CohereClient
|
||||
from ..oai.gemini import GeminiClient
|
||||
from ..oai.groq import GroqClient
|
||||
from ..oai.mistral import MistralAIClient
|
||||
from ..oai.ollama import OllamaClient
|
||||
from ..oai.together import TogetherClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
__all__ = ("FileLogger",)
|
||||
|
||||
|
||||
def safe_serialize(obj: Any) -> str:
    """Serialize `obj` to JSON without ever raising on unsupported values.

    Values that implement a `to_json()` hook are serialized through it; any
    other non-JSON-serializable value is replaced by a descriptive placeholder
    string instead of raising `TypeError`.
    """

    def _fallback(value: Any) -> str:
        if hasattr(value, "to_json"):
            return str(value.to_json())
        return f"<<non-serializable: {type(value).__qualname__}>>"

    return json.dumps(obj, default=_fallback)
|
||||
|
||||
|
||||
@export_module("autogen.logger")
class FileLogger(BaseLogger):
    """BaseLogger implementation that appends one JSON record per logged event
    to a file under `<cwd>/autogen_logs` (filename from `config["filename"]`,
    default "runtime.log").
    """

    def __init__(self, config: dict[str, Any]):
        """Create the log directory/file and attach a `logging.FileHandler`.

        Args:
            config (dict[str, Any]): Logger configuration. Supported key:
                "filename" — log file name (defaults to "runtime.log").
        """
        self.config = config
        self.session_id = str(uuid.uuid4())

        curr_dir = os.getcwd()
        self.log_dir = os.path.join(curr_dir, "autogen_logs")
        os.makedirs(self.log_dir, exist_ok=True)

        self.log_file = os.path.join(self.log_dir, self.config.get("filename", "runtime.log"))
        try:
            # Touch the file so permission problems surface at construction time.
            with open(self.log_file, "a"):
                pass
        except Exception as e:
            logger.error(f"[file_logger] Failed to create logging file: {e}")

        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        file_handler = logging.FileHandler(self.log_file)
        self.logger.addHandler(file_handler)

    def start(self) -> str:
        """Start the logger and return the session_id."""
        try:
            self.logger.info(f"Started new session with Session ID: {self.session_id}")
        except Exception as e:
            logger.error(f"[file_logger] Failed to create logging file: {e}")
        finally:
            # NOTE: the finally-return intentionally preserves the original
            # contract of always yielding a session id, even on failure.
            return self.session_id

    def log_chat_completion(
        self,
        invocation_id: uuid.UUID,
        client_id: int,
        wrapper_id: int,
        source: str | Agent,
        request: dict[str, float | str | list[dict[str, str]]],
        response: str | ChatCompletion,
        is_cached: int,
        cost: float,
        start_time: str,
    ) -> None:
        """Log a chat completion."""
        thread_id = threading.get_ident()
        source_name = (
            source
            if isinstance(source, str)
            else source.name
            if hasattr(source, "name") and source.name is not None
            else ""
        )
        try:
            log_data = json.dumps({
                "invocation_id": str(invocation_id),
                "client_id": client_id,
                "wrapper_id": wrapper_id,
                "request": to_dict(request),
                "response": str(response),
                "is_cached": is_cached,
                "cost": cost,
                "start_time": start_time,
                "end_time": get_current_ts(),
                "thread_id": thread_id,
                "source_name": source_name,
            })

            self.logger.info(log_data)
        except Exception as e:
            self.logger.error(f"[file_logger] Failed to log chat completion: {e}")

    def log_new_agent(self, agent: ConversableAgent, init_args: dict[str, Any] | None = None) -> None:
        """Log a new agent instance.

        Args:
            agent (ConversableAgent): The agent being created.
            init_args (dict | None): Constructor arguments; None means "no args".
        """
        # Fix: the original used a mutable default argument ({}), which is
        # shared across all calls; normalize a None sentinel instead.
        init_args = init_args if init_args is not None else {}
        thread_id = threading.get_ident()

        try:
            log_data = json.dumps({
                "id": id(agent),
                "agent_name": agent.name if hasattr(agent, "name") and agent.name is not None else "",
                "wrapper_id": to_dict(
                    agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else ""
                ),
                "session_id": self.session_id,
                "current_time": get_current_ts(),
                "agent_type": type(agent).__name__,
                "args": to_dict(init_args),
                "thread_id": thread_id,
            })
            self.logger.info(log_data)
        except Exception as e:
            self.logger.error(f"[file_logger] Failed to log new agent: {e}")

    def log_event(self, source: str | Agent, name: str, **kwargs: dict[str, Any]) -> None:
        """Log an event from an agent or a string source."""
        from .. import Agent

        # Non-serializable kwargs values are replaced by a placeholder string
        # containing the type's qualified name rather than raising.
        json_args = json.dumps(kwargs, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
        thread_id = threading.get_ident()

        if isinstance(source, Agent):
            try:
                log_data = json.dumps({
                    "source_id": id(source),
                    "source_name": str(source.name) if hasattr(source, "name") else source,
                    "event_name": name,
                    "agent_module": source.__module__,
                    "agent_class": source.__class__.__name__,
                    "json_state": json_args,
                    "timestamp": get_current_ts(),
                    "thread_id": thread_id,
                })
                self.logger.info(log_data)
            except Exception as e:
                self.logger.error(f"[file_logger] Failed to log event {e}")
        else:
            try:
                log_data = json.dumps({
                    "source_id": id(source),
                    "source_name": str(source.name) if hasattr(source, "name") else source,
                    "event_name": name,
                    "json_state": json_args,
                    "timestamp": get_current_ts(),
                    "thread_id": thread_id,
                })
                self.logger.info(log_data)
            except Exception as e:
                self.logger.error(f"[file_logger] Failed to log event {e}")

    def log_new_wrapper(
        self, wrapper: OpenAIWrapper, init_args: dict[str, LLMConfig | list[LLMConfig]] | None = None
    ) -> None:
        """Log a new wrapper instance.

        Args:
            wrapper (OpenAIWrapper): The wrapper being created.
            init_args (dict | None): Constructor arguments; None means "no args".
        """
        # Fix: same mutable-default-argument hazard as log_new_agent.
        init_args = init_args if init_args is not None else {}
        thread_id = threading.get_ident()

        try:
            log_data = json.dumps({
                "wrapper_id": id(wrapper),
                "session_id": self.session_id,
                "json_state": json.dumps(init_args),
                "timestamp": get_current_ts(),
                "thread_id": thread_id,
            })
            self.logger.info(log_data)
        except Exception as e:
            self.logger.error(f"[file_logger] Failed to log event {e}")

    def log_new_client(
        self,
        client: (
            AzureOpenAI
            | OpenAI
            | CerebrasClient
            | GeminiClient
            | AnthropicClient
            | MistralAIClient
            | TogetherClient
            | GroqClient
            | CohereClient
            | OllamaClient
            | BedrockClient
        ),
        wrapper: OpenAIWrapper,
        init_args: dict[str, Any],
    ) -> None:
        """Log a new client instance."""
        thread_id = threading.get_ident()

        try:
            log_data = json.dumps({
                "client_id": id(client),
                "wrapper_id": id(wrapper),
                "session_id": self.session_id,
                "class": type(client).__name__,
                "json_state": json.dumps(init_args),
                "timestamp": get_current_ts(),
                "thread_id": thread_id,
            })
            self.logger.info(log_data)
        except Exception as e:
            self.logger.error(f"[file_logger] Failed to log event {e}")

    def log_function_use(self, source: str | Agent, function: F, args: dict[str, Any], returns: Any) -> None:
        """Log a registered function(can be a tool) use from an agent or a string source."""
        thread_id = threading.get_ident()

        try:
            log_data = json.dumps({
                "source_id": id(source),
                "source_name": str(source.name) if hasattr(source, "name") else source,
                "agent_module": source.__module__,
                "agent_class": source.__class__.__name__,
                "timestamp": get_current_ts(),
                "thread_id": thread_id,
                "input_args": safe_serialize(args),
                "returns": safe_serialize(returns),
            })
            self.logger.info(log_data)
        except Exception as e:
            self.logger.error(f"[file_logger] Failed to log event {e}")

    def get_connection(self) -> None:
        """Method is intentionally left blank because there is no specific connection needed for the FileLogger."""
        pass

    def stop(self) -> None:
        """Close the file handler and remove it from the logger."""
        for handler in self.logger.handlers:
            if isinstance(handler, logging.FileHandler):
                handler.close()
                self.logger.removeHandler(handler)
||||
42
mm_agents/coact/autogen/logger/logger_factory.py
Normal file
42
mm_agents/coact/autogen/logger/logger_factory.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from .base_logger import BaseLogger
|
||||
from .file_logger import FileLogger
|
||||
from .sqlite_logger import SqliteLogger
|
||||
|
||||
__all__ = ("LoggerFactory",)
|
||||
|
||||
|
||||
@export_module("autogen.logger")
class LoggerFactory:
    """Factory class to create logger objects."""

    @staticmethod
    def get_logger(
        logger_type: Literal["sqlite", "file"] = "sqlite", config: Optional[dict[str, Any]] = None
    ) -> BaseLogger:
        """Build and return a logger of the requested backend.

        Args:
            logger_type (Literal["sqlite", "file"], optional): Type of logger. Defaults to "sqlite".
            config (Optional[dict[str, Any]], optional): Configuration for logger. Defaults to None.

        Returns:
            BaseLogger: Logger object

        Raises:
            ValueError: If `logger_type` is not one of the supported backends.
        """
        effective_config: dict[str, Any] = {} if config is None else config

        if logger_type == "sqlite":
            return SqliteLogger(effective_config)
        if logger_type == "file":
            return FileLogger(effective_config)
        raise ValueError(f"[logger_factory] Unknown logger type: {logger_type}")
||||
57
mm_agents/coact/autogen/logger/logger_utils.py
Normal file
57
mm_agents/coact/autogen/logger/logger_utils.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
import inspect
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path, PurePath
|
||||
from typing import Any, Union
|
||||
|
||||
__all__ = ("get_current_ts", "to_dict")
|
||||
|
||||
|
||||
def get_current_ts() -> str:
    """Get current timestamp in UTC timezone.

    Returns:
        str: Current timestamp formatted as "YYYY-MM-DD HH:MM:SS.ffffff" (UTC).
    """
    now_utc = datetime.now(timezone.utc)
    return now_utc.strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
|
||||
|
||||
def to_dict(
    obj: Union[int, float, str, bool, dict[Any, Any], list[Any], tuple[Any, ...], Any],
    exclude: tuple[str, ...] = (),
    no_recursive: tuple[Any, ...] = (),
) -> Any:
    """Recursively convert an object into plain, JSON-friendly data.

    Args:
        obj (Union[int, float, str, bool, dict[Any, Any], list[Any], tuple[Any, ...], Any]): Object to convert
        exclude (tuple[str, ...], optional): Keys to exclude. Defaults to ().
        no_recursive (tuple[Any, ...], optional): Types to exclude from recursive conversion. Defaults to ().
    """

    def _convert(value: Any) -> Any:
        # Values of a no_recursive type are stringified instead of descended into.
        if isinstance(value, no_recursive):
            return to_dict(str(value))
        return to_dict(value, exclude, no_recursive)

    if isinstance(obj, (int, float, str, bool)):
        return obj
    if isinstance(obj, (Path, PurePath)):
        return str(obj)
    if callable(obj):
        return inspect.getsource(obj).strip()
    if isinstance(obj, dict):
        return {str(key): _convert(value) for key, value in obj.items() if key not in exclude}
    if isinstance(obj, (list, tuple)):
        return [_convert(item) for item in obj]
    if hasattr(obj, "__dict__"):
        return {str(key): _convert(value) for key, value in vars(obj).items() if key not in exclude}
    return obj
|
||||
523
mm_agents/coact/autogen/logger/sqlite_logger.py
Normal file
523
mm_agents/coact/autogen/logger/sqlite_logger.py
Normal file
@@ -0,0 +1,523 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
||||
|
||||
from ..doc_utils import export_module
|
||||
from .base_logger import BaseLogger, LLMConfig
|
||||
from .logger_utils import get_current_ts, to_dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
from .. import Agent, ConversableAgent, OpenAIWrapper
|
||||
from ..oai.anthropic import AnthropicClient
|
||||
from ..oai.bedrock import BedrockClient
|
||||
from ..oai.cerebras import CerebrasClient
|
||||
from ..oai.cohere import CohereClient
|
||||
from ..oai.gemini import GeminiClient
|
||||
from ..oai.groq import GroqClient
|
||||
from ..oai.mistral import MistralAIClient
|
||||
from ..oai.ollama import OllamaClient
|
||||
from ..oai.together import TogetherClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
lock = threading.Lock()
|
||||
|
||||
__all__ = ("SqliteLogger",)
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def safe_serialize(obj: Any) -> str:
    """Safely serialize an object to JSON.

    Args:
        obj (Any): Object to serialize.

    Returns:
        str: Serialized object.
    """

    def default(o: Any) -> str:
        # Prefer the object's own to_json() hook when available; otherwise emit
        # a placeholder rather than letting json.dumps raise TypeError.
        return str(o.to_json()) if hasattr(o, "to_json") else f"<<non-serializable: {type(o).__qualname__}>>"

    return json.dumps(obj, default=default)
|
||||
|
||||
|
||||
@export_module("autogen.logger")
|
||||
class SqliteLogger(BaseLogger):
|
||||
"""Sqlite logger class."""
|
||||
|
||||
schema_version = 1
|
||||
|
||||
    def __init__(self, config: dict[str, Any]):
        """Initialize the SqliteLogger.

        Args:
            config (dict[str, Any]): Configuration for the logger. Supported
                key: "dbname" — path of the sqlite database file (defaults to
                "logs.db").
        """
        self.config = config

        try:
            self.dbname = self.config.get("dbname", "logs.db")
            # check_same_thread=False: the connection is shared across threads;
            # access is serialized by the module-level `lock` in _run_query.
            self.con = sqlite3.connect(self.dbname, check_same_thread=False)
            self.cur = self.con.cursor()
            self.session_id = str(uuid.uuid4())
        except sqlite3.Error as e:
            # NOTE(review): if connect() fails, self.con / self.cur /
            # self.session_id are never set; later methods access them —
            # confirm callers tolerate the resulting AttributeError.
            logger.error(f"[SqliteLogger] Failed to connect to database {self.dbname}: {e}")
|
||||
|
||||
def start(self) -> str:
|
||||
try:
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS chat_completions(
|
||||
id INTEGER PRIMARY KEY,
|
||||
invocation_id TEXT,
|
||||
client_id INTEGER,
|
||||
wrapper_id INTEGER,
|
||||
session_id TEXT,
|
||||
source_name TEXT,
|
||||
request TEXT,
|
||||
response TEXT,
|
||||
is_cached INEGER,
|
||||
cost REAL,
|
||||
start_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
end_time DATETIME DEFAULT CURRENT_TIMESTAMP)
|
||||
"""
|
||||
self._run_query(query=query)
|
||||
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS agents (
|
||||
id INTEGER PRIMARY KEY, -- Key assigned by the database
|
||||
agent_id INTEGER, -- result of python id(agent)
|
||||
wrapper_id INTEGER, -- result of python id(agent.client)
|
||||
session_id TEXT,
|
||||
name TEXT, -- agent.name
|
||||
class TEXT, -- type or class name of agent
|
||||
init_args TEXT, -- JSON serialization of constructor
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(agent_id, session_id))
|
||||
"""
|
||||
self._run_query(query=query)
|
||||
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS oai_wrappers (
|
||||
id INTEGER PRIMARY KEY, -- Key assigned by the database
|
||||
wrapper_id INTEGER, -- result of python id(wrapper)
|
||||
session_id TEXT,
|
||||
init_args TEXT, -- JSON serialization of constructor
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(wrapper_id, session_id))
|
||||
"""
|
||||
self._run_query(query=query)
|
||||
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS oai_clients (
|
||||
id INTEGER PRIMARY KEY, -- Key assigned by the database
|
||||
client_id INTEGER, -- result of python id(client)
|
||||
wrapper_id INTEGER, -- result of python id(wrapper)
|
||||
session_id TEXT,
|
||||
class TEXT, -- type or class name of client
|
||||
init_args TEXT, -- JSON serialization of constructor
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(client_id, session_id))
|
||||
"""
|
||||
self._run_query(query=query)
|
||||
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS version (
|
||||
id INTEGER PRIMARY KEY CHECK (id = 1), -- id of the logging database
|
||||
version_number INTEGER NOT NULL -- version of the logging database
|
||||
);
|
||||
"""
|
||||
self._run_query(query=query)
|
||||
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS events (
|
||||
event_name TEXT,
|
||||
source_id INTEGER,
|
||||
source_name TEXT,
|
||||
agent_module TEXT DEFAULT NULL,
|
||||
agent_class_name TEXT DEFAULT NULL,
|
||||
id INTEGER PRIMARY KEY,
|
||||
json_state TEXT,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
"""
|
||||
self._run_query(query=query)
|
||||
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS function_calls (
|
||||
source_id INTEGER,
|
||||
source_name TEXT,
|
||||
function_name TEXT,
|
||||
args TEXT DEFAULT NULL,
|
||||
returns TEXT DEFAULT NULL,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
"""
|
||||
self._run_query(query=query)
|
||||
|
||||
current_version = self._get_current_db_version()
|
||||
if current_version is None:
|
||||
self._run_query(
|
||||
query="INSERT INTO version (id, version_number) VALUES (1, ?);", args=(SqliteLogger.schema_version,)
|
||||
)
|
||||
self._apply_migration()
|
||||
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"[SqliteLogger] start logging error: {e}")
|
||||
finally:
|
||||
return self.session_id
|
||||
|
||||
def _get_current_db_version(self) -> None | int:
|
||||
self.cur.execute("SELECT version_number FROM version ORDER BY id DESC LIMIT 1")
|
||||
result = self.cur.fetchone()
|
||||
return result[0] if result is not None else None
|
||||
|
||||
# Example migration script name format: 002_update_agents_table.sql
|
||||
def _apply_migration(self, migrations_dir: str = "./migrations") -> None:
|
||||
current_version = self._get_current_db_version()
|
||||
current_version = SqliteLogger.schema_version if current_version is None else current_version
|
||||
|
||||
if os.path.isdir(migrations_dir):
|
||||
migrations = sorted(os.listdir(migrations_dir))
|
||||
else:
|
||||
logger.info("no migration scripts, skip...")
|
||||
return
|
||||
|
||||
migrations_to_apply = [m for m in migrations if int(m.split("_")[0]) > current_version]
|
||||
|
||||
for script in migrations_to_apply:
|
||||
with open(script) as f:
|
||||
migration_sql = f.read()
|
||||
self._run_query_script(script=migration_sql)
|
||||
|
||||
latest_version = int(script.split("_")[0])
|
||||
query = "UPDATE version SET version_number = ? WHERE id = 1"
|
||||
args = (latest_version,)
|
||||
self._run_query(query=query, args=args)
|
||||
|
||||
    def _run_query(self, query: str, args: tuple[Any, ...] = ()) -> None:
        """Executes a given SQL query.

        Errors are logged, never raised, so a database failure cannot crash
        the code being logged.

        Args:
            query (str): The SQL query to execute.
            args (Tuple): The arguments to pass to the SQL query.
        """
        try:
            # The module-level lock serializes all access: the connection is
            # created with check_same_thread=False and shared across threads.
            with lock:
                self.cur.execute(query, args)
                self.con.commit()
        except Exception as e:
            logger.error("[sqlite logger]Error running query with query %s and args %s: %s", query, args, e)
|
||||
|
||||
    def _run_query_script(self, script: str) -> None:
        """Executes SQL script.

        Like _run_query, failures are logged rather than raised.

        Args:
            script (str): SQL script to execute (may contain multiple statements).
        """
        try:
            # Serialized by the same module-level lock as single queries.
            with lock:
                self.cur.executescript(script)
                self.con.commit()
        except Exception as e:
            logger.error("[sqlite logger]Error running query script %s: %s", script, e)
|
||||
|
||||
def log_chat_completion(
|
||||
self,
|
||||
invocation_id: uuid.UUID,
|
||||
client_id: int,
|
||||
wrapper_id: int,
|
||||
source: str | Agent,
|
||||
request: dict[str, float | str | list[dict[str, str]]],
|
||||
response: str | ChatCompletion,
|
||||
is_cached: int,
|
||||
cost: float,
|
||||
start_time: str,
|
||||
) -> None:
|
||||
"""Log chat completion.
|
||||
|
||||
Args:
|
||||
invocation_id (uuid.UUID): Invocation ID.
|
||||
client_id (int): Client ID.
|
||||
wrapper_id (int): Wrapper ID.
|
||||
source (str | Agent): Source of the chat completion.
|
||||
request (dict[str, float | str | list[dict[str, str]]]): Request for the chat completion.
|
||||
response (str | ChatCompletion): Response for the chat completion.
|
||||
is_cached (int): Whether the response is cached.
|
||||
cost (float): Cost of the chat completion.
|
||||
start_time (str): Start time of the chat completion.
|
||||
"""
|
||||
if self.con is None:
|
||||
return
|
||||
|
||||
end_time = get_current_ts()
|
||||
|
||||
if response is None or isinstance(response, str):
|
||||
response_messages = json.dumps({"response": response})
|
||||
else:
|
||||
response_messages = json.dumps(to_dict(response), indent=4)
|
||||
|
||||
source_name = (
|
||||
source
|
||||
if isinstance(source, str)
|
||||
else source.name
|
||||
if hasattr(source, "name") and source.name is not None
|
||||
else ""
|
||||
)
|
||||
|
||||
query = """
|
||||
INSERT INTO chat_completions (
|
||||
invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time, source_name
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
args = (
|
||||
invocation_id,
|
||||
client_id,
|
||||
wrapper_id,
|
||||
self.session_id,
|
||||
json.dumps(request),
|
||||
response_messages,
|
||||
is_cached,
|
||||
cost,
|
||||
start_time,
|
||||
end_time,
|
||||
source_name,
|
||||
)
|
||||
|
||||
self._run_query(query=query, args=args)
|
||||
|
||||
def log_new_agent(self, agent: ConversableAgent, init_args: dict[str, Any]) -> None:
|
||||
"""Log new agent.
|
||||
|
||||
Args:
|
||||
agent (ConversableAgent): Agent to log.
|
||||
init_args (dict[str, Any]): Initialization arguments of the agent
|
||||
"""
|
||||
from .. import Agent
|
||||
|
||||
if self.con is None:
|
||||
return
|
||||
|
||||
args = to_dict(
|
||||
init_args,
|
||||
exclude=(
|
||||
"self",
|
||||
"__class__",
|
||||
"api_key",
|
||||
"organization",
|
||||
"base_url",
|
||||
"azure_endpoint",
|
||||
"azure_ad_token",
|
||||
"azure_ad_token_provider",
|
||||
),
|
||||
no_recursive=(Agent,),
|
||||
)
|
||||
|
||||
# We do an upsert since both the superclass and subclass may call this method (in that order)
|
||||
query = """
|
||||
INSERT INTO agents (agent_id, wrapper_id, session_id, name, class, init_args, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (agent_id, session_id) DO UPDATE SET
|
||||
wrapper_id = excluded.wrapper_id,
|
||||
name = excluded.name,
|
||||
class = excluded.class,
|
||||
init_args = excluded.init_args,
|
||||
timestamp = excluded.timestamp
|
||||
"""
|
||||
args = (
|
||||
id(agent),
|
||||
agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else "",
|
||||
self.session_id,
|
||||
agent.name if hasattr(agent, "name") and agent.name is not None else "",
|
||||
type(agent).__name__,
|
||||
json.dumps(args),
|
||||
get_current_ts(),
|
||||
)
|
||||
self._run_query(query=query, args=args)
|
||||
|
||||
def log_event(self, source: str | Agent, name: str, **kwargs: dict[str, Any]) -> None:
|
||||
"""Log event.
|
||||
|
||||
Args:
|
||||
source (str | Agent): Source of the event.
|
||||
name (str): Name of the event.
|
||||
**kwargs (dict[str, Any]): Additional arguments for the event.
|
||||
"""
|
||||
from autogen import Agent
|
||||
|
||||
if self.con is None:
|
||||
return
|
||||
|
||||
json_args = json.dumps(kwargs, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
|
||||
|
||||
if isinstance(source, Agent):
|
||||
query = """
|
||||
INSERT INTO events (source_id, source_name, event_name, agent_module, agent_class_name, json_state, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
args = (
|
||||
id(source),
|
||||
source.name if hasattr(source, "name") else source,
|
||||
name,
|
||||
source.__module__,
|
||||
source.__class__.__name__,
|
||||
json_args,
|
||||
get_current_ts(),
|
||||
)
|
||||
self._run_query(query=query, args=args)
|
||||
else:
|
||||
query = """
|
||||
INSERT INTO events (source_id, source_name, event_name, json_state, timestamp) VALUES (?, ?, ?, ?, ?)
|
||||
"""
|
||||
args_str_based = (
|
||||
id(source),
|
||||
source.name if hasattr(source, "name") else source,
|
||||
name,
|
||||
json_args,
|
||||
get_current_ts(),
|
||||
)
|
||||
self._run_query(query=query, args=args_str_based)
|
||||
|
||||
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: dict[str, LLMConfig | list[LLMConfig]]) -> None:
    """Record a newly created wrapper in ``oai_wrappers``, once per session.

    Args:
        wrapper (OpenAIWrapper): Wrapper to log.
        init_args (dict[str, LLMConfig | list[LLMConfig]]): Initialization arguments of the wrapper
    """
    # No-op when the logger has no open connection.
    if self.con is None:
        return

    # Strip credentials and endpoint details so secrets never reach the database.
    sanitized = to_dict(
        init_args,
        exclude=(
            "self",
            "__class__",
            "api_key",
            "organization",
            "base_url",
            "azure_endpoint",
            "azure_ad_token",
            "azure_ad_token_provider",
        ),
    )

    # Re-logging the same wrapper within a session is silently ignored.
    sql = (
        "INSERT INTO oai_wrappers (wrapper_id, session_id, init_args, timestamp)"
        " VALUES (?, ?, ?, ?)"
        " ON CONFLICT (wrapper_id, session_id) DO NOTHING;"
    )
    self._run_query(
        query=sql,
        args=(id(wrapper), self.session_id, json.dumps(sanitized), get_current_ts()),
    )
def log_function_use(self, source: str | Agent, function: F, args: dict[str, Any], returns: Any) -> None:
    """Record one function/tool invocation in ``function_calls``.

    Args:
        source (str | Agent): Source of the function use.
        function (F): Function to log.
        args (dict[str, Any]): Arguments of the function.
        returns (Any): Returns of the function.
    """
    # No-op when the logger has no open connection.
    if self.con is None:
        return

    sql = (
        "INSERT INTO function_calls (source_id, source_name, function_name, args, returns, timestamp)"
        " VALUES (?, ?, ?, ?, ?, ?)"
    )
    # Arguments and return values may not be JSON-serializable; safe_serialize
    # handles those cases.
    row: tuple[Any, ...] = (
        id(source),
        source.name if hasattr(source, "name") else source,
        function.__name__,
        safe_serialize(args),
        safe_serialize(returns),
        get_current_ts(),
    )
    self._run_query(query=sql, args=row)
def log_new_client(
    self,
    client: (
        AzureOpenAI
        | OpenAI
        | CerebrasClient
        | GeminiClient
        | AnthropicClient
        | MistralAIClient
        | TogetherClient
        | GroqClient
        | CohereClient
        | OllamaClient
        | BedrockClient
    ),
    wrapper: OpenAIWrapper,
    init_args: dict[str, Any],
) -> None:
    """Record a newly constructed model client in ``oai_clients``, once per session.

    Args:
        client (AzureOpenAI | OpenAI | CerebrasClient | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | GroqClient | CohereClient | OllamaClient | BedrockClient): Client to log.
        wrapper (OpenAIWrapper): Wrapper of the client.
        init_args (dict[str, Any]): Initialization arguments of the client.
    """
    # No-op when the logger has no open connection.
    if self.con is None:
        return

    # Strip credentials and endpoint details so secrets never reach the database.
    sanitized = to_dict(
        init_args,
        exclude=(
            "self",
            "__class__",
            "api_key",
            "organization",
            "base_url",
            "azure_endpoint",
            "azure_ad_token",
            "azure_ad_token_provider",
        ),
    )

    # Re-logging the same client within a session is silently ignored.
    sql = (
        "INSERT INTO oai_clients (client_id, wrapper_id, session_id, class, init_args, timestamp)"
        " VALUES (?, ?, ?, ?, ?, ?)"
        " ON CONFLICT (client_id, session_id) DO NOTHING;"
    )
    self._run_query(
        query=sql,
        args=(
            id(client),
            id(wrapper),
            self.session_id,
            type(client).__name__,
            json.dumps(sanitized),
            get_current_ts(),
        ),
    )
def stop(self) -> None:
    """Stop the logger and release the SQLite connection.

    Closes the open connection (if any) and clears ``self.con`` so later
    calls to the ``log_*`` methods hit their ``self.con is None`` guard
    instead of operating on a closed connection (which would raise
    ``sqlite3.ProgrammingError``), and ``get_connection`` reports the
    logger as stopped. Safe to call repeatedly.
    """
    if self.con:
        self.con.close()
        # Mark the logger as stopped; every logging method checks
        # `self.con is None` and becomes a no-op after this point.
        self.con = None
def get_connection(self) -> None | sqlite3.Connection:
|
||||
"""Get connection."""
|
||||
if self.con:
|
||||
return self.con
|
||||
return None
|
||||
Reference in New Issue
Block a user