281 lines
7.7 KiB
Python
281 lines
7.7 KiB
Python
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Portions derived from https://github.com/Lancetnik/FastDepends are under the MIT License.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
from contextlib import AsyncExitStack, ExitStack
|
|
from functools import partial, wraps
|
|
from typing import (
|
|
Any,
|
|
AsyncIterator,
|
|
Callable,
|
|
Iterator,
|
|
Optional,
|
|
Protocol,
|
|
Sequence,
|
|
TypeVar,
|
|
Union,
|
|
cast,
|
|
overload,
|
|
)
|
|
|
|
from typing_extensions import ParamSpec
|
|
|
|
from ._compat import ConfigDict
|
|
from .core import CallModel, build_call_model
|
|
from .dependencies import dependency_provider, model
|
|
|
|
P = ParamSpec("P")
|
|
T = TypeVar("T")
|
|
|
|
|
|
def Depends(  # noqa: N802
    dependency: Callable[P, T],
    *,
    use_cache: bool = True,
    cast: bool = True,
) -> Any:
    """Create a ``model.Depends`` marker for ``dependency``.

    This is a thin public shortcut: all three arguments are forwarded
    unchanged, by keyword, to ``model.Depends``, and whatever it returns is
    handed back. The return type is declared as ``Any``.

    Args:
        dependency: Callable to register as the dependency.
        use_cache: Forwarded as-is; its meaning is defined by ``model.Depends``.
        cast: Forwarded as-is; its meaning is defined by ``model.Depends``.

    Returns:
        The object produced by ``model.Depends``.
    """
    marker = model.Depends(dependency=dependency, use_cache=use_cache, cast=cast)
    return marker
|
|
|
|
|
|
class _InjectWrapper(Protocol[P, T]):
    """Structural type of a decorator that wraps a callable for injection.

    It describes a callable taking the function to wrap plus an optional
    ``CallModel`` and returning a callable with the same parameters and
    return type as the wrapped function.
    """

    def __call__(
        self,
        func: Callable[P, T],
        model: Optional[CallModel[P, T]] = None,
    ) -> Callable[P, T]:
        """Wrap ``func``; ``model`` may be given explicitly and defaults to ``None``."""
