Enhance metric evaluation in DesktopEnv
- Add assertions to ensure the number of metrics matches the number of result and expected getters.
- Refactor metric calculation logic to handle cases with and without expected values more clearly.
- Improve comments for better understanding of single and multiple metric evaluations.
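Before the diff, a minimal self-contained sketch of the control flow the patch settles on for the multi-metric case. Everything here (evaluate_multi, its parameters, the stand-in states) is an illustrative assumption, not the class's real attributes:

from typing import Any, Callable, Dict, List

def evaluate_multi(metrics: List[Callable[..., float]],
                   result_states: List[Any],
                   evaluator: Dict[str, Any],
                   conj: str = "and") -> float:
    # Fail loudly on a mismatched config, as the new assertions do.
    assert len(metrics) == len(result_states), \
        "The number of metrics and result states must be the same"
    if "expected" in evaluator:
        assert len(metrics) == len(evaluator["expected"]), \
            "The number of metrics and expected states must be the same"
    results = []
    for idx, metric in enumerate(metrics):
        # Key presence, not truthiness, decides whether a comparison
        # state is passed; this mirrors the patched branch structure.
        if "expected" in evaluator:
            score = metric(result_states[idx], evaluator["expected"][idx])
        else:
            score = metric(result_states[idx])
        if conj == "and" and float(score) == 0.0:
            return 0  # one failed metric fails the conjunction
        if conj == "or" and float(score) == 1.0:
            return 1  # one passed metric settles the disjunction
        results.append(score)
    return sum(results) / len(results) if conj == "and" else max(results)

The short-circuit returns mean an 'and' conjunction can stop at the first failing metric and an 'or' at the first passing one, exactly as in the hunks below.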
@@ -274,7 +274,11 @@ class DesktopEnv(gym.Env):
             return 0

         if type(self.metric) == list:
             # Multiple metrics to evaluate whether the task is successfully completed
             results = []
+            assert len(self.metric) == len(self.result_getter), "The number of metrics and result getters must be the same"
+            if "expected" in self.evaluator:
+                assert len(self.metric) == len(self.expected_getter), "The number of metrics and expected getters must be the same"
             for idx, metric in enumerate(self.metric):
                 try:
                     config = self.evaluator["result"][idx]
@@ -284,12 +288,11 @@ class DesktopEnv(gym.Env):
                     if self.metric_conj == 'and':
                         return 0

-                expected = self.evaluator["expected"][idx]
-                expected_state = self.expected_getter[idx](self, expected) if expected else None
-
-                metric: int = metric(result_state, expected_state,
-                                     **self.metric_options[idx]) if expected_state is not None \
-                    else metric(result_state, **self.metric_options[idx])
+                if "expected" in self.evaluator:
+                    expected_state = self.expected_getter[idx](self, self.evaluator["expected"][idx])
+                    metric: int = metric(result_state, expected_state, **self.metric_options[idx])
+                else:
+                    metric: int = metric(result_state, **self.metric_options[idx])

                 if self.metric_conj == 'and' and float(metric) == 0.0:
                     return 0
@@ -297,20 +300,21 @@ class DesktopEnv(gym.Env):
                     return 1
                 else:
                     results.append(metric)

             return sum(results) / len(results) if self.metric_conj == 'and' else max(results)
         else:
             # Single metric to evaluate whether the task is successfully completed
             try:
                 result_state = self.result_getter(self, self.evaluator["result"])
             except FileNotFoundError:
                 logger.error("File not found!")
                 return 0

-            expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \
-                else None
-
-            metric: float = self.metric(result_state, expected_state,
-                                        **self.metric_options) if expected_state is not None \
-                else self.metric(result_state, **self.metric_options)
+            if "expected" in self.evaluator:
+                expected_state = self.expected_getter(self, self.evaluator["expected"])
+                metric: float = self.metric(result_state, expected_state, **self.metric_options)
+            else:
+                metric: float = self.metric(result_state, **self.metric_options)

         return metric
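As a usage illustration of the single-metric branch above, a hedged, runnable rendering with trivial stand-in callables (none of these names come from the repository; the real getters take the environment instance as their first argument):

def evaluate_single(metric, result_getter, expected_getter, evaluator, metric_options):
    # Same shape as the patched else-branch: dict membership of
    # "expected" selects between the two metric call signatures.
    result_state = result_getter(evaluator["result"])
    if "expected" in evaluator:
        expected_state = expected_getter(evaluator["expected"])
        return metric(result_state, expected_state, **metric_options)
    return metric(result_state, **metric_options)

score = evaluate_single(
    metric=lambda got, want: float(got == want),
    result_getter=lambda cfg: cfg["value"],
    expected_getter=lambda cfg: cfg["value"],
    evaluator={"result": {"value": "done"}, "expected": {"value": "done"}},
    metric_options={},
)
assert score == 1.0  # states match, so the task counts as completed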