239 lines
9.7 KiB
Python
239 lines
9.7 KiB
Python
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import ast
|
|
import re
|
|
from dataclasses import dataclass
|
|
|
|
from ...doc_utils import export_module
|
|
from .context_variables import ContextVariables
|
|
|
|
|
|
@dataclass
@export_module("autogen")
class ContextExpression:
    """A class to evaluate logical expressions using context variables.

    Args:
        expression (str): A string containing a logical expression with context variable references.
            - Variable references use ${var_name} syntax: ${logged_in}, ${attempts}
            - String literals can use normal quotes: 'hello', "world"
            - Supported operators:
                - Logical: not/!, and/&, or/|
                - Comparison: >, <, >=, <=, ==, !=
            - Supported functions:
                - len(${var_name}): Gets the length of a list, string, or other collection
            - Parentheses can be used for grouping
            - Examples:
                - "not ${logged_in} and ${is_admin} or ${guest_checkout}"
                - "!${logged_in} & ${is_admin} | ${guest_checkout}"
                - "len(${orders}) > 0 & ${user_active}"
                - "len(${cart_items}) == 0 | ${checkout_started}"

    Note:
        The symbolic operators ``&`` and ``|`` are only recognized when surrounded by
        whitespace, and ``!`` only directly before ``${`` or ``(``.

    Raises:
        SyntaxError: If the expression cannot be parsed
        ValueError: If the expression contains disallowed operations
    """

    expression: str

    def __post_init__(self) -> None:
        # Validate the expression immediately upon creation so errors surface early.
        try:
            # Collect every ${var_name} reference (the list may contain duplicates).
            self._variable_names = self._extract_variable_names(self.expression)

            # Convert symbolic operators (!, &, |) to Python keywords.
            python_expr = self._convert_to_python_syntax(self.expression)

            # Replace ${var_name} with a bare var_name so the text is parseable Python.
            sanitized_expr = self._prepare_for_ast(python_expr)

            # Parse once; this validated AST is exactly what evaluate() walks.
            self._ast = ast.parse(sanitized_expr, mode="eval")

            # Reject anything outside the small whitelisted grammar.
            self._validate_operations(self._ast.body)

            # Python-syntax version of the expression; retained in case other code reads it.
            self._python_expr = python_expr
        except SyntaxError as e:
            raise SyntaxError(f"Invalid expression syntax in '{self.expression}': {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Error validating expression '{self.expression}': {str(e)}") from e

    def _extract_variable_names(self, expr: str) -> list[str]:
        """Extract all variable references ${var_name} from the expression."""
        # Find all patterns like ${var_name}
        return re.findall(r"\${([^}]*)}", expr)

    def _convert_to_python_syntax(self, expr: str) -> str:
        """Convert symbolic operators to Python keywords.

        String literals are masked with placeholders first so that operator characters
        inside quotes are left untouched, then restored afterwards.
        """
        string_literals: list[str] = []

        def replace_string_literal(match: re.Match[str]) -> str:
            string_literals.append(match.group(0))
            return f"__STRING_LITERAL_{len(string_literals) - 1}__"

        # Replace both single and double quoted strings
        expr_without_strings = re.sub(r"'[^']*'|\"[^\"]*\"", replace_string_literal, expr)

        # Replace a standalone ! that precedes a variable reference or an opening parenthesis.
        # `!=` is unaffected because `=` follows the `!`.
        expr_without_strings = re.sub(r"!\s*(\${|\()", "not \\1", expr_without_strings)

        # & and | are converted only when surrounded by whitespace.
        expr_without_strings = re.sub(r"\s+&\s+", " and ", expr_without_strings)
        expr_without_strings = re.sub(r"\s+\|\s+", " or ", expr_without_strings)

        # Now put string literals back
        for i, literal in enumerate(string_literals):
            expr_without_strings = expr_without_strings.replace(f"__STRING_LITERAL_{i}__", literal)

        return expr_without_strings

    def _prepare_for_ast(self, expr: str) -> str:
        """Convert the expression to valid Python for AST parsing by replacing variables with placeholders."""
        # Replace ${var_name} with var_name for AST parsing
        processed_expr = expr
        for var_name in self._variable_names:
            processed_expr = processed_expr.replace(f"${{{var_name}}}", var_name)
        return processed_expr

    def _validate_operations(self, node: ast.AST) -> None:
        """Recursively validate that only allowed operations exist in the AST.

        Raises:
            ValueError: If a node type, comparison operator, or function call is not permitted.
        """
        # ast.Constant covers numbers, strings, booleans and None on Python 3.8+. The legacy
        # aliases ast.Num / ast.Str / ast.NameConstant are deprecated and removed in
        # Python 3.14, where merely referencing them raises AttributeError.
        allowed_node_types = (
            # Boolean operations
            ast.BoolOp,
            ast.UnaryOp,
            ast.And,
            ast.Or,
            ast.Not,
            # Support for negative numbers
            ast.USub,
            # Comparison operations
            ast.Compare,
            ast.Eq,
            ast.NotEq,
            ast.Lt,
            ast.LtE,
            ast.Gt,
            ast.GtE,
            # Basic nodes
            ast.Name,
            ast.Load,
            ast.Constant,
            ast.Expression,
            # Support for function calls (restricted to len() below)
            ast.Call,
        )

        if not isinstance(node, allowed_node_types):
            raise ValueError(f"Operation type {type(node).__name__} is not allowed in logical expressions")

        # Special validation for function calls - only allow len()
        if isinstance(node, ast.Call):
            if not (isinstance(node.func, ast.Name) and node.func.id == "len"):
                raise ValueError(f"Only the len() function is allowed, got: {getattr(node.func, 'id', 'unknown')}")
            if len(node.args) != 1:
                raise ValueError(f"len() function must have exactly one argument, got {len(node.args)}")

        # Special validation for Compare nodes (gives a clearer message than the generic check)
        if isinstance(node, ast.Compare):
            for op in node.ops:
                if not isinstance(op, (ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE)):
                    raise ValueError(f"Comparison operator {type(op).__name__} is not allowed")

        # Recursively check child nodes
        for child in ast.iter_child_nodes(node):
            self._validate_operations(child)

    def evaluate(self, context_variables: ContextVariables) -> bool:
        """Evaluate the expression using the provided context variables.

        The expression is evaluated by walking the AST validated in ``__post_init__``.
        Context values are never spliced into source text and ``eval`` is not used, so a
        string value containing quotes or code cannot alter the expression.

        Value handling:
            - bool/int/float/str (and other) values are used as-is.
            - list/dict/tuple values are reduced to their truthiness, except inside len().
            - len(${var}) yields 0 when the value does not support len().

        Args:
            context_variables: Context variables to use for evaluation

        Returns:
            bool: The result of evaluating the expression. ``and``/``or`` follow Python
            semantics and may return an operand value rather than a strict bool.

        Raises:
            KeyError: If a variable referenced in the expression is not found in the context
            ValueError: If evaluation fails (e.g. incomparable operand types, or a bare name
                that was not written as ${...})
        """
        # Resolve every referenced variable up front so a missing key raises KeyError even
        # for operands that short-circuit evaluation would never reach.
        raw_values: dict[str, object] = {}
        for var_name in self._variable_names:
            if not context_variables.contains(var_name):
                raise KeyError(f"Missing context variable: '{var_name}'")
            raw_values[var_name] = context_variables.get(var_name)

        try:
            return self._eval_node(self._ast.body, raw_values)  # type: ignore[return-value]
        except Exception as e:
            raise ValueError(
                f"Error evaluating expression '{self.expression}' (are you sure you're using ${{my_context_variable_key}}): {str(e)}"
            ) from e

    def _eval_node(self, node: ast.AST, values: dict[str, object]) -> object:
        """Recursively evaluate a validated AST node against resolved variable values."""
        if isinstance(node, ast.Expression):
            return self._eval_node(node.body, values)

        if isinstance(node, ast.Constant):
            return node.value

        if isinstance(node, ast.Name):
            return self._coerce_variable(self._resolve_name(node.id, values))

        if isinstance(node, ast.BoolOp):
            # Python short-circuit semantics: return the deciding operand's value.
            is_and = isinstance(node.op, ast.And)
            result: object = None
            for operand in node.values:
                result = self._eval_node(operand, values)
                if is_and and not result:
                    return result
                if not is_and and result:
                    return result
            return result

        if isinstance(node, ast.UnaryOp):
            operand_value = self._eval_node(node.operand, values)
            if isinstance(node.op, ast.Not):
                return not operand_value
            if isinstance(node.op, ast.USub):
                return -operand_value  # type: ignore[operator]
            raise ValueError(f"Operation type {type(node.op).__name__} is not allowed in logical expressions")

        if isinstance(node, ast.Compare):
            # Chained comparisons (a < b < c) behave like Python's: stop at the first falsy result.
            left = self._eval_node(node.left, values)
            outcome: object = True
            for op, comparator in zip(node.ops, node.comparators):
                right = self._eval_node(comparator, values)
                outcome = self._compare(op, left, right)
                if not outcome:
                    return outcome
                left = right
            return outcome

        if isinstance(node, ast.Call):
            # Validation guarantees this is len() with exactly one positional argument.
            (arg,) = node.args
            if isinstance(arg, ast.Name):
                # Measure the raw context value, not its truthiness coercion.
                target = self._resolve_name(arg.id, values)
            else:
                target = self._eval_node(arg, values)
            try:
                return len(target)  # type: ignore[arg-type]
            except TypeError:
                # If the value doesn't support len(), treat as 0
                return 0

        raise ValueError(f"Operation type {type(node).__name__} is not allowed in logical expressions")

    @staticmethod
    def _resolve_name(name: str, values: dict[str, object]) -> object:
        """Return the raw value of a ${...} variable, rejecting unknown bare names."""
        if name not in values:
            raise NameError(f"name '{name}' is not defined")
        return values[name]

    @staticmethod
    def _coerce_variable(value: object) -> object:
        """Reduce collections to their truthiness; leave all other values unchanged."""
        if isinstance(value, (list, dict, tuple)):
            return bool(value)
        return value

    @staticmethod
    def _compare(op: ast.cmpop, left: object, right: object) -> object:
        """Apply a single whitelisted comparison operator."""
        if isinstance(op, ast.Eq):
            return left == right
        if isinstance(op, ast.NotEq):
            return left != right
        if isinstance(op, ast.Lt):
            return left < right  # type: ignore[operator]
        if isinstance(op, ast.LtE):
            return left <= right  # type: ignore[operator]
        if isinstance(op, ast.Gt):
            return left > right  # type: ignore[operator]
        if isinstance(op, ast.GtE):
            return left >= right  # type: ignore[operator]
        raise ValueError(f"Comparison operator {type(op).__name__} is not allowed")

    def __str__(self) -> str:
        return f"ContextExpression('{self.expression}')"
|