Enhance metric evaluation in DesktopEnv

- Add assertions to ensure the number of metrics matches the number of result getters and, when the evaluator defines "expected", the number of expected getters.
- Refactor metric calculation logic to handle cases with and without expected values more clearly.
- Improve comments for better understanding of single and multiple metric evaluations.
This commit is contained in:
Timothyxxx
2025-04-02 23:45:56 +08:00
parent d373817edb
commit ec583d6f0c

View File

@@ -274,7 +274,11 @@ class DesktopEnv(gym.Env):
return 0
if type(self.metric) == list:
# Multiple metrics to evaluate whether the task is successfully completed
results = []
assert len(self.metric) == len(self.result_getter), "The number of metrics and result getters must be the same"
if "expected" in self.evaluator:
assert len(self.metric) == len(self.expected_getter), "The number of metrics and expected getters must be the same"
for idx, metric in enumerate(self.metric):
try:
config = self.evaluator["result"][idx]
@@ -284,12 +288,11 @@ class DesktopEnv(gym.Env):
if self.metric_conj == 'and':
return 0
expected = self.evaluator["expected"][idx]
expected_state = self.expected_getter[idx](self, expected) if expected else None
metric: int = metric(result_state, expected_state,
**self.metric_options[idx]) if expected_state is not None \
else metric(result_state, **self.metric_options[idx])
if "expected" in self.evaluator:
expected_state = self.expected_getter[idx](self, self.evaluator["expected"][idx])
metric: int = metric(result_state, expected_state, **self.metric_options[idx])
else:
metric: int = metric(result_state, **self.metric_options[idx])
if self.metric_conj == 'and' and float(metric) == 0.0:
return 0
@@ -297,20 +300,21 @@ class DesktopEnv(gym.Env):
return 1
else:
results.append(metric)
return sum(results) / len(results) if self.metric_conj == 'and' else max(results)
else:
# Single metric to evaluate whether the task is successfully completed
try:
result_state = self.result_getter(self, self.evaluator["result"])
except FileNotFoundError:
logger.error("File not found!")
return 0
expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \
else None
metric: float = self.metric(result_state, expected_state,
**self.metric_options) if expected_state is not None \
else self.metric(result_state, **self.metric_options)
if "expected" in self.evaluator:
expected_state = self.expected_getter(self, self.evaluator["expected"])
metric: float = self.metric(result_state, expected_state, **self.metric_options)
else:
metric: float = self.metric(result_state, **self.metric_options)
return metric