Merge remote-tracking branch 'origin/main'

# Conflicts:
#	desktop_env/evaluators/getters/__init__.py
#	desktop_env/evaluators/metrics/__init__.py
#	requirements.txt
This commit is contained in:
Timothyxxx
2024-01-10 23:20:49 +08:00
32 changed files with 626 additions and 80 deletions

View File

@@ -1,4 +1,5 @@
from .file import get_cloud_file, get_vm_file, get_cache_file
from .misc import get_rule
from .info import get_vm_screen_size, get_vm_window_size, get_vm_wallpaper
from .misc import get_rule, get_accessibility_tree
from .vlc import get_vlc_playing_info, get_vlc_config

View File

@@ -1,5 +1,6 @@
import os
from typing import Dict
from typing import Optional
import requests
@@ -27,7 +28,7 @@ def get_cloud_file(env, config: Dict[str, str]) -> str:
return _path
def get_vm_file(env, config: Dict[str, str]) -> str:
def get_vm_file(env, config: Dict[str, str]) -> Optional[str]:
"""
Config:
path (str): absolute path on the VM to fetch
@@ -37,10 +38,9 @@ def get_vm_file(env, config: Dict[str, str]) -> str:
_path = os.path.join(env.cache_dir, config["dest"])
file = env.controller.get_file(config["path"])
if file is None:
raise FileNotFoundError("File not found on VM: {:}".format(config["path"]))
return None
#raise FileNotFoundError("File not found on VM: {:}".format(config["path"]))
with open(_path, "wb") as f:
f.write(file)

View File

@@ -0,0 +1,23 @@
from typing import Dict
import os
import requests
def get_string(env, config: Dict[str, str]) -> str:
"""
Config:
string (str)
"""
return config["string"]
def get_command_line(env, config: Dict[str, str]) -> str:
"""
Config:
string (str)
"""
f = os.popen(config["command"])
return f.read()

View File

@@ -1,5 +1,6 @@
import logging
from typing import TypeVar
#from typing import Dict, List
logger = logging.getLogger("desktopenv.getters.misc")
@@ -11,3 +12,8 @@ def get_rule(env, config: R) -> R:
Returns the rule as-is.
"""
return config["rules"]
def get_accessibility_tree(env, *args) -> str:
accessibility_tree: str = env.controller.get_accessibility_tree()
logger.debug("AT@eval: %s", accessibility_tree)
return accessibility_tree

View File

@@ -6,4 +6,5 @@ from .docs import is_first_line_centered, check_file_exists, compare_contains_im
from .pdf import check_pdf_pages
from .libreoffice import check_libre_locale
from .vlc import is_vlc_playing, is_vlc_recordings_folder, is_vlc_fullscreen, compare_images, compare_audios, compare_videos
from .general import check_csv
from .general import check_csv, check_accessibility_tree, check_list

View File

