Files
sci-gui-agent-benchmark/desktop_env/evaluators/metrics/general.py

198 lines
6.1 KiB
Python

import csv
import json
import functools
import operator
import re
from numbers import Number
from typing import Callable, Any, Union
from typing import Dict, List, Pattern
import lxml.etree
from lxml.cssselect import CSSSelector
from lxml.etree import _Element
from rapidfuzz import fuzz
from .utils import _match_record, _match_value_to_rule
import sqlite3
def exact_match(result, rules) -> float:
    """
    Score a result by strict equality with the expected value.

    Args:
        result: the value produced by the task
        rules: dict like
          {
            "expected": the value `result` has to be equal to
          }

    Returns:
        float: 1.0 if `result == rules["expected"]`, otherwise 0.0
    """
    expected = rules["expected"]
    # Both values are echoed to stdout before they are compared.
    print(result, expected)
    return 1. if result == expected else 0.
def fuzzy_match(result, rules) -> float:
    """
    Score a result by its fuzzy similarity to the expected value.

    Args:
        result: the value produced by the task
        rules: dict like
          {
            "expected": the value `result` is compared against
          }

    Returns:
        float: `fuzz.ratio(result, rules["expected"])` divided by 100
    """
    similarity = fuzz.ratio(result, rules["expected"])
    return similarity / 100.
def check_csv(result: str, rules: Dict[str, List[Dict[str, str]]]) -> float:
    """
    Check the rows of a csv file against expected and unexpected records.

    Args:
        result (str): path to csv file
        rules (Dict[str, List[Dict[str, str]]]): dict like
          {
            "expect": [{key: value}]
            "unexpect": [{key: value}]
          }
          Every rule is compared with every row through `_match_record`.
          Both keys are optional.

    Returns:
        float: 1.0 if each "expect" rule matches at least one row and no
          "unexpect" rule matches any row, otherwise 0.0. Also 0.0 if
          `result` is None.
    """
    if result is None:
        return 0.

    # Look the rule lists up once instead of once per csv row.
    expect_rules = rules.get("expect", [])
    unexpect_rules = rules.get("unexpect", [])

    expect_metrics = [False] * len(expect_rules)
    unexpect_metric = True

    # newline="" hands line-ending handling over to the csv module, as its
    # documentation requires; otherwise newlines embedded in quoted fields
    # are rewritten by the text layer before the reader sees them.
    with open(result, newline="") as f:
        for rcd in csv.DictReader(f):
            for i, r in enumerate(expect_rules):
                expect_metrics[i] = expect_metrics[i] or _match_record(r, rcd)
            unexpect_metric = unexpect_metric and not any(
                _match_record(r, rcd) for r in unexpect_rules
            )
    return float(all(expect_metrics) and unexpect_metric)
def check_list(result: str, rules: Dict[str, List[str]]) -> float:
    """
    Scan a text file line by line for required and forbidden regexes.

    Args:
        result (str): path to list file
        rules (Dict[str, List[str]]): dict like
          {
            "expect": list of str as regexes
            "unexpect": list of str as regexes
          }
          Both keys are optional.

    Returns:
        float: 1.0 if every "expect" regex is found (`re.search`) in at
          least one line and no "unexpect" regex is found in any line,
          otherwise 0.0. Also 0.0 if `result` is None.
    """
    if result is None:
        return 0.

    required = [re.compile(source) for source in rules.get("expect", [])]
    forbidden = [re.compile(source) for source in rules.get("unexpect", [])]

    found = [False] * len(required)
    clean = True
    with open(result) as f:
        for line in f:
            # Mark each required pattern the first time some line matches it.
            for idx, pattern in enumerate(required):
                if not found[idx] and pattern.search(line) is not None:
                    found[idx] = True
            # A single hit of any forbidden pattern fails the file for good.
            if clean and any(pattern.search(line) is not None for pattern in forbidden):
                clean = False

    return float(all(found) and clean)
