# Source: sci-gui-agent-benchmark/desktop_env/evaluators/metrics/general.py
# Snapshot: 2024-03-08 19:24:15 +08:00 (381 lines, 12 KiB, Python)

import csv
import functools
import json
import operator
import re
import pdfplumber
import sqlite3
from numbers import Number
from typing import Callable, Any, Union
from typing import Dict, List, Pattern
import datetime
import pandas as pd
import lxml.etree
from lxml.cssselect import CSSSelector
from lxml.etree import _Element
from rapidfuzz import fuzz
from docx import Document
from .utils import _match_record, _match_value_to_rule
def check_include_exclude(result: str, rules: Dict[str, List[str]]) -> float:
    """Score ``result`` by required and forbidden substrings.

    Args:
        result (str): text to inspect; ``None`` scores 0.
        rules (Dict[str, List[str]]): optional ``"include"`` list (every
            entry must occur in ``result``) and ``"exclude"`` list (no entry
            may occur in ``result``).

    Returns:
        float: 1. when all includes are present and no exclude is, else 0.
    """
    if result is None:
        return 0.
    print(result, rules)
    required = rules.get("include", [])
    forbidden = rules.get("exclude", [])
    if not all(needle in result for needle in required):
        return 0.
    # Only reached when every required substring is present.
    found_forbidden = any(needle in result for needle in forbidden)
    return 0. if found_forbidden else 1.
def exact_match(result, rules) -> float:
    """Return 1. when ``result`` equals ``rules["expected"]``, else 0.

    Both values are printed for debugging before comparing.
    """
    expected = rules["expected"]
    print(result, expected)
    return 1. if result == expected else 0.
def is_in_list(result, rules) -> float:
    """Return 1. when ``rules["expected"]`` is a member of ``result``, else 0.

    ``result`` may be any object supporting the ``in`` operator (list, str, ...).
    """
    return 1. if rules["expected"] in result else 0.
def fuzzy_match(result, rules) -> float:
    """Return the rapidfuzz ``ratio`` of ``result`` vs ``rules["expected"]``, scaled to [0, 1]."""
    target = rules["expected"]
    similarity = fuzz.ratio(result, target)
    return similarity / 100.
def fuzzy_place_math(result_file_path, rules) -> float:
    """Check that every word in a .docx file contains one of the allowed answers.

    Args:
        result_file_path: path to the .docx document to inspect.
        rules: dict with ``"expected"``, a list of acceptable answer strings.

    Returns:
        float: 1. if the document has at least one word and every
        whitespace-separated word contains one of the expected answers as a
        substring; 0. otherwise (the first offending word is printed).
    """
    expected = rules["expected"]  # a list of possible answers
    doc = Document(result_file_path)
    # str.split() with no argument drops blank lines and leading/trailing
    # whitespace, so only real words are collected.
    words_list = [word for para in doc.paragraphs for word in para.text.split()]
    # Print the extracted word list for debugging.
    print(words_list)
    if not words_list:
        # An empty document contains no answers; don't let it pass vacuously.
        print("No words found in", result_file_path)
        return 0.
    for word in words_list:
        if not any(ans in word for ans in expected):
            print("Wrong place:", word)
            return 0.
    return 1.
def check_csv(result: str, rules: Dict[str, List[Dict[str, str]]]) -> float:
    """
    Args:
        result (str): path to csv file; ``None`` scores 0.
        rules (Dict[str, List[Dict[str, str]]]): dict like
            {
                "expect": [{key: value}]
                "unexpect": [{key: value}]
            }

    Returns:
        float: 1. when every ``expect`` record matches at least one row
        (via ``_match_record``) and no ``unexpect`` record matches any row,
        else 0.
    """
    if result is None:
        return 0.

    # Hoisted out of the row loop; the rule lists don't change per row.
    expect_rules = rules.get("expect", [])
    unexpect_rules = rules.get("unexpect", [])
    expect_metrics = [False] * len(expect_rules)
    unexpect_metric = True
    # newline="" is required by the csv module so quoted fields containing
    # line breaks are parsed correctly.
    with open(result, newline="") as f:
        for rcd in csv.DictReader(f):
            for i, r in enumerate(expect_rules):
                expect_metrics[i] = expect_metrics[i] or _match_record(r, rcd)
            unexpect_metric = unexpect_metric and not any(_match_record(r, rcd) for r in unexpect_rules)
    return float(all(expect_metrics) and unexpect_metric)