@@ -1,5 +1,17 @@
import csv
from typing import Dict, List
import lxml.etree
from lxml.etree import _Element
from lxml.cssselect import CSSSelector
from typing import Dict, List, Pattern
from typing import Callable, Any
from numbers import Number
import operator
from rapidfuzz import fuzz
import functools
import re
def _match_record(pattern: Dict[str, str], item: Dict[str, str]) -> float:
return all(k in item and item[k]==val for k, val in pattern.items())
@@ -22,9 +34,92 @@ def check_csv(result: str, rules: Dict[str, List[Dict[str, str]]]) -> float:
unexpect_metric = True
with open(result) as f:
reader = csv.DictReader(f)
for rcd in reader:
for i, r in enumerate(rules.get("expect", [])):
expect_metrics[i] = expect_metrics[i] or _match_record(r, rcd)
unexpect_metric = unexpect_metric and all(_match_record(r, rcd) for r in rules.get("unexpect", []))
unexpect_metric = unexpect_metric and not any(_match_record(r, rcd) for r in rules.get("unexpect", []))
return float(all(expect_metrics) and unexpect_metric)
def check_list(result: str, rules: Dict[str, List[str]]) -> float:
"""
Args:
result (str): path to list file
rules (Dict[str, List[str]]): dict like
{
"expect": list of str as regexes
"unexpect": list of str as regexes
}
Returns:
float
"""
expect_patterns: List[Pattern[str]] = [re.compile(ptt) for ptt in rules.get("expect", [])]
unexpect_patterns: List[Pattern[str]] = [re.compile(ptt) for ptt in rules.get("unexpect", [])]
expect_metrics = [False] * len(expect_patterns)
unexpect_metric = True
with open(result) as f:
for l in f:
for i, r in enumerate(expect_patterns):
expect_metrics[i] = expect_metrics[i] or (r.search(l) is not None)
unexpect_metric = unexpect_metric and all(r.search(l) is None for r in unexpect_patterns)
return float(all(expect_metrics) and unexpect_metric)
_accessibility_ns_map = { "st": "uri:deskat:state.at-spi.gnome.org"
, "attr": "uri:deskat:attributes.at-spi.gnome.org"
, "cp": "uri:deskat:component.at-spi.gnome.org"
, "doc": "uri:deskat:document.at-spi.gnome.org"
, "docattr": "uri:deskat:attributes.document.at-spi.gnome.org"
, "txt": "uri:deskat:text.at-spi.gnome.org"
, "val": "uri:deskat:value.at-spi.gnome.org"
, "act": "uri:deskat:action.at-spi.gnome.org"
}
def check_accessibility_tree(result: str, rules: Dict[str, Any]) -> float:
"""
Args:
result (str): XML of GNOME Accessibility Tree
rules (Dict[str, Any]): dict like
{
"selectors": list of str as CSS selectors, will be connected by ", "
to form a composite selector. Only one from `selectors` and
`xpath` is needed. If both are present, `xpath` takes the
priority.
"xpath": str as xpath. Only one from `selectors` and `xpath` is
needed. If both are present, `xpath` takes the priority.
"text": str as the expected text content of the selected element.
"exact": bool specifying whether exact match or fuzzy match should
be performed. defaults to True
}
Returns:
float
"""
at: _Element = lxml.etree.fromstring(result)
if "xpath" in rules:
elements: List[_Element] = at.xpath(rules["xpath"], namespaces=_accessibility_ns_map)
elif "selectors" in rules:
selector = CSSSelector(", ".join(rules["selectors"]), namespaces=_accessibility_ns_map)
elements: List[_Element] = selector(at)
else:
raise ValueError("At least one of xpath and selectors is required")
if len(elements)==0:
return 0.
if "text" in rules:
match_func: Callable[[str], Number] = functools.partial( operator.eq if rules["exact"] else fuzz.ratio
, rules["text"]
)
match_score: Number = 0
for elm in elements:
match_score = max(match_score, match_func(elm.text or None))
else:
match_score = 1.
return float(match_score)
#def check_existence(result: str, *args) -> float:
#return 1. - (result is None)

View File

