From 91824f754cd3eaa69a0c93bad9cae28e93a4ff1c Mon Sep 17 00:00:00 2001
From: rhythmcao
Date: Thu, 18 Jan 2024 14:12:54 +0800
Subject: [PATCH] 1. extend evaluator to list (compatible with single
 evaluator) 2. fix a variable name error in metrics/general.py

---
 desktop_env/envs/desktop_env.py           | 69 ++++++++++++++++++-----
 desktop_env/evaluators/metrics/general.py |  4 +-
 2 files changed, 56 insertions(+), 17 deletions(-)

diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py
index 94e3b0d..df49f1e 100644
--- a/desktop_env/envs/desktop_env.py
+++ b/desktop_env/envs/desktop_env.py
@@ -8,7 +8,7 @@ import time
 from typing import Callable, Any, Optional, Tuple
 # import uuid
 # import platform
-from typing import List, Dict
+from typing import List, Dict, Union
 
 import gymnasium as gym
 
@@ -211,12 +211,28 @@ class DesktopEnv(gym.Env):
         self.instruction = task_config["instruction"]
         self.config = task_config["config"]
 
+        # evaluator dict
+        # func -> metric function string, or list of metric function strings
+        # conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"
+        # result -> result getter config, or list of result getter configs
+        # expected (optional) -> expected getter config, or list of expected getter configs
+        # options (optional) -> metric options, or list of metric options
+        # if func is a str list, then result, expected (if exists), options (if exists) should also be lists of the same length
+        # even if one of the metrics does not need expected or options field, it should be included in the list with None
         self.evaluator = task_config["evaluator"]
-        self.metric: Metric = getattr(metrics, self.evaluator["func"])
-        self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
-        self.expected_getter: Getter = getattr(getters, "get_{:}".format(
-            self.evaluator["expected"]["type"])) if "expected" in self.evaluator else None
-        self.metric_options: Dict[str, Any] = self.evaluator.get("options", {})
+        self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] if type(self.evaluator["func"]) == list \
+            else getattr(metrics, self.evaluator["func"])
+        self.metric_conj: str = self.evaluator.get("conj", "and")  # take conjunction of multiple metrics
+        self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in self.evaluator["result"]] \
+            if type(self.evaluator["result"]) == list else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
+        if "expected" in self.evaluator:
+            self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in self.evaluator["expected"]] \
+                if type(self.evaluator["expected"]) == list else getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
+        else: self.expected_getter = [None] * len(self.metric) if type(self.metric) == list else None
+        self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = [opt if opt else {} for opt in self.evaluator["options"]] \
+            if type(self.evaluator.get("options", {})) == list else self.evaluator["options"] if "options" in self.evaluator else \
+            [{} for _ in self.metric] if type(self.metric) == list else {}
+        assert type(self.evaluator["func"]) != list or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(self.metric_options))
 
     def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
         logger.info("Resetting environment...")
@@ -311,17 +327,40 @@ class DesktopEnv(gym.Env):
 
         self.setup_controller.setup(self.evaluator.get("postconfig", []))
 
-        try:
-            result_state = self.result_getter(self, self.evaluator["result"])
-        except FileNotFoundError:
-            logger.error("File not found!")
-            return 0
+        if type(self.metric) == list:
+            for idx, metric in enumerate(self.metric):
+                try:
+                    config = self.evaluator["result"][idx]
+                    result_state = self.result_getter[idx](self, config)
+                except FileNotFoundError:
+                    logger.error("File not found!")
+                    if self.metric_conj == 'and':
+                        return 0
+                    continue  # 'or': this metric counts as failed, try the remaining ones
 
-        expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \
-            else None
+                expected_getter = self.expected_getter[idx]  # None when this metric has no expected config
+                expected_state = expected_getter(self, self.evaluator["expected"][idx]) if expected_getter else None
 
-        metric: float = self.metric(result_state, expected_state, **self.metric_options) if expected_state is not None \
-            else self.metric(result_state, **self.metric_options)
+                metric: int = metric(result_state, expected_state, **self.metric_options[idx]) if expected_state is not None \
+                    else metric(result_state, **self.metric_options[idx])
+
+                if self.metric_conj == 'and' and not bool(metric):
+                    return 0
+                elif self.metric_conj == 'or' and bool(metric):
+                    return 1
+            return 1 if self.metric_conj == 'and' else 0
+        else:
+            try:
+                result_state = self.result_getter(self, self.evaluator["result"])
+            except FileNotFoundError:
+                logger.error("File not found!")
+                return 0
+
+            expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \
+                else None
+
+            metric: float = self.metric(result_state, expected_state, **self.metric_options) if expected_state is not None \
+                else self.metric(result_state, **self.metric_options)
 
         return metric
 
diff --git a/desktop_env/evaluators/metrics/general.py b/desktop_env/evaluators/metrics/general.py
index 6246861..98c9596 100644
--- a/desktop_env/evaluators/metrics/general.py
+++ b/desktop_env/evaluators/metrics/general.py
@@ -180,8 +180,8 @@ def check_json(result: str, rules: Dict[str, List[Dict[str, Union[List[str], str
     with open(result) as f:
         result: Dict[str, Any] = json.load(f)
 
-    expect_rules = rule.get("expect", {})
-    unexpect_rules = rule.get("unexpect", {})
+    expect_rules = rules.get("expect", {})
+    unexpect_rules = rules.get("unexpect", {})
 
     metric = True
     for r in expect_rules: