Merge branch 'main' of https://github.com/xlang-ai/DesktopEnv
This commit is contained in:
@@ -61,7 +61,8 @@ from .general import (
|
||||
fuzzy_match,
|
||||
check_include_exclude,
|
||||
check_direct_json_object,
|
||||
diff_text_file
|
||||
diff_text_file,
|
||||
literal_match
|
||||
)
|
||||
from .gimp import (
|
||||
check_brightness_decrease_and_structure_sim,
|
||||
@@ -131,7 +132,7 @@ from .vscode import (
|
||||
check_python_file_by_gold_file,
|
||||
compare_zip_files
|
||||
)
|
||||
|
||||
from .others import compare_epub, check_mp3_meta
|
||||
|
||||
def infeasible():
|
||||
pass
|
||||
|
||||
@@ -58,6 +58,8 @@ def contains_page_break(docx_file):
|
||||
|
||||
def compare_docx_files(file1, file2, **options):
|
||||
ignore_blanks = options.get('ignore_blanks', True)
|
||||
ignore_case = options.get('ignore_case', False)
|
||||
ignore_order = options.get('ignore_order', False)
|
||||
content_only = options.get('content_only', False)
|
||||
|
||||
def get_paragraph_texts_odt(document):
|
||||
@@ -82,11 +84,17 @@ def compare_docx_files(file1, file2, **options):
|
||||
doc2 = Document(file2)
|
||||
doc1_paragraphs = [p.text for p in doc1.paragraphs]
|
||||
doc2_paragraphs = [p.text for p in doc2.paragraphs]
|
||||
if ignore_order:
|
||||
doc1_paragraphs = sorted(doc1_paragraphs)
|
||||
doc2_paragraphs = sorted(doc2_paragraphs)
|
||||
elif file1.endswith('.odt') and file2.endswith('.odt'):
|
||||
doc1 = load(file1)
|
||||
doc2 = load(file2)
|
||||
doc1_paragraphs = get_paragraph_texts_odt(doc1)
|
||||
doc2_paragraphs = get_paragraph_texts_odt(doc2)
|
||||
if ignore_order:
|
||||
doc1_paragraphs = sorted(doc1_paragraphs)
|
||||
doc2_paragraphs = sorted(doc2_paragraphs)
|
||||
else:
|
||||
# Unsupported file types or mismatch
|
||||
print("Unsupported file types or mismatch between file types.")
|
||||
@@ -96,6 +104,8 @@ def compare_docx_files(file1, file2, **options):
|
||||
# Compare the content of the documents
|
||||
text1 = re.sub(r'\s+', ' ', '\n'.join(doc1_paragraphs)).strip()
|
||||
text2 = re.sub(r'\s+', ' ', '\n'.join(doc2_paragraphs)).strip()
|
||||
if ignore_case:
|
||||
text1, text2 = text1.lower(), text2.lower()
|
||||
similarity = fuzz.ratio(text1, text2) / 100.0
|
||||
return similarity
|
||||
|
||||
@@ -103,6 +113,8 @@ def compare_docx_files(file1, file2, **options):
|
||||
if ignore_blanks:
|
||||
text1 = re.sub(r'\s+', ' ', '\n'.join(doc1_paragraphs)).strip()
|
||||
text2 = re.sub(r'\s+', ' ', '\n'.join(doc2_paragraphs)).strip()
|
||||
if ignore_case:
|
||||
text1, text2 = text1.lower(), text2.lower()
|
||||
if text1 != text2:
|
||||
return 0
|
||||
else:
|
||||
@@ -111,6 +123,8 @@ def compare_docx_files(file1, file2, **options):
|
||||
|
||||
# Compare each paragraph
|
||||
for p1, p2 in zip(doc1_paragraphs, doc2_paragraphs):
|
||||
if ignore_case:
|
||||
p1, p2 = p1.lower(), p2.lower()
|
||||
if p1 != p2:
|
||||
return 0
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import csv
|
||||
import functools
|
||||
import json
|
||||
import yaml
|
||||
import operator
|
||||
import re
|
||||
import sqlite3
|
||||
@@ -39,6 +40,24 @@ def exact_match(result, rules) -> float:
|
||||
else:
|
||||
return 0.
|
||||
|
||||
|
||||
def literal_match(result: Any, expected: Any, **options) -> float:
|
||||
literal_type = options.get('type', 'str')
|
||||
if literal_type == 'str':
|
||||
ignore_case = options.get('ignore_case', False)
|
||||
score = str(result) == str(expected) if not ignore_case else str(result).lower() == str(expected).lower()
|
||||
return float(score)
|
||||
elif literal_type == 'list':
|
||||
if type(result) not in [list, tuple] or type(expected) not in [list, tuple] or len(result) != len(expected):
|
||||
return .0
|
||||
ignore_case = options.get('ignore_case', False)
|
||||
result = [str(s) for s in result] if not ignore_case else [str(s).lower() for s in result]
|
||||
expected = [str(s) for s in expected] if not ignore_case else [str(s).lower() for s in expected]
|
||||
return float(result == expected)
|
||||
else:
|
||||
raise NotImplementedError(f"Type {type} not supported")
|
||||
|
||||
|
||||
def is_in_list(result, rules) -> float:
|
||||
expect = rules["expected"]
|
||||
if expect in result:
|
||||
@@ -132,11 +151,11 @@ _accessibility_ns_map = {"st": "uri:deskat:state.at-spi.gnome.org"
|
||||
}
|
||||
|
||||
|
||||
def check_accessibility_tree(result: str, rules: Dict[str, Any]) -> float:
|
||||
def check_accessibility_tree(result: str, rules: List[Dict[str, Any]]) -> float:
|
||||
"""
|
||||
Args:
|
||||
result (str): XML of GNOME Accessibility Tree
|
||||
rules (Dict[str, Any]): dict like
|
||||
rules (List[Dict[str, Any]]): list of dict like
|
||||
{
|
||||
"selectors": list of str as CSS selectors, will be connected by ", "
|
||||
to form a composite selector. Only one from `selectors` and
|
||||
@@ -154,30 +173,33 @@ def check_accessibility_tree(result: str, rules: Dict[str, Any]) -> float:
|
||||
"""
|
||||
|
||||
at: _Element = lxml.etree.fromstring(result)
|
||||
if "xpath" in rules:
|
||||
elements: List[_Element] = at.xpath(rules["xpath"], namespaces=_accessibility_ns_map)
|
||||
elif "selectors" in rules:
|
||||
selector = CSSSelector(", ".join(rules["selectors"]), namespaces=_accessibility_ns_map)
|
||||
elements: List[_Element] = selector(at)
|
||||
else:
|
||||
raise ValueError("At least one of xpath and selectors is required")
|
||||
total_match_score = 1.
|
||||
for r in rules:
|
||||
if "xpath" in r:
|
||||
elements: List[_Element] = at.xpath(r["xpath"], namespaces=_accessibility_ns_map)
|
||||
elif "selectors" in r:
|
||||
selector = CSSSelector(", ".join(r["selectors"]), namespaces=_accessibility_ns_map)
|
||||
elements: List[_Element] = selector(at)
|
||||
else:
|
||||
raise ValueError("At least one of xpath and selectors is required")
|
||||
|
||||
if len(elements) == 0:
|
||||
print("no elements")
|
||||
return 0.
|
||||
if len(elements) == 0:
|
||||
print("no elements")
|
||||
return 0.
|
||||
|
||||
if "text" in rules:
|
||||
match_func: Callable[[str], Number] = functools.partial(operator.eq if rules["exact"] \
|
||||
else (lambda a, b: fuzz.ratio(a, b) / 100.)
|
||||
, rules["text"]
|
||||
)
|
||||
match_score: Number = 0
|
||||
for elm in elements:
|
||||
match_score = max(match_score, match_func(elm.text or None))
|
||||
else:
|
||||
match_score = 1.
|
||||
if "text" in r:
|
||||
match_func: Callable[[str], Number] = functools.partial( operator.eq if r["exact"] \
|
||||
else (lambda a, b: fuzz.ratio(a, b) / 100.)
|
||||
, r["text"]
|
||||
)
|
||||
match_score: Number = 0
|
||||
for elm in elements:
|
||||
match_score = max(match_score, match_func(elm.text or None))
|
||||
else:
|
||||
match_score = 1.
|
||||
total_match_score *= match_score
|
||||
|
||||
return float(match_score)
|
||||
return float(total_match_score)
|
||||
|
||||
|
||||
# def check_existence(result: str, *args) -> float:
|
||||
@@ -189,7 +211,7 @@ def run_sqlite3(result: str, rules: Dict[str, Any]) -> float:
|
||||
return float(cursor.fetchone()[0] or 0)
|
||||
|
||||
|
||||
def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str]]]]) -> float:
|
||||
def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str]]]], is_yaml: bool = False) -> float:
|
||||
"""
|
||||
Args:
|
||||
result (str): path to json file
|
||||
@@ -204,6 +226,7 @@ def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str
|
||||
],
|
||||
"unexpect": <the same as `expect`
|
||||
}
|
||||
is_yaml (bool): yaml rather than json
|
||||
|
||||
Returns:
|
||||
float
|
||||
@@ -212,7 +235,10 @@ def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str
|
||||
if result is None:
|
||||
return 0.
|
||||
with open(result) as f:
|
||||
result: Dict[str, Any] = json.load(f)
|
||||
if is_yaml:
|
||||
result: Dict[str, Any] = yaml.load(f, Loader=yaml.Loader)
|
||||
else:
|
||||
result: Dict[str, Any] = json.load(f)
|
||||
|
||||
expect_rules = rules.get("expect", {})
|
||||
unexpect_rules = rules.get("unexpect", {})
|
||||
|
||||
128
desktop_env/evaluators/metrics/others.py
Normal file
128
desktop_env/evaluators/metrics/others.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import zipfile
|
||||
import os.path
|
||||
import os
|
||||
|
||||
import lxml.html
|
||||
from lxml.html import HtmlElement
|
||||
from typing import List, Dict
|
||||
from typing import Union, TypeVar
|
||||
from mutagen.easyid3 import EasyID3
|
||||
|
||||
from .general import diff_text_file
|
||||
from .utils import _match_value_to_rule
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("desktopenv.metric.others")
|
||||
|
||||
def process_epub(filename: str) -> List[str]:
|
||||
file_list: List[str] = []
|
||||
|
||||
base_dir: str = filename + ".dir"
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(filename, "r") as z_f:
|
||||
with z_f.open("toc.ncx") as in_f\
|
||||
, open(os.path.join(base_dir, "toc.ncx"), "w") as out_f:
|
||||
contents: str = in_f.read().decode()
|
||||
contents = contents.splitlines()
|
||||
for l in contents:
|
||||
if "navPoint" not in l:
|
||||
out_f.write(l + "\n")
|
||||
file_list.append(os.path.join(base_dir, "toc.ncx"))
|
||||
with z_f.open("content.opf") as in_f\
|
||||
, open(os.path.join(base_dir, "content.opf"), "w") as out_f:
|
||||
contents: str = in_f.read().decode()
|
||||
contents = contents.splitlines()
|
||||
for l in contents:
|
||||
if "dc:identifier" not in l:
|
||||
out_f.write(l + "\n")
|
||||
file_list.append(os.path.join(base_dir, "content.opf"))
|
||||
for f_n in z_f.namelist():
|
||||
if f_n.endswith(".html"):
|
||||
with z_f.open(f_n) as in_f\
|
||||
, open(os.path.join(base_dir, f_n), "w") as out_f:
|
||||
html: HtmlElement = lxml.html.fromstring(
|
||||
''.join( filter( lambda ch: ch!="\n" and ch!="\r"
|
||||
, in_f.read().decode()
|
||||
)
|
||||
).encode()
|
||||
)
|
||||
out_f.write(lxml.html.tostring(html, pretty_print=True, encoding="unicode"))
|
||||
file_list.append(os.path.join(base_dir, f_n))
|
||||
logger.debug("%s: %s", filename, file_list)
|
||||
return list(sorted(file_list))
|
||||
except zipfile.BadZipFile:
|
||||
return []
|
||||
|
||||
def compare_epub(result: str, expected: str) -> float:
|
||||
if result is None:
|
||||
return 0.
|
||||
result_files: List[str] = process_epub(result)
|
||||
expected_files: List[str] = process_epub(expected)
|
||||
|
||||
metric: float = 1.
|
||||
for f1, f2 in zip(result_files, expected_files):
|
||||
current_metric: float = diff_text_file(f1, f2)
|
||||
logger.debug("%s vs %s: %f", f1, f2, current_metric)
|
||||
metric *= current_metric
|
||||
return metric
|
||||
|
||||
V = TypeVar("Value")
|
||||
|
||||
def check_mp3_meta(result: str, meta: Dict[str, Dict[str, Union[str, V]]]) -> bool:
|
||||
# checks using _match_value_to_rule
|
||||
if result is None:
|
||||
return 0.
|
||||
|
||||
id3_dict = EasyID3(result)
|
||||
metric: bool = True
|
||||
for k, r in meta.items():
|
||||
value = id3_dict.get(k, "")
|
||||
if isinstance(value, list):
|
||||
value: str = value[0]
|
||||
logger.debug("%s.%s: %s", result, k, value)
|
||||
metric = metric and _match_value_to_rule(value, r)
|
||||
return float(metric)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import datetime
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
|
||||
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)))
|
||||
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)))
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)))
|
||||
|
||||
file_handler.setLevel(logging.INFO)
|
||||
debug_handler.setLevel(logging.DEBUG)
|
||||
stdout_handler.setLevel(logging.INFO)
|
||||
sdebug_handler.setLevel(logging.DEBUG)
|
||||
|
||||
formatter = logging.Formatter(fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
|
||||
file_handler.setFormatter(formatter)
|
||||
debug_handler.setFormatter(formatter)
|
||||
stdout_handler.setFormatter(formatter)
|
||||
sdebug_handler.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(debug_handler)
|
||||
logger.addHandler(stdout_handler)
|
||||
logger.addHandler(sdebug_handler)
|
||||
|
||||
metric = check_mp3_meta( "snapshots/test/cache/3f05f3b9-29ba-4b6b-95aa-2204697ffc06/Cheng Xiang - Missing You - gt.mp3"
|
||||
, { "title": { "method": "eq"
|
||||
, "ref": "Missing You"
|
||||
}
|
||||
, "artist": { "method": "eq"
|
||||
, "ref": "Cheng Xiang"
|
||||
}
|
||||
}
|
||||
)
|
||||
print(metric)
|
||||
@@ -12,6 +12,7 @@ import pandas as pd
|
||||
from openpyxl import Workbook
|
||||
from openpyxl.cell.cell import Cell
|
||||
from openpyxl.worksheet.cell_range import MultiCellRange
|
||||
from openpyxl.utils import get_column_letter
|
||||
from openpyxl.worksheet.datavalidation import DataValidation
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
|
||||
@@ -208,8 +209,10 @@ def compare_table(result: str, expected: str = None, **options) -> float:
|
||||
for rl in r["rules"]:
|
||||
for rng in MultiCellRange(rl["range"]):
|
||||
for cdn in rng.cells:
|
||||
value1: str = str(read_cell_value(*sheet1, cdn))
|
||||
value2: str = str(read_cell_value(*sheet2, cdn))
|
||||
coordinate: str = "{:}{:d}".format(get_column_letter(cdn[1]), cdn[0])
|
||||
value1: str = str(read_cell_value(*sheet1, coordinate))
|
||||
value2: str = str(read_cell_value(*sheet2, coordinate))
|
||||
logger.debug("%s: %s vs %s", cdn, value1, value2)
|
||||
|
||||
for rplc in rl.get("normalization", []):
|
||||
value1 = value1.replace(rplc[0], rplc[1])
|
||||
@@ -230,11 +233,11 @@ def compare_table(result: str, expected: str = None, **options) -> float:
|
||||
|
||||
if rl["type"]=="includes":
|
||||
metric: bool = value1 in value2
|
||||
if rl["type"]=="includes_by":
|
||||
elif rl["type"]=="includes_by":
|
||||
metric: bool = value2 in value1
|
||||
if rl["type"]=="fuzzy_match":
|
||||
elif rl["type"]=="fuzzy_match":
|
||||
metric: bool = fuzz.ratio(value1, value2) >= rl.get("threshold", 85.)
|
||||
if rl["type"]=="exact_match":
|
||||
elif rl["type"]=="exact_match":
|
||||
metric: bool = value1==value2
|
||||
total_metric = total_metric and metric
|
||||
|
||||
|
||||
@@ -311,14 +311,15 @@ def read_cell_value(xlsx_file: str, sheet_name: str, coordinate: str) -> Any:
|
||||
, namespaces=_xlsx_ns_imapping
|
||||
)
|
||||
logger.debug("%s.%s[%s]: %s", xlsx_file, sheet_name, coordinate, repr(cell))
|
||||
if "@t" not in cell["c"]:
|
||||
try:
|
||||
if "@t" not in cell["c"] or cell["c"]["@t"] == "n":
|
||||
return float(cell["c"]["v"])
|
||||
if cell["c"]["@t"] == "s":
|
||||
return shared_strs[int(cell["c"]["v"])]
|
||||
if cell["c"]["@t"] == "str":
|
||||
return cell["c"]["v"]
|
||||
except (KeyError, ValueError):
|
||||
return None
|
||||
if cell["c"]["@t"] == "s":
|
||||
return shared_strs[int(cell["c"]["v"])]
|
||||
if cell["c"]["@t"] == "n":
|
||||
return float(cell["c"]["v"])
|
||||
if cell["c"]["@t"] == "str":
|
||||
return cell["c"]["v"]
|
||||
# }}} read_cell_value #
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user