CoACT initialize (#292)

This commit is contained in:
Linxin Song
2025-07-30 19:35:20 -07:00
committed by GitHub
parent 862d704b8c
commit b968155757
228 changed files with 42386 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
from .dependencies import Provider, dependency_provider
from .use import Depends, inject
__all__ = (
"Depends",
"Provider",
"dependency_provider",
"inject",
)

View File

@@ -0,0 +1,80 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
import sys
from importlib.metadata import version as get_version
from typing import Any, Dict, Optional, Tuple, Type
from pydantic import BaseModel, create_model
from pydantic.version import VERSION as PYDANTIC_VERSION
__all__ = (
"PYDANTIC_V2",
"BaseModel",
"ConfigDict",
"ExceptionGroup",
"create_model",
"evaluate_forwardref",
"get_config_base",
)
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
default_pydantic_config = {"arbitrary_types_allowed": True}
evaluate_forwardref: Any
# isort: off
if PYDANTIC_V2:
from pydantic import ConfigDict
from pydantic._internal._typing_extra import ( # type: ignore[no-redef]
eval_type_lenient as evaluate_forwardref,
)
def model_schema(model: Type[BaseModel]) -> Dict[str, Any]:
return model.model_json_schema()
def get_config_base(config_data: Optional[ConfigDict] = None) -> ConfigDict:
return config_data or ConfigDict(**default_pydantic_config) # type: ignore[typeddict-item]
def get_aliases(model: Type[BaseModel]) -> Tuple[str, ...]:
return tuple(f.alias or name for name, f in model.model_fields.items())
class CreateBaseModel(BaseModel):
"""Just to support FastStream < 0.3.7."""
model_config = ConfigDict(arbitrary_types_allowed=True)
else:
from pydantic.typing import evaluate_forwardref as evaluate_forwardref # type: ignore[no-redef]
from pydantic.config import get_config, ConfigDict, BaseConfig
def get_config_base(config_data: Optional[ConfigDict] = None) -> Type[BaseConfig]: # type: ignore[misc,no-any-unimported]
return get_config(config_data or ConfigDict(**default_pydantic_config)) # type: ignore[typeddict-item,no-any-unimported,no-any-return]
def model_schema(model: Type[BaseModel]) -> Dict[str, Any]:
return model.schema()
def get_aliases(model: Type[BaseModel]) -> Tuple[str, ...]:
return tuple(f.alias or name for name, f in model.__fields__.items()) # type: ignore[attr-defined]
class CreateBaseModel(BaseModel): # type: ignore[no-redef]
"""Just to support FastStream < 0.3.7."""
class Config:
arbitrary_types_allowed = True
ANYIO_V3 = get_version("anyio").startswith("3.")
if ANYIO_V3:
from anyio import ExceptionGroup as ExceptionGroup # type: ignore[attr-defined]
else:
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup as ExceptionGroup
else:
ExceptionGroup = ExceptionGroup

View File

@@ -0,0 +1,14 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
from .build import build_call_model
from .model import CallModel
__all__ = (
"CallModel",
"build_call_model",
)

View File

@@ -0,0 +1,225 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
import inspect
from copy import deepcopy
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
from typing_extensions import (
Annotated,
ParamSpec,
get_args,
get_origin,
)
from .._compat import ConfigDict, create_model, get_config_base
from ..dependencies import Depends
from ..library import CustomField
from ..utils import (
get_typed_signature,
is_async_gen_callable,
is_coroutine_callable,
is_gen_callable,
)
from .model import CallModel, ResponseModel
CUSTOM_ANNOTATIONS = (Depends, CustomField)
P = ParamSpec("P")
T = TypeVar("T")
def build_call_model(
call: Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
*,
cast: bool = True,
use_cache: bool = True,
is_sync: Optional[bool] = None,
extra_dependencies: Sequence[Depends] = (),
pydantic_config: Optional[ConfigDict] = None,
) -> CallModel[P, T]:
name = getattr(call, "__name__", type(call).__name__)
is_call_async = is_coroutine_callable(call) or is_async_gen_callable(call)
if is_sync is None:
is_sync = not is_call_async
else:
assert not (is_sync and is_call_async), f"You cannot use async dependency `{name}` at sync main"
typed_params, return_annotation = get_typed_signature(call)
if (is_call_generator := is_gen_callable(call) or is_async_gen_callable(call)) and (
return_args := get_args(return_annotation)
):
return_annotation = return_args[0]
class_fields: Dict[str, Tuple[Any, Any]] = {}
dependencies: Dict[str, CallModel[..., Any]] = {}
custom_fields: Dict[str, CustomField] = {}
positional_args: List[str] = []
keyword_args: List[str] = []
var_positional_arg: Optional[str] = None
var_keyword_arg: Optional[str] = None
for param_name, param in typed_params.parameters.items():
dep: Optional[Depends] = None
custom: Optional[CustomField] = None
if param.annotation is inspect.Parameter.empty:
annotation = Any
elif get_origin(param.annotation) is Annotated:
annotated_args = get_args(param.annotation)
type_annotation = annotated_args[0]
custom_annotations = []
regular_annotations = []
for arg in annotated_args[1:]:
if isinstance(arg, CUSTOM_ANNOTATIONS):
custom_annotations.append(arg)
else:
regular_annotations.append(arg)
assert len(custom_annotations) <= 1, (
f"Cannot specify multiple `Annotated` Custom arguments for `{param_name}`!"
)
next_custom = next(iter(custom_annotations), None)
if next_custom is not None:
if isinstance(next_custom, Depends):
dep = next_custom
elif isinstance(next_custom, CustomField):
custom = deepcopy(next_custom)
else: # pragma: no cover
raise AssertionError("unreachable")
annotation = param.annotation if regular_annotations else type_annotation
else:
annotation = param.annotation
else:
annotation = param.annotation
default: Any
if param.kind == inspect.Parameter.VAR_POSITIONAL:
default = ()
var_positional_arg = param_name
elif param.kind == inspect.Parameter.VAR_KEYWORD:
default = {}
var_keyword_arg = param_name
elif param.default is inspect.Parameter.empty:
default = Ellipsis
else:
default = param.default
if isinstance(default, Depends):
assert not dep, "You can not use `Depends` with `Annotated` and default both"
dep, default = default, Ellipsis
elif isinstance(default, CustomField):
assert not custom, "You can not use `CustomField` with `Annotated` and default both"
custom, default = default, Ellipsis
else:
class_fields[param_name] = (annotation, default)
if dep:
if not cast:
dep.cast = False
dependencies[param_name] = build_call_model(
dep.dependency,
cast=dep.cast,
use_cache=dep.use_cache,
is_sync=is_sync,
pydantic_config=pydantic_config,
)
if dep.cast is True:
class_fields[param_name] = (annotation, Ellipsis)
keyword_args.append(param_name)
elif custom:
assert not (is_sync and is_coroutine_callable(custom.use)), (
f"You cannot use async custom field `{type(custom).__name__}` at sync `{name}`"
)
custom.set_param_name(param_name)
custom_fields[param_name] = custom
if custom.cast is False:
annotation = Any
if custom.required:
class_fields[param_name] = (annotation, default)
else:
class_fields[param_name] = class_fields.get(param_name, (Optional[annotation], None))
keyword_args.append(param_name)
else:
if param.kind is param.KEYWORD_ONLY:
keyword_args.append(param_name)
elif param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
positional_args.append(param_name)
func_model = create_model( # type: ignore[call-overload]
name,
__config__=get_config_base(pydantic_config),
**class_fields,
)
response_model: Optional[Type[ResponseModel[T]]] = None
if cast and return_annotation and return_annotation is not inspect.Parameter.empty:
response_model = create_model( # type: ignore[call-overload,assignment]
"ResponseModel",
__config__=get_config_base(pydantic_config),
response=(return_annotation, Ellipsis),
)
return CallModel(
call=call,
model=func_model,
response_model=response_model,
params=class_fields,
cast=cast,
use_cache=use_cache,
is_async=is_call_async,
is_generator=is_call_generator,
dependencies=dependencies,
custom_fields=custom_fields,
positional_args=positional_args,
keyword_args=keyword_args,
var_positional_arg=var_positional_arg,
var_keyword_arg=var_keyword_arg,
extra_dependencies=[
build_call_model(
d.dependency,
cast=d.cast,
use_cache=d.use_cache,
is_sync=is_sync,
pydantic_config=pydantic_config,
)
for d in extra_dependencies
],
)

View File

@@ -0,0 +1,576 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
from collections import namedtuple
from contextlib import AsyncExitStack, ExitStack
from functools import partial
from inspect import Parameter, unwrap
from itertools import chain
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generator,
Generic,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
import anyio
from typing_extensions import ParamSpec
from .._compat import BaseModel, ExceptionGroup, get_aliases
from ..library import CustomField
from ..utils import (
async_map,
is_async_gen_callable,
is_coroutine_callable,
is_gen_callable,
run_async,
solve_generator_async,
solve_generator_sync,
)
P = ParamSpec("P")
T = TypeVar("T")
PriorityPair = namedtuple("PriorityPair", ("call", "dependencies_number", "dependencies_names"))
class ResponseModel(BaseModel, Generic[T]):
response: T
class CallModel(Generic[P, T]):
call: Union[
Callable[P, T],
Callable[P, Awaitable[T]],
]
is_async: bool
is_generator: bool
model: Optional[Type[BaseModel]]
response_model: Optional[Type[ResponseModel[T]]]
params: Dict[str, Tuple[Any, Any]]
alias_arguments: Tuple[str, ...]
dependencies: Dict[str, "CallModel[..., Any]"]
extra_dependencies: Iterable["CallModel[..., Any]"]
sorted_dependencies: Tuple[Tuple["CallModel[..., Any]", int], ...]
custom_fields: Dict[str, CustomField]
keyword_args: Tuple[str, ...]
positional_args: Tuple[str, ...]
var_positional_arg: Optional[str]
var_keyword_arg: Optional[str]
# Dependencies and custom fields
use_cache: bool
cast: bool
__slots__ = (
"call",
"is_async",
"is_generator",
"model",
"response_model",
"params",
"alias_arguments",
"keyword_args",
"positional_args",
"var_positional_arg",
"var_keyword_arg",
"dependencies",
"extra_dependencies",
"sorted_dependencies",
"custom_fields",
"use_cache",
"cast",
)
@property
def call_name(self) -> str:
call = unwrap(self.call)
return getattr(call, "__name__", type(call).__name__)
@property
def flat_params(self) -> Dict[str, Tuple[Any, Any]]:
params = self.params
for d in (*self.dependencies.values(), *self.extra_dependencies):
params.update(d.flat_params)
return params
@property
def flat_dependencies(
self,
) -> Dict[
Callable[..., Any],
Tuple[
"CallModel[..., Any]",
Tuple[Callable[..., Any], ...],
],
]:
flat: Dict[
Callable[..., Any],
Tuple[
CallModel[..., Any],
Tuple[Callable[..., Any], ...],
],
] = {}
for i in (*self.dependencies.values(), *self.extra_dependencies):
flat.update({
i.call: (
i,
tuple(j.call for j in i.dependencies.values()),
)
})
flat.update(i.flat_dependencies)
return flat
def __init__(
self,
/,
call: Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
model: Optional[Type[BaseModel]],
params: Dict[str, Tuple[Any, Any]],
response_model: Optional[Type[ResponseModel[T]]] = None,
use_cache: bool = True,
cast: bool = True,
is_async: bool = False,
is_generator: bool = False,
dependencies: Optional[Dict[str, "CallModel[..., Any]"]] = None,
extra_dependencies: Optional[Iterable["CallModel[..., Any]"]] = None,
keyword_args: Optional[List[str]] = None,
positional_args: Optional[List[str]] = None,
var_positional_arg: Optional[str] = None,
var_keyword_arg: Optional[str] = None,
custom_fields: Optional[Dict[str, CustomField]] = None,
):
self.call = call
self.model = model
if model:
self.alias_arguments = get_aliases(model)
else: # pragma: no cover
self.alias_arguments = ()
self.keyword_args = tuple(keyword_args or ())
self.positional_args = tuple(positional_args or ())
self.var_positional_arg = var_positional_arg
self.var_keyword_arg = var_keyword_arg
self.response_model = response_model
self.use_cache = use_cache
self.cast = cast
self.is_async = is_async or is_coroutine_callable(call) or is_async_gen_callable(call)
self.is_generator = is_generator or is_gen_callable(call) or is_async_gen_callable(call)
self.dependencies = dependencies or {}
self.extra_dependencies = extra_dependencies or ()
self.custom_fields = custom_fields or {}
sorted_dep: List[CallModel[..., Any]] = []
flat = self.flat_dependencies
for calls in flat.values():
_sort_dep(sorted_dep, calls, flat)
self.sorted_dependencies = tuple((i, len(i.sorted_dependencies)) for i in sorted_dep if i.use_cache)
for name in chain(self.dependencies.keys(), self.custom_fields.keys()):
params.pop(name, None)
self.params = params
def _solve(
self,
/,
*args: Tuple[Any, ...],
cache_dependencies: Dict[
Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
T,
],
dependency_overrides: Optional[
Dict[
Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
]
] = None,
**kwargs: Dict[str, Any],
) -> Generator[
Tuple[
Sequence[Any],
Dict[str, Any],
Callable[..., Any],
],
Any,
T,
]:
if dependency_overrides:
call = dependency_overrides.get(self.call, self.call)
assert self.is_async or not is_coroutine_callable(call), (
f"You cannot use async dependency `{self.call_name}` at sync main"
)
else:
call = self.call
if self.use_cache and call in cache_dependencies:
return cache_dependencies[call]
kw: Dict[str, Any] = {}
for arg in self.keyword_args:
if (v := kwargs.pop(arg, Parameter.empty)) is not Parameter.empty:
kw[arg] = v
if self.var_keyword_arg is not None:
kw[self.var_keyword_arg] = kwargs
else:
kw.update(kwargs)
for arg in self.positional_args:
if args:
kw[arg], args = args[0], args[1:]
else:
break
keyword_args: Iterable[str]
if self.var_positional_arg is not None:
kw[self.var_positional_arg] = args
keyword_args = self.keyword_args
else:
keyword_args = self.keyword_args + self.positional_args
for arg in keyword_args:
if not self.cast and arg in self.params:
kw[arg] = self.params[arg][1]
if not args:
break
if arg not in self.dependencies:
kw[arg], args = args[0], args[1:]
solved_kw: Dict[str, Any]
solved_kw = yield args, kw, call
args_: Sequence[Any]
if self.cast:
assert self.model, "Cast should be used only with model"
casted_model = self.model(**solved_kw)
kwargs_ = {arg: getattr(casted_model, arg, solved_kw.get(arg)) for arg in keyword_args}
if self.var_keyword_arg:
kwargs_.update(getattr(casted_model, self.var_keyword_arg, {}))
if self.var_positional_arg is not None:
args_ = [getattr(casted_model, arg, solved_kw.get(arg)) for arg in self.positional_args]
args_.extend(getattr(casted_model, self.var_positional_arg, ()))
else:
args_ = ()
else:
kwargs_ = {arg: solved_kw.get(arg) for arg in keyword_args}
args_ = tuple(map(solved_kw.get, self.positional_args)) if self.var_positional_arg is None else ()
response: T
response = yield args_, kwargs_, call
if self.cast and not self.is_generator:
response = self._cast_response(response)
if self.use_cache: # pragma: no branch
cache_dependencies[call] = response
return response
def _cast_response(self, /, value: Any) -> Any:
if self.response_model is not None:
return self.response_model(response=value).response
else:
return value
def solve(
self,
/,
*args: Any,
stack: ExitStack,
cache_dependencies: Dict[
Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
T,
],
dependency_overrides: Optional[
Dict[
Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
]
] = None,
nested: bool = False,
**kwargs: Any,
) -> T:
cast_gen = self._solve(
*args,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
**kwargs,
)
try:
args, kwargs, _ = next(cast_gen) # type: ignore[assignment]
except StopIteration as e:
cached_value: T = e.value
return cached_value
# Heat cache and solve extra dependencies
for dep, _ in self.sorted_dependencies:
dep.solve(
*args,
stack=stack,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
nested=True,
**kwargs,
)
# Always get from cache
for dep in self.extra_dependencies:
dep.solve(
*args,
stack=stack,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
nested=True,
**kwargs,
)
for dep_arg, dep in self.dependencies.items():
kwargs[dep_arg] = dep.solve(
stack=stack,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
nested=True,
**kwargs,
)
for custom in self.custom_fields.values():
if custom.field:
custom.use_field(kwargs)
else:
kwargs = custom.use(**kwargs)
final_args, final_kwargs, call = cast_gen.send(kwargs)
if self.is_generator and nested:
response = solve_generator_sync(
*final_args,
call=call,
stack=stack,
**final_kwargs,
)
else:
response = call(*final_args, **final_kwargs)
try:
cast_gen.send(response)
except StopIteration as e:
value: T = e.value
if not self.cast or nested or not self.is_generator:
return value
else:
return map(self._cast_response, value) # type: ignore[no-any-return, call-overload]
raise AssertionError("unreachable")
async def asolve(
self,
/,
*args: Any,
stack: AsyncExitStack,
cache_dependencies: Dict[
Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
T,
],
dependency_overrides: Optional[
Dict[
Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
]
] = None,
nested: bool = False,
**kwargs: Any,
) -> T:
cast_gen = self._solve(
*args,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
**kwargs,
)
try:
args, kwargs, _ = next(cast_gen) # type: ignore[assignment]
except StopIteration as e:
cached_value: T = e.value
return cached_value
# Heat cache and solve extra dependencies
dep_to_solve: List[Callable[..., Awaitable[Any]]] = []
try:
async with anyio.create_task_group() as tg:
for dep, subdep in self.sorted_dependencies:
solve = partial(
dep.asolve,
*args,
stack=stack,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
nested=True,
**kwargs,
)
if not subdep:
tg.start_soon(solve)
else:
dep_to_solve.append(solve)
except ExceptionGroup as exgr:
for ex in exgr.exceptions:
raise ex from None
for i in dep_to_solve:
await i()
# Always get from cache
for dep in self.extra_dependencies:
await dep.asolve(
*args,
stack=stack,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
nested=True,
**kwargs,
)
for dep_arg, dep in self.dependencies.items():
kwargs[dep_arg] = await dep.asolve(
stack=stack,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
nested=True,
**kwargs,
)
custom_to_solve: List[CustomField] = []
try:
async with anyio.create_task_group() as tg:
for custom in self.custom_fields.values():
if custom.field:
tg.start_soon(run_async, custom.use_field, kwargs)
else:
custom_to_solve.append(custom)
except ExceptionGroup as exgr:
for ex in exgr.exceptions:
raise ex from None
for j in custom_to_solve:
kwargs = await run_async(j.use, **kwargs)
final_args, final_kwargs, call = cast_gen.send(kwargs)
if self.is_generator and nested:
response = await solve_generator_async(
*final_args,
call=call,
stack=stack,
**final_kwargs,
)
else:
response = await run_async(call, *final_args, **final_kwargs)
try:
cast_gen.send(response)
except StopIteration as e:
value: T = e.value
if not self.cast or nested or not self.is_generator:
return value
else:
return async_map(self._cast_response, value) # type: ignore[return-value, arg-type]
raise AssertionError("unreachable")
def _sort_dep(
collector: List["CallModel[..., Any]"],
items: Tuple[
"CallModel[..., Any]",
Tuple[Callable[..., Any], ...],
],
flat: Dict[
Callable[..., Any],
Tuple[
"CallModel[..., Any]",
Tuple[Callable[..., Any], ...],
],
],
) -> None:
model, calls = items
if model in collector:
return
if not calls:
position = -1
else:
for i in calls:
sub_model, _ = flat[i]
if sub_model not in collector: # pragma: no branch
_sort_dep(collector, flat[i], flat)
position = max(collector.index(flat[i][0]) for i in calls)
collector.insert(position + 1, model)

View File

@@ -0,0 +1,15 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
from .model import Depends
from .provider import Provider, dependency_provider
__all__ = (
"Depends",
"Provider",
"dependency_provider",
)

View File

@@ -0,0 +1,29 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import Any, Callable
class Depends:
use_cache: bool
cast: bool
def __init__(
self,
dependency: Callable[..., Any],
*,
use_cache: bool = True,
cast: bool = True,
) -> None:
self.dependency = dependency
self.use_cache = use_cache
self.cast = cast
def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({attr}{cache})"

View File

@@ -0,0 +1,39 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator
class Provider:
dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]]
def __init__(self) -> None:
self.dependency_overrides = {}
def clear(self) -> None:
self.dependency_overrides = {}
def override(
self,
original: Callable[..., Any],
override: Callable[..., Any],
) -> None:
self.dependency_overrides[original] = override
@contextmanager
def scope(
self,
original: Callable[..., Any],
override: Callable[..., Any],
) -> Iterator[None]:
self.dependency_overrides[original] = override
yield
self.dependency_overrides.pop(original, None)
dependency_provider = Provider()