# Namespace prefix -> URI table, passed as `namespaces=` to the XPath and
# CSS-selector queries run against the accessibility-tree XML.
_accessibility_ns_map = {
    "st": "uri:deskat:state.at-spi.gnome.org",
    "attr": "uri:deskat:attributes.at-spi.gnome.org",
    "cp": "uri:deskat:component.at-spi.gnome.org",
    "doc": "uri:deskat:document.at-spi.gnome.org",
    "docattr": "uri:deskat:attributes.document.at-spi.gnome.org",
    "txt": "uri:deskat:text.at-spi.gnome.org",
    "val": "uri:deskat:value.at-spi.gnome.org",
    "act": "uri:deskat:action.at-spi.gnome.org",
}
def check_accessibility_tree(result: str, rules: Dict[str, Any]) -> float:
    """
    Select elements from an accessibility tree and optionally check their text.

    Args:
        result (str): XML of GNOME Accessibility Tree
        rules (Dict[str, Any]): dict like
          {
            "selectors": list of str as CSS selectors, will be connected by ", "
              to form a composite selector. Only one from `selectors` and
              `xpath` is needed. If both are present, `xpath` takes the
              priority.
            "xpath": str as xpath. Only one from `selectors` and `xpath` is
              needed. If both are present, `xpath` takes the priority.
            "text": str as the expected text content of the selected element.
            "exact": bool specifying whether exact match or fuzzy match should
              be performed. defaults to True
          }

    Returns:
        float: 0.0 if `result` is None or nothing is selected; 1.0 if
          something is selected and no "text" rule is given; otherwise the
          best text score over all selected elements (1.0/0.0 for exact
          matching, `fuzz.ratio / 100` for fuzzy matching).

    Raises:
        ValueError: if `rules` contains neither "xpath" nor "selectors".
    """
    # Same convention as the file-based checkers: no result means no score.
    if result is None:
        return 0.

    at: _Element = lxml.etree.fromstring(result)
    if "xpath" in rules:
        elements: List[_Element] = at.xpath(rules["xpath"], namespaces=_accessibility_ns_map)
    elif "selectors" in rules:
        selector = CSSSelector(", ".join(rules["selectors"]), namespaces=_accessibility_ns_map)
        elements = selector(at)
    else:
        raise ValueError("At least one of xpath and selectors is required")

    if len(elements) == 0:
        return 0.
    if "text" not in rules:
        # Finding any element is sufficient when no text is expected.
        return 1.

    expected_text = rules["text"]
    # `exact` is optional and documented to default to True; indexing
    # rules["exact"] directly raised KeyError whenever it was omitted.
    if rules.get("exact", True):
        match_func: Callable[[str], Number] = functools.partial(operator.eq, expected_text)
    else:
        match_func = lambda text: fuzz.ratio(expected_text, text) / 100.

    match_score: Number = 0
    for elm in elements:
        # Missing or empty element text is passed to the matcher as None.
        match_score = max(match_score, match_func(elm.text or None))
    return float(match_score)
# def check_existence(result: str, *args) -> float:
# return 1. - (result is None)
def run_sqlite3(result: str, rules: Dict[str, Any]) -> float:
    """
    Run a SQL query on a sqlite3 database and use its first value as score.

    Args:
        result (str): path to the sqlite3 database file
        rules (Dict[str, Any]): dict like
          {
            "sql": str as the query to execute
          }

    Returns:
        float: the first column of the first row returned by the query,
          converted to float. 0.0 if that value is NULL/falsy, if the query
          returns no rows, or if `result` is None.
    """
    # Same convention as the file-based checkers: no result means no score.
    if result is None:
        return 0.

    connection: sqlite3.Connection = sqlite3.connect(result)
    try:
        cursor: sqlite3.Cursor = connection.execute(rules["sql"])
        row = cursor.fetchone()
    finally:
        # Always release the database handle, even if the query raises.
        connection.close()

    # fetchone() yields None for an empty result set; score that as 0
    # instead of failing on `None[0]`.
    if row is None:
        return 0.
    return float(row[0] or 0)
# Sentinel for "this key path does not exist"; None cannot be used because
# JSON null is loaded as None and is a legitimate value.
_JSON_PATH_MISSING = object()


def _resolve_json_path(data: Any, keys: List[Any]) -> Any:
    """Follow `keys` one by one into `data`; return `_JSON_PATH_MISSING` if a step cannot be taken."""
    value = data
    for k in keys:
        try:
            value = value[k]
        except (KeyError, IndexError, TypeError):
            return _JSON_PATH_MISSING
    return value


def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str]]]]) -> float:
    """
    Check values inside a json file against expected and unexpected rules.

    Args:
        result (str): path to json file
        rules (Dict[str, List[Dict[str, Union[List[str], str]]]]): dict like
          {
            "expect": [
                {
                    "key": list of str
                    "method": str
                    "ref": something
                }
            ],
            "unexpect": <the same as `expect`>
          }
          "key" is the path followed into the loaded json; the value found
          there is handed to `_match_value_to_rule` together with the rule.

    Returns:
        float: 1.0 if every "expect" rule matches and no "unexpect" rule
          matches, otherwise 0.0. A key path that does not exist counts as
          a failed "expect" rule and as a satisfied "unexpect" rule. Also
          0.0 if `result` is None.
    """
    if result is None:
        return 0.

    # Keep the path and the parsed document in separate names.
    with open(result) as f:
        data: Dict[str, Any] = json.load(f)

    for r in rules.get("expect", []):
        value = _resolve_json_path(data, r["key"])
        # An absent value cannot satisfy an expectation.
        if value is _JSON_PATH_MISSING or not _match_value_to_rule(value, r):
            return 0.

    for r in rules.get("unexpect", []):
        value = _resolve_json_path(data, r["key"])
        # An absent value cannot be the unexpected one, so only existing
        # values are checked.
        if value is not _JSON_PATH_MISSING and _match_value_to_rule(value, r):
            return 0.

    # Return a real float as annotated, not a bool.
    return 1.