383 lines
14 KiB
Python
383 lines
14 KiB
Python
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import functools
|
|
import json
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Iterable
|
|
from contextvars import ContextVar
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Annotated, Any, Mapping, Optional, Type, TypeVar, Union
|
|
|
|
from httpx import Client as httpxClient
|
|
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, SecretStr, ValidationInfo, field_serializer, field_validator
|
|
|
|
if TYPE_CHECKING:
|
|
from .oai.client import ModelClient
|
|
|
|
_KT = TypeVar("_KT")
|
|
_VT = TypeVar("_VT")
|
|
|
|
__all__ = [
|
|
"LLMConfig",
|
|
"LLMConfigEntry",
|
|
"register_llm_config",
|
|
]
|
|
|
|
|
|
def _add_default_api_type(d: dict[str, Any]) -> dict[str, Any]:
|
|
if "api_type" not in d:
|
|
d["api_type"] = "openai"
|
|
return d
|
|
|
|
|
|
# Meta class to allow LLMConfig.current and LLMConfig.default to be used as class properties
|
|
class MetaLLMConfig(type):
|
|
def __init__(cls, *args: Any, **kwargs: Any) -> None:
|
|
pass
|
|
|
|
@property
|
|
def current(cls) -> "LLMConfig":
|
|
current_llm_config = LLMConfig.get_current_llm_config(llm_config=None)
|
|
if current_llm_config is None:
|
|
raise ValueError("No current LLMConfig set. Are you inside a context block?")
|
|
return current_llm_config # type: ignore[return-value]
|
|
|
|
@property
|
|
def default(cls) -> "LLMConfig":
|
|
return cls.current
|
|
|
|
|
|
class LLMConfig(metaclass=MetaLLMConfig):
|
|
_current_llm_config: ContextVar["LLMConfig"] = ContextVar("current_llm_config")
|
|
|
|
def __init__(self, **kwargs: Any) -> None:
|
|
outside_properties = list((self._get_base_model_class()).model_json_schema()["properties"].keys())
|
|
outside_properties.remove("config_list")
|
|
|
|
if "config_list" in kwargs and isinstance(kwargs["config_list"], dict):
|
|
kwargs["config_list"] = [kwargs["config_list"]]
|
|
|
|
modified_kwargs = (
|
|
kwargs
|
|
if "config_list" in kwargs
|
|
else {
|
|
**{
|
|
"config_list": [
|
|
{k: v for k, v in kwargs.items() if k not in outside_properties},
|
|
]
|
|
},
|
|
**{k: v for k, v in kwargs.items() if k in outside_properties},
|
|
}
|
|
)
|
|
|
|
modified_kwargs["config_list"] = [
|
|
_add_default_api_type(v) if isinstance(v, dict) else v for v in modified_kwargs["config_list"]
|
|
]
|
|
for x in ["max_tokens", "top_p"]:
|
|
if x in modified_kwargs:
|
|
modified_kwargs["config_list"] = [{**v, x: modified_kwargs[x]} for v in modified_kwargs["config_list"]]
|
|
modified_kwargs.pop(x)
|
|
|
|
self._model = self._get_base_model_class()(**modified_kwargs)
|
|
|
|
# used by BaseModel to create instance variables
|
|
def __enter__(self) -> "LLMConfig":
|
|
# Store previous context and set self as current
|
|
self._token = LLMConfig._current_llm_config.set(self)
|
|
return self
|
|
|
|
def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None:
|
|
LLMConfig._current_llm_config.reset(self._token)
|
|
|
|
@classmethod
|
|
def get_current_llm_config(cls, llm_config: "Optional[LLMConfig]" = None) -> "Optional[LLMConfig]":
|
|
if llm_config is not None:
|
|
return llm_config
|
|
try:
|
|
return (LLMConfig._current_llm_config.get()).copy()
|
|
except LookupError:
|
|
return None
|
|
|
|
def _satisfies_criteria(self, value: Any, criteria_values: Any) -> bool:
|
|
if value is None:
|
|
return False
|
|
|
|
if isinstance(value, list):
|
|
return bool(set(value) & set(criteria_values)) # Non-empty intersection
|
|
else:
|
|
return value in criteria_values
|
|
|
|
@classmethod
|
|
def from_json(
|
|
cls,
|
|
*,
|
|
env: Optional[str] = None,
|
|
path: Optional[Union[str, Path]] = None,
|
|
file_location: Optional[str] = None,
|
|
**kwargs: Any,
|
|
) -> "LLMConfig":
|
|
from .oai.openai_utils import config_list_from_json
|
|
|
|
if env is None and path is None:
|
|
raise ValueError("Either 'env' or 'path' must be provided")
|
|
if env is not None and path is not None:
|
|
raise ValueError("Only one of 'env' or 'path' can be provided")
|
|
|
|
config_list = config_list_from_json(
|
|
env_or_file=env if env is not None else str(path), file_location=file_location
|
|
)
|
|
return LLMConfig(config_list=config_list, **kwargs)
|
|
|
|
def where(self, *, exclude: bool = False, **kwargs: Any) -> "LLMConfig":
|
|
from .oai.openai_utils import filter_config
|
|
|
|
filtered_config_list = filter_config(config_list=self.config_list, filter_dict=kwargs, exclude=exclude)
|
|
if len(filtered_config_list) == 0:
|
|
raise ValueError(f"No config found that satisfies the filter criteria: {kwargs}")
|
|
|
|
kwargs = self.model_dump()
|
|
kwargs["config_list"] = filtered_config_list
|
|
|
|
return LLMConfig(**kwargs)
|
|
|
|
# @functools.wraps(BaseModel.model_dump)
|
|
def model_dump(self, *args: Any, exclude_none: bool = True, **kwargs: Any) -> dict[str, Any]:
|
|
d = self._model.model_dump(*args, exclude_none=exclude_none, **kwargs)
|
|
return {k: v for k, v in d.items() if not (isinstance(v, list) and len(v) == 0)}
|
|
|
|
# @functools.wraps(BaseModel.model_dump_json)
|
|
def model_dump_json(self, *args: Any, exclude_none: bool = True, **kwargs: Any) -> str:
|
|
# return self._model.model_dump_json(*args, exclude_none=exclude_none, **kwargs)
|
|
d = self.model_dump(*args, exclude_none=exclude_none, **kwargs)
|
|
return json.dumps(d)
|
|
|
|
# @functools.wraps(BaseModel.model_validate)
|
|
def model_validate(self, *args: Any, **kwargs: Any) -> Any:
|
|
return self._model.model_validate(*args, **kwargs)
|
|
|
|
@functools.wraps(BaseModel.model_validate_json)
|
|
def model_validate_json(self, *args: Any, **kwargs: Any) -> Any:
|
|
return self._model.model_validate_json(*args, **kwargs)
|
|
|
|
@functools.wraps(BaseModel.model_validate_strings)
|
|
def model_validate_strings(self, *args: Any, **kwargs: Any) -> Any:
|
|
return self._model.model_validate_strings(*args, **kwargs)
|
|
|
|
def __eq__(self, value: Any) -> bool:
|
|
return hasattr(value, "_model") and self._model == value._model
|
|
|
|
def _getattr(self, o: object, name: str) -> Any:
|
|
val = getattr(o, name)
|
|
return val
|
|
|
|
def get(self, key: str, default: Optional[Any] = None) -> Any:
|
|
val = getattr(self._model, key, default)
|
|
return val
|
|
|
|
def __getitem__(self, key: str) -> Any:
|
|
try:
|
|
return self._getattr(self._model, key)
|
|
except AttributeError:
|
|
raise KeyError(f"Key '{key}' not found in {self.__class__.__name__}")
|
|
|
|
def __setitem__(self, key: str, value: Any) -> None:
|
|
try:
|
|
setattr(self._model, key, value)
|
|
except ValueError:
|
|
raise ValueError(f"'{self.__class__.__name__}' object has no field '{key}'")
|
|
|
|
def __getattr__(self, name: Any) -> Any:
|
|
try:
|
|
return self._getattr(self._model, name)
|
|
except AttributeError:
|
|
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
|
|
|
def __setattr__(self, name: str, value: Any) -> None:
|
|
if name == "_model":
|
|
object.__setattr__(self, name, value)
|
|
else:
|
|
setattr(self._model, name, value)
|
|
|
|
def __contains__(self, key: str) -> bool:
|
|
return hasattr(self._model, key)
|
|
|
|
def __repr__(self) -> str:
|
|
d = self.model_dump()
|
|
r = [f"{k}={repr(v)}" for k, v in d.items()]
|
|
|
|
s = f"LLMConfig({', '.join(r)})"
|
|
# Replace any keys ending with 'key' or 'token' values with stars for security
|
|
s = re.sub(
|
|
r"(['\"])(\w*(key|token))\1:\s*(['\"])([^'\"]*)(?:\4)", r"\1\2\1: \4**********\4", s, flags=re.IGNORECASE
|
|
)
|
|
return s
|
|
|
|
def __copy__(self) -> "LLMConfig":
|
|
return LLMConfig(**self.model_dump())
|
|
|
|
def __deepcopy__(self, memo: Optional[dict[int, Any]] = None) -> "LLMConfig":
|
|
return self.__copy__()
|
|
|
|
def copy(self) -> "LLMConfig":
|
|
return self.__copy__()
|
|
|
|
def deepcopy(self, memo: Optional[dict[int, Any]] = None) -> "LLMConfig":
|
|
return self.__deepcopy__(memo)
|
|
|
|
def __str__(self) -> str:
|
|
return repr(self)
|
|
|
|
def items(self) -> Iterable[tuple[str, Any]]:
|
|
d = self.model_dump()
|
|
return d.items()
|
|
|
|
def keys(self) -> Iterable[str]:
|
|
d = self.model_dump()
|
|
return d.keys()
|
|
|
|
def values(self) -> Iterable[Any]:
|
|
d = self.model_dump()
|
|
return d.values()
|
|
|
|
_base_model_classes: dict[tuple[Type["LLMConfigEntry"], ...], Type[BaseModel]] = {}
|
|
|
|
@classmethod
|
|
def _get_base_model_class(cls) -> Type["BaseModel"]:
|
|
def _get_cls(llm_config_classes: tuple[Type[LLMConfigEntry], ...]) -> Type[BaseModel]:
|
|
if llm_config_classes in LLMConfig._base_model_classes:
|
|
return LLMConfig._base_model_classes[llm_config_classes]
|
|
|
|
class _LLMConfig(BaseModel):
|
|
temperature: Optional[float] = None
|
|
check_every_ms: Optional[int] = None
|
|
max_new_tokens: Optional[int] = None
|
|
seed: Optional[int] = None
|
|
allow_format_str_template: Optional[bool] = None
|
|
response_format: Optional[Union[str, dict[str, Any], BaseModel, Type[BaseModel]]] = None
|
|
timeout: Optional[int] = None
|
|
cache_seed: Optional[int] = None
|
|
|
|
tools: list[Any] = Field(default_factory=list)
|
|
functions: list[Any] = Field(default_factory=list)
|
|
parallel_tool_calls: Optional[bool] = None
|
|
|
|
config_list: Annotated[ # type: ignore[valid-type]
|
|
list[Annotated[Union[llm_config_classes], Field(discriminator="api_type")]],
|
|
Field(default_factory=list, min_length=1),
|
|
]
|
|
|
|
# Following field is configuration for pydantic to disallow extra fields
|
|
model_config = ConfigDict(extra="forbid")
|
|
|
|
LLMConfig._base_model_classes[llm_config_classes] = _LLMConfig
|
|
|
|
return _LLMConfig
|
|
|
|
return _get_cls(tuple(_llm_config_classes))
|
|
|
|
|
|
class LLMConfigEntry(BaseModel, ABC):
|
|
api_type: str
|
|
model: str = Field(..., min_length=1)
|
|
api_key: Optional[SecretStr] = None
|
|
api_version: Optional[str] = None
|
|
max_tokens: Optional[int] = None
|
|
base_url: Optional[HttpUrl] = None
|
|
voice: Optional[str] = None
|
|
model_client_cls: Optional[str] = None
|
|
http_client: Optional[httpxClient] = None
|
|
response_format: Optional[Union[str, dict[str, Any], BaseModel, Type[BaseModel]]] = None
|
|
default_headers: Optional[Mapping[str, Any]] = None
|
|
tags: list[str] = Field(default_factory=list)
|
|
|
|
# Following field is configuration for pydantic to disallow extra fields
|
|
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
|
|
|
@abstractmethod
|
|
def create_client(self) -> "ModelClient": ...
|
|
|
|
@field_validator("base_url", mode="before")
|
|
@classmethod
|
|
def check_base_url(cls, v: Any, info: ValidationInfo) -> Any:
|
|
if not str(v).startswith("https://") and not str(v).startswith("http://"):
|
|
v = f"http://{str(v)}"
|
|
return v
|
|
|
|
@field_serializer("base_url")
|
|
def serialize_base_url(self, v: Any) -> Any:
|
|
return str(v)
|
|
|
|
@field_serializer("api_key", when_used="unless-none")
|
|
def serialize_api_key(self, v: SecretStr) -> Any:
|
|
return v.get_secret_value()
|
|
|
|
def model_dump(self, *args: Any, exclude_none: bool = True, **kwargs: Any) -> dict[str, Any]:
|
|
return BaseModel.model_dump(self, exclude_none=exclude_none, *args, **kwargs)
|
|
|
|
def model_dump_json(self, *args: Any, exclude_none: bool = True, **kwargs: Any) -> str:
|
|
return BaseModel.model_dump_json(self, exclude_none=exclude_none, *args, **kwargs)
|
|
|
|
def get(self, key: str, default: Optional[Any] = None) -> Any:
|
|
val = getattr(self, key, default)
|
|
if isinstance(val, SecretStr):
|
|
return val.get_secret_value()
|
|
return val
|
|
|
|
def __getitem__(self, key: str) -> Any:
|
|
try:
|
|
val = getattr(self, key)
|
|
if isinstance(val, SecretStr):
|
|
return val.get_secret_value()
|
|
return val
|
|
except AttributeError:
|
|
raise KeyError(f"Key '{key}' not found in {self.__class__.__name__}")
|
|
|
|
def __setitem__(self, key: str, value: Any) -> None:
|
|
setattr(self, key, value)
|
|
|
|
def __contains__(self, key: str) -> bool:
|
|
return hasattr(self, key)
|
|
|
|
def items(self) -> Iterable[tuple[str, Any]]:
|
|
d = self.model_dump()
|
|
return d.items()
|
|
|
|
def keys(self) -> Iterable[str]:
|
|
d = self.model_dump()
|
|
return d.keys()
|
|
|
|
def values(self) -> Iterable[Any]:
|
|
d = self.model_dump()
|
|
return d.values()
|
|
|
|
def __repr__(self) -> str:
|
|
# Override to eliminate none values from the repr
|
|
d = self.model_dump()
|
|
r = [f"{k}={repr(v)}" for k, v in d.items()]
|
|
|
|
s = f"{self.__class__.__name__}({', '.join(r)})"
|
|
|
|
# Replace any keys ending with '_key' or '_token' values with stars for security
|
|
# This regex will match any key ending with '_key' or '_token' and its value, and replace the value with stars
|
|
# It also captures the type of quote used (single or double) and reuses it in the replacement
|
|
s = re.sub(r'(\w+_(key|token)\s*=\s*)([\'"]).*?\3', r"\1\3**********\3", s, flags=re.IGNORECASE)
|
|
|
|
return s
|
|
|
|
def __str__(self) -> str:
|
|
return repr(self)
|
|
|
|
|
|
_llm_config_classes: list[Type[LLMConfigEntry]] = []
|
|
|
|
|
|
def register_llm_config(cls: Type[LLMConfigEntry]) -> Type[LLMConfigEntry]:
|
|
if isinstance(cls, type) and issubclass(cls, LLMConfigEntry):
|
|
_llm_config_classes.append(cls)
|
|
else:
|
|
raise TypeError(f"Expected a subclass of LLMConfigEntry, got {cls}")
|
|
return cls
|