Fix minor bugs
This commit is contained in:
@@ -112,6 +112,7 @@ class DesktopEnv(gym.Env):
|
|||||||
bytes_per_pixel = bits_per_pixel // 8
|
bytes_per_pixel = bits_per_pixel // 8
|
||||||
vram_size = width * height * bytes_per_pixel
|
vram_size = width * height * bytes_per_pixel
|
||||||
return vram_size
|
return vram_size
|
||||||
|
|
||||||
if not os.path.isfile(self.path_to_vm):
|
if not os.path.isfile(self.path_to_vm):
|
||||||
logger.warning(f"The specified vmx file does not exist: {self.path_to_vm}")
|
logger.warning(f"The specified vmx file does not exist: {self.path_to_vm}")
|
||||||
return False
|
return False
|
||||||
@@ -213,26 +214,42 @@ class DesktopEnv(gym.Env):
|
|||||||
|
|
||||||
# evaluator dict
|
# evaluator dict
|
||||||
# func -> metric function string, or list of metric function strings
|
# func -> metric function string, or list of metric function strings
|
||||||
# conj -> conjuction of multiple metrics if func is a list with length > 1, "and"/"or"
|
# conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"
|
||||||
# result -> result getter config, or list of result getter configs
|
# result -> result getter config, or list of result getter configs
|
||||||
# expected (optional) -> expected getter config, or list of expected getter configs
|
# expected (optional) -> expected getter config, or list of expected getter configs
|
||||||
# options (optional) -> metric options, or list of metric options
|
# options (optional) -> metric options, or list of metric options
|
||||||
# if func is a str list, then result, expected (if exists), options (if exists) should also be lists of the same length
|
# if func is a str list, then result, expected (if exists), options (if exists) should also be lists of the same length
|
||||||
# even if one of the metrics does not need expected or options field, it should be included in the list with None
|
# even if one of the metrics does not need expected or options field, it should be included in the list with None
|
||||||
self.evaluator = task_config["evaluator"]
|
self.evaluator = task_config["evaluator"]
|
||||||
self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] if type(self.evaluator["func"]) == list \
|
self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] \
|
||||||
|
if isinstance(self.evaluator["func"], list) \
|
||||||
else getattr(metrics, self.evaluator["func"])
|
else getattr(metrics, self.evaluator["func"])
|
||||||
self.metric_conj: str = self.evaluator.get("conj", "and") # take conjuction of multiple metrics
|
self.metric_conj: str = self.evaluator.get("conj", "and") # take conjunction of multiple metrics
|
||||||
self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in self.evaluator["result"]] \
|
self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in
|
||||||
if type(self.evaluator["result"]) == list else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
|
self.evaluator["result"]] \
|
||||||
|
if isinstance(self.evaluator["result"], list) \
|
||||||
|
else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
|
||||||
if "expected" in self.evaluator:
|
if "expected" in self.evaluator:
|
||||||
self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in self.evaluator["expected"]] \
|
self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in
|
||||||
if type(self.evaluator["expected"]) == list else getattr(getters, "get_{:}".format())
|
self.evaluator["expected"]] \
|
||||||
else: self.expected_getter = [None for _ in len(self.metric)] if type(self.metric) == list else None
|
if isinstance(self.evaluator["expected"], list) \
|
||||||
self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = [opt if opt else {} for opt in self.evaluator["options"]] \
|
else getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
|
||||||
if type(self.evaluator.get("options", {})) == list else self.evaluator["options"] if "options" in self.evaluator else \
|
else:
|
||||||
[{} for _ in len(self.metric)] if type(self.metric) == list else {}
|
self.expected_getter = [None] * len(self.metric) \
|
||||||
assert type(self.evaluator["func"]) != list or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(self.metric_options))
|
if isinstance(self.metric, list) \
|
||||||
|
else None
|
||||||
|
self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = [opt if opt else {} for opt in
|
||||||
|
self.evaluator["options"]] \
|
||||||
|
if isinstance(self.evaluator.get("options", {}), list) \
|
||||||
|
else self.evaluator["options"] \
|
||||||
|
if "options" in self.evaluator \
|
||||||
|
else [{}] * len(self.metric) \
|
||||||
|
if isinstance(self.metric, list) \
|
||||||
|
else {}
|
||||||
|
|
||||||
|
assert (not isinstance(self.evaluator["func"], list)
|
||||||
|
or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(
|
||||||
|
self.metric_options)))
|
||||||
|
|
||||||
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
|
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
|
||||||
logger.info("Resetting environment...")
|
logger.info("Resetting environment...")
|
||||||
@@ -340,7 +357,8 @@ class DesktopEnv(gym.Env):
|
|||||||
expected = self.evaluator["expected"][idx]
|
expected = self.evaluator["expected"][idx]
|
||||||
expected_state = self.expected_getter[idx](self, expected) if expected else None
|
expected_state = self.expected_getter[idx](self, expected) if expected else None
|
||||||
|
|
||||||
metric: int = metric(result_state, expected_state, **self.metric_options[idx]) if expected_state is not None \
|
metric: int = metric(result_state, expected_state,
|
||||||
|
**self.metric_options[idx]) if expected_state is not None \
|
||||||
else metric(result_state, **self.metric_options[idx])
|
else metric(result_state, **self.metric_options[idx])
|
||||||
|
|
||||||
if self.metric_conj == 'and' and not bool(metric):
|
if self.metric_conj == 'and' and not bool(metric):
|
||||||
@@ -358,7 +376,8 @@ class DesktopEnv(gym.Env):
|
|||||||
expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \
|
expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \
|
||||||
else None
|
else None
|
||||||
|
|
||||||
metric: float = self.metric(result_state, expected_state, **self.metric_options) if expected_state is not None \
|
metric: float = self.metric(result_state, expected_state,
|
||||||
|
**self.metric_options) if expected_state is not None \
|
||||||
else self.metric(result_state, **self.metric_options)
|
else self.metric(result_state, **self.metric_options)
|
||||||
|
|
||||||
return metric
|
return metric
|
||||||
|
|||||||
Reference in New Issue
Block a user