From bdd21d06ca50cb8605c9ca26c3f513f144ab0b07 Mon Sep 17 00:00:00 2001
From: Timothyxxx <384084775@qq.com>
Date: Fri, 19 Jan 2024 20:34:11 +0800
Subject: [PATCH] Fix minor bugs

---
 desktop_env/envs/desktop_env.py | 51 ++++++++++++++++++++++-----------
 1 file changed, 35 insertions(+), 16 deletions(-)

diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py
index df49f1e..2f4287e 100644
--- a/desktop_env/envs/desktop_env.py
+++ b/desktop_env/envs/desktop_env.py
@@ -112,6 +112,7 @@ class DesktopEnv(gym.Env):
             bytes_per_pixel = bits_per_pixel // 8
             vram_size = width * height * bytes_per_pixel
             return vram_size
+
         if not os.path.isfile(self.path_to_vm):
             logger.warning(f"The specified vmx file does not exist: {self.path_to_vm}")
             return False
@@ -213,26 +214,42 @@ class DesktopEnv(gym.Env):
 
         # evaluator dict
         #   func -> metric function string, or list of metric function strings
-        #   conj -> conjuction of multiple metrics if func is a list with length > 1, "and"/"or"
+        #   conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"
         #   result -> result getter config, or list of result getter configs
         #   expected (optional) -> expected getter config, or list of expected getter configs
         #   options (optional) -> metric options, or list of metric options
         # if func is a str list, then result, expected (if exists), options (if exists) should also be lists of the same length
         # even if one of the metrics does not need expected or options field, it should be included in the list with None
         self.evaluator = task_config["evaluator"]
-        self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] if type(self.evaluator["func"]) == list \
-            else getattr(metrics, self.evaluator["func"])
-        self.metric_conj: str = self.evaluator.get("conj", "and")  # take conjuction of multiple metrics
-        self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in self.evaluator["result"]] \
-            if type(self.evaluator["result"]) == list else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
+        self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] \
+            if isinstance(self.evaluator["func"], list) \
+            else getattr(metrics, self.evaluator["func"])
+        self.metric_conj: str = self.evaluator.get("conj", "and")  # take conjunction of multiple metrics
+        self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in
+                                      self.evaluator["result"]] \
+            if isinstance(self.evaluator["result"], list) \
+            else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
         if "expected" in self.evaluator:
-            self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in self.evaluator["expected"]] \
-                if type(self.evaluator["expected"]) == list else getattr(getters, "get_{:}".format())
-        else: self.expected_getter = [None for _ in len(self.metric)] if type(self.metric) == list else None
-        self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = [opt if opt else {} for opt in self.evaluator["options"]] \
-            if type(self.evaluator.get("options", {})) == list else self.evaluator["options"] if "options" in self.evaluator else \
-            [{} for _ in len(self.metric)] if type(self.metric) == list else {}
-        assert type(self.evaluator["func"]) != list or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(self.metric_options))
+            self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp
+                                            in self.evaluator["expected"]] \
+                if isinstance(self.evaluator["expected"], list) \
+                else getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
+        else:
+            self.expected_getter = [None] * len(self.metric) \
+                if isinstance(self.metric, list) \
+                else None
+        self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = [opt if opt else {} for opt in
+                                                                            self.evaluator["options"]] \
+            if isinstance(self.evaluator.get("options", {}), list) \
+            else self.evaluator["options"] \
+            if "options" in self.evaluator \
+            else [{}] * len(self.metric) \
+            if isinstance(self.metric, list) \
+            else {}
+
+        assert (not isinstance(self.evaluator["func"], list)
+                or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(
+                    self.metric_options)))
 
     def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
         logger.info("Resetting environment...")
@@ -340,9 +357,10 @@ class DesktopEnv(gym.Env):
 
                     expected = self.evaluator["expected"][idx]
                     expected_state = self.expected_getter[idx](self, expected) if expected else None
-                    metric: int = metric(result_state, expected_state, **self.metric_options[idx]) if expected_state is not None \
+                    metric: int = metric(result_state, expected_state,
+                                         **self.metric_options[idx]) if expected_state is not None \
                         else metric(result_state, **self.metric_options[idx])
-
+
                 if self.metric_conj == 'and' and not bool(metric):
                     return 0
                 elif self.metric_conj == 'or' and bool(metric):
@@ -358,7 +376,8 @@ class DesktopEnv(gym.Env):
 
             expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \
                 else None
-            metric: float = self.metric(result_state, expected_state, **self.metric_options) if expected_state is not None \
+            metric: float = self.metric(result_state, expected_state,
+                                        **self.metric_options) if expected_state is not None \
                 else self.metric(result_state, **self.metric_options)
 
             return metric
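
For reference, the evaluator parsing above accepts either a single metric or a list of
metrics. Below is a minimal sketch of the two shapes of task_config["evaluator"] that
this code handles; the metric and getter names ("check_foo", "check_baz", "bar_file",
"rule") are hypothetical placeholders, and real values must name functions in the
metrics module and "get_<type>" functions in the getters module:

    # Single metric: scalar fields throughout.
    evaluator_single = {
        "func": "check_foo",              # hypothetical; resolved via getattr(metrics, "check_foo")
        "result": {"type": "bar_file"},   # hypothetical; resolved via getters.get_bar_file(self, config)
        "expected": {"type": "rule"},     # optional
        "options": {},                    # optional
    }

    # List of metrics: func, result, expected, options are lists of equal length;
    # entries a metric does not need are still present as None (options fall back to {}).
    evaluator_multi = {
        "func": ["check_foo", "check_baz"],
        "conj": "and",                    # "and" (default) or "or"
        "result": [{"type": "bar_file"}, {"type": "rule"}],
        "expected": [{"type": "rule"}, None],
        "options": [{}, None],
    }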