View File

@@ -0,0 +1,10 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
from .model import CustomField
__all__ = ("CustomField",)

View File

@@ -0,0 +1,46 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
from abc import ABC
from typing import Any, Dict, Optional, TypeVar
Cls = TypeVar("Cls", bound="CustomField")
class CustomField(ABC):
param_name: Optional[str]
cast: bool
required: bool
__slots__ = (
"cast",
"param_name",
"required",
"field",
)
def __init__(
self,
*,
cast: bool = True,
required: bool = True,
) -> None:
self.cast = cast
self.param_name = None
self.required = required
self.field = False
def set_param_name(self: Cls, name: str) -> Cls:
self.param_name = name
return self
def use(self, /, **kwargs: Any) -> Dict[str, Any]:
assert self.param_name, "You should specify `param_name` before using"
return kwargs
def use_field(self, kwargs: Dict[str, Any]) -> None:
raise NotImplementedError("You should implement `use_field` method.")

View File

@@ -0,0 +1,6 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,66 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import Any, Dict, List, Optional
from ._compat import PYDANTIC_V2, create_model, model_schema
from .core import CallModel
def get_schema(
call: CallModel[Any, Any],
embed: bool = False,
resolve_refs: bool = False,
) -> Dict[str, Any]:
assert call.model, "Call should has a model"
params_model = create_model( # type: ignore[call-overload]
call.model.__name__, **call.flat_params
)
body: Dict[str, Any] = model_schema(params_model)
if not call.flat_params:
body = {"title": body["title"], "type": "null"}
if resolve_refs:
pydantic_key = "$defs" if PYDANTIC_V2 else "definitions"
body = _move_pydantic_refs(body, pydantic_key)
body.pop(pydantic_key, None)
if embed and len(body["properties"]) == 1:
body = list(body["properties"].values())[0]
return body
def _move_pydantic_refs(original: Any, key: str, refs: Optional[Dict[str, Any]] = None) -> Any:
if not isinstance(original, Dict):
return original
data = original.copy()
if refs is None:
raw_refs = data.get(key, {})
refs = _move_pydantic_refs(raw_refs, key, raw_refs)
name: Optional[str] = None
for k in data:
if k == "$ref":
name = data[k].replace(f"#/{key}/", "")
elif isinstance(data[k], dict):
data[k] = _move_pydantic_refs(data[k], key, refs)
elif isinstance(data[k], List):
for i in range(len(data[k])):
data[k][i] = _move_pydantic_refs(data[k][i], key, refs)
if name:
assert refs, "Smth wrong"
data = refs[name]
return data

