Enhance metric evaluation in DesktopEnv
- Add assertions to ensure the number of metrics matches the number of result and expected getters.
- Refactor the metric calculation logic to handle cases with and without expected values more clearly.
- Improve comments to clarify the single-metric and multiple-metric evaluation paths.
This commit is contained in:
@@ -274,7 +274,11 @@ class DesktopEnv(gym.Env):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
if type(self.metric) == list:
|
if type(self.metric) == list:
|
||||||
|
# Multiple metrics to evaluate whether the task is successfully completed
|
||||||
results = []
|
results = []
|
||||||
|
assert len(self.metric) == len(self.result_getter), "The number of metrics and result getters must be the same"
|
||||||
|
if "expected" in self.evaluator:
|
||||||
|
assert len(self.metric) == len(self.expected_getter), "The number of metrics and expected getters must be the same"
|
||||||
for idx, metric in enumerate(self.metric):
|
for idx, metric in enumerate(self.metric):
|
||||||
try:
|
try:
|
||||||
config = self.evaluator["result"][idx]
|
config = self.evaluator["result"][idx]
|
||||||
@@ -284,12 +288,11 @@ class DesktopEnv(gym.Env):
|
|||||||
if self.metric_conj == 'and':
|
if self.metric_conj == 'and':
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
expected = self.evaluator["expected"][idx]
|
if "expected" in self.evaluator:
|
||||||
expected_state = self.expected_getter[idx](self, expected) if expected else None
|
expected_state = self.expected_getter[idx](self, self.evaluator["expected"][idx])
|
||||||
|
metric: int = metric(result_state, expected_state, **self.metric_options[idx])
|
||||||
metric: int = metric(result_state, expected_state,
|
else:
|
||||||
**self.metric_options[idx]) if expected_state is not None \
|
metric: int = metric(result_state, **self.metric_options[idx])
|
||||||
else metric(result_state, **self.metric_options[idx])
|
|
||||||
|
|
||||||
if self.metric_conj == 'and' and float(metric) == 0.0:
|
if self.metric_conj == 'and' and float(metric) == 0.0:
|
||||||
return 0
|
return 0
|
||||||
@@ -297,20 +300,21 @@ class DesktopEnv(gym.Env):
|
|||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
results.append(metric)
|
results.append(metric)
|
||||||
|
|
||||||
return sum(results) / len(results) if self.metric_conj == 'and' else max(results)
|
return sum(results) / len(results) if self.metric_conj == 'and' else max(results)
|
||||||
else:
|
else:
|
||||||
|
# Single metric to evaluate whether the task is successfully completed
|
||||||
try:
|
try:
|
||||||
result_state = self.result_getter(self, self.evaluator["result"])
|
result_state = self.result_getter(self, self.evaluator["result"])
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.error("File not found!")
|
logger.error("File not found!")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \
|
if "expected" in self.evaluator:
|
||||||
else None
|
expected_state = self.expected_getter(self, self.evaluator["expected"])
|
||||||
|
metric: float = self.metric(result_state, expected_state, **self.metric_options)
|
||||||
metric: float = self.metric(result_state, expected_state,
|
else:
|
||||||
**self.metric_options) if expected_state is not None \
|
metric: float = self.metric(result_state, **self.metric_options)
|
||||||
else self.metric(result_state, **self.metric_options)
|
|
||||||
|
|
||||||
return metric
|
return metric
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user