CoACT initialize (#292)
This commit is contained in:
187
mm_agents/coact/autogen/fast_depends/utils.py
Normal file
187
mm_agents/coact/autogen/fast_depends/utils.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterable,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
List,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import anyio
|
||||
from typing_extensions import (
|
||||
Annotated,
|
||||
ParamSpec,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from ._compat import evaluate_forwardref
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import FrameType
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def run_async(
|
||||
func: Union[
|
||||
Callable[P, T],
|
||||
Callable[P, Awaitable[T]],
|
||||
],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> T:
|
||||
if is_coroutine_callable(func):
|
||||
return await cast(Callable[P, Awaitable[T]], func)(*args, **kwargs)
|
||||
else:
|
||||
return await run_in_threadpool(cast(Callable[P, T], func), *args, **kwargs)
|
||||
|
||||
|
||||
async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
if kwargs:
|
||||
func = functools.partial(func, **kwargs)
|
||||
return await anyio.to_thread.run_sync(func, *args)
|
||||
|
||||
|
||||
async def solve_generator_async(
|
||||
*sub_args: Any, call: Callable[..., Any], stack: AsyncExitStack, **sub_values: Any
|
||||
) -> Any:
|
||||
if is_gen_callable(call):
|
||||
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
|
||||
elif is_async_gen_callable(call): # pragma: no branch
|
||||
cm = asynccontextmanager(call)(*sub_args, **sub_values)
|
||||
return await stack.enter_async_context(cm)
|
||||
|
||||
|
||||
def solve_generator_sync(*sub_args: Any, call: Callable[..., Any], stack: ExitStack, **sub_values: Any) -> Any:
|
||||
cm = contextmanager(call)(*sub_args, **sub_values)
|
||||
return stack.enter_context(cm)
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> Tuple[inspect.Signature, Any]:
|
||||
signature = inspect.signature(call)
|
||||
|
||||
locals = collect_outer_stack_locals()
|
||||
|
||||
# We unwrap call to get the original unwrapped function
|
||||
call = inspect.unwrap(call)
|
||||
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=get_typed_annotation(
|
||||
param.annotation,
|
||||
globalns,
|
||||
locals,
|
||||
),
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
|
||||
return inspect.Signature(typed_params), get_typed_annotation(
|
||||
signature.return_annotation,
|
||||
globalns,
|
||||
locals,
|
||||
)
|
||||
|
||||
|
||||
def collect_outer_stack_locals() -> Dict[str, Any]:
|
||||
frame = inspect.currentframe()
|
||||
|
||||
frames: List[FrameType] = []
|
||||
while frame is not None:
|
||||
if "fast_depends" not in frame.f_code.co_filename:
|
||||
frames.append(frame)
|
||||
frame = frame.f_back
|
||||
|
||||
locals = {}
|
||||
for f in frames[::-1]:
|
||||
locals.update(f.f_locals)
|
||||
|
||||
return locals
|
||||
|
||||
|
||||
def get_typed_annotation(
|
||||
annotation: Any,
|
||||
globalns: Dict[str, Any],
|
||||
locals: Dict[str, Any],
|
||||
) -> Any:
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
|
||||
if isinstance(annotation, ForwardRef):
|
||||
annotation = evaluate_forwardref(annotation, globalns, locals)
|
||||
|
||||
if get_origin(annotation) is Annotated and (args := get_args(annotation)):
|
||||
solved_args = [get_typed_annotation(x, globalns, locals) for x in args]
|
||||
annotation.__origin__, annotation.__metadata__ = solved_args[0], tuple(solved_args[1:])
|
||||
|
||||
return annotation
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def contextmanager_in_threadpool(
|
||||
cm: ContextManager[T],
|
||||
) -> AsyncGenerator[T, None]:
|
||||
exit_limiter = anyio.CapacityLimiter(1)
|
||||
try:
|
||||
yield await run_in_threadpool(cm.__enter__)
|
||||
except Exception as e:
|
||||
ok = bool(await anyio.to_thread.run_sync(cm.__exit__, type(e), e, None, limiter=exit_limiter))
|
||||
if not ok: # pragma: no branch
|
||||
raise e
|
||||
else:
|
||||
await anyio.to_thread.run_sync(cm.__exit__, None, None, None, limiter=exit_limiter)
|
||||
|
||||
|
||||
def is_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isgeneratorfunction(call):
|
||||
return True
|
||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
||||
return inspect.isgeneratorfunction(dunder_call)
|
||||
|
||||
|
||||
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isasyncgenfunction(call):
|
||||
return True
|
||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
||||
return inspect.isasyncgenfunction(dunder_call)
|
||||
|
||||
|
||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isclass(call):
|
||||
return False
|
||||
|
||||
if asyncio.iscoroutinefunction(call):
|
||||
return True
|
||||
|
||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
||||
return asyncio.iscoroutinefunction(dunder_call)
|
||||
|
||||
|
||||
async def async_map(func: Callable[..., T], async_iterable: AsyncIterable[Any]) -> AsyncIterable[T]:
|
||||
async for i in async_iterable:
|
||||
yield func(i)
|
||||
Reference in New Issue
Block a user