View File

@@ -0,0 +1,280 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
from contextlib import AsyncExitStack, ExitStack
from functools import partial, wraps
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
Optional,
Protocol,
Sequence,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import ParamSpec
from ._compat import ConfigDict
from .core import CallModel, build_call_model
from .dependencies import dependency_provider, model
P = ParamSpec("P")
T = TypeVar("T")
def Depends( # noqa: N802
dependency: Callable[P, T],
*,
use_cache: bool = True,
cast: bool = True,
) -> Any:
return model.Depends(
dependency=dependency,
use_cache=use_cache,
cast=cast,
)
class _InjectWrapper(Protocol[P, T]):
def __call__(
self,
func: Callable[P, T],
model: Optional[CallModel[P, T]] = None,
) -> Callable[P, T]: ...
@overload
def inject( # pragma: no cover
func: None,
*,
cast: bool = True,
extra_dependencies: Sequence[model.Depends] = (),
pydantic_config: Optional[ConfigDict] = None,
dependency_overrides_provider: Optional[Any] = dependency_provider,
wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> _InjectWrapper[P, T]: ...
@overload
def inject( # pragma: no cover
func: Callable[P, T],
*,
cast: bool = True,
extra_dependencies: Sequence[model.Depends] = (),
pydantic_config: Optional[ConfigDict] = None,
dependency_overrides_provider: Optional[Any] = dependency_provider,
wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> Callable[P, T]: ...
def inject(
func: Optional[Callable[P, T]] = None,
*,
cast: bool = True,
extra_dependencies: Sequence[model.Depends] = (),
pydantic_config: Optional[ConfigDict] = None,
dependency_overrides_provider: Optional[Any] = dependency_provider,
wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> Union[
Callable[P, T],
_InjectWrapper[P, T],
]:
decorator = _wrap_inject(
dependency_overrides_provider=dependency_overrides_provider,
wrap_model=wrap_model,
extra_dependencies=extra_dependencies,
cast=cast,
pydantic_config=pydantic_config,
)
if func is None:
return decorator
else:
return decorator(func)
def _wrap_inject(
dependency_overrides_provider: Optional[Any],
wrap_model: Callable[
[CallModel[P, T]],
CallModel[P, T],
],
extra_dependencies: Sequence[model.Depends],
cast: bool,
pydantic_config: Optional[ConfigDict],
) -> _InjectWrapper[P, T]:
if (
dependency_overrides_provider
and getattr(dependency_overrides_provider, "dependency_overrides", None) is not None
):
overrides = dependency_overrides_provider.dependency_overrides
else:
overrides = None
def func_wrapper(
func: Callable[P, T],
model: Optional[CallModel[P, T]] = None,
) -> Callable[P, T]:
if model is None:
real_model = wrap_model(
build_call_model(
call=func,
extra_dependencies=extra_dependencies,
cast=cast,
pydantic_config=pydantic_config,
)
)
else:
real_model = model
if real_model.is_async:
injected_wrapper: Callable[P, T]
if real_model.is_generator:
injected_wrapper = partial(solve_async_gen, real_model, overrides) # type: ignore[assignment]
else:
@wraps(func)
async def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
async with AsyncExitStack() as stack:
r = await real_model.asolve(
*args,
stack=stack,
dependency_overrides=overrides,
cache_dependencies={},
nested=False,
**kwargs,
)
return r
raise AssertionError("unreachable")
else:
if real_model.is_generator:
injected_wrapper = partial(solve_gen, real_model, overrides) # type: ignore[assignment]
else:
@wraps(func)
def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
with ExitStack() as stack:
r = real_model.solve(
*args,
stack=stack,
dependency_overrides=overrides,
cache_dependencies={},
nested=False,
**kwargs,
)
return r
raise AssertionError("unreachable")
return injected_wrapper
return func_wrapper
class solve_async_gen: # noqa: N801
_iter: Optional[AsyncIterator[Any]] = None
def __init__(
self,
model: "CallModel[..., Any]",
overrides: Optional[Any],
*args: Any,
**kwargs: Any,
):
self.call = model
self.args = args
self.kwargs = kwargs
self.overrides = overrides
def __aiter__(self) -> "solve_async_gen":
self._iter = None
self.stack = AsyncExitStack()
return self
async def __anext__(self) -> Any:
if self._iter is None:
stack = self.stack = AsyncExitStack()
await self.stack.__aenter__()
self._iter = cast(
AsyncIterator[Any],
(
await self.call.asolve(
*self.args,
stack=stack,
dependency_overrides=self.overrides,
cache_dependencies={},
nested=False,
**self.kwargs,
)
).__aiter__(),
)
try:
r = await self._iter.__anext__()
except StopAsyncIteration as e:
await self.stack.__aexit__(None, None, None)
raise e
else:
return r
class solve_gen: # noqa: N801
_iter: Optional[Iterator[Any]] = None
def __init__(
self,
model: "CallModel[..., Any]",
overrides: Optional[Any],
*args: Any,
**kwargs: Any,
):
self.call = model
self.args = args
self.kwargs = kwargs
self.overrides = overrides
def __iter__(self) -> "solve_gen":
self._iter = None
self.stack = ExitStack()
return self
def __next__(self) -> Any:
if self._iter is None:
stack = self.stack = ExitStack()
self.stack.__enter__()
self._iter = cast(
Iterator[Any],
iter(
self.call.solve(
*self.args,
stack=stack,
dependency_overrides=self.overrides,
cache_dependencies={},
nested=False,
**self.kwargs,
)
),
)
try:
r = next(self._iter)
except StopIteration as e:
self.stack.__exit__(None, None, None)
raise e
else:
return r

View File

@@ -0,0 +1,187 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
import asyncio
import functools
import inspect
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
AsyncIterable,
Awaitable,
Callable,
ContextManager,
Dict,
ForwardRef,
List,
Tuple,
TypeVar,
Union,
cast,
)
import anyio
from typing_extensions import (
Annotated,
ParamSpec,
get_args,
get_origin,
)
from ._compat import evaluate_forwardref
if TYPE_CHECKING:
from types import FrameType
P = ParamSpec("P")
T = TypeVar("T")
async def run_async(
func: Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
*args: P.args,
**kwargs: P.kwargs,
) -> T:
if is_coroutine_callable(func):
return await cast(Callable[P, Awaitable[T]], func)(*args, **kwargs)
else:
return await run_in_threadpool(cast(Callable[P, T], func), *args, **kwargs)
async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
if kwargs:
func = functools.partial(func, **kwargs)
return await anyio.to_thread.run_sync(func, *args)
async def solve_generator_async(
*sub_args: Any, call: Callable[..., Any], stack: AsyncExitStack, **sub_values: Any
) -> Any:
if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
elif is_async_gen_callable(call): # pragma: no branch
cm = asynccontextmanager(call)(*sub_args, **sub_values)
return await stack.enter_async_context(cm)
def solve_generator_sync(*sub_args: Any, call: Callable[..., Any], stack: ExitStack, **sub_values: Any) -> Any:
cm = contextmanager(call)(*sub_args, **sub_values)
return stack.enter_context(cm)
def get_typed_signature(call: Callable[..., Any]) -> Tuple[inspect.Signature, Any]:
signature = inspect.signature(call)
locals = collect_outer_stack_locals()
# We unwrap call to get the original unwrapped function
call = inspect.unwrap(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(
param.annotation,
globalns,
locals,
),
)
for param in signature.parameters.values()
]
return inspect.Signature(typed_params), get_typed_annotation(
signature.return_annotation,
globalns,
locals,
)
def collect_outer_stack_locals() -> Dict[str, Any]:
frame = inspect.currentframe()
frames: List[FrameType] = []
while frame is not None:
if "fast_depends" not in frame.f_code.co_filename:
frames.append(frame)
frame = frame.f_back
locals = {}
for f in frames[::-1]:
locals.update(f.f_locals)
return locals
def get_typed_annotation(
annotation: Any,
globalns: Dict[str, Any],
locals: Dict[str, Any],
) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
if isinstance(annotation, ForwardRef):
annotation = evaluate_forwardref(annotation, globalns, locals)
if get_origin(annotation) is Annotated and (args := get_args(annotation)):
solved_args = [get_typed_annotation(x, globalns, locals) for x in args]
annotation.__origin__, annotation.__metadata__ = solved_args[0], tuple(solved_args[1:])
return annotation
@asynccontextmanager
async def contextmanager_in_threadpool(
cm: ContextManager[T],
) -> AsyncGenerator[T, None]:
exit_limiter = anyio.CapacityLimiter(1)
try:
yield await run_in_threadpool(cm.__enter__)
except Exception as e:
ok = bool(await anyio.to_thread.run_sync(cm.__exit__, type(e), e, None, limiter=exit_limiter))
if not ok: # pragma: no branch
raise e
else:
await anyio.to_thread.run_sync(cm.__exit__, None, None, None, limiter=exit_limiter)
def is_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(call):
return True
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(call):
return True
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
if inspect.isclass(call):
return False
if asyncio.iscoroutinefunction(call):
return True
dunder_call = getattr(call, "__call__", None) # noqa: B004
return asyncio.iscoroutinefunction(dunder_call)
async def async_map(func: Callable[..., T], async_iterable: AsyncIterable[Any]) -> AsyncIterable[T]:
async for i in async_iterable:
yield func(i)