|
|
|
|
|
|
@overload
def inject(  # pragma: no cover
    # ``func`` defaults to None here so that the keyword-only decorator form
    # ``@inject(cast=False)`` matches this overload for type checkers; the
    # runtime implementation below already defaults ``func`` to None.
    func: None = None,
    *,
    cast: bool = True,
    extra_dependencies: Sequence[model.Depends] = (),
    pydantic_config: Optional[ConfigDict] = None,
    dependency_overrides_provider: Optional[Any] = dependency_provider,
    wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> _InjectWrapper[P, T]: ...


@overload
def inject(  # pragma: no cover
    func: Callable[P, T],
    *,
    cast: bool = True,
    extra_dependencies: Sequence[model.Depends] = (),
    pydantic_config: Optional[ConfigDict] = None,
    dependency_overrides_provider: Optional[Any] = dependency_provider,
    wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> Callable[P, T]: ...


def inject(
    func: Optional[Callable[P, T]] = None,
    *,
    cast: bool = True,
    extra_dependencies: Sequence[model.Depends] = (),
    pydantic_config: Optional[ConfigDict] = None,
    dependency_overrides_provider: Optional[Any] = dependency_provider,
    wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> Union[
    Callable[P, T],
    _InjectWrapper[P, T],
]:
    """Wrap a callable for dependency injection, usable with or without arguments.

    Supports both ``@inject`` (``func`` is the decorated callable) and
    ``@inject(...)`` (``func`` is ``None`` and a decorator is returned).

    Args:
        func: Callable to wrap, or ``None`` to get a decorator back.
        cast: Passed through to ``_wrap_inject``.
        extra_dependencies: Passed through to ``_wrap_inject``.
        pydantic_config: Passed through to ``_wrap_inject``.
        dependency_overrides_provider: Passed through to ``_wrap_inject``;
            defaults to the module-level ``dependency_provider``.
        wrap_model: Passed through to ``_wrap_inject``; defaults to the
            identity function.

    Returns:
        The wrapped callable when ``func`` is given, otherwise the decorator
        produced by ``_wrap_inject``.
    """
    decorator = _wrap_inject(
        dependency_overrides_provider=dependency_overrides_provider,
        wrap_model=wrap_model,
        extra_dependencies=extra_dependencies,
        cast=cast,
        pydantic_config=pydantic_config,
    )

    # Bare ``@inject`` passes the function directly; ``@inject(...)`` does not.
    return decorator if func is None else decorator(func)
|
|
|
|
|
|
def _wrap_inject(
    dependency_overrides_provider: Optional[Any],
    wrap_model: Callable[
        [CallModel[P, T]],
        CallModel[P, T],
    ],
    extra_dependencies: Sequence[model.Depends],
    cast: bool,
    pydantic_config: Optional[ConfigDict],
) -> _InjectWrapper[P, T]:
    """Build the decorator that turns a callable into its injected version.

    Args:
        dependency_overrides_provider: Optional object; when it is truthy and
            has a non-``None`` ``dependency_overrides`` attribute, that
            attribute is passed to every solve call as ``dependency_overrides``.
        wrap_model: Applied to the call model built for the decorated callable.
        extra_dependencies: Forwarded to ``build_call_model``.
        cast: Forwarded to ``build_call_model``.
        pydantic_config: Forwarded to ``build_call_model``.

    Returns:
        A decorator ``func_wrapper(func, model=None)``.
    """
    # The overrides object is looked up once, when the decorator is created.
    overrides = (
        getattr(dependency_overrides_provider, "dependency_overrides", None)
        if dependency_overrides_provider
        else None
    )

    def func_wrapper(
        func: Callable[P, T],
        model: Optional[CallModel[P, T]] = None,
    ) -> Callable[P, T]:
        # Use the supplied model as-is; otherwise build one and run it through wrap_model.
        call_model = (
            model
            if model is not None
            else wrap_model(
                build_call_model(
                    call=func,
                    extra_dependencies=extra_dependencies,
                    cast=cast,
                    pydantic_config=pydantic_config,
                )
            )
        )

        is_async = call_model.is_async
        is_generator = call_model.is_generator

        injected_wrapper: Callable[P, T]

        if is_generator:
            # Generators are solved lazily: calling the partial creates an
            # iterator object that solves the model on first iteration.
            solver = solve_async_gen if is_async else solve_gen
            injected_wrapper = partial(solver, call_model, overrides)  # type: ignore[assignment]

        elif is_async:

            @wraps(func)
            async def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                async with AsyncExitStack() as stack:
                    return await call_model.asolve(
                        *args,
                        stack=stack,
                        dependency_overrides=overrides,
                        cache_dependencies={},
                        nested=False,
                        **kwargs,
                    )

                # Only reached if the exit stack suppresses an exception.
                raise AssertionError("unreachable")

        else:

            @wraps(func)
            def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                with ExitStack() as stack:
                    return call_model.solve(
                        *args,
                        stack=stack,
                        dependency_overrides=overrides,
                        cache_dependencies={},
                        nested=False,
                        **kwargs,
                    )

                # Only reached if the exit stack suppresses an exception.
                raise AssertionError("unreachable")

        return injected_wrapper

    return func_wrapper
|
|
|
|
|
|
class solve_async_gen:  # noqa: N801
    """Lazy async iterator over the async iterable returned by ``model.asolve``.

    Nothing is solved until the first ``__anext__``. At that point an
    ``AsyncExitStack`` is entered and passed to ``asolve`` as ``stack``. The
    stack is closed when iteration ends normally or with any exception, so
    whatever was registered on it is released in both cases.
    """

    _iter: Optional[AsyncIterator[Any]] = None

    def __init__(
        self,
        model: "CallModel[..., Any]",
        overrides: Optional[Any],
        *args: Any,
        **kwargs: Any,
    ):
        """Store everything needed to solve ``model`` later.

        Args:
            model: Object whose ``asolve`` coroutine yields the async iterable.
            overrides: Passed to ``asolve`` as ``dependency_overrides``.
            *args: Positional arguments forwarded to ``asolve``.
            **kwargs: Keyword arguments forwarded to ``asolve``.
        """
        self.call = model
        self.args = args
        self.kwargs = kwargs
        self.overrides = overrides

    def __aiter__(self) -> "solve_async_gen":
        """Reset state so the next ``__anext__`` solves the model afresh."""
        self._iter = None
        self.stack = AsyncExitStack()
        return self

    async def __anext__(self) -> Any:
        """Return the next item, solving the model on first use.

        Raises:
            StopAsyncIteration: When the underlying iterator is exhausted, or
                when the exit stack suppresses an exception raised during
                iteration.
        """
        try:
            if self._iter is None:
                stack = self.stack = AsyncExitStack()
                await stack.__aenter__()
                self._iter = cast(
                    AsyncIterator[Any],
                    (
                        await self.call.asolve(
                            *self.args,
                            stack=stack,
                            dependency_overrides=self.overrides,
                            cache_dependencies={},
                            nested=False,
                            **self.kwargs,
                        )
                    ).__aiter__(),
                )

            return await self._iter.__anext__()

        except StopAsyncIteration:
            # Normal exhaustion: close the stack without error details.
            await self.stack.__aexit__(None, None, None)
            raise

        except BaseException as exc:
            # Failure while solving or iterating: close the stack with the
            # exception details, as ``async with`` would, so cleanups still run.
            if await self.stack.__aexit__(type(exc), exc, exc.__traceback__):
                # The stack suppressed the error; end the iteration instead.
                raise StopAsyncIteration from None
            raise
|
|
|
|
|
|
class solve_gen:  # noqa: N801
    """Lazy iterator over the iterable returned by ``model.solve``.

    Nothing is solved until the first ``next()``. At that point an
    ``ExitStack`` is entered and passed to ``solve`` as ``stack``. The stack
    is closed when iteration ends normally or with any exception, so whatever
    was registered on it is released in both cases.
    """

    _iter: Optional[Iterator[Any]] = None

    def __init__(
        self,
        model: "CallModel[..., Any]",
        overrides: Optional[Any],
        *args: Any,
        **kwargs: Any,
    ):
        """Store everything needed to solve ``model`` later.

        Args:
            model: Object whose ``solve`` method yields the iterable.
            overrides: Passed to ``solve`` as ``dependency_overrides``.
            *args: Positional arguments forwarded to ``solve``.
            **kwargs: Keyword arguments forwarded to ``solve``.
        """
        self.call = model
        self.args = args
        self.kwargs = kwargs
        self.overrides = overrides

    def __iter__(self) -> "solve_gen":
        """Reset state so the next ``next()`` solves the model afresh."""
        self._iter = None
        self.stack = ExitStack()
        return self

    def __next__(self) -> Any:
        """Return the next item, solving the model on first use.

        Raises:
            StopIteration: When the underlying iterator is exhausted, or when
                the exit stack suppresses an exception raised during iteration.
        """
        try:
            if self._iter is None:
                stack = self.stack = ExitStack()
                stack.__enter__()
                self._iter = cast(
                    Iterator[Any],
                    iter(
                        self.call.solve(
                            *self.args,
                            stack=stack,
                            dependency_overrides=self.overrides,
                            cache_dependencies={},
                            nested=False,
                            **self.kwargs,
                        )
                    ),
                )

            return next(self._iter)

        except StopIteration:
            # Normal exhaustion: close the stack without error details.
            self.stack.__exit__(None, None, None)
            raise

        except BaseException as exc:
            # Failure while solving or iterating: close the stack with the
            # exception details, as ``with`` would, so cleanups still run.
            if self.stack.__exit__(type(exc), exc, exc.__traceback__):
                # The stack suppressed the error; end the iteration instead.
                raise StopIteration from None
            raise
|