524 lines
18 KiB
Python
524 lines
18 KiB
Python
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
|
# SPDX-License-Identifier: MIT
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import sqlite3
|
|
import threading
|
|
import uuid
|
|
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
|
|
|
from ..doc_utils import export_module
|
|
from .base_logger import BaseLogger, LLMConfig
|
|
from .logger_utils import get_current_ts, to_dict
|
|
|
|
if TYPE_CHECKING:
|
|
from openai import AzureOpenAI, OpenAI
|
|
from openai.types.chat import ChatCompletion
|
|
|
|
from .. import Agent, ConversableAgent, OpenAIWrapper
|
|
from ..oai.anthropic import AnthropicClient
|
|
from ..oai.bedrock import BedrockClient
|
|
from ..oai.cerebras import CerebrasClient
|
|
from ..oai.cohere import CohereClient
|
|
from ..oai.gemini import GeminiClient
|
|
from ..oai.groq import GroqClient
|
|
from ..oai.mistral import MistralAIClient
|
|
from ..oai.ollama import OllamaClient
|
|
from ..oai.together import TogetherClient
|
|
|
|
# Only the SQLite logger is exported from this module.
__all__ = ("SqliteLogger",)

# Module logger used to report the SQLite logger's own failures.
logger = logging.getLogger(__name__)

# Held by SqliteLogger's query helpers while they use the shared cursor.
lock = threading.Lock()

# Type variable for "any callable"; used to annotate logged functions.
F = TypeVar("F", bound=Callable[..., Any])
|
|
|
|
|
|
def safe_serialize(obj: Any) -> str:
    """Serialize ``obj`` to a JSON string, substituting values ``json`` cannot encode.

    Values the ``json`` module cannot encode natively are replaced by the string form of
    their ``to_json()`` result when they have a ``to_json`` attribute, and by a
    ``<<non-serializable: TypeName>>`` placeholder otherwise.

    Args:
        obj (Any): Object to serialize.

    Returns:
        str: The JSON document.
    """

    def _fallback(value: Any) -> str:
        # Guard clause first: anything without ``to_json`` becomes a readable placeholder.
        if not hasattr(value, "to_json"):
            return f"<<non-serializable: {type(value).__qualname__}>>"
        return str(value.to_json())

    return json.dumps(obj, default=_fallback)
