diff --git a/desktop_env/controllers/setup.py b/desktop_env/controllers/setup.py index a550389..b42b2ff 100644 --- a/desktop_env/controllers/setup.py +++ b/desktop_env/controllers/setup.py @@ -8,13 +8,17 @@ class SetupController: def __init__(self, http_server: str): self.http_server = http_server + "/setup" - def setup(self, config): + def setup(self, config: List[Dict[str, Any]]): """ - Setup Config: - { - download: list[tuple[string]], # a list of tuples of url of file to download and the save path - ... - } + Args: + config (List[Dict[str, Any]]): list of dict like {str: Any}. each + config dict has the structure like + { + "type": str, corresponding to the `_{:}_setup` methods of + this class + "parameters": dict like {str: Any} providing the keyword + parameters + } """ for cfg in config: diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index 34270c8..d058d1f 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -3,9 +3,10 @@ from __future__ import annotations import os import subprocess import time -import uuid -import platform +#import uuid +#import platform from typing import List +import tempfile import gymnasium as gym import requests @@ -33,15 +34,48 @@ class DesktopEnv(gym.Env): self, path_to_vm: str, snapshot_path: str = "base", + task_id: str = "", instruction: str = None, config: dict = None, evaluator: dict = None, action_space: str = "computer_13", + tmp_dir: str = "tmp", + cache_dir: str = "cache" ): + """ + Args: + path_to_vm (str): path to .vmx file + snapshot_path (str): name of the base snapshot, can be found in + `snapshot*.displayName` field in .vmsd file + + task_id (str): identifying the task + instruction (str): task instruction + config (List[Dict[str, Any]]): the config dict, refer to the doc of + SetupController + evaluator (Dict[str, Any]): defines the evaluation method.
dict + like + { + "func": str as the evaluation method + "paths": paths to the involved files needed for evaluation + } + action_space (str): "computer_13" | "pyautogui" + + tmp_dir (str): temporary directory to store trajectory artifacts such as + the extracted screenshots + cache_dir (str): cache directory to cache task-related files such as the + reference file for evaluation + """ + # Initialize environment variables - self.path_to_vm = path_to_vm + self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory + self.task_id: str = task_id + self.tmp_dir_base: str = tmp_dir + self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset + self.cache_dir: str = os.path.join(cache_dir, task_id) + os.makedirs(self.cache_dir, exist_ok=True) + # Initialize emulator and controller print("Initializing...") self._start_emulator() @@ -57,12 +91,18 @@ class DesktopEnv(gym.Env): self.action_space = action_space # todo: define the action space and the observation space as gym did, or extend theirs + # counters + self._traj_no: int = -1 + self._step_no: int = 0 + def _start_emulator(self): while True: try: output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT) output = output.decode() - if self.path_to_vm.lstrip("~/") in output: + output: List[str] = output.splitlines() + #if self.path_to_vm.lstrip("~/") in output: + if self.path_to_vm in output: print("VM is running.") break else: @@ -89,9 +129,10 @@ class DesktopEnv(gym.Env): _execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path]) def _get_screenshot(self): - random_uuid = str(uuid.uuid4()) - os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True) - image_path = os.path.join("tmp", random_uuid, "screenshot.png") + #random_uuid = str(uuid.uuid4()) + #os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True) + #image_path = 
os.path.join("tmp", random_uuid, "screenshot.png") + image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no)) # Get the screenshot and save to the image_path screenshot = self.controller.get_screenshot() @@ -107,6 +148,16 @@ class DesktopEnv(gym.Env): def reset(self, seed=None, options=None): print("Resetting environment...") + print("Setting counters...") + self._traj_no += 1 + self._step_no = 0 + + print("Setup new temp dir...") + self.tmp_dir = tempfile.mkdtemp( prefix="{:d}.{:}.".format(self._traj_no, self.task_id) + , dir=self.tmp_dir_base + ) + os.makedirs(os.path.join(self.tmp_dir, "screenshots")) + print("Reverting to snapshot to {}...".format(self.snapshot_path)) _execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path]) time.sleep(5) @@ -125,6 +176,8 @@ class DesktopEnv(gym.Env): return observation def step(self, action, pause=0.5): + self._step_no += 1 + # fixme: add reminding logic here, decide if the action is valid for the current action_space if self.action_space == "computer_13": # the set of all possible actions defined in the action representation @@ -149,9 +202,13 @@ class DesktopEnv(gym.Env): Evaluate whether the task is successfully completed. 
""" def copy_file_to_local(_file_info): - random_uuid = str(uuid.uuid4()) - os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True) - _path = os.path.join("tmp", random_uuid, "tmp.xlsx") + #random_uuid = str(uuid.uuid4()) + #os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True) + #_path = os.path.join("tmp", random_uuid, "tmp.xlsx") + _path = os.path.join(self.cache_dir, _file_info["dest"]) + if os.path.exists(_path): + return _path + if _file_info["type"] == "cloud_file": url = _file_info["path"] response = requests.get(url, stream=True) diff --git a/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json b/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json index b3f61c2..bd04600 100644 --- a/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json +++ b/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json @@ -32,10 +32,12 @@ "expected": { "type": "cloud_file", "path": "https://drive.usercontent.google.com/download?id=17f1wZuJPvUEc5at_Fy3c18VFdOk0x7xz&export=download&authuser=0&confirm=t&uuid=6d2edffd-0ce0-426e-9820-8af25b4667f3&at=APZUnTVh7JS85dwZBaV2hytWQgDK:1702361510956" + "dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx" }, "actual": { "type": "vm_file", "path": "C:\\Users\\tianbaox\\Desktop\\Quarterly_Product_Sales_by_Zone.xlsx" + "dest": "Quarterly_Product_Sales_by_Zone.xlsx" } } } diff --git a/main.py b/main.py index 1c6a242..7be30f0 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,7 @@ def human_agent(): env = DesktopEnv( path_to_vm=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx" , action_space="computer_13" , snapshot_path="base_setup" + , task_id=example["id"] , instruction=example["instruction"] , config=example["config"] , evaluator=example["evaluator"]