CoACT initialize (#292)
This commit is contained in:
16
mm_agents/coact/autogen/fast_depends/__init__.py
Normal file
16
mm_agents/coact/autogen/fast_depends/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .dependencies import Provider, dependency_provider
|
||||
from .use import Depends, inject
|
||||
|
||||
__all__ = (
|
||||
"Depends",
|
||||
"Provider",
|
||||
"dependency_provider",
|
||||
"inject",
|
||||
)
|
||||
80
mm_agents/coact/autogen/fast_depends/_compat.py
Normal file
80
mm_agents/coact/autogen/fast_depends/_compat.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import sys
|
||||
from importlib.metadata import version as get_version
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
|
||||
__all__ = (
|
||||
"PYDANTIC_V2",
|
||||
"BaseModel",
|
||||
"ConfigDict",
|
||||
"ExceptionGroup",
|
||||
"create_model",
|
||||
"evaluate_forwardref",
|
||||
"get_config_base",
|
||||
)
|
||||
|
||||
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
|
||||
default_pydantic_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
evaluate_forwardref: Any
|
||||
# isort: off
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import ConfigDict
|
||||
from pydantic._internal._typing_extra import ( # type: ignore[no-redef]
|
||||
eval_type_lenient as evaluate_forwardref,
|
||||
)
|
||||
|
||||
def model_schema(model: Type[BaseModel]) -> Dict[str, Any]:
|
||||
return model.model_json_schema()
|
||||
|
||||
def get_config_base(config_data: Optional[ConfigDict] = None) -> ConfigDict:
|
||||
return config_data or ConfigDict(**default_pydantic_config) # type: ignore[typeddict-item]
|
||||
|
||||
def get_aliases(model: Type[BaseModel]) -> Tuple[str, ...]:
|
||||
return tuple(f.alias or name for name, f in model.model_fields.items())
|
||||
|
||||
class CreateBaseModel(BaseModel):
|
||||
"""Just to support FastStream < 0.3.7."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
else:
|
||||
from pydantic.typing import evaluate_forwardref as evaluate_forwardref # type: ignore[no-redef]
|
||||
from pydantic.config import get_config, ConfigDict, BaseConfig
|
||||
|
||||
def get_config_base(config_data: Optional[ConfigDict] = None) -> Type[BaseConfig]: # type: ignore[misc,no-any-unimported]
|
||||
return get_config(config_data or ConfigDict(**default_pydantic_config)) # type: ignore[typeddict-item,no-any-unimported,no-any-return]
|
||||
|
||||
def model_schema(model: Type[BaseModel]) -> Dict[str, Any]:
|
||||
return model.schema()
|
||||
|
||||
def get_aliases(model: Type[BaseModel]) -> Tuple[str, ...]:
|
||||
return tuple(f.alias or name for name, f in model.__fields__.items()) # type: ignore[attr-defined]
|
||||
|
||||
class CreateBaseModel(BaseModel): # type: ignore[no-redef]
|
||||
"""Just to support FastStream < 0.3.7."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
ANYIO_V3 = get_version("anyio").startswith("3.")
|
||||
|
||||
if ANYIO_V3:
|
||||
from anyio import ExceptionGroup as ExceptionGroup # type: ignore[attr-defined]
|
||||
else:
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import ExceptionGroup as ExceptionGroup
|
||||
else:
|
||||
ExceptionGroup = ExceptionGroup
|
||||
14
mm_agents/coact/autogen/fast_depends/core/__init__.py
Normal file
14
mm_agents/coact/autogen/fast_depends/core/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .build import build_call_model
|
||||
from .model import CallModel
|
||||
|
||||
__all__ = (
|
||||
"CallModel",
|
||||
"build_call_model",
|
||||
)
|
||||
225
mm_agents/coact/autogen/fast_depends/core/build.py
Normal file
225
mm_agents/coact/autogen/fast_depends/core/build.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import (
|
||||
Annotated,
|
||||
ParamSpec,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from .._compat import ConfigDict, create_model, get_config_base
|
||||
from ..dependencies import Depends
|
||||
from ..library import CustomField
|
||||
from ..utils import (
|
||||
get_typed_signature,
|
||||
is_async_gen_callable,
|
||||
is_coroutine_callable,
|
||||
is_gen_callable,
|
||||
)
|
||||
from .model import CallModel, ResponseModel
|
||||
|
||||
CUSTOM_ANNOTATIONS = (Depends, CustomField)
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def build_call_model(
|
||||
call: Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
*,
|
||||
cast: bool = True,
|
||||
use_cache: bool = True,
|
||||
is_sync: Optional[bool] = None,
|
||||
extra_dependencies: Sequence[Depends] = (),
|
||||
pydantic_config: Optional[ConfigDict] = None,
|
||||
) -> CallModel[P, T]:
|
||||
name = getattr(call, "__name__", type(call).__name__)
|
||||
|
||||
is_call_async = is_coroutine_callable(call) or is_async_gen_callable(call)
|
||||
if is_sync is None:
|
||||
is_sync = not is_call_async
|
||||
else:
|
||||
assert not (is_sync and is_call_async), f"You cannot use async dependency `{name}` at sync main"
|
||||
|
||||
typed_params, return_annotation = get_typed_signature(call)
|
||||
if (is_call_generator := is_gen_callable(call) or is_async_gen_callable(call)) and (
|
||||
return_args := get_args(return_annotation)
|
||||
):
|
||||
return_annotation = return_args[0]
|
||||
|
||||
class_fields: Dict[str, Tuple[Any, Any]] = {}
|
||||
dependencies: Dict[str, CallModel[..., Any]] = {}
|
||||
custom_fields: Dict[str, CustomField] = {}
|
||||
positional_args: List[str] = []
|
||||
keyword_args: List[str] = []
|
||||
var_positional_arg: Optional[str] = None
|
||||
var_keyword_arg: Optional[str] = None
|
||||
|
||||
for param_name, param in typed_params.parameters.items():
|
||||
dep: Optional[Depends] = None
|
||||
custom: Optional[CustomField] = None
|
||||
|
||||
if param.annotation is inspect.Parameter.empty:
|
||||
annotation = Any
|
||||
|
||||
elif get_origin(param.annotation) is Annotated:
|
||||
annotated_args = get_args(param.annotation)
|
||||
type_annotation = annotated_args[0]
|
||||
|
||||
custom_annotations = []
|
||||
regular_annotations = []
|
||||
for arg in annotated_args[1:]:
|
||||
if isinstance(arg, CUSTOM_ANNOTATIONS):
|
||||
custom_annotations.append(arg)
|
||||
else:
|
||||
regular_annotations.append(arg)
|
||||
|
||||
assert len(custom_annotations) <= 1, (
|
||||
f"Cannot specify multiple `Annotated` Custom arguments for `{param_name}`!"
|
||||
)
|
||||
|
||||
next_custom = next(iter(custom_annotations), None)
|
||||
if next_custom is not None:
|
||||
if isinstance(next_custom, Depends):
|
||||
dep = next_custom
|
||||
elif isinstance(next_custom, CustomField):
|
||||
custom = deepcopy(next_custom)
|
||||
else: # pragma: no cover
|
||||
raise AssertionError("unreachable")
|
||||
|
||||
annotation = param.annotation if regular_annotations else type_annotation
|
||||
else:
|
||||
annotation = param.annotation
|
||||
else:
|
||||
annotation = param.annotation
|
||||
|
||||
default: Any
|
||||
if param.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||
default = ()
|
||||
var_positional_arg = param_name
|
||||
elif param.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
default = {}
|
||||
var_keyword_arg = param_name
|
||||
elif param.default is inspect.Parameter.empty:
|
||||
default = Ellipsis
|
||||
else:
|
||||
default = param.default
|
||||
|
||||
if isinstance(default, Depends):
|
||||
assert not dep, "You can not use `Depends` with `Annotated` and default both"
|
||||
dep, default = default, Ellipsis
|
||||
|
||||
elif isinstance(default, CustomField):
|
||||
assert not custom, "You can not use `CustomField` with `Annotated` and default both"
|
||||
custom, default = default, Ellipsis
|
||||
|
||||
else:
|
||||
class_fields[param_name] = (annotation, default)
|
||||
|
||||
if dep:
|
||||
if not cast:
|
||||
dep.cast = False
|
||||
|
||||
dependencies[param_name] = build_call_model(
|
||||
dep.dependency,
|
||||
cast=dep.cast,
|
||||
use_cache=dep.use_cache,
|
||||
is_sync=is_sync,
|
||||
pydantic_config=pydantic_config,
|
||||
)
|
||||
|
||||
if dep.cast is True:
|
||||
class_fields[param_name] = (annotation, Ellipsis)
|
||||
|
||||
keyword_args.append(param_name)
|
||||
|
||||
elif custom:
|
||||
assert not (is_sync and is_coroutine_callable(custom.use)), (
|
||||
f"You cannot use async custom field `{type(custom).__name__}` at sync `{name}`"
|
||||
)
|
||||
|
||||
custom.set_param_name(param_name)
|
||||
custom_fields[param_name] = custom
|
||||
|
||||
if custom.cast is False:
|
||||
annotation = Any
|
||||
|
||||
if custom.required:
|
||||
class_fields[param_name] = (annotation, default)
|
||||
|
||||
else:
|
||||
class_fields[param_name] = class_fields.get(param_name, (Optional[annotation], None))
|
||||
|
||||
keyword_args.append(param_name)
|
||||
|
||||
else:
|
||||
if param.kind is param.KEYWORD_ONLY:
|
||||
keyword_args.append(param_name)
|
||||
elif param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
|
||||
positional_args.append(param_name)
|
||||
|
||||
func_model = create_model( # type: ignore[call-overload]
|
||||
name,
|
||||
__config__=get_config_base(pydantic_config),
|
||||
**class_fields,
|
||||
)
|
||||
|
||||
response_model: Optional[Type[ResponseModel[T]]] = None
|
||||
if cast and return_annotation and return_annotation is not inspect.Parameter.empty:
|
||||
response_model = create_model( # type: ignore[call-overload,assignment]
|
||||
"ResponseModel",
|
||||
__config__=get_config_base(pydantic_config),
|
||||
response=(return_annotation, Ellipsis),
|
||||
)
|
||||
|
||||
return CallModel(
|
||||
call=call,
|
||||
model=func_model,
|
||||
response_model=response_model,
|
||||
params=class_fields,
|
||||
cast=cast,
|
||||
use_cache=use_cache,
|
||||
is_async=is_call_async,
|
||||
is_generator=is_call_generator,
|
||||
dependencies=dependencies,
|
||||
custom_fields=custom_fields,
|
||||
positional_args=positional_args,
|
||||
keyword_args=keyword_args,
|
||||
var_positional_arg=var_positional_arg,
|
||||
var_keyword_arg=var_keyword_arg,
|
||||
extra_dependencies=[
|
||||
build_call_model(
|
||||
d.dependency,
|
||||
cast=d.cast,
|
||||
use_cache=d.use_cache,
|
||||
is_sync=is_sync,
|
||||
pydantic_config=pydantic_config,
|
||||
)
|
||||
for d in extra_dependencies
|
||||
],
|
||||
)
|
||||
576
mm_agents/coact/autogen/fast_depends/core/model.py
Normal file
576
mm_agents/coact/autogen/fast_depends/core/model.py
Normal file
@@ -0,0 +1,576 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from collections import namedtuple
|
||||
from contextlib import AsyncExitStack, ExitStack
|
||||
from functools import partial
|
||||
from inspect import Parameter, unwrap
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import anyio
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from .._compat import BaseModel, ExceptionGroup, get_aliases
|
||||
from ..library import CustomField
|
||||
from ..utils import (
|
||||
async_map,
|
||||
is_async_gen_callable,
|
||||
is_coroutine_callable,
|
||||
is_gen_callable,
|
||||
run_async,
|
||||
solve_generator_async,
|
||||
solve_generator_sync,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
PriorityPair = namedtuple("PriorityPair", ("call", "dependencies_number", "dependencies_names"))
|
||||
|
||||
|
||||
class ResponseModel(BaseModel, Generic[T]):
|
||||
response: T
|
||||
|
||||
|
||||
class CallModel(Generic[P, T]):
|
||||
call: Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
]
|
||||
is_async: bool
|
||||
is_generator: bool
|
||||
model: Optional[Type[BaseModel]]
|
||||
response_model: Optional[Type[ResponseModel[T]]]
|
||||
|
||||
params: Dict[str, Tuple[Any, Any]]
|
||||
alias_arguments: Tuple[str, ...]
|
||||
|
||||
dependencies: Dict[str, "CallModel[..., Any]"]
|
||||
extra_dependencies: Iterable["CallModel[..., Any]"]
|
||||
sorted_dependencies: Tuple[Tuple["CallModel[..., Any]", int], ...]
|
||||
custom_fields: Dict[str, CustomField]
|
||||
keyword_args: Tuple[str, ...]
|
||||
positional_args: Tuple[str, ...]
|
||||
var_positional_arg: Optional[str]
|
||||
var_keyword_arg: Optional[str]
|
||||
|
||||
# Dependencies and custom fields
|
||||
use_cache: bool
|
||||
cast: bool
|
||||
|
||||
__slots__ = (
|
||||
"call",
|
||||
"is_async",
|
||||
"is_generator",
|
||||
"model",
|
||||
"response_model",
|
||||
"params",
|
||||
"alias_arguments",
|
||||
"keyword_args",
|
||||
"positional_args",
|
||||
"var_positional_arg",
|
||||
"var_keyword_arg",
|
||||
"dependencies",
|
||||
"extra_dependencies",
|
||||
"sorted_dependencies",
|
||||
"custom_fields",
|
||||
"use_cache",
|
||||
"cast",
|
||||
)
|
||||
|
||||
@property
|
||||
def call_name(self) -> str:
|
||||
call = unwrap(self.call)
|
||||
return getattr(call, "__name__", type(call).__name__)
|
||||
|
||||
@property
|
||||
def flat_params(self) -> Dict[str, Tuple[Any, Any]]:
|
||||
params = self.params
|
||||
for d in (*self.dependencies.values(), *self.extra_dependencies):
|
||||
params.update(d.flat_params)
|
||||
return params
|
||||
|
||||
@property
|
||||
def flat_dependencies(
|
||||
self,
|
||||
) -> Dict[
|
||||
Callable[..., Any],
|
||||
Tuple[
|
||||
"CallModel[..., Any]",
|
||||
Tuple[Callable[..., Any], ...],
|
||||
],
|
||||
]:
|
||||
flat: Dict[
|
||||
Callable[..., Any],
|
||||
Tuple[
|
||||
CallModel[..., Any],
|
||||
Tuple[Callable[..., Any], ...],
|
||||
],
|
||||
] = {}
|
||||
|
||||
for i in (*self.dependencies.values(), *self.extra_dependencies):
|
||||
flat.update({
|
||||
i.call: (
|
||||
i,
|
||||
tuple(j.call for j in i.dependencies.values()),
|
||||
)
|
||||
})
|
||||
|
||||
flat.update(i.flat_dependencies)
|
||||
|
||||
return flat
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
/,
|
||||
call: Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
model: Optional[Type[BaseModel]],
|
||||
params: Dict[str, Tuple[Any, Any]],
|
||||
response_model: Optional[Type[ResponseModel[T]]] = None,
|
||||
use_cache: bool = True,
|
||||
cast: bool = True,
|
||||
is_async: bool = False,
|
||||
is_generator: bool = False,
|
||||
dependencies: Optional[Dict[str, "CallModel[..., Any]"]] = None,
|
||||
extra_dependencies: Optional[Iterable["CallModel[..., Any]"]] = None,
|
||||
keyword_args: Optional[List[str]] = None,
|
||||
positional_args: Optional[List[str]] = None,
|
||||
var_positional_arg: Optional[str] = None,
|
||||
var_keyword_arg: Optional[str] = None,
|
||||
custom_fields: Optional[Dict[str, CustomField]] = None,
|
||||
):
|
||||
self.call = call
|
||||
self.model = model
|
||||
|
||||
if model:
|
||||
self.alias_arguments = get_aliases(model)
|
||||
else: # pragma: no cover
|
||||
self.alias_arguments = ()
|
||||
|
||||
self.keyword_args = tuple(keyword_args or ())
|
||||
self.positional_args = tuple(positional_args or ())
|
||||
self.var_positional_arg = var_positional_arg
|
||||
self.var_keyword_arg = var_keyword_arg
|
||||
self.response_model = response_model
|
||||
self.use_cache = use_cache
|
||||
self.cast = cast
|
||||
self.is_async = is_async or is_coroutine_callable(call) or is_async_gen_callable(call)
|
||||
self.is_generator = is_generator or is_gen_callable(call) or is_async_gen_callable(call)
|
||||
|
||||
self.dependencies = dependencies or {}
|
||||
self.extra_dependencies = extra_dependencies or ()
|
||||
self.custom_fields = custom_fields or {}
|
||||
|
||||
sorted_dep: List[CallModel[..., Any]] = []
|
||||
flat = self.flat_dependencies
|
||||
for calls in flat.values():
|
||||
_sort_dep(sorted_dep, calls, flat)
|
||||
|
||||
self.sorted_dependencies = tuple((i, len(i.sorted_dependencies)) for i in sorted_dep if i.use_cache)
|
||||
for name in chain(self.dependencies.keys(), self.custom_fields.keys()):
|
||||
params.pop(name, None)
|
||||
self.params = params
|
||||
|
||||
def _solve(
|
||||
self,
|
||||
/,
|
||||
*args: Tuple[Any, ...],
|
||||
cache_dependencies: Dict[
|
||||
Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
T,
|
||||
],
|
||||
dependency_overrides: Optional[
|
||||
Dict[
|
||||
Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
]
|
||||
] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> Generator[
|
||||
Tuple[
|
||||
Sequence[Any],
|
||||
Dict[str, Any],
|
||||
Callable[..., Any],
|
||||
],
|
||||
Any,
|
||||
T,
|
||||
]:
|
||||
if dependency_overrides:
|
||||
call = dependency_overrides.get(self.call, self.call)
|
||||
assert self.is_async or not is_coroutine_callable(call), (
|
||||
f"You cannot use async dependency `{self.call_name}` at sync main"
|
||||
)
|
||||
|
||||
else:
|
||||
call = self.call
|
||||
|
||||
if self.use_cache and call in cache_dependencies:
|
||||
return cache_dependencies[call]
|
||||
|
||||
kw: Dict[str, Any] = {}
|
||||
|
||||
for arg in self.keyword_args:
|
||||
if (v := kwargs.pop(arg, Parameter.empty)) is not Parameter.empty:
|
||||
kw[arg] = v
|
||||
|
||||
if self.var_keyword_arg is not None:
|
||||
kw[self.var_keyword_arg] = kwargs
|
||||
else:
|
||||
kw.update(kwargs)
|
||||
|
||||
for arg in self.positional_args:
|
||||
if args:
|
||||
kw[arg], args = args[0], args[1:]
|
||||
else:
|
||||
break
|
||||
|
||||
keyword_args: Iterable[str]
|
||||
if self.var_positional_arg is not None:
|
||||
kw[self.var_positional_arg] = args
|
||||
keyword_args = self.keyword_args
|
||||
|
||||
else:
|
||||
keyword_args = self.keyword_args + self.positional_args
|
||||
for arg in keyword_args:
|
||||
if not self.cast and arg in self.params:
|
||||
kw[arg] = self.params[arg][1]
|
||||
|
||||
if not args:
|
||||
break
|
||||
|
||||
if arg not in self.dependencies:
|
||||
kw[arg], args = args[0], args[1:]
|
||||
|
||||
solved_kw: Dict[str, Any]
|
||||
solved_kw = yield args, kw, call
|
||||
|
||||
args_: Sequence[Any]
|
||||
if self.cast:
|
||||
assert self.model, "Cast should be used only with model"
|
||||
casted_model = self.model(**solved_kw)
|
||||
|
||||
kwargs_ = {arg: getattr(casted_model, arg, solved_kw.get(arg)) for arg in keyword_args}
|
||||
if self.var_keyword_arg:
|
||||
kwargs_.update(getattr(casted_model, self.var_keyword_arg, {}))
|
||||
|
||||
if self.var_positional_arg is not None:
|
||||
args_ = [getattr(casted_model, arg, solved_kw.get(arg)) for arg in self.positional_args]
|
||||
args_.extend(getattr(casted_model, self.var_positional_arg, ()))
|
||||
else:
|
||||
args_ = ()
|
||||
|
||||
else:
|
||||
kwargs_ = {arg: solved_kw.get(arg) for arg in keyword_args}
|
||||
|
||||
args_ = tuple(map(solved_kw.get, self.positional_args)) if self.var_positional_arg is None else ()
|
||||
|
||||
response: T
|
||||
response = yield args_, kwargs_, call
|
||||
|
||||
if self.cast and not self.is_generator:
|
||||
response = self._cast_response(response)
|
||||
|
||||
if self.use_cache: # pragma: no branch
|
||||
cache_dependencies[call] = response
|
||||
|
||||
return response
|
||||
|
||||
def _cast_response(self, /, value: Any) -> Any:
|
||||
if self.response_model is not None:
|
||||
return self.response_model(response=value).response
|
||||
else:
|
||||
return value
|
||||
|
||||
def solve(
|
||||
self,
|
||||
/,
|
||||
*args: Any,
|
||||
stack: ExitStack,
|
||||
cache_dependencies: Dict[
|
||||
Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
T,
|
||||
],
|
||||
dependency_overrides: Optional[
|
||||
Dict[
|
||||
Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
]
|
||||
] = None,
|
||||
nested: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
cast_gen = self._solve(
|
||||
*args,
|
||||
cache_dependencies=cache_dependencies,
|
||||
dependency_overrides=dependency_overrides,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
args, kwargs, _ = next(cast_gen) # type: ignore[assignment]
|
||||
except StopIteration as e:
|
||||
cached_value: T = e.value
|
||||
return cached_value
|
||||
|
||||
# Heat cache and solve extra dependencies
|
||||
for dep, _ in self.sorted_dependencies:
|
||||
dep.solve(
|
||||
*args,
|
||||
stack=stack,
|
||||
cache_dependencies=cache_dependencies,
|
||||
dependency_overrides=dependency_overrides,
|
||||
nested=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Always get from cache
|
||||
for dep in self.extra_dependencies:
|
||||
dep.solve(
|
||||
*args,
|
||||
stack=stack,
|
||||
cache_dependencies=cache_dependencies,
|
||||
dependency_overrides=dependency_overrides,
|
||||
nested=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for dep_arg, dep in self.dependencies.items():
|
||||
kwargs[dep_arg] = dep.solve(
|
||||
stack=stack,
|
||||
cache_dependencies=cache_dependencies,
|
||||
dependency_overrides=dependency_overrides,
|
||||
nested=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for custom in self.custom_fields.values():
|
||||
if custom.field:
|
||||
custom.use_field(kwargs)
|
||||
else:
|
||||
kwargs = custom.use(**kwargs)
|
||||
|
||||
final_args, final_kwargs, call = cast_gen.send(kwargs)
|
||||
|
||||
if self.is_generator and nested:
|
||||
response = solve_generator_sync(
|
||||
*final_args,
|
||||
call=call,
|
||||
stack=stack,
|
||||
**final_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
response = call(*final_args, **final_kwargs)
|
||||
|
||||
try:
|
||||
cast_gen.send(response)
|
||||
except StopIteration as e:
|
||||
value: T = e.value
|
||||
|
||||
if not self.cast or nested or not self.is_generator:
|
||||
return value
|
||||
|
||||
else:
|
||||
return map(self._cast_response, value) # type: ignore[no-any-return, call-overload]
|
||||
|
||||
raise AssertionError("unreachable")
|
||||
|
||||
async def asolve(
|
||||
self,
|
||||
/,
|
||||
*args: Any,
|
||||
stack: AsyncExitStack,
|
||||
cache_dependencies: Dict[
|
||||
Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
T,
|
||||
],
|
||||
dependency_overrides: Optional[
|
||||
Dict[
|
||||
Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
]
|
||||
] = None,
|
||||
nested: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
cast_gen = self._solve(
|
||||
*args,
|
||||
cache_dependencies=cache_dependencies,
|
||||
dependency_overrides=dependency_overrides,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
args, kwargs, _ = next(cast_gen) # type: ignore[assignment]
|
||||
except StopIteration as e:
|
||||
cached_value: T = e.value
|
||||
return cached_value
|
||||
|
||||
# Heat cache and solve extra dependencies
|
||||
dep_to_solve: List[Callable[..., Awaitable[Any]]] = []
|
||||
try:
|
||||
async with anyio.create_task_group() as tg:
|
||||
for dep, subdep in self.sorted_dependencies:
|
||||
solve = partial(
|
||||
dep.asolve,
|
||||
*args,
|
||||
stack=stack,
|
||||
cache_dependencies=cache_dependencies,
|
||||
dependency_overrides=dependency_overrides,
|
||||
nested=True,
|
||||
**kwargs,
|
||||
)
|
||||
if not subdep:
|
||||
tg.start_soon(solve)
|
||||
else:
|
||||
dep_to_solve.append(solve)
|
||||
except ExceptionGroup as exgr:
|
||||
for ex in exgr.exceptions:
|
||||
raise ex from None
|
||||
|
||||
for i in dep_to_solve:
|
||||
await i()
|
||||
|
||||
# Always get from cache
|
||||
for dep in self.extra_dependencies:
|
||||
await dep.asolve(
|
||||
*args,
|
||||
stack=stack,
|
||||
cache_dependencies=cache_dependencies,
|
||||
dependency_overrides=dependency_overrides,
|
||||
nested=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for dep_arg, dep in self.dependencies.items():
|
||||
kwargs[dep_arg] = await dep.asolve(
|
||||
stack=stack,
|
||||
cache_dependencies=cache_dependencies,
|
||||
dependency_overrides=dependency_overrides,
|
||||
nested=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
custom_to_solve: List[CustomField] = []
|
||||
|
||||
try:
|
||||
async with anyio.create_task_group() as tg:
|
||||
for custom in self.custom_fields.values():
|
||||
if custom.field:
|
||||
tg.start_soon(run_async, custom.use_field, kwargs)
|
||||
else:
|
||||
custom_to_solve.append(custom)
|
||||
|
||||
except ExceptionGroup as exgr:
|
||||
for ex in exgr.exceptions:
|
||||
raise ex from None
|
||||
|
||||
for j in custom_to_solve:
|
||||
kwargs = await run_async(j.use, **kwargs)
|
||||
|
||||
final_args, final_kwargs, call = cast_gen.send(kwargs)
|
||||
|
||||
if self.is_generator and nested:
|
||||
response = await solve_generator_async(
|
||||
*final_args,
|
||||
call=call,
|
||||
stack=stack,
|
||||
**final_kwargs,
|
||||
)
|
||||
else:
|
||||
response = await run_async(call, *final_args, **final_kwargs)
|
||||
|
||||
try:
|
||||
cast_gen.send(response)
|
||||
except StopIteration as e:
|
||||
value: T = e.value
|
||||
|
||||
if not self.cast or nested or not self.is_generator:
|
||||
return value
|
||||
|
||||
else:
|
||||
return async_map(self._cast_response, value) # type: ignore[return-value, arg-type]
|
||||
|
||||
raise AssertionError("unreachable")
|
||||
|
||||
|
||||
def _sort_dep(
|
||||
collector: List["CallModel[..., Any]"],
|
||||
items: Tuple[
|
||||
"CallModel[..., Any]",
|
||||
Tuple[Callable[..., Any], ...],
|
||||
],
|
||||
flat: Dict[
|
||||
Callable[..., Any],
|
||||
Tuple[
|
||||
"CallModel[..., Any]",
|
||||
Tuple[Callable[..., Any], ...],
|
||||
],
|
||||
],
|
||||
) -> None:
|
||||
model, calls = items
|
||||
|
||||
if model in collector:
|
||||
return
|
||||
|
||||
if not calls:
|
||||
position = -1
|
||||
|
||||
else:
|
||||
for i in calls:
|
||||
sub_model, _ = flat[i]
|
||||
if sub_model not in collector: # pragma: no branch
|
||||
_sort_dep(collector, flat[i], flat)
|
||||
|
||||
position = max(collector.index(flat[i][0]) for i in calls)
|
||||
|
||||
collector.insert(position + 1, model)
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .model import Depends
|
||||
from .provider import Provider, dependency_provider
|
||||
|
||||
__all__ = (
|
||||
"Depends",
|
||||
"Provider",
|
||||
"dependency_provider",
|
||||
)
|
||||
29
mm_agents/coact/autogen/fast_depends/dependencies/model.py
Normal file
29
mm_agents/coact/autogen/fast_depends/dependencies/model.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
class Depends:
|
||||
use_cache: bool
|
||||
cast: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dependency: Callable[..., Any],
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
cast: bool = True,
|
||||
) -> None:
|
||||
self.dependency = dependency
|
||||
self.use_cache = use_cache
|
||||
self.cast = cast
|
||||
|
||||
def __repr__(self) -> str:
|
||||
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
|
||||
cache = "" if self.use_cache else ", use_cache=False"
|
||||
return f"{self.__class__.__name__}({attr}{cache})"
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, Iterator
|
||||
|
||||
|
||||
class Provider:
|
||||
dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.dependency_overrides = {}
|
||||
|
||||
def clear(self) -> None:
|
||||
self.dependency_overrides = {}
|
||||
|
||||
def override(
|
||||
self,
|
||||
original: Callable[..., Any],
|
||||
override: Callable[..., Any],
|
||||
) -> None:
|
||||
self.dependency_overrides[original] = override
|
||||
|
||||
@contextmanager
|
||||
def scope(
|
||||
self,
|
||||
original: Callable[..., Any],
|
||||
override: Callable[..., Any],
|
||||
) -> Iterator[None]:
|
||||
self.dependency_overrides[original] = override
|
||||
yield
|
||||
self.dependency_overrides.pop(original, None)
|
||||
|
||||
|
||||
dependency_provider = Provider()
|
||||
10
mm_agents/coact/autogen/fast_depends/library/__init__.py
Normal file
10
mm_agents/coact/autogen/fast_depends/library/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .model import CustomField
|
||||
|
||||
__all__ = ("CustomField",)
|
||||
46
mm_agents/coact/autogen/fast_depends/library/model.py
Normal file
46
mm_agents/coact/autogen/fast_depends/library/model.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
|
||||
Cls = TypeVar("Cls", bound="CustomField")
|
||||
|
||||
|
||||
class CustomField(ABC):
|
||||
param_name: Optional[str]
|
||||
cast: bool
|
||||
required: bool
|
||||
|
||||
__slots__ = (
|
||||
"cast",
|
||||
"param_name",
|
||||
"required",
|
||||
"field",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cast: bool = True,
|
||||
required: bool = True,
|
||||
) -> None:
|
||||
self.cast = cast
|
||||
self.param_name = None
|
||||
self.required = required
|
||||
self.field = False
|
||||
|
||||
def set_param_name(self: Cls, name: str) -> Cls:
|
||||
self.param_name = name
|
||||
return self
|
||||
|
||||
def use(self, /, **kwargs: Any) -> Dict[str, Any]:
|
||||
assert self.param_name, "You should specify `param_name` before using"
|
||||
return kwargs
|
||||
|
||||
def use_field(self, kwargs: Dict[str, Any]) -> None:
|
||||
raise NotImplementedError("You should implement `use_field` method.")
|
||||
6
mm_agents/coact/autogen/fast_depends/py.typed
Normal file
6
mm_agents/coact/autogen/fast_depends/py.typed
Normal file
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
66
mm_agents/coact/autogen/fast_depends/schema.py
Normal file
66
mm_agents/coact/autogen/fast_depends/schema.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ._compat import PYDANTIC_V2, create_model, model_schema
|
||||
from .core import CallModel
|
||||
|
||||
|
||||
def get_schema(
|
||||
call: CallModel[Any, Any],
|
||||
embed: bool = False,
|
||||
resolve_refs: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
assert call.model, "Call should has a model"
|
||||
params_model = create_model( # type: ignore[call-overload]
|
||||
call.model.__name__, **call.flat_params
|
||||
)
|
||||
|
||||
body: Dict[str, Any] = model_schema(params_model)
|
||||
|
||||
if not call.flat_params:
|
||||
body = {"title": body["title"], "type": "null"}
|
||||
|
||||
if resolve_refs:
|
||||
pydantic_key = "$defs" if PYDANTIC_V2 else "definitions"
|
||||
body = _move_pydantic_refs(body, pydantic_key)
|
||||
body.pop(pydantic_key, None)
|
||||
|
||||
if embed and len(body["properties"]) == 1:
|
||||
body = list(body["properties"].values())[0]
|
||||
|
||||
return body
|
||||
|
||||
|
||||
def _move_pydantic_refs(original: Any, key: str, refs: Optional[Dict[str, Any]] = None) -> Any:
|
||||
if not isinstance(original, Dict):
|
||||
return original
|
||||
|
||||
data = original.copy()
|
||||
|
||||
if refs is None:
|
||||
raw_refs = data.get(key, {})
|
||||
refs = _move_pydantic_refs(raw_refs, key, raw_refs)
|
||||
|
||||
name: Optional[str] = None
|
||||
for k in data:
|
||||
if k == "$ref":
|
||||
name = data[k].replace(f"#/{key}/", "")
|
||||
|
||||
elif isinstance(data[k], dict):
|
||||
data[k] = _move_pydantic_refs(data[k], key, refs)
|
||||
|
||||
elif isinstance(data[k], List):
|
||||
for i in range(len(data[k])):
|
||||
data[k][i] = _move_pydantic_refs(data[k][i], key, refs)
|
||||
|
||||
if name:
|
||||
assert refs, "Smth wrong"
|
||||
data = refs[name]
|
||||
|
||||
return data
|
||||
280
mm_agents/coact/autogen/fast_depends/use.py
Normal file
280
mm_agents/coact/autogen/fast_depends/use.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from contextlib import AsyncExitStack, ExitStack
|
||||
from functools import partial, wraps
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Iterator,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from ._compat import ConfigDict
|
||||
from .core import CallModel, build_call_model
|
||||
from .dependencies import dependency_provider, model
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def Depends( # noqa: N802
|
||||
dependency: Callable[P, T],
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
cast: bool = True,
|
||||
) -> Any:
|
||||
return model.Depends(
|
||||
dependency=dependency,
|
||||
use_cache=use_cache,
|
||||
cast=cast,
|
||||
)
|
||||
|
||||
|
||||
class _InjectWrapper(Protocol[P, T]):
|
||||
def __call__(
|
||||
self,
|
||||
func: Callable[P, T],
|
||||
model: Optional[CallModel[P, T]] = None,
|
||||
) -> Callable[P, T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def inject( # pragma: no cover
|
||||
func: None,
|
||||
*,
|
||||
cast: bool = True,
|
||||
extra_dependencies: Sequence[model.Depends] = (),
|
||||
pydantic_config: Optional[ConfigDict] = None,
|
||||
dependency_overrides_provider: Optional[Any] = dependency_provider,
|
||||
wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
|
||||
) -> _InjectWrapper[P, T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def inject( # pragma: no cover
|
||||
func: Callable[P, T],
|
||||
*,
|
||||
cast: bool = True,
|
||||
extra_dependencies: Sequence[model.Depends] = (),
|
||||
pydantic_config: Optional[ConfigDict] = None,
|
||||
dependency_overrides_provider: Optional[Any] = dependency_provider,
|
||||
wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
|
||||
) -> Callable[P, T]: ...
|
||||
|
||||
|
||||
def inject(
|
||||
func: Optional[Callable[P, T]] = None,
|
||||
*,
|
||||
cast: bool = True,
|
||||
extra_dependencies: Sequence[model.Depends] = (),
|
||||
pydantic_config: Optional[ConfigDict] = None,
|
||||
dependency_overrides_provider: Optional[Any] = dependency_provider,
|
||||
wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
|
||||
) -> Union[
|
||||
Callable[P, T],
|
||||
_InjectWrapper[P, T],
|
||||
]:
|
||||
decorator = _wrap_inject(
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
wrap_model=wrap_model,
|
||||
extra_dependencies=extra_dependencies,
|
||||
cast=cast,
|
||||
pydantic_config=pydantic_config,
|
||||
)
|
||||
|
||||
if func is None:
|
||||
return decorator
|
||||
|
||||
else:
|
||||
return decorator(func)
|
||||
|
||||
|
||||
def _wrap_inject(
|
||||
dependency_overrides_provider: Optional[Any],
|
||||
wrap_model: Callable[
|
||||
[CallModel[P, T]],
|
||||
CallModel[P, T],
|
||||
],
|
||||
extra_dependencies: Sequence[model.Depends],
|
||||
cast: bool,
|
||||
pydantic_config: Optional[ConfigDict],
|
||||
) -> _InjectWrapper[P, T]:
|
||||
if (
|
||||
dependency_overrides_provider
|
||||
and getattr(dependency_overrides_provider, "dependency_overrides", None) is not None
|
||||
):
|
||||
overrides = dependency_overrides_provider.dependency_overrides
|
||||
else:
|
||||
overrides = None
|
||||
|
||||
def func_wrapper(
|
||||
func: Callable[P, T],
|
||||
model: Optional[CallModel[P, T]] = None,
|
||||
) -> Callable[P, T]:
|
||||
if model is None:
|
||||
real_model = wrap_model(
|
||||
build_call_model(
|
||||
call=func,
|
||||
extra_dependencies=extra_dependencies,
|
||||
cast=cast,
|
||||
pydantic_config=pydantic_config,
|
||||
)
|
||||
)
|
||||
else:
|
||||
real_model = model
|
||||
|
||||
if real_model.is_async:
|
||||
injected_wrapper: Callable[P, T]
|
||||
|
||||
if real_model.is_generator:
|
||||
injected_wrapper = partial(solve_async_gen, real_model, overrides) # type: ignore[assignment]
|
||||
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
async def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
async with AsyncExitStack() as stack:
|
||||
r = await real_model.asolve(
|
||||
*args,
|
||||
stack=stack,
|
||||
dependency_overrides=overrides,
|
||||
cache_dependencies={},
|
||||
nested=False,
|
||||
**kwargs,
|
||||
)
|
||||
return r
|
||||
|
||||
raise AssertionError("unreachable")
|
||||
|
||||
else:
|
||||
if real_model.is_generator:
|
||||
injected_wrapper = partial(solve_gen, real_model, overrides) # type: ignore[assignment]
|
||||
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
with ExitStack() as stack:
|
||||
r = real_model.solve(
|
||||
*args,
|
||||
stack=stack,
|
||||
dependency_overrides=overrides,
|
||||
cache_dependencies={},
|
||||
nested=False,
|
||||
**kwargs,
|
||||
)
|
||||
return r
|
||||
|
||||
raise AssertionError("unreachable")
|
||||
|
||||
return injected_wrapper
|
||||
|
||||
return func_wrapper
|
||||
|
||||
|
||||
class solve_async_gen: # noqa: N801
|
||||
_iter: Optional[AsyncIterator[Any]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: "CallModel[..., Any]",
|
||||
overrides: Optional[Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.call = model
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.overrides = overrides
|
||||
|
||||
def __aiter__(self) -> "solve_async_gen":
|
||||
self._iter = None
|
||||
self.stack = AsyncExitStack()
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> Any:
|
||||
if self._iter is None:
|
||||
stack = self.stack = AsyncExitStack()
|
||||
await self.stack.__aenter__()
|
||||
self._iter = cast(
|
||||
AsyncIterator[Any],
|
||||
(
|
||||
await self.call.asolve(
|
||||
*self.args,
|
||||
stack=stack,
|
||||
dependency_overrides=self.overrides,
|
||||
cache_dependencies={},
|
||||
nested=False,
|
||||
**self.kwargs,
|
||||
)
|
||||
).__aiter__(),
|
||||
)
|
||||
|
||||
try:
|
||||
r = await self._iter.__anext__()
|
||||
except StopAsyncIteration as e:
|
||||
await self.stack.__aexit__(None, None, None)
|
||||
raise e
|
||||
else:
|
||||
return r
|
||||
|
||||
|
||||
class solve_gen: # noqa: N801
|
||||
_iter: Optional[Iterator[Any]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: "CallModel[..., Any]",
|
||||
overrides: Optional[Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.call = model
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.overrides = overrides
|
||||
|
||||
def __iter__(self) -> "solve_gen":
|
||||
self._iter = None
|
||||
self.stack = ExitStack()
|
||||
return self
|
||||
|
||||
def __next__(self) -> Any:
|
||||
if self._iter is None:
|
||||
stack = self.stack = ExitStack()
|
||||
self.stack.__enter__()
|
||||
self._iter = cast(
|
||||
Iterator[Any],
|
||||
iter(
|
||||
self.call.solve(
|
||||
*self.args,
|
||||
stack=stack,
|
||||
dependency_overrides=self.overrides,
|
||||
cache_dependencies={},
|
||||
nested=False,
|
||||
**self.kwargs,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
r = next(self._iter)
|
||||
except StopIteration as e:
|
||||
self.stack.__exit__(None, None, None)
|
||||
raise e
|
||||
else:
|
||||
return r
|
||||
187
mm_agents/coact/autogen/fast_depends/utils.py
Normal file
187
mm_agents/coact/autogen/fast_depends/utils.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterable,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
List,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import anyio
|
||||
from typing_extensions import (
|
||||
Annotated,
|
||||
ParamSpec,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from ._compat import evaluate_forwardref
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import FrameType
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def run_async(
|
||||
func: Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> T:
|
||||
if is_coroutine_callable(func):
|
||||
return await cast(Callable[P, Awaitable[T]], func)(*args, **kwargs)
|
||||
else:
|
||||
return await run_in_threadpool(cast(Callable[P, T], func), *args, **kwargs)
|
||||
|
||||
|
||||
async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
if kwargs:
|
||||
func = functools.partial(func, **kwargs)
|
||||
return await anyio.to_thread.run_sync(func, *args)
|
||||
|
||||
|
||||
async def solve_generator_async(
|
||||
*sub_args: Any, call: Callable[..., Any], stack: AsyncExitStack, **sub_values: Any
|
||||
) -> Any:
|
||||
if is_gen_callable(call):
|
||||
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
|
||||
elif is_async_gen_callable(call): # pragma: no branch
|
||||
cm = asynccontextmanager(call)(*sub_args, **sub_values)
|
||||
return await stack.enter_async_context(cm)
|
||||
|
||||
|
||||
def solve_generator_sync(*sub_args: Any, call: Callable[..., Any], stack: ExitStack, **sub_values: Any) -> Any:
|
||||
cm = contextmanager(call)(*sub_args, **sub_values)
|
||||
return stack.enter_context(cm)
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> Tuple[inspect.Signature, Any]:
|
||||
signature = inspect.signature(call)
|
||||
|
||||
locals = collect_outer_stack_locals()
|
||||
|
||||
# We unwrap call to get the original unwrapped function
|
||||
call = inspect.unwrap(call)
|
||||
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=get_typed_annotation(
|
||||
param.annotation,
|
||||
globalns,
|
||||
locals,
|
||||
),
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
|
||||
return inspect.Signature(typed_params), get_typed_annotation(
|
||||
signature.return_annotation,
|
||||
globalns,
|
||||
locals,
|
||||
)
|
||||
|
||||
|
||||
def collect_outer_stack_locals() -> Dict[str, Any]:
|
||||
frame = inspect.currentframe()
|
||||
|
||||
frames: List[FrameType] = []
|
||||
while frame is not None:
|
||||
if "fast_depends" not in frame.f_code.co_filename:
|
||||
frames.append(frame)
|
||||
frame = frame.f_back
|
||||
|
||||
locals = {}
|
||||
for f in frames[::-1]:
|
||||
locals.update(f.f_locals)
|
||||
|
||||
return locals
|
||||
|
||||
|
||||
def get_typed_annotation(
|
||||
annotation: Any,
|
||||
globalns: Dict[str, Any],
|
||||
locals: Dict[str, Any],
|
||||
) -> Any:
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
|
||||
if isinstance(annotation, ForwardRef):
|
||||
annotation = evaluate_forwardref(annotation, globalns, locals)
|
||||
|
||||
if get_origin(annotation) is Annotated and (args := get_args(annotation)):
|
||||
solved_args = [get_typed_annotation(x, globalns, locals) for x in args]
|
||||
annotation.__origin__, annotation.__metadata__ = solved_args[0], tuple(solved_args[1:])
|
||||
|
||||
return annotation
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def contextmanager_in_threadpool(
|
||||
cm: ContextManager[T],
|
||||
) -> AsyncGenerator[T, None]:
|
||||
exit_limiter = anyio.CapacityLimiter(1)
|
||||
try:
|
||||
yield await run_in_threadpool(cm.__enter__)
|
||||
except Exception as e:
|
||||
ok = bool(await anyio.to_thread.run_sync(cm.__exit__, type(e), e, None, limiter=exit_limiter))
|
||||
if not ok: # pragma: no branch
|
||||
raise e
|
||||
else:
|
||||
await anyio.to_thread.run_sync(cm.__exit__, None, None, None, limiter=exit_limiter)
|
||||
|
||||
|
||||
def is_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isgeneratorfunction(call):
|
||||
return True
|
||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
||||
return inspect.isgeneratorfunction(dunder_call)
|
||||
|
||||
|
||||
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isasyncgenfunction(call):
|
||||
return True
|
||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
||||
return inspect.isasyncgenfunction(dunder_call)
|
||||
|
||||
|
||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isclass(call):
|
||||
return False
|
||||
|
||||
if asyncio.iscoroutinefunction(call):
|
||||
return True
|
||||
|
||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
||||
return asyncio.iscoroutinefunction(dunder_call)
|
||||
|
||||
|
||||
async def async_map(func: Callable[..., T], async_iterable: AsyncIterable[Any]) -> AsyncIterable[T]:
|
||||
async for i in async_iterable:
|
||||
yield func(i)
|
||||
Reference in New Issue
Block a user