CoACT initialize (#292)
This commit is contained in:
526
mm_agents/coact/autogen/import_utils.py
Normal file
526
mm_agents/coact/autogen/import_utils.py
Normal file
@@ -0,0 +1,526 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager, suppress
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generator, Generic, Iterable, Optional, TypeVar, Union
|
||||
|
||||
__all__ = [
|
||||
"optional_import_block",
|
||||
"patch_object",
|
||||
"require_optional_import",
|
||||
"run_for_optional_imports",
|
||||
"skip_on_missing_imports",
|
||||
]
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleInfo:
|
||||
name: str
|
||||
min_version: Optional[str] = None
|
||||
max_version: Optional[str] = None
|
||||
min_inclusive: bool = False
|
||||
max_inclusive: bool = False
|
||||
|
||||
def is_in_sys_modules(self) -> Optional[str]:
|
||||
"""Check if the module is installed and satisfies the version constraints
|
||||
|
||||
Returns:
|
||||
None if the module is installed and satisfies the version constraints, otherwise a message indicating the issue.
|
||||
|
||||
"""
|
||||
if self.name not in sys.modules:
|
||||
return f"'{self.name}' is not installed."
|
||||
else:
|
||||
if hasattr(sys.modules[self.name], "__file__") and sys.modules[self.name].__file__ is not None:
|
||||
autogen_path = (Path(__file__).parent).resolve()
|
||||
test_path = (Path(__file__).parent.parent / "test").resolve()
|
||||
module_path = Path(sys.modules[self.name].__file__).resolve() # type: ignore[arg-type]
|
||||
|
||||
if str(autogen_path) in str(module_path) or str(test_path) in str(module_path):
|
||||
# The module is in the autogen or test directory
|
||||
# Aka similarly named module in the autogen or test directory
|
||||
return f"'{self.name}' is not installed."
|
||||
|
||||
installed_version = (
|
||||
sys.modules[self.name].__version__ if hasattr(sys.modules[self.name], "__version__") else None
|
||||
)
|
||||
if installed_version is None and (self.min_version or self.max_version):
|
||||
return f"'{self.name}' is installed, but the version is not available."
|
||||
|
||||
if self.min_version:
|
||||
msg = f"'{self.name}' is installed, but the installed version {installed_version} is too low (required '{self}')."
|
||||
if not self.min_inclusive and installed_version == self.min_version:
|
||||
return msg
|
||||
if self.min_inclusive and installed_version < self.min_version: # type: ignore[operator]
|
||||
return msg
|
||||
|
||||
if self.max_version:
|
||||
msg = f"'{self.name}' is installed, but the installed version {installed_version} is too high (required '{self}')."
|
||||
if not self.max_inclusive and installed_version == self.max_version:
|
||||
return msg
|
||||
if self.max_inclusive and installed_version > self.max_version: # type: ignore[operator]
|
||||
return msg
|
||||
|
||||
return None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = self.name
|
||||
if self.min_version:
|
||||
s += f">={self.min_version}" if self.min_inclusive else f">{self.min_version}"
|
||||
if self.max_version:
|
||||
s += f"<={self.max_version}" if self.max_inclusive else f"<{self.max_version}"
|
||||
return s
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, module_info: str) -> "ModuleInfo":
|
||||
"""Parse a string to create a ModuleInfo object
|
||||
|
||||
Args:
|
||||
module_info (str): A string containing the module name and optional version constraints
|
||||
|
||||
Returns:
|
||||
ModuleInfo: A ModuleInfo object with the parsed information
|
||||
|
||||
Raises:
|
||||
ValueError: If the module information is invalid
|
||||
"""
|
||||
|
||||
pattern = re.compile(r"^(?P<name>[a-zA-Z0-9-_]+)(?P<constraint>.*)$")
|
||||
match = pattern.match(module_info.strip())
|
||||
|
||||
if not match:
|
||||
raise ValueError(f"Invalid package information: {module_info}")
|
||||
|
||||
name = match.group("name")
|
||||
constraints = match.group("constraint").strip()
|
||||
min_version = max_version = None
|
||||
min_inclusive = max_inclusive = False
|
||||
|
||||
if constraints:
|
||||
constraint_pattern = re.findall(r"(>=|<=|>|<)([0-9\.]+)?", constraints)
|
||||
|
||||
if not all(version for _, version in constraint_pattern):
|
||||
raise ValueError(f"Invalid module information: {module_info}")
|
||||
|
||||
for operator, version in constraint_pattern:
|
||||
if operator == ">=":
|
||||
min_version = version
|
||||
min_inclusive = True
|
||||
elif operator == "<=":
|
||||
max_version = version
|
||||
max_inclusive = True
|
||||
elif operator == ">":
|
||||
min_version = version
|
||||
min_inclusive = False
|
||||
elif operator == "<":
|
||||
max_version = version
|
||||
max_inclusive = False
|
||||
else:
|
||||
raise ValueError(f"Invalid package information: {module_info}")
|
||||
|
||||
return ModuleInfo(
|
||||
name=name,
|
||||
min_version=min_version,
|
||||
max_version=max_version,
|
||||
min_inclusive=min_inclusive,
|
||||
max_inclusive=max_inclusive,
|
||||
)
|
||||
|
||||
|
||||
class Result:
|
||||
def __init__(self) -> None:
|
||||
self._failed: Optional[bool] = None
|
||||
|
||||
@property
|
||||
def is_successful(self) -> bool:
|
||||
if self._failed is None:
|
||||
raise ValueError("Result not set")
|
||||
return not self._failed
|
||||
|
||||
|
||||
@contextmanager
|
||||
def optional_import_block() -> Generator[Result, None, None]:
|
||||
"""Guard a block of code to suppress ImportErrors
|
||||
|
||||
A context manager to temporarily suppress ImportErrors.
|
||||
Use this to attempt imports without failing immediately on missing modules.
|
||||
|
||||
Example:
|
||||
```python
|
||||
with optional_import_block():
|
||||
import some_module
|
||||
import some_other_module
|
||||
```
|
||||
"""
|
||||
result = Result()
|
||||
try:
|
||||
yield result
|
||||
result._failed = False
|
||||
except ImportError as e:
|
||||
# Ignore ImportErrors during this context
|
||||
logger.debug(f"Ignoring ImportError: {e}")
|
||||
result._failed = True
|
||||
|
||||
|
||||
def get_missing_imports(modules: Union[str, Iterable[str]]) -> dict[str, str]:
|
||||
"""Get missing modules from a list of module names
|
||||
|
||||
Args:
|
||||
modules (Union[str, Iterable[str]]): Module name or list of module names
|
||||
|
||||
Returns:
|
||||
List of missing module names
|
||||
"""
|
||||
if isinstance(modules, str):
|
||||
modules = [modules]
|
||||
|
||||
module_infos = [ModuleInfo.from_str(module) for module in modules]
|
||||
x = {m.name: m.is_in_sys_modules() for m in module_infos}
|
||||
return {k: v for k, v in x.items() if v}
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
G = TypeVar("G", bound=Union[Callable[..., Any], type])
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
class PatchObject(ABC, Generic[T]):
|
||||
def __init__(self, o: T, missing_modules: dict[str, str], dep_target: str):
|
||||
if not self.accept(o):
|
||||
raise ValueError(f"Cannot patch object of type {type(o)}")
|
||||
|
||||
self.o = o
|
||||
self.missing_modules = missing_modules
|
||||
self.dep_target = dep_target
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def accept(cls, o: Any) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def patch(self, except_for: Iterable[str]) -> T: ...
|
||||
|
||||
def get_object_with_metadata(self) -> Any:
|
||||
return self.o
|
||||
|
||||
@property
|
||||
def msg(self) -> str:
|
||||
o = self.get_object_with_metadata()
|
||||
plural = len(self.missing_modules) > 1
|
||||
fqn = f"{o.__module__}.{o.__name__}" if hasattr(o, "__module__") else o.__name__
|
||||
# modules_str = ", ".join([f"'{m}'" for m in self.missing_modules])
|
||||
msg = f"{'Modules' if plural else 'A module'} needed for {fqn} {'are' if plural else 'is'} missing:\n"
|
||||
for _, status in self.missing_modules.items():
|
||||
msg += f" - {status}\n"
|
||||
msg += f"Please install {'them' if plural else 'it'} using:\n'pip install ag2[{self.dep_target}]'"
|
||||
return msg
|
||||
|
||||
def copy_metadata(self, retval: T) -> None:
|
||||
"""Copy metadata from original object to patched object
|
||||
|
||||
Args:
|
||||
retval: Patched object
|
||||
|
||||
"""
|
||||
o = self.o
|
||||
if hasattr(o, "__doc__"):
|
||||
retval.__doc__ = o.__doc__
|
||||
if hasattr(o, "__name__"):
|
||||
retval.__name__ = o.__name__ # type: ignore[attr-defined]
|
||||
if hasattr(o, "__module__"):
|
||||
retval.__module__ = o.__module__
|
||||
|
||||
_registry: list[type["PatchObject[Any]"]] = []
|
||||
|
||||
@classmethod
|
||||
def register(cls) -> Callable[[type["PatchObject[Any]"]], type["PatchObject[Any]"]]:
|
||||
def decorator(subclass: type["PatchObject[Any]"]) -> type["PatchObject[Any]"]:
|
||||
cls._registry.append(subclass)
|
||||
return subclass
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
o: T,
|
||||
*,
|
||||
missing_modules: dict[str, str],
|
||||
dep_target: str,
|
||||
) -> Optional["PatchObject[T]"]:
|
||||
for subclass in cls._registry:
|
||||
if subclass.accept(o):
|
||||
return subclass(o, missing_modules, dep_target)
|
||||
return None
|
||||
|
||||
|
||||
@PatchObject.register()
|
||||
class PatchCallable(PatchObject[F]):
|
||||
@classmethod
|
||||
def accept(cls, o: Any) -> bool:
|
||||
return inspect.isfunction(o) or inspect.ismethod(o)
|
||||
|
||||
def patch(self, except_for: Iterable[str]) -> F:
|
||||
if self.o.__name__ in except_for:
|
||||
return self.o
|
||||
|
||||
f: Callable[..., Any] = self.o
|
||||
|
||||
# @wraps(f.__call__) # type: ignore[operator]
|
||||
@wraps(f)
|
||||
def _call(*args: Any, **kwargs: Any) -> Any:
|
||||
raise ImportError(self.msg)
|
||||
|
||||
self.copy_metadata(_call) # type: ignore[arg-type]
|
||||
|
||||
return _call # type: ignore[return-value]
|
||||
|
||||
|
||||
@PatchObject.register()
|
||||
class PatchStatic(PatchObject[F]):
|
||||
@classmethod
|
||||
def accept(cls, o: Any) -> bool:
|
||||
# return inspect.ismethoddescriptor(o)
|
||||
return isinstance(o, staticmethod)
|
||||
|
||||
def patch(self, except_for: Iterable[str]) -> F:
|
||||
if hasattr(self.o, "__name__"):
|
||||
name = self.o.__name__
|
||||
elif hasattr(self.o, "__func__"):
|
||||
name = self.o.__func__.__name__
|
||||
else:
|
||||
raise ValueError(f"Cannot determine name for object {self.o}")
|
||||
if name in except_for:
|
||||
return self.o
|
||||
|
||||
f: Callable[..., Any] = self.o.__func__ # type: ignore[attr-defined]
|
||||
|
||||
@wraps(f)
|
||||
def _call(*args: Any, **kwargs: Any) -> Any:
|
||||
raise ImportError(self.msg)
|
||||
|
||||
self.copy_metadata(_call) # type: ignore[arg-type]
|
||||
|
||||
return staticmethod(_call) # type: ignore[return-value]
|
||||
|
||||
def get_object_with_metadata(self) -> Any:
|
||||
return self.o.__func__ # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@PatchObject.register()
|
||||
class PatchInit(PatchObject[F]):
|
||||
@classmethod
|
||||
def accept(cls, o: Any) -> bool:
|
||||
return inspect.ismethoddescriptor(o) and o.__name__ == "__init__"
|
||||
|
||||
def patch(self, except_for: Iterable[str]) -> F:
|
||||
if self.o.__name__ in except_for:
|
||||
return self.o
|
||||
|
||||
f: Callable[..., Any] = self.o
|
||||
|
||||
@wraps(f)
|
||||
def _call(*args: Any, **kwargs: Any) -> Any:
|
||||
raise ImportError(self.msg)
|
||||
|
||||
self.copy_metadata(_call) # type: ignore[arg-type]
|
||||
|
||||
return staticmethod(_call) # type: ignore[return-value]
|
||||
|
||||
def get_object_with_metadata(self) -> Any:
|
||||
return self.o
|
||||
|
||||
|
||||
@PatchObject.register()
|
||||
class PatchProperty(PatchObject[Any]):
|
||||
@classmethod
|
||||
def accept(cls, o: Any) -> bool:
|
||||
return inspect.isdatadescriptor(o) and hasattr(o, "fget")
|
||||
|
||||
def patch(self, except_for: Iterable[str]) -> property:
|
||||
if not hasattr(self.o, "fget"):
|
||||
raise ValueError(f"Cannot patch property without getter: {self.o}")
|
||||
f: Callable[..., Any] = self.o.fget
|
||||
|
||||
if f.__name__ in except_for:
|
||||
return self.o # type: ignore[no-any-return]
|
||||
|
||||
@wraps(f)
|
||||
def _call(*args: Any, **kwargs: Any) -> Any:
|
||||
raise ImportError(self.msg)
|
||||
|
||||
self.copy_metadata(_call)
|
||||
|
||||
return property(_call)
|
||||
|
||||
def get_object_with_metadata(self) -> Any:
|
||||
return self.o.fget
|
||||
|
||||
|
||||
@PatchObject.register()
|
||||
class PatchClass(PatchObject[type[Any]]):
|
||||
@classmethod
|
||||
def accept(cls, o: Any) -> bool:
|
||||
return inspect.isclass(o)
|
||||
|
||||
def patch(self, except_for: Iterable[str]) -> type[Any]:
|
||||
if self.o.__name__ in except_for:
|
||||
return self.o
|
||||
|
||||
for name, member in inspect.getmembers(self.o):
|
||||
# Patch __init__ method if possible, but not other internal methods
|
||||
if name.startswith("__") and name != "__init__":
|
||||
continue
|
||||
patched = patch_object(
|
||||
member,
|
||||
missing_modules=self.missing_modules,
|
||||
dep_target=self.dep_target,
|
||||
fail_if_not_patchable=False,
|
||||
except_for=except_for,
|
||||
)
|
||||
with suppress(AttributeError):
|
||||
setattr(self.o, name, patched)
|
||||
|
||||
return self.o
|
||||
|
||||
|
||||
def patch_object(
|
||||
o: T,
|
||||
*,
|
||||
missing_modules: dict[str, str],
|
||||
dep_target: str,
|
||||
fail_if_not_patchable: bool = True,
|
||||
except_for: Optional[Union[str, Iterable[str]]] = None,
|
||||
) -> T:
|
||||
patcher = PatchObject.create(o, missing_modules=missing_modules, dep_target=dep_target)
|
||||
if fail_if_not_patchable and patcher is None:
|
||||
raise ValueError(f"Cannot patch object of type {type(o)}")
|
||||
|
||||
except_for = except_for if except_for is not None else []
|
||||
except_for = [except_for] if isinstance(except_for, str) else except_for
|
||||
|
||||
return patcher.patch(except_for=except_for) if patcher else o
|
||||
|
||||
|
||||
def require_optional_import(
|
||||
modules: Union[str, Iterable[str]],
|
||||
dep_target: str,
|
||||
*,
|
||||
except_for: Optional[Union[str, Iterable[str]]] = None,
|
||||
) -> Callable[[T], T]:
|
||||
"""Decorator to handle optional module dependencies
|
||||
|
||||
Args:
|
||||
modules: Module name or list of module names required
|
||||
dep_target: Target name for pip installation (e.g. 'test' in pip install ag2[test])
|
||||
except_for: Name or list of names of objects to exclude from patching
|
||||
"""
|
||||
missing_modules = get_missing_imports(modules)
|
||||
|
||||
if not missing_modules:
|
||||
|
||||
def decorator(o: T) -> T:
|
||||
return o
|
||||
|
||||
else:
|
||||
|
||||
def decorator(o: T) -> T:
|
||||
return patch_object(o, missing_modules=missing_modules, dep_target=dep_target, except_for=except_for)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _mark_object(o: T, dep_target: str) -> T:
|
||||
import pytest
|
||||
|
||||
markname = dep_target.replace("-", "_")
|
||||
pytest_mark_markname = getattr(pytest.mark, markname)
|
||||
pytest_mark_o = pytest_mark_markname(o)
|
||||
|
||||
pytest_mark_o = pytest.mark.aux_neg_flag(pytest_mark_o)
|
||||
|
||||
return pytest_mark_o # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def run_for_optional_imports(modules: Union[str, Iterable[str]], dep_target: str) -> Callable[[G], G]:
|
||||
"""Decorator to run a test if and only if optional modules are installed
|
||||
|
||||
Args:
|
||||
modules: Module name or list of module names
|
||||
dep_target: Target name for pip installation (e.g. 'test' in pip install ag2[test])
|
||||
"""
|
||||
# missing_modules = get_missing_imports(modules)
|
||||
# if missing_modules:
|
||||
# raise ImportError(f"Missing module{'s' if len(missing_modules) > 1 else ''}: {', '.join(missing_modules)}. Install using 'pip install ag2[{dep_target}]'")
|
||||
|
||||
def decorator(o: G) -> G:
|
||||
missing_modules = get_missing_imports(modules)
|
||||
|
||||
if isinstance(o, type):
|
||||
wrapped = require_optional_import(modules, dep_target)(o)
|
||||
else:
|
||||
if inspect.iscoroutinefunction(o):
|
||||
|
||||
@wraps(o)
|
||||
async def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
if missing_modules:
|
||||
raise ImportError(
|
||||
f"Missing module{'s' if len(missing_modules) > 1 else ''}: {', '.join(missing_modules)}. Install using 'pip install ag2[{dep_target}]'"
|
||||
)
|
||||
return await o(*args, **kwargs)
|
||||
|
||||
else:
|
||||
|
||||
@wraps(o)
|
||||
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
if missing_modules:
|
||||
raise ImportError(
|
||||
f"Missing module{'s' if len(missing_modules) > 1 else ''}: {', '.join(missing_modules)}. Install using 'pip install ag2[{dep_target}]'"
|
||||
)
|
||||
return o(*args, **kwargs)
|
||||
|
||||
pytest_mark_o: G = _mark_object(wrapped, dep_target) # type: ignore[assignment]
|
||||
|
||||
return pytest_mark_o
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def skip_on_missing_imports(modules: Union[str, Iterable[str]], dep_target: str) -> Callable[[T], T]:
|
||||
"""Decorator to skip a test if an optional module is missing
|
||||
|
||||
Args:
|
||||
modules: Module name or list of module names
|
||||
dep_target: Target name for pip installation (e.g. 'test' in pip install ag2[test])
|
||||
"""
|
||||
import pytest
|
||||
|
||||
missing_modules = get_missing_imports(modules)
|
||||
|
||||
if not missing_modules:
|
||||
|
||||
def decorator(o: T) -> T:
|
||||
pytest_mark_o = _mark_object(o, dep_target)
|
||||
return pytest_mark_o # type: ignore[no-any-return]
|
||||
|
||||
else:
|
||||
|
||||
def decorator(o: T) -> T:
|
||||
pytest_mark_o = _mark_object(o, dep_target)
|
||||
|
||||
return pytest.mark.skip( # type: ignore[return-value,no-any-return]
|
||||
f"Missing module{'s' if len(missing_modules) > 1 else ''}: {', '.join(missing_modules)}. Install using 'pip install ag2[{dep_target}]'"
|
||||
)(pytest_mark_o)
|
||||
|
||||
return decorator
|
||||
Reference in New Issue
Block a user