def check_list(result: str, rules: Dict[str, List[str]]) -> float:
    """Check a text file line by line against wanted and banned regexes.

    Args:
        result (str): path to list file; ``None`` scores 0.
        rules (Dict[str, List[str]]): dict like
            {
                "expect": list of str as regexes; each must match some line
                "unexpect": list of str as regexes; none may match any line
            }

    Returns:
        float: 1. if every ``expect`` pattern is found (``re.search``) on
        some line and no ``unexpect`` pattern is found on any line, else 0.
    """
    if result is None:
        return 0.

    wanted: List[Pattern[str]] = [re.compile(p) for p in rules.get("expect", [])]
    banned: List[Pattern[str]] = [re.compile(p) for p in rules.get("unexpect", [])]

    seen = [False] * len(wanted)
    clean = True
    with open(result) as handle:
        for line in handle:
            for idx, pattern in enumerate(wanted):
                if not seen[idx]:
                    seen[idx] = pattern.search(line) is not None
            if clean:
                clean = all(pattern.search(line) is None for pattern in banned)
    return float(all(seen) and clean)
_accessibility_ns_map = {"st": "uri:deskat:state.at-spi.gnome.org"
, "attr": "uri:deskat:attributes.at-spi.gnome.org"
, "cp": "uri:deskat:component.at-spi.gnome.org"
, "doc": "uri:deskat:document.at-spi.gnome.org"
, "docattr": "uri:deskat:attributes.document.at-spi.gnome.org"
, "txt": "uri:deskat:text.at-spi.gnome.org"
, "val": "uri:deskat:value.at-spi.gnome.org"
, "act": "uri:deskat:action.at-spi.gnome.org"
}
def check_accessibility_tree(result: str, rules: Dict[str, Any]) -> float:
    """
    Args:
        result (str): XML of GNOME Accessibility Tree
        rules (Dict[str, Any]): dict like
            {
                "selectors": list of str as CSS selectors, will be connected by ", "
                    to form a composite selector. Only one from `selectors` and
                    `xpath` is needed. If both are present, `xpath` takes the
                    priority.
                "xpath": str as xpath. Only one from `selectors` and `xpath` is
                    needed. If both are present, `xpath` takes the priority.
                "text": str as the expected text content of the selected element.
                "exact": bool specifying whether exact match or fuzzy match should
                    be performed. defaults to True.
            }

    Returns:
        float: 0. if nothing is selected; 1. if something is selected and no
        "text" rule is given; otherwise the best score over the selected
        elements (1./0. for exact match, rapidfuzz ratio / 100 for fuzzy).

    Raises:
        ValueError: if neither "xpath" nor "selectors" is given.
    """
    at: _Element = lxml.etree.fromstring(result)
    if "xpath" in rules:
        elements: List[_Element] = at.xpath(rules["xpath"], namespaces=_accessibility_ns_map)
    elif "selectors" in rules:
        selector = CSSSelector(", ".join(rules["selectors"]), namespaces=_accessibility_ns_map)
        elements = selector(at)
    else:
        raise ValueError("At least one of xpath and selectors is required")

    if len(elements) == 0:
        print("no elements")
        return 0.

    if "text" not in rules:
        return 1.

    # "exact" is documented as optional (default True); indexing it directly
    # used to raise KeyError whenever a rule omitted it.
    exact: bool = rules.get("exact", True)
    scorer = operator.eq if exact else (lambda a, b: fuzz.ratio(a, b) / 100.)
    match_func: Callable[[str], Number] = functools.partial(scorer, rules["text"])

    match_score: Number = 0
    for elm in elements:
        # Empty element text is passed to the scorer as None.
        match_score = max(match_score, match_func(elm.text or None))
    return float(match_score)
