Files
sci-gui-agent-benchmark/desktop_env/evaluators/metrics/vscode.py
2024-01-22 23:30:10 +08:00

129 lines
2.8 KiB
Python

from typing import Dict
import json, copy
def check_json_keybindings(actual: str, expected: Dict, **options) -> float:
    """Check that a JSON keybindings file contains an expected entry.

    The file is first parsed as plain JSON. If that fails, or does not yield
    a list, it is parsed again with its first line discarded. This tolerates
    a leading non-JSON line such as a ``//`` comment.

    Args:
        actual (str): path to the result keybindings file; a falsy value means
            no result was produced.
        expected (dict): dict whose "expect" value is the entry that must
            appear as an element of the top-level JSON list.
    Return:
        float: 1.0 if the expected entry is in the list, otherwise 0.0. This
            includes a missing/unreadable file or content that is not a JSON
            list.
    """
    if not actual:
        return 0.0

    def _load(fp, skip_first_line):
        """Parse ``fp`` as JSON, optionally discarding its first line; None on failure."""
        try:
            with open(fp, 'r') as f:
                if skip_first_line:
                    f.readline()
                return json.load(f)
        except (OSError, ValueError):
            # Missing/unreadable file or malformed JSON. JSONDecodeError and
            # UnicodeDecodeError are both ValueError subclasses.
            return None

    for skip_first_line in (False, True):
        data = _load(actual, skip_first_line)
        if isinstance(data, list):
            break
    else:
        # Neither parsing strategy produced a JSON list.
        return 0.0

    return 1.0 if expected['expect'] in data else 0.0
def check_json_settings(actual: str, expected: Dict, **options) -> float:
    """Check that a JSON settings file contains every expected key/value pair.

    Keys present in the file but not in the expectation are ignored. This is
    a top-level subset check.

    Args:
        actual (str): path to the result JSON file; a falsy value means no
            result was produced.
        expected (dict): dict whose "expect" value holds the settings that
            must appear in the file's top-level object with exactly these
            values.
    Return:
        float: 1.0 if every expected setting matches, otherwise 0.0. This
            includes a missing/unreadable file, invalid JSON, or a top level
            that is not a JSON object.
    """
    if not actual:
        return 0.0
    try:
        with open(actual, 'r') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file or malformed JSON: treat as a failed check.
        return 0.0
    if not isinstance(data, dict):
        return 0.0
    expect = expected['expect']
    # "data is unchanged by applying expect" <=> every expected pair is present.
    # A shallow copy is enough: update() only replaces top-level keys and
    # never mutates nested values.
    merged = dict(data)
    merged.update(expect)
    return 1.0 if merged == data else 0.0
def compare_text_file(actual: str, expected: str, **options) -> float:
    """Score whether two text files have exactly the same contents.

    Args:
        actual (str): path to the result text file; a falsy value scores 0.0
        expected (str): path to the gold text file
    Return:
        float: 1.0 when the decoded texts are identical, else 0.0
    """
    if not actual:
        return 0.
    with open(actual) as result_file:
        result_text = result_file.read()
    with open(expected) as gold_file:
        gold_text = gold_file.read()
    return 1.0 if result_text == gold_text else 0.0
def compare_config(actual: str, rules: Dict, **options) -> float:
    """Score whether a config file's full text equals an expected string.

    Args:
        actual (str): path to the result config file; a falsy value scores 0.0
        rules (dict): must contain key "expect", the exact expected file text
    Return:
        float: 1.0 on an exact text match, else 0.0
    """
    if actual:
        with open(actual) as config_file:
            contents = config_file.read()
        if contents == rules['expect']:
            return 1.0
    return 0.0
def compare_answer(actual: str, rules: Dict, **options) -> float:
    """Score an answer string by exact match against the expected answer.

    Args:
        actual (str): result string; a falsy value (None or "") always
            scores 0.0
        rules (dict): must contain key "expect", the gold answer string
    Return:
        float: 1.0 on an exact match, else 0.0
    """
    if not actual:
        return 0.0
    # TODO: can use text embedding to get non-zero return
    return 1.0 if actual == rules['expect'] else 0.0
def is_extension_installed(actual: str, rules: Dict, **options):
    """Score whether a substring is (or is not) present in ``actual``.

    Args:
        actual (str): text searched with ``in``. Presumably the captured list
            of installed extensions; confirm against the caller.
        rules (dict): "type" is either "contain" or "not_contain";
            "expected" is the substring to look for.
    Return:
        float: 1.0 when the presence of ``rules['expected']`` matches the
            requested type, else 0.0.
    Raises:
        NotImplementedError: if ``rules['type']`` is any other value.
    """
    check_type = rules['type']
    if check_type not in ('contain', 'not_contain'):
        raise NotImplementedError
    present = rules['expected'] in actual
    want_present = check_type == 'contain'
    return 1.0 if present == want_present else 0.0