diff --git a/desktop_env/evaluators/getters/misc.py b/desktop_env/evaluators/getters/misc.py index db04aea..976db19 100644 --- a/desktop_env/evaluators/getters/misc.py +++ b/desktop_env/evaluators/getters/misc.py @@ -1,5 +1,5 @@ import logging -from typing import TypeVar +from typing import TypeVar, Dict from datetime import datetime, timedelta logger = logging.getLogger("desktopenv.getters.misc") @@ -74,13 +74,13 @@ relativeTime_to_IntDay = { "first monday four months later": "special" } -def get_rule(env, config: R) -> R: +def get_rule(env, config: Dict[str, R]) -> R: """ Returns the rule as-is. """ return config["rules"] -def get_rule_relativeTime(env, config: R) -> R: +def get_rule_relativeTime(env, config: Dict[str, R]) -> R: """ According to the rule definded in funciton "apply_rules_to_timeFormat", convert the relative time to absolute time. config: diff --git a/desktop_env/evaluators/metrics/general.py b/desktop_env/evaluators/metrics/general.py index 4458a69..26a8e3c 100644 --- a/desktop_env/evaluators/metrics/general.py +++ b/desktop_env/evaluators/metrics/general.py @@ -1,6 +1,7 @@ import csv import functools import json +import yaml import operator import re import sqlite3 @@ -132,11 +133,11 @@ _accessibility_ns_map = {"st": "uri:deskat:state.at-spi.gnome.org" } -def check_accessibility_tree(result: str, rules: Dict[str, Any]) -> float: +def check_accessibility_tree(result: str, rules: List[Dict[str, Any]]) -> float: """ Args: result (str): XML of GNOME Accessibility Tree - rules (Dict[str, Any]): dict like + rules (List[Dict[str, Any]]): list of dict like { "selectors": list of str as CSS selectors, will be connected by ", " to form a composite selector. Only one from `selectors` and @@ -154,30 +155,33 @@ def check_accessibility_tree(result: str, rules: Dict[str, Any]) -> float: """ at: _Element = lxml.etree.fromstring(result) - if "xpath" in rules: - elements: List[_Element] = at.xpath(rules["xpath"], namespaces=_accessibility_ns_map) - elif "selectors" in rules: - selector = CSSSelector(", ".join(rules["selectors"]), namespaces=_accessibility_ns_map) - elements: List[_Element] = selector(at) - else: - raise ValueError("At least one of xpath and selectors is required") + total_match_score = 1. + for r in rules: + if "xpath" in r: + elements: List[_Element] = at.xpath(r["xpath"], namespaces=_accessibility_ns_map) + elif "selectors" in r: + selector = CSSSelector(", ".join(r["selectors"]), namespaces=_accessibility_ns_map) + elements: List[_Element] = selector(at) + else: + raise ValueError("At least one of xpath and selectors is required") - if len(elements) == 0: - print("no elements") - return 0. + if len(elements) == 0: + print("no elements") + return 0. - if "text" in rules: - match_func: Callable[[str], Number] = functools.partial(operator.eq if rules["exact"] \ - else (lambda a, b: fuzz.ratio(a, b) / 100.) - , rules["text"] - ) - match_score: Number = 0 - for elm in elements: - match_score = max(match_score, match_func(elm.text or None)) - else: - match_score = 1. + if "text" in r: + match_func: Callable[[str], Number] = functools.partial( operator.eq if r["exact"] \ + else (lambda a, b: fuzz.ratio(a, b) / 100.) + , r["text"] + ) + match_score: Number = 0 + for elm in elements: + match_score = max(match_score, match_func(elm.text or None)) + else: + match_score = 1. + total_match_score *= match_score - return float(match_score) + return float(total_match_score) # def check_existence(result: str, *args) -> float: @@ -189,7 +193,7 @@ def run_sqlite3(result: str, rules: Dict[str, Any]) -> float: return float(cursor.fetchone()[0] or 0) -def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str]]]]) -> float: +def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str]]]], is_yaml: bool = False) -> float: """ Args: result (str): path to json file @@ -204,6 +208,7 @@ def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str ], "unexpect": page-tab[name=\"About Profiles\"]" - ] - } + "rules": [ + { + "selectors": [ + "application[name=Thunderbird] page-tab-list[attr|id=\"tabmail-tabs\"]>page-tab[name=\"About Profiles\"]" + ] + } + ] }, "func": "check_accessibility_tree" } diff --git a/requirements.txt b/requirements.txt index a6082f9..6571f11 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,3 +42,4 @@ func-timeout beautifulsoup4 dashscope google-generativeai +PyYaml