diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index ada435d..712ade8 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -6,7 +6,7 @@ import time #import uuid #import platform from typing import List, Dict -from typing import Callable, Any +from typing import Callable, Any, Optional import tempfile import gymnasium as gym @@ -37,40 +37,24 @@ class DesktopEnv(gym.Env): def __init__( self, path_to_vm: str, - snapshot_path: str = "base", - task_id: str = "", - instruction: str = None, - config: dict = None, - evaluator: dict = None, action_space: str = "computer_13", + task_config: Dict[str, Any] = None, tmp_dir: str = "tmp", cache_dir: str = "cache" ): """ Args: path_to_vm (str): path to .vmx file - snapshot_path (str): name of the based snapshot, can be found in - `snapshot*.displayName` field in .vmsd file - - task_id (str): identifying the task - instruction (str): task instruction - config (List[Dict[str, Any]]): the config dict, refer to the doc of - SetupController - evaluator (Dict[str, Any]): defines the evaluation method. dict - like - { - "func": str as the evaluation method - "result": { - "type": str - ...other keys - } - "condition": { - "type": str - ...other keys - } - } action_space (str): "computer_13" | "pyautogui" + task_config (Dict[str, Any]): manages task configs integratedly, + including + * base snapshot + * task id (uuid) + * instruction + * setup config + * evaluator config + tmp_dir (str): temporary directory to store trajectory stuffs like the extracted screenshots cache_dir (str): cache directory to cache task-related stuffs like @@ -79,13 +63,8 @@ class DesktopEnv(gym.Env): # Initialize environment variables self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) - self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory - - self.task_id: str = task_id self.tmp_dir_base: str = tmp_dir - self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset - self.cache_dir: str = os.path.join(cache_dir, task_id) - os.makedirs(self.cache_dir, exist_ok=True) + self.cache_dir_base: str = cache_dir # Initialize emulator and controller print("Initializing...") @@ -93,20 +72,27 @@ class DesktopEnv(gym.Env): self.host = f"http://{self._get_vm_ip()}:5000" self.controller = PythonController(http_server=self.host) self.setup_controller = SetupController(http_server=self.host) - self.instruction = instruction - self.config = config - - self.evaluator = evaluator - self.metric: Metric = getattr(metrics, evaluator["func"]) - self.result_getter: Getter = getattr(getters, "get_{:}".format(evaluator["result"]["type"])) - self.expected_getter: Getter = getattr(getters, "get_{:}".format(evaluator["expected"]["type"])) # mode: human or machine assert action_space in ["computer_13", "pyautogui"] self.action_space = action_space # todo: define the action space and the observation space as gym did, or extend theirs - # counters + # task-aware stuffs + self.snapshot_path = task_config["snapshot"] # todo: handling the logic of snapshot directory + self.task_id: str = task_config["id"] + self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id) + os.makedirs(self.cache_dir, exist_ok=True) + self.instruction = task_config["instruction"] + self.config = task_config["config"] + + self.evaluator = task_config["evaluator"] + self.metric: Metric = getattr(metrics, self.evaluator["func"]) + self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"])) + self.expected_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"])) + + # episodic stuffs, like tmp dir and counters + self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset self._traj_no: int = -1 self._step_no: int = 0 @@ -160,9 +146,23 @@ class DesktopEnv(gym.Env): screenshot_image_path = self._get_screenshot() return screenshot_image_path - def reset(self, seed=None, options=None): + def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None): print("Resetting environment...") + print("Switching task...") + if task_config is not None: + self.snapshot_path = task_config["snapshot"] + self.task_id = task_config["id"] + self.cache_dir = os.path.join(self.cache_dir_base, self.task_id) + os.makedirs(self.cache_dir, exist_ok=True) + self.instruction = task_config["instruction"] + self.config = task_config["config"] + + self.evaluator = task_config["evaluator"] + self.metric: Metric = getattr(metrics, self.evaluator["func"]) + self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"])) + self.expected_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"])) + print("Setting counters...") self._traj_no += 1 self._step_no = 0 diff --git a/main.py b/main.py index 4fbe822..911b34f 100644 --- a/main.py +++ b/main.py @@ -9,17 +9,13 @@ def human_agent(): with open("evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json", "r") as f: example = json.load(f) + example["snapshot"] = "Init6" #env = DesktopEnv( path_to_vm="/home/yuri/vmware/Windows 10 x64/Windows 10 x64.vmx" # path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx", env = DesktopEnv( path_to_vm="/home/david/vmware/KUbuntu 64-bit/KUbuntu 64-bit.vmx" , action_space="computer_13" - #, snapshot_path="base_setup" - , snapshot_path="Init6" - , task_id=example["id"] - , instruction=example["instruction"] - , config=example["config"] - , evaluator=example["evaluator"] + , task_config=example ) # reset the environment to certain snapshot