# def check_existence(result: str, *args) -> float:
# return 1. - (result is None)
def run_sqlite3(result: str, rules: Dict[str, Any]) -> float:
    """Run ``rules["sql"]`` against the SQLite database at ``result`` and use it as a score.

    Args:
        result (str): path to the SQLite database; ``None`` scores 0.
        rules (Dict[str, Any]): dict with ``"sql"``, a query whose first column
            of the first returned row is the score.

    Returns:
        float: that value converted to float; 0. when the query returns no
        rows or the value is falsy (e.g. NULL).
    """
    if result is None:
        return 0.
    # NOTE(review): sqlite3.connect creates an empty database file if
    # `result` does not exist — confirm that side effect is acceptable.
    connection: sqlite3.Connection = sqlite3.connect(result)
    try:
        cursor: sqlite3.Cursor = connection.execute(rules["sql"])
        row = cursor.fetchone()
    finally:
        # Connection's own context manager only commits/rolls back and does
        # not close, so close explicitly to avoid leaking the handle.
        connection.close()
    if row is None:
        return 0.
    return float(row[0] or 0)
def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str]]]]) -> float:
    """
    Args:
        result (str): path to json file; ``None`` scores 0.
        rules (Dict[str, List[Dict[str, Union[List[str], str]]]]): dict like
            {
                "expect": [
                    {
                        "key": list of str
                        "method": str
                        "ref": something
                    }
                ],
                "unexpect": <the same as `expect`
            }
            Each rule's "key" is a path walked from the JSON root; the value
            found is checked with ``_match_value_to_rule``.

    Returns:
        float: 1. if every ``expect`` rule matches and no ``unexpect`` rule
        matches, else 0.
    """
    if result is None:
        return 0.
    with open(result) as f:
        data: Dict[str, Any] = json.load(f)

    def _lookup(key_path):
        # Walk nested containers along key_path; raises if a key is missing.
        value = data
        for k in key_path:
            value = value[k]
        return value

    # Stop at the first failing rule: the score cannot recover, and walking
    # further key paths could only raise spuriously.
    for r in rules.get("expect", []):
        if not _match_value_to_rule(_lookup(r["key"]), r):
            return 0.
    for r in rules.get("unexpect", []):
        if _match_value_to_rule(_lookup(r["key"]), r):
            return 0.
    return 1.
def check_direct_json_object(result, rules) -> float:
    """
    One of the most commonly used functions to evaluate.
    Compare two json objects directly.

    Args:
        result: the candidate object, either already parsed or a JSON
            string. Strings are stripped and every ``'`` is replaced by ``"``
            before ``json.loads``, so Python-style dict reprs are accepted.
        rules: dict with
            "expected": dict of key -> expected value.
            "expect_in_result" (optional, default False): when True, each
                expected value only needs to be contained (``in``) in the
                result's value for that key instead of being equal to it.

    Returns:
        float: 1. if every expected key satisfies the comparison, else 0.
        Unparseable strings, ``None``, and results without a ``get`` method
        (lists, numbers, ...) score 0.
    """
    if isinstance(result, str):
        # Normalise Python-style quoting so json.loads can parse it.
        result = result.strip().replace("'", '"')
        try:
            result = json.loads(result)
        except json.JSONDecodeError as err:
            print("result is not valid JSON:", err)
            return 0.
    print("result: ")
    print(result)
    print("expected: ")
    print(rules["expected"])
    if result is None or not hasattr(result, "get"):
        return 0.

    expected_json = rules["expected"]
    expect_in_result = rules.get("expect_in_result", False)
    for key, expected_value in expected_json.items():
        actual_value = result.get(key)
        if expect_in_result:
            try:
                if expected_value not in actual_value:
                    return 0.
            except TypeError:
                # Missing key (None) or a value that can't contain the
                # expected one: treat as not contained.
                return 0.
        elif expected_value != actual_value:
            return 0.
    return 1.0
def compare_time_in_speedtest_results(speedtest_result_path, time_diff):
    """Check that every speedtest result was recorded within ``time_diff`` minutes of now.

    Args:
        speedtest_result_path: path to the speedtest results CSV. The first
            column whose name starts with ``TEST_DATE`` is used; only the last
            five characters of each value are read, as an ``HH:MM`` time.
        time_diff: maximum allowed difference in minutes (anything ``int()``
            accepts).

    Returns:
        bool: True if the file has at least one row and every row's time is
        strictly less than ``time_diff`` minutes from the current local time;
        False otherwise, including when no ``TEST_DATE`` column exists.
    """
    results = pd.read_csv(speedtest_result_path)
    date_col = next((col for col in results.columns if str(col).startswith('TEST_DATE')), None)
    if date_col is None or results.empty:
        # No recorded test at all cannot be "recent".
        return False

    minutes_per_day = 24 * 60
    max_diff = int(time_diff)
    now = datetime.datetime.now()
    now_minutes = now.hour * 60 + now.minute
    for date in results[date_col]:
        recorded = datetime.datetime.strptime(str(date)[-5:], '%H:%M')
        diff = abs(recorded.hour * 60 + recorded.minute - now_minutes)
        # Only HH:MM is compared, so measure the shorter way around the
        # clock: 23:59 and 00:01 are 2 minutes apart, not 1438.
        diff = min(diff, minutes_per_day - diff)
        if not diff < max_diff:
            return False
    return True
