diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index 2328547..ada435d 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -5,16 +5,20 @@ import subprocess import time #import uuid #import platform -from typing import List +from typing import List, Dict +from typing import Callable, Any import tempfile import gymnasium as gym -import requests +#import requests from desktop_env.controllers.python import PythonController from desktop_env.controllers.setup import SetupController -from desktop_env.evaluators import eval_funcs +#from desktop_env.evaluators import eval_funcs +from desktop_env.evaluators import metrics, getters +Metric = Callable[[Any, Any], float] +Getter = Callable[[gym.Env, Dict[str, Any]], Any] def _execute_command(command: List[str]) -> None: if command[:4] == ["vmrun", "-T", "ws", "start"]: @@ -56,7 +60,14 @@ class DesktopEnv(gym.Env): like { "func": str as the evaluation method - "paths": paths to the involved files needed for evaluation + "result": { + "type": str + ...other keys + } + "expected": { + "type": str + ...other keys + } } action_space (str): "computer_13" | "pyautogui" @@ -84,7 +95,11 @@ class DesktopEnv(gym.Env): self.setup_controller = SetupController(http_server=self.host) self.instruction = instruction self.config = config + self.evaluator = evaluator + self.metric: Metric = getattr(metrics, evaluator["func"]) + self.result_getter: Getter = getattr(getters, "get_{:}".format(evaluator["result"]["type"])) + self.expected_getter: Getter = getattr(getters, "get_{:}".format(evaluator["expected"]["type"])) # mode: human or machine assert action_space in ["computer_13", "pyautogui"] @@ -201,42 +216,22 @@ class DesktopEnv(gym.Env): """ Evaluate whether the task is successfully completed. 
""" - def copy_file_to_local(_file_info): - #random_uuid = str(uuid.uuid4()) - #os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True) - #_path = os.path.join("tmp", random_uuid, "tmp.xlsx") - _path = os.path.join(self.cache_dir, _file_info["dest"]) - if os.path.exists(_path): - return _path - - if _file_info["type"] == "cloud_file": - url = _file_info["path"] - response = requests.get(url, stream=True, verify=False) - response.raise_for_status() - - with open(_path, 'wb') as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - elif _file_info["type"] == "vm_file": - # fixme: stream this part maybe as well - file = self.controller.get_file(_file_info["path"]) - with open(_path, "wb") as f: - f.write(file) - else: - raise NotImplementedError - - return _path # todo: make this more flexible by refactoring - eval_func = eval_funcs[self.evaluator["func"]] - eval_func_vars = {} + #eval_func = eval_funcs[self.evaluator["func"]] + #eval_func_vars = {} +# + #for var_name, file_info in self.evaluator["paths"].items(): + #path = copy_file_to_local(file_info) + #eval_func_vars[var_name] = path +# + #return eval_func(**eval_func_vars) - for var_name, file_info in self.evaluator["paths"].items(): - path = copy_file_to_local(file_info) - eval_func_vars[var_name] = path + result = self.result_getter(self, self.evaluator["result"]) + expected = self.expected_getter(self, self.evaluator["expected"]) + metric: float = self.metric(result, expected) - return eval_func(**eval_func_vars) + return metric def render(self, mode='rgb_array'): if mode == 'rgb_array': diff --git a/desktop_env/evaluators/__init__.py b/desktop_env/evaluators/__init__.py index 32a19e7..c88feff 100644 --- a/desktop_env/evaluators/__init__.py +++ b/desktop_env/evaluators/__init__.py @@ -1,5 +1,5 @@ -from .table import compare_table +#from .table import compare_table -eval_funcs = { - "compare_table(expected, actual)": compare_table -} +#eval_funcs = { + 
#"compare_table(expected, actual)": compare_table +#} diff --git a/desktop_env/evaluators/getters/__init__.py b/desktop_env/evaluators/getters/__init__.py new file mode 100644 index 0000000..7d8c3fd --- /dev/null +++ b/desktop_env/evaluators/getters/__init__.py @@ -0,0 +1 @@ +from .file import get_cloud_file, get_vm_file diff --git a/desktop_env/evaluators/getters/file.py b/desktop_env/evaluators/getters/file.py new file mode 100644 index 0000000..0c3969f --- /dev/null +++ b/desktop_env/evaluators/getters/file.py @@ -0,0 +1,43 @@ +from typing import Dict + +import os +import requests + +def get_cloud_file(env, config: Dict[str, str]) -> str: + """ + Config: + path (str): the url to download from + dest (str): file name of the downloaded file + """ + + _path = os.path.join(env.cache_dir, config["dest"]) + if os.path.exists(_path): + return _path + + url = config["path"] + response = requests.get(url, stream=True, verify=False) + response.raise_for_status() + + with open(_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + return _path + +def get_vm_file(env, config: Dict[str, str]) -> str: + """ + Config: + path (str): absolute path on the VM to fetch + dest (str): file name of the downloaded file + """ + + _path = os.path.join(env.cache_dir, config["dest"]) + if os.path.exists(_path): + return _path + + file = env.controller.get_file(config["path"]) + with open(_path, "wb") as f: + f.write(file) + + return _path diff --git a/desktop_env/evaluators/replay.py b/desktop_env/evaluators/getters/replay.py similarity index 100% rename from desktop_env/evaluators/replay.py rename to desktop_env/evaluators/getters/replay.py diff --git a/desktop_env/evaluators/metrics/__init__.py b/desktop_env/evaluators/metrics/__init__.py new file mode 100644 index 0000000..740c9e2 --- /dev/null +++ b/desktop_env/evaluators/metrics/__init__.py @@ -0,0 +1 @@ +from .table import compare_table diff --git a/desktop_env/evaluators/table.py 
b/desktop_env/evaluators/metrics/table.py similarity index 100% rename from desktop_env/evaluators/table.py rename to desktop_env/evaluators/metrics/table.py diff --git a/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json b/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json index ce816b2..078cc35 100644 --- a/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json +++ b/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json @@ -27,18 +27,16 @@ "libreoffice calc" ], "evaluator": { - "func": "compare_table(expected, actual)", - "paths": { - "expected": { - "type": "cloud_file", - "path": "http://101.43.24.67/s/BAfFwa3689XTYoo/download/Quarterly_Product_Sales_by_Zone_gold.xlsx", - "dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx" - }, - "actual": { - "type": "vm_file", - "path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx", - "dest": "Quarterly_Product_Sales_by_Zone.xlsx" - } - } + "func": "compare_table", + "result": { + "type": "cloud_file", + "path": "http://101.43.24.67/s/BAfFwa3689XTYoo/download/Quarterly_Product_Sales_by_Zone_gold.xlsx", + "dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx" + }, + "expected": { + "type": "vm_file", + "path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx", + "dest": "Quarterly_Product_Sales_by_Zone.xlsx" + } } }