ver Jan4th

Updated the interfaces for Thunderbird evaluation; not yet tested.
This commit is contained in:
David Chang
2024-01-04 22:41:57 +08:00
parent f831aa93df
commit 5fedf5b891
10 changed files with 361 additions and 93 deletions

View File

@@ -68,19 +68,8 @@ class DesktopEnv(gym.Env):
self.cache_dir_base: str = cache_dir
# task-aware stuffs
self.snapshot_path = task_config["snapshot"] # todo: handling the logic of snapshot directory
self.task_id: str = task_config["id"]
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True)
self.instruction = task_config["instruction"]
self.config = task_config["config"]
self.evaluator = task_config["evaluator"]
self.metric: Metric = getattr(metrics, self.evaluator["func"])
self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(
self.evaluator["expected"]["type"])) if "expected" in self.evaluator else None
self.metric_options: Dict[str, Any] = self.evaluator.get("options", {})
# todo: handling the logic of snapshot directory
self._set_task_info(task_config)
# Initialize emulator and controller
print("Initializing...")
@@ -151,25 +140,27 @@ class DesktopEnv(gym.Env):
screenshot_image_path = self._get_screenshot()
return screenshot_image_path
def _set_task_info(self, task_config: Dict[str, Any]):
    """
    Copy the task-specific settings from `task_config` onto this environment.

    Reads the "snapshot", "id", "instruction", "config" and "evaluator"
    entries of `task_config`, creates the per-task cache directory under
    `self.cache_dir_base`, and resolves the metric and getter callables
    named by the evaluator entry.

    Args:
        task_config: task description; must contain the keys listed above.
            The evaluator entry must provide "func" and "result" (with a
            "type"); "expected" and "options" are optional.
    """
    self.snapshot_path = task_config["snapshot"]
    self.task_id: str = task_config["id"]

    # Each task gets its own sub-directory of the shared cache base.
    task_cache_dir = os.path.join(self.cache_dir_base, self.task_id)
    self.cache_dir: str = task_cache_dir
    os.makedirs(task_cache_dir, exist_ok=True)

    self.instruction = task_config["instruction"]
    self.config = task_config["config"]

    evaluator_spec = task_config["evaluator"]
    self.evaluator = evaluator_spec

    # The metric is looked up by name in the `metrics` module.
    self.metric: Metric = getattr(metrics, evaluator_spec["func"])

    # Getters are looked up in the `getters` module as `get_<type>`.
    result_type = evaluator_spec["result"]["type"]
    self.result_getter: Getter = getattr(getters, "get_{:}".format(result_type))

    # An evaluator without an "expected" entry has no expected-state getter.
    if "expected" in evaluator_spec:
        expected_type = evaluator_spec["expected"]["type"]
        self.expected_getter: Getter = getattr(getters, "get_{:}".format(expected_type))
    else:
        self.expected_getter = None

    self.metric_options: Dict[str, Any] = evaluator_spec.get("options", {})
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None):
print("Resetting environment...")
print("Switching task...")
if task_config is not None:
self.snapshot_path = task_config["snapshot"]
self.task_id = task_config["id"]
self.cache_dir = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True)
self.instruction = task_config["instruction"]
self.config = task_config["config"]
self.evaluator = task_config["evaluator"]
self.metric: Metric = getattr(metrics, self.evaluator["func"])
self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(
self.evaluator["expected"]["type"])) if "expected" in self.evaluator else None
self.metric_options = self.evaluator.get("options", {})
self._set_task_info(task_config)
self.setup_controller.reset_cache_dir(self.cache_dir)
print("Setting counters...")
@@ -228,6 +219,9 @@ class DesktopEnv(gym.Env):
"""
Evaluate whether the task is successfully completed.
"""
self.setup_controller.setup(self.evaluator["postconfig"])
result_state = self.result_getter(self, self.evaluator["result"])
expected_state = self.expected_getter(self, self.evaluator["expected"]) if "expected" in self.evaluator \
else None