def is_included_all_json_objects(gold_file_path, result_file_path):
    """Return True if every top-level key/value pair of the gold JSON also appears in the result JSON.

    Both arguments are paths to JSON files whose top level is an object;
    extra keys in the result file are allowed.
    """
    for label, value in (("gold_file_path: ", gold_file_path),
                         ("result_file_path: ", result_file_path)):
        print(label)
        print(value)
    with open(gold_file_path, 'r') as gold_file:
        gold = json.load(gold_file)
    with open(result_file_path, 'r') as result_file:
        produced = json.load(result_file)
    return all(
        key in produced.keys() and gold[key] == produced[key]
        for key in gold.keys()
    )
def is_gold_text_included_in_pdf(pdf_file_path, gold_text_path):
    """Check that every value of a gold JSON file occurs in a PDF's extracted text.

    Args:
        pdf_file_path: path to the PDF to search.
        gold_text_path: path to a JSON file whose top-level object maps keys
            to text snippets that must appear in the PDF.

    Returns:
        bool: True if every snippet is a substring of the concatenated page
        text; otherwise prints the keys of the missing snippets and returns
        False.
    """
    print("gold_text_path: ")
    print(gold_text_path)
    print("pdf_file_path: ")
    print(pdf_file_path)
    with open(gold_text_path, 'r') as f:
        gold_json = json.load(f)
    with pdfplumber.open(pdf_file_path) as pdf:
        # extract_text() may yield None for pages without a text layer;
        # treat those as empty instead of crashing on str + None. join()
        # also avoids quadratic concatenation on long documents.
        text = ''.join(page.extract_text() or '' for page in pdf.pages)
    false_list = [key for key in gold_json if gold_json[key] not in text]
    if false_list:
        print("false_list: ")
        print(false_list)
        return False
    return True
def file_contains(file_path, config):
    """Return True if the file at ``file_path`` contains every string in ``config["expected"]``.

    A falsy ``file_path`` (``None`` or ``""``) returns False without reading anything.
    """
    if file_path:
        with open(file_path, 'r') as handle:
            contents = handle.read()
        return all(snippet in contents for snippet in config["expected"])
    return False
def check_csv_line_number(file_path, line_number):
    """Check that a CSV file has exactly the expected number of records.

    Args:
        file_path: path to the file; must end with ``.csv``. A falsy path
            (``None`` or ``""``) returns False instead of raising.
        line_number: dict with ``"expected"``, the expected record count
            (anything ``int()`` accepts). A header row, if present, is counted.

    Returns:
        bool: True if the number of CSV records equals the expected count.
    """
    if not file_path or not file_path.endswith('.csv'):
        return False
    # newline='' is what the csv module requires so that quoted fields
    # containing line breaks are read correctly.
    with open(file_path, 'r', newline='') as f:
        line_count = sum(1 for _ in csv.reader(f))
    return line_count == int(line_number["expected"])
def compare_terminal_and_txt(txt_file_path, terminal_output):
    """Return True if ``terminal_output`` is exactly the text stored in ``txt_file_path``."""
    with open(txt_file_path, 'r') as txt_file:
        expected_output = txt_file.read()
    if terminal_output == expected_output:
        return True
    return False
def compare_python_pure_text(py_file_path, gold_file_path):
    """Compare two text files while ignoring all whitespace.

    Returns True when both files' contents are identical after every
    whitespace character (spaces, tabs, newlines, ...) has been removed.
    Both paths are printed for debugging first.
    """
    print("py_file_path: ")
    print(py_file_path)
    print("gold_file_path: ")
    print(gold_file_path)

    def _squeezed_contents(path):
        # Read the file and drop every whitespace character.
        with open(path, 'r') as source:
            return ''.join(source.read().split())

    candidate = _squeezed_contents(py_file_path)
    reference = _squeezed_contents(gold_file_path)
    return candidate == reference