ver Dec22nd

re-organized the evaluator structure to improve the extensibility
This commit is contained in:
David Chang
2023-12-22 14:01:26 +08:00
parent 295d09f1b2
commit f4664bd069
8 changed files with 91 additions and 53 deletions

View File

@@ -5,16 +5,20 @@ import subprocess
import time import time
#import uuid #import uuid
#import platform #import platform
from typing import List from typing import List, Dict
from typing import Callable, Any
import tempfile import tempfile
import gymnasium as gym import gymnasium as gym
import requests #import requests
from desktop_env.controllers.python import PythonController from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController from desktop_env.controllers.setup import SetupController
from desktop_env.evaluators import eval_funcs #from desktop_env.evaluators import eval_funcs
from desktop_env.evaluators import metrics, getters
Metric = Callable[[Any, Any], float]
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
def _execute_command(command: List[str]) -> None: def _execute_command(command: List[str]) -> None:
if command[:4] == ["vmrun", "-T", "ws", "start"]: if command[:4] == ["vmrun", "-T", "ws", "start"]:
@@ -56,7 +60,14 @@ class DesktopEnv(gym.Env):
like like
{ {
"func": str as the evaluation method "func": str as the evaluation method
"paths": paths to the involved files needed for evaluation "result": {
"type": str
...other keys
}
"condition": {
"type": str
...other keys
}
} }
action_space (str): "computer_13" | "pyautogui" action_space (str): "computer_13" | "pyautogui"
@@ -84,7 +95,11 @@ class DesktopEnv(gym.Env):
self.setup_controller = SetupController(http_server=self.host) self.setup_controller = SetupController(http_server=self.host)
self.instruction = instruction self.instruction = instruction
self.config = config self.config = config
self.evaluator = evaluator self.evaluator = evaluator
self.metric: Metric = getattr(metrics, evaluator["func"])
self.result_getter: Getter = getattr(getters, "get_{:}".format(evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(evaluator["expected"]["type"]))
# mode: human or machine # mode: human or machine
assert action_space in ["computer_13", "pyautogui"] assert action_space in ["computer_13", "pyautogui"]
@@ -201,42 +216,22 @@ class DesktopEnv(gym.Env):
""" """
Evaluate whether the task is successfully completed. Evaluate whether the task is successfully completed.
""" """
def copy_file_to_local(_file_info):
#random_uuid = str(uuid.uuid4())
#os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
#_path = os.path.join("tmp", random_uuid, "tmp.xlsx")
_path = os.path.join(self.cache_dir, _file_info["dest"])
if os.path.exists(_path):
return _path
if _file_info["type"] == "cloud_file":
url = _file_info["path"]
response = requests.get(url, stream=True, verify=False)
response.raise_for_status()
with open(_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
elif _file_info["type"] == "vm_file":
# fixme: stream this part maybe as well
file = self.controller.get_file(_file_info["path"])
with open(_path, "wb") as f:
f.write(file)
else:
raise NotImplementedError
return _path
# todo: make this more flexible by refactoring # todo: make this more flexible by refactoring
eval_func = eval_funcs[self.evaluator["func"]] #eval_func = eval_funcs[self.evaluator["func"]]
eval_func_vars = {} #eval_func_vars = {}
#
#for var_name, file_info in self.evaluator["paths"].items():
#path = copy_file_to_local(file_info)
#eval_func_vars[var_name] = path
#
#return eval_func(**eval_func_vars)
for var_name, file_info in self.evaluator["paths"].items(): result = self.result_getter(self, self.evaluator["result"])
path = copy_file_to_local(file_info) expected = self.expected_getter(self, self.evaluator["expected"])
eval_func_vars[var_name] = path metric: float = self.metric(result, expected)
return eval_func(**eval_func_vars) return metric
def render(self, mode='rgb_array'): def render(self, mode='rgb_array'):
if mode == 'rgb_array': if mode == 'rgb_array':

View File

@@ -1,5 +1,5 @@
from .table import compare_table #from .table import compare_table
eval_funcs = { #eval_funcs = {
"compare_table(expected, actual)": compare_table #"compare_table(expected, actual)": compare_table
} #}

View File

@@ -0,0 +1 @@
from .file import get_cloud_file, get_vm_file

View File

@@ -0,0 +1,43 @@
from typing import Dict
import os
import requests
def get_cloud_file(env, config: Dict[str, str]) -> str:
"""
Config:
path (str): the url to download from
dest (str): file name of the downloaded file
"""
_path = os.path.join(env.cache_dir, config["dest"])
if os.path.exists(_path):
return _path
url = config["path"]
response = requests.get(url, stream=True, verify=False)
response.raise_for_status()
with open(_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
return _path
def get_vm_file(env, config: Dict[str, str]) -> str:
"""
Config:
path (str): absolute path on the VM to fetch
dest (str): file name of the downloaded file
"""
_path = os.path.join(env.cache_dir, config["dest"])
if os.path.exists(_path):
return _path
file = env.controller.get_file(config["path"])
with open(_path, "wb") as f:
f.write(file)
return _path

View File

@@ -0,0 +1 @@
from .table import compare_table

View File

@@ -27,18 +27,16 @@
"libreoffice calc" "libreoffice calc"
], ],
"evaluator": { "evaluator": {
"func": "compare_table(expected, actual)", "func": "compare_table",
"paths": { "result": {
"expected": { "type": "cloud_file",
"type": "cloud_file", "path": "http://101.43.24.67/s/BAfFwa3689XTYoo/download/Quarterly_Product_Sales_by_Zone_gold.xlsx",
"path": "http://101.43.24.67/s/BAfFwa3689XTYoo/download/Quarterly_Product_Sales_by_Zone_gold.xlsx", "dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx"
"dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx" },
}, "expected": {
"actual": { "type": "vm_file",
"type": "vm_file", "path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx",
"path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx", "dest": "Quarterly_Product_Sales_by_Zone.xlsx"
"dest": "Quarterly_Product_Sales_by_Zone.xlsx" }
}
}
} }
} }