|
|
|
|
|
|
@export_module("autogen.logger")
class SqliteLogger(BaseLogger):
    """Sqlite logger class.

    Persists chat completions, agents, OpenAI wrappers/clients, events and function calls
    to a SQLite database. Every row written by one instance is tagged with the same
    ``session_id`` (a random UUID4 string).

    The connection is opened with ``check_same_thread=False``, so cursor use is guarded by
    the module-level ``lock``. Database errors are logged through the module logger
    instead of being raised to the caller.
    """

    schema_version = 1

    # Constructor arguments never written to the ``init_args`` columns (credential- and
    # endpoint-like names, plus the ``self``/``__class__`` entries of captured locals).
    _EXCLUDED_INIT_ARGS: tuple[str, ...] = (
        "self",
        "__class__",
        "api_key",
        "organization",
        "base_url",
        "azure_endpoint",
        "azure_ad_token",
        "azure_ad_token_provider",
    )

    # DDL executed by start(); every statement is idempotent (IF NOT EXISTS).
    _CREATE_TABLE_QUERIES: tuple[str, ...] = (
        """
        CREATE TABLE IF NOT EXISTS chat_completions(
            id INTEGER PRIMARY KEY,
            invocation_id TEXT,
            client_id INTEGER,
            wrapper_id INTEGER,
            session_id TEXT,
            source_name TEXT,
            request TEXT,
            response TEXT,
            is_cached INTEGER,
            cost REAL,
            start_time DATETIME DEFAULT CURRENT_TIMESTAMP,
            end_time DATETIME DEFAULT CURRENT_TIMESTAMP)
        """,
        """
        CREATE TABLE IF NOT EXISTS agents (
            id INTEGER PRIMARY KEY,                             -- Key assigned by the database
            agent_id INTEGER,                                   -- result of python id(agent)
            wrapper_id INTEGER,                                 -- result of python id(agent.client)
            session_id TEXT,
            name TEXT,                                          -- agent.name
            class TEXT,                                         -- type or class name of agent
            init_args TEXT,                                     -- JSON serialization of constructor
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
            UNIQUE(agent_id, session_id))
        """,
        """
        CREATE TABLE IF NOT EXISTS oai_wrappers (
            id INTEGER PRIMARY KEY,                             -- Key assigned by the database
            wrapper_id INTEGER,                                 -- result of python id(wrapper)
            session_id TEXT,
            init_args TEXT,                                     -- JSON serialization of constructor
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
            UNIQUE(wrapper_id, session_id))
        """,
        """
        CREATE TABLE IF NOT EXISTS oai_clients (
            id INTEGER PRIMARY KEY,                             -- Key assigned by the database
            client_id INTEGER,                                  -- result of python id(client)
            wrapper_id INTEGER,                                 -- result of python id(wrapper)
            session_id TEXT,
            class TEXT,                                         -- type or class name of client
            init_args TEXT,                                     -- JSON serialization of constructor
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
            UNIQUE(client_id, session_id))
        """,
        """
        CREATE TABLE IF NOT EXISTS version (
            id INTEGER PRIMARY KEY CHECK (id = 1),              -- id of the logging database
            version_number INTEGER NOT NULL                     -- version of the logging database
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS events (
            event_name TEXT,
            source_id INTEGER,
            source_name TEXT,
            agent_module TEXT DEFAULT NULL,
            agent_class_name TEXT DEFAULT NULL,
            id INTEGER PRIMARY KEY,
            json_state TEXT,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS function_calls (
            source_id INTEGER,
            source_name TEXT,
            function_name TEXT,
            args TEXT DEFAULT NULL,
            returns TEXT DEFAULT NULL,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        );
        """,
    )

    def __init__(self, config: dict[str, Any]):
        """Initialize the SqliteLogger.

        Args:
            config (dict[str, Any]): Configuration for the logger. ``"dbname"`` selects the
                SQLite database file (default ``"logs.db"``).
        """
        self.config = config
        self.dbname = self.config.get("dbname", "logs.db")
        # Set before connecting so every attribute exists even if the connection fails;
        # the log_* methods rely on ``self.con is None`` to become no-ops.
        self.session_id = str(uuid.uuid4())
        self.con: sqlite3.Connection | None = None
        self.cur: sqlite3.Cursor | None = None

        try:
            self.con = sqlite3.connect(self.dbname, check_same_thread=False)
            self.cur = self.con.cursor()
        except sqlite3.Error as e:
            logger.error(f"[SqliteLogger] Failed to connect to database {self.dbname}: {e}")

    def start(self) -> str:
        """Create the logging tables if needed, record the schema version and apply migrations.

        Returns:
            str: The session id tagging every row written by this logger. It is returned even
            when setup fails; failures are logged rather than raised.
        """
        if self.con is None or self.cur is None:
            logger.error("[SqliteLogger] start logging error: no database connection")
            return self.session_id

        try:
            for query in self._CREATE_TABLE_QUERIES:
                self._run_query(query=query)

            if self._get_current_db_version() is None:
                self._run_query(
                    query="INSERT INTO version (id, version_number) VALUES (1, ?);",
                    args=(SqliteLogger.schema_version,),
                )
            self._apply_migration()
        except Exception as e:
            # Setup is best-effort: log any failure (sqlite3.Error, malformed migration
            # file, I/O error) and still hand back the session id. Unlike a `return` in
            # `finally`, this does not swallow KeyboardInterrupt/SystemExit.
            logger.error(f"[SqliteLogger] start logging error: {e}")
        return self.session_id

    def _get_current_db_version(self) -> None | int:
        """Return the recorded schema version, or None if the ``version`` table has no row."""
        # Hold the lock across execute+fetchone so another thread cannot reuse the shared
        # cursor in between and replace its result set.
        with lock:
            self.cur.execute("SELECT version_number FROM version ORDER BY id DESC LIMIT 1")
            result = self.cur.fetchone()
        return result[0] if result is not None else None

    # Example migration script name format: 002_update_agents_table.sql
    def _apply_migration(self, migrations_dir: str = "./migrations") -> None:
        """Apply migration scripts newer than the recorded schema version.

        Scripts are files in ``migrations_dir`` (resolved against the current working
        directory when relative) whose names start with a decimal version number followed by
        ``_``. They run in ascending numeric order. After each successful script the
        ``version`` row is updated. If a script fails, the remaining ones are skipped so the
        recorded version never claims a migration that did not run.

        Args:
            migrations_dir (str): Directory containing the migration scripts.
        """
        current_version = self._get_current_db_version()
        current_version = SqliteLogger.schema_version if current_version is None else current_version

        if not os.path.isdir(migrations_dir):
            logger.info("no migration scripts, skip...")
            return

        pending: list[tuple[int, str]] = []
        for filename in os.listdir(migrations_dir):
            prefix = filename.split("_", 1)[0]
            path = os.path.join(migrations_dir, filename)
            # Ignore entries without a numeric version prefix (README, editor backups, ...).
            if not (prefix.isascii() and prefix.isdigit()) or not os.path.isfile(path):
                continue
            version = int(prefix)
            if version > current_version:
                pending.append((version, path))

        # Numeric sort, so "10_x.sql" runs after "9_y.sql".
        for version, path in sorted(pending):
            with open(path, encoding="utf-8") as f:
                migration_sql = f.read()
            if not self._run_query_script(script=migration_sql):
                logger.error("[SqliteLogger] migration %s failed; remaining migrations skipped", path)
                return
            self._run_query(query="UPDATE version SET version_number = ? WHERE id = 1", args=(version,))

    def _run_query(self, query: str, args: tuple[Any, ...] = ()) -> None:
        """Execute and commit a single SQL statement, logging (not raising) any error.

        Args:
            query (str): The SQL query to execute.
            args (tuple[Any, ...]): Positional parameters bound to the query's ``?`` placeholders.
        """
        try:
            with lock:
                self.cur.execute(query, args)
                self.con.commit()
        except Exception as e:
            logger.error("[sqlite logger]Error running query with query %s and args %s: %s", query, args, e)

    def _run_query_script(self, script: str) -> bool:
        """Execute and commit a multi-statement SQL script, logging (not raising) any error.

        Args:
            script (str): SQL script to execute.

        Returns:
            bool: True if the script ran and was committed, False if an error was logged.
        """
        try:
            with lock:
                self.cur.executescript(script)
                self.con.commit()
        except Exception as e:
            logger.error("[sqlite logger]Error running query script %s: %s", script, e)
            return False
        return True

    def log_chat_completion(
        self,
        invocation_id: uuid.UUID,
        client_id: int,
        wrapper_id: int,
        source: str | Agent,
        request: dict[str, float | str | list[dict[str, str]]],
        response: str | ChatCompletion,
        is_cached: int,
        cost: float,
        start_time: str,
    ) -> None:
        """Log chat completion.

        Args:
            invocation_id (uuid.UUID): Invocation ID (a ``str`` is also accepted).
            client_id (int): Client ID.
            wrapper_id (int): Wrapper ID.
            source (str | Agent): Source of the chat completion.
            request (dict[str, float | str | list[dict[str, str]]]): Request for the chat completion.
            response (str | ChatCompletion): Response for the chat completion.
            is_cached (int): Whether the response is cached.
            cost (float): Cost of the chat completion.
            start_time (str): Start time of the chat completion.
        """
        if self.con is None:
            return

        end_time = get_current_ts()

        if response is None or isinstance(response, str):
            response_messages = json.dumps({"response": response})
        else:
            response_messages = json.dumps(to_dict(response), indent=4)

        if isinstance(source, str):
            source_name = source
        elif hasattr(source, "name") and source.name is not None:
            source_name = source.name
        else:
            source_name = ""

        query = """
        INSERT INTO chat_completions (
            invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time, source_name
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        args = (
            # sqlite3 cannot bind uuid.UUID objects; store the canonical string form.
            str(invocation_id) if invocation_id is not None else None,
            client_id,
            wrapper_id,
            self.session_id,
            # Identical to json.dumps for plain JSON data; never raises on unknown objects.
            safe_serialize(request),
            response_messages,
            is_cached,
            cost,
            start_time,
            end_time,
            source_name,
        )

        self._run_query(query=query, args=args)

    def log_new_agent(self, agent: ConversableAgent, init_args: dict[str, Any]) -> None:
        """Log new agent.

        Args:
            agent (ConversableAgent): Agent to log.
            init_args (dict[str, Any]): Initialization arguments of the agent.
        """
        from .. import Agent

        if self.con is None:
            return

        serialized_init_args = to_dict(init_args, exclude=self._EXCLUDED_INIT_ARGS, no_recursive=(Agent,))

        # We do an upsert since both the superclass and subclass may call this method (in that order)
        query = """
        INSERT INTO agents (agent_id, wrapper_id, session_id, name, class, init_args, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT (agent_id, session_id) DO UPDATE SET
            wrapper_id = excluded.wrapper_id,
            name = excluded.name,
            class = excluded.class,
            init_args = excluded.init_args,
            timestamp = excluded.timestamp
        """
        args = (
            id(agent),
            agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else "",
            self.session_id,
            agent.name if hasattr(agent, "name") and agent.name is not None else "",
            type(agent).__name__,
            safe_serialize(serialized_init_args),
            get_current_ts(),
        )
        self._run_query(query=query, args=args)

    def log_event(self, source: str | Agent, name: str, **kwargs: dict[str, Any]) -> None:
        """Log event.

        Args:
            source (str | Agent): Source of the event.
            name (str): Name of the event.
            **kwargs (dict[str, Any]): Additional arguments for the event, stored as JSON.
        """
        from .. import Agent

        if self.con is None:
            return

        json_args = json.dumps(kwargs, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
        source_name = source.name if hasattr(source, "name") else source

        if isinstance(source, Agent):
            # Agents additionally record the module and class they come from.
            query = """
            INSERT INTO events (source_id, source_name, event_name, agent_module, agent_class_name, json_state, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)
            """
            args: tuple[Any, ...] = (
                id(source),
                source_name,
                name,
                source.__module__,
                source.__class__.__name__,
                json_args,
                get_current_ts(),
            )
        else:
            query = """
            INSERT INTO events (source_id, source_name, event_name, json_state, timestamp) VALUES (?, ?, ?, ?, ?)
            """
            args = (
                id(source),
                source_name,
                name,
                json_args,
                get_current_ts(),
            )
        self._run_query(query=query, args=args)

    def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: dict[str, LLMConfig | list[LLMConfig]]) -> None:
        """Log new wrapper.

        Args:
            wrapper (OpenAIWrapper): Wrapper to log.
            init_args (dict[str, LLMConfig | list[LLMConfig]]): Initialization arguments of the wrapper.
        """
        if self.con is None:
            return

        serialized_init_args = to_dict(init_args, exclude=self._EXCLUDED_INIT_ARGS)

        # First registration wins within a session.
        query = """
        INSERT INTO oai_wrappers (wrapper_id, session_id, init_args, timestamp) VALUES (?, ?, ?, ?)
        ON CONFLICT (wrapper_id, session_id) DO NOTHING;
        """
        args = (
            id(wrapper),
            self.session_id,
            safe_serialize(serialized_init_args),
            get_current_ts(),
        )
        self._run_query(query=query, args=args)

    def log_function_use(self, source: str | Agent, function: F, args: dict[str, Any], returns: Any) -> None:
        """Log function use.

        Args:
            source (str | Agent): Source of the function use.
            function (F): Function to log. Callables without ``__name__`` (e.g.
                ``functools.partial``) are recorded by their ``repr``.
            args (dict[str, Any]): Arguments of the function.
            returns (Any): Returns of the function.
        """
        if self.con is None:
            return

        query = """
        INSERT INTO function_calls (source_id, source_name, function_name, args, returns, timestamp) VALUES (?, ?, ?, ?, ?, ?)
        """
        query_args: tuple[Any, ...] = (
            id(source),
            source.name if hasattr(source, "name") else source,
            getattr(function, "__name__", repr(function)),
            safe_serialize(args),
            safe_serialize(returns),
            get_current_ts(),
        )
        self._run_query(query=query, args=query_args)

    def log_new_client(
        self,
        client: (
            AzureOpenAI
            | OpenAI
            | CerebrasClient
            | GeminiClient
            | AnthropicClient
            | MistralAIClient
            | TogetherClient
            | GroqClient
            | CohereClient
            | OllamaClient
            | BedrockClient
        ),
        wrapper: OpenAIWrapper,
        init_args: dict[str, Any],
    ) -> None:
        """Log new client.

        Args:
            client (AzureOpenAI | OpenAI | CerebrasClient | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | GroqClient | CohereClient | OllamaClient | BedrockClient): Client to log.
            wrapper (OpenAIWrapper): Wrapper of the client.
            init_args (dict[str, Any]): Initialization arguments of the client.
        """
        if self.con is None:
            return

        serialized_init_args = to_dict(init_args, exclude=self._EXCLUDED_INIT_ARGS)

        # First registration wins within a session.
        query = """
        INSERT INTO oai_clients (client_id, wrapper_id, session_id, class, init_args, timestamp) VALUES (?, ?, ?, ?, ?, ?)
        ON CONFLICT (client_id, session_id) DO NOTHING;
        """
        args = (
            id(client),
            id(wrapper),
            self.session_id,
            type(client).__name__,
            safe_serialize(serialized_init_args),
            get_current_ts(),
        )
        self._run_query(query=query, args=args)

    def stop(self) -> None:
        """Stop the logger by closing the database connection.

        Afterwards the ``log_*`` methods are no-ops and ``get_connection`` returns None.
        """
        with lock:
            if self.con is not None:
                self.con.close()
            self.con = None
            self.cur = None

    def get_connection(self) -> None | sqlite3.Connection:
        """Get connection.

        Returns:
            None | sqlite3.Connection: The open connection, or None if connecting failed or
            the logger was stopped.
        """
        return self.con
|