1. extend evaluator to list (compatible with single evaluator) 2. fix a variable name error in metrics/general.py
This commit is contained in:
@@ -8,7 +8,7 @@ import time
|
||||
from typing import Callable, Any, Optional, Tuple
|
||||
# import uuid
|
||||
# import platform
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Union
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
@@ -211,12 +211,28 @@ class DesktopEnv(gym.Env):
|
||||
self.instruction = task_config["instruction"]
|
||||
self.config = task_config["config"]
|
||||
|
||||
# evaluator dict
|
||||
# func -> metric function string, or list of metric function strings
|
||||
# conj -> conjuction of multiple metrics if func is a list with length > 1, "and"/"or"
|
||||
# result -> result getter config, or list of result getter configs
|
||||
# expected (optional) -> expected getter config, or list of expected getter configs
|
||||
# options (optional) -> metric options, or list of metric options
|
||||
# if func is a str list, then result, expected (if exists), options (if exists) should also be lists of the same length
|
||||
# even if one of the metrics does not need expected or options field, it should be included in the list with None
|
||||
self.evaluator = task_config["evaluator"]
|
||||
self.metric: Metric = getattr(metrics, self.evaluator["func"])
|
||||
self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
|
||||
self.expected_getter: Getter = getattr(getters, "get_{:}".format(
|
||||
self.evaluator["expected"]["type"])) if "expected" in self.evaluator else None
|
||||
self.metric_options: Dict[str, Any] = self.evaluator.get("options", {})
|
||||
self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] if type(self.evaluator["func"]) == list \
|
||||
else getattr(metrics, self.evaluator["func"])
|
||||
self.metric_conj: str = self.evaluator.get("conj", "and") # take conjuction of multiple metrics
|
||||
self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in self.evaluator["result"]] \
|
||||
if type(self.evaluator["result"]) == list else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
|
||||
if "expected" in self.evaluator:
|
||||
self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in self.evaluator["expected"]] \
|
||||
if type(self.evaluator["expected"]) == list else getattr(getters, "get_{:}".format())
|
||||
else: self.expected_getter = [None for _ in len(self.metric)] if type(self.metric) == list else None
|
||||
self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = [opt if opt else {} for opt in self.evaluator["options"]] \
|
||||
if type(self.evaluator.get("options", {})) == list else self.evaluator["options"] if "options" in self.evaluator else \
|
||||
[{} for _ in len(self.metric)] if type(self.metric) == list else {}
|
||||
assert type(self.evaluator["func"]) != list or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(self.metric_options))
|
||||
|
||||
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
|
||||
logger.info("Resetting environment...")
|
||||
@@ -311,17 +327,39 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
self.setup_controller.setup(self.evaluator.get("postconfig", []))
|
||||
|
||||
try:
|
||||
result_state = self.result_getter(self, self.evaluator["result"])
|
||||
except FileNotFoundError:
|
||||
logger.error("File not found!")
|
||||
return 0
|
||||
if type(self.metric) == list:
|
||||
for idx, metric in enumerate(self.metric):
|
||||
try:
|
||||
config = self.evaluator["result"][idx]
|
||||
result_state = self.result_getter[idx](self, config)
|
||||
except FileNotFoundError:
|
||||
logger.error("File not found!")
|
||||
if self.metric_conj == 'and':
|
||||
return 0
|
||||
|
||||
expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \
|
||||
else None
|
||||
expected = self.evaluator["expected"][idx]
|
||||
expected_state = self.expected_getter[idx](self, expected) if expected else None
|
||||
|
||||
metric: float = self.metric(result_state, expected_state, **self.metric_options) if expected_state is not None \
|
||||
else self.metric(result_state, **self.metric_options)
|
||||
metric: int = metric(result_state, expected_state, **self.metric_options[idx]) if expected_state is not None \
|
||||
else metric(result_state, **self.metric_options[idx])
|
||||
|
||||
if self.metric_conj == 'and' and not bool(metric):
|
||||
return 0
|
||||
elif self.metric_conj == 'or' and bool(metric):
|
||||
return 1
|
||||
return 1 if self.metric_conj == 'and' else 0
|
||||
else:
|
||||
try:
|
||||
result_state = self.result_getter(self, self.evaluator["result"])
|
||||
except FileNotFoundError:
|
||||
logger.error("File not found!")
|
||||
return 0
|
||||
|
||||
expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \
|
||||
else None
|
||||
|
||||
metric: float = self.metric(result_state, expected_state, **self.metric_options) if expected_state is not None \
|
||||
else self.metric(result_state, **self.metric_options)
|
||||
|
||||
return metric
|
||||
|
||||
|
||||
@@ -180,8 +180,8 @@ def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str
|
||||
with open(result) as f:
|
||||
result: Dict[str, Any] = json.load(f)
|
||||
|
||||
expect_rules = rule.get("expect", {})
|
||||
unexpect_rules = rule.get("unexpect", {})
|
||||
expect_rules = rules.get("expect", {})
|
||||
unexpect_rules = rules.get("unexpect", {})
|
||||
|
||||
metric = True
|
||||
for r in expect_rules:
|
||||
|
||||
Reference in New Issue
Block a user