From ec583d6f0c2954400306f56b2a38b862146a6308 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Wed, 2 Apr 2025 23:45:56 +0800 Subject: [PATCH] Enhance metric evaluation in DesktopEnv - Add assertions to ensure the number of metrics matches the number of result and expected getters. - Refactor metric calculation logic to handle cases with and without expected values more clearly. - Improve comments for better understanding of single and multiple metric evaluations. --- desktop_env/desktop_env.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index dd5d364..1d16771 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -274,7 +274,11 @@ class DesktopEnv(gym.Env): return 0 if type(self.metric) == list: + # Multiple metrics to evaluate whether the task is successfully completed results = [] + assert len(self.metric) == len(self.result_getter), "The number of metrics and result getters must be the same" + if "expected" in self.evaluator: + assert len(self.metric) == len(self.expected_getter), "The number of metrics and expected getters must be the same" for idx, metric in enumerate(self.metric): try: config = self.evaluator["result"][idx] @@ -284,12 +288,11 @@ class DesktopEnv(gym.Env): if self.metric_conj == 'and': return 0 - expected = self.evaluator["expected"][idx] - expected_state = self.expected_getter[idx](self, expected) if expected else None - - metric: int = metric(result_state, expected_state, - **self.metric_options[idx]) if expected_state is not None \ - else metric(result_state, **self.metric_options[idx]) + if "expected" in self.evaluator: + expected_state = self.expected_getter[idx](self, self.evaluator["expected"][idx]) + metric: int = metric(result_state, expected_state, **self.metric_options[idx]) + else: + metric: int = metric(result_state, **self.metric_options[idx]) if self.metric_conj == 'and' and float(metric) == 0.0: return 0 @@ -297,20 +300,21 @@ class DesktopEnv(gym.Env): return 1 else: results.append(metric) + return sum(results) / len(results) if self.metric_conj == 'and' else max(results) else: + # Single metric to evaluate whether the task is successfully completed try: result_state = self.result_getter(self, self.evaluator["result"]) except FileNotFoundError: logger.error("File not found!") return 0 - expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \ - else None - - metric: float = self.metric(result_state, expected_state, - **self.metric_options) if expected_state is not None \ - else self.metric(result_state, **self.metric_options) + if "expected" in self.evaluator: + expected_state = self.expected_getter(self, self.evaluator["expected"]) + metric: float = self.metric(result_state, expected_state, **self.metric_options) + else: + metric: float = self.metric(result_state, **self.metric_options) return metric