@@ -31,6 +31,9 @@ def compare_table(actual: str, expected: str, **options) -> float:
float: the score
"""
if actual is None:
return 0.
df1 = pd.read_excel(expected)
df2 = pd.read_excel(actual)
metric: bool = df1.equals(df2)
@@ -71,6 +74,9 @@ def compare_table(actual: str, expected: str, **options) -> float:
return float(metric)
def check_sheet_list(result: str, rules: List[Dict[str, Any]]) -> float:
if result is None:
return 0.
# workbook: Workbook = openpyxl.load_workbook(filename=result)
workbook = pd.ExcelFile(result)
worksheet_names: List[str] = workbook.sheet_names
@@ -109,10 +115,16 @@ def check_sheet_list(result: str, rules: List[Dict[str, Any]]) -> float:
return float(passes)
def check_xlsx_freeze(result: str, rules: Dict[str, str]) -> float:
if result is None:
return 0.
worksheet: Worksheet = openpyxl.load_workbook(filename=result).active
return float(worksheet.freeze_panes == rules["position"])
def check_xlsx_zoom(result: str, rules: Dict[str, Union[str, Number]]) -> float:
if result is None:
return 0.
worksheet = openpyxl.load_workbook(filename=result).active
zoom_scale: Number = worksheet.sheet_view.zoomScale or 100.
return float( getattr(operator, rules["relation"])( zoom_scale

View File

@@ -0,0 +1,53 @@
#from playwright.sync_api import sync_playwright, Browser
#from marionette_driver.marionette import Marionette
#import marionette
#import pyatspi
import lxml.etree
from lxml.cssselect import CSSSelector
from lxml.etree import _Element
from typing import List
if __name__ == "__main__":
#with sync_playwright() as plwr:
#while True:
##try:
#thunderbird: Browser = plwr.firefox.connect("http://127.0.0.1:6000", timeout=60)
#break
##except:
##pass
#for ctx in thunderbird.contexts:
#for p in ctx.pages:
#print(p.url)
#thunderbird = Marionette()
#print(thunderbird.start_session())
#print(thunderbird.chrome_window_handles)
#print(thunderbird.window_handles)
#print(thunderbird.current_chrome_window_handle)
#thunderbird.set_context(Marionette.CONTEXT_CONTENT)
#print(thunderbird.current_window_handle)
#thunderbird.switch_to_window(thunderbird.chrome_window_handles[0])
#thunderbird.switch_to_default_content()
#thunderbird.switch_to_frame()
#print(thunderbird.get_url())
#print(thunderbird.get_window_type())
#thunderbird.fullscreen()
#print(thunderbird.close())
#registry = pyatspi.Registry.get_default()
#registry
#xml = "../../任务数据/Thunderbird/vertical-card-view.xml"
xml = "../../任务数据/Thunderbird/vertical-table-view.xml"
at: _Element = lxml.etree.parse(xml)
#elements: List[_Element] = CSSSelector('application[name=Thunderbird] page-tab-list')(at) # page tab tags
#elements: List[_Element] = CSSSelector('application[name=Thunderbird] panel>scroll-pane>internal-frame>panel[name$="anonym-x2024@outlook.com"]')(at) # email tag page
#elements: List[_Element] = CSSSelector('application[name=Thunderbird] panel>scroll-pane>internal-frame>panel[name$="anonym-x2024@outlook.com"]>section:nth-child(3)')(at) # email tag page
#elements: List[_Element] = CSSSelector('application[name=Thunderbird] panel>scroll-pane>internal-frame>panel[name$="anonym-x2024@outlook.com"]>section[attr|id=threadPane]>section[attr|id="threadTree"]>table[attr|class="tree-table"]>section[attr|class~="tree-table-header"]>table-row>column-header[name=Subject]>push-button', namespaces={"attr": "uri:deskat:attributes.at-spi.gnome.org"})(at) # table view, column header
elements: List[_Element] = CSSSelector('application[name=Thunderbird] panel>scroll-pane>internal-frame>panel[name$="anonym-x2024@outlook.com"]>section[attr|id=threadPane]>section[attr|id="threadTree"]>table[attr|class="tree-table"]>tree>tree-item>section[name="Subject"]>section>section', namespaces={"attr": "uri:deskat:attributes.at-spi.gnome.org"})(at) # table view, column header
print(len(elements))
for elm in elements:
print(lxml.etree.tostring(elm, encoding="unicode", pretty_print=True))

View File

@@ -20,5 +20,13 @@ def compare_text_file(actual: str, expected: str, **options) -> float:
return 1.0
return 0.0
def compare_answer(actual: str, expected: str, **options) -> float:
if actual == expected:
return 1.0
# TODO: can use text embedding to get non-zero return
return 0.0
if __name__ == '__main__':
print(compare_text_file("README.md", "README.md"))