ver Dec22nd
Reorganized the evaluator structure to improve extensibility
This commit is contained in:
@@ -5,16 +5,20 @@ import subprocess
|
|||||||
import time
|
import time
|
||||||
#import uuid
|
#import uuid
|
||||||
#import platform
|
#import platform
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
|
from typing import Callable, Any
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import requests
|
#import requests
|
||||||
|
|
||||||
from desktop_env.controllers.python import PythonController
|
from desktop_env.controllers.python import PythonController
|
||||||
from desktop_env.controllers.setup import SetupController
|
from desktop_env.controllers.setup import SetupController
|
||||||
from desktop_env.evaluators import eval_funcs
|
#from desktop_env.evaluators import eval_funcs
|
||||||
|
from desktop_env.evaluators import metrics, getters
|
||||||
|
|
||||||
|
Metric = Callable[[Any, Any], float]
|
||||||
|
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
|
||||||
|
|
||||||
def _execute_command(command: List[str]) -> None:
|
def _execute_command(command: List[str]) -> None:
|
||||||
if command[:4] == ["vmrun", "-T", "ws", "start"]:
|
if command[:4] == ["vmrun", "-T", "ws", "start"]:
|
||||||
@@ -56,7 +60,14 @@ class DesktopEnv(gym.Env):
|
|||||||
like
|
like
|
||||||
{
|
{
|
||||||
"func": str as the evaluation method
|
"func": str as the evaluation method
|
||||||
"paths": paths to the involved files needed for evaluation
|
"result": {
|
||||||
|
"type": str
|
||||||
|
...other keys
|
||||||
|
}
|
||||||
|
"expected": {
|
||||||
|
"type": str
|
||||||
|
...other keys
|
||||||
|
}
|
||||||
}
|
}
|
||||||
action_space (str): "computer_13" | "pyautogui"
|
action_space (str): "computer_13" | "pyautogui"
|
||||||
|
|
||||||
@@ -84,7 +95,11 @@ class DesktopEnv(gym.Env):
|
|||||||
self.setup_controller = SetupController(http_server=self.host)
|
self.setup_controller = SetupController(http_server=self.host)
|
||||||
self.instruction = instruction
|
self.instruction = instruction
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
|
self.metric: Metric = getattr(metrics, evaluator["func"])
|
||||||
|
self.result_getter: Getter = getattr(getters, "get_{:}".format(evaluator["result"]["type"]))
|
||||||
|
self.expected_getter: Getter = getattr(getters, "get_{:}".format(evaluator["expected"]["type"]))
|
||||||
|
|
||||||
# mode: human or machine
|
# mode: human or machine
|
||||||
assert action_space in ["computer_13", "pyautogui"]
|
assert action_space in ["computer_13", "pyautogui"]
|
||||||
@@ -201,42 +216,22 @@ class DesktopEnv(gym.Env):
|
|||||||
"""
|
"""
|
||||||
Evaluate whether the task is successfully completed.
|
Evaluate whether the task is successfully completed.
|
||||||
"""
|
"""
|
||||||
def copy_file_to_local(_file_info):
|
|
||||||
#random_uuid = str(uuid.uuid4())
|
|
||||||
#os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
|
||||||
#_path = os.path.join("tmp", random_uuid, "tmp.xlsx")
|
|
||||||
_path = os.path.join(self.cache_dir, _file_info["dest"])
|
|
||||||
if os.path.exists(_path):
|
|
||||||
return _path
|
|
||||||
|
|
||||||
if _file_info["type"] == "cloud_file":
|
|
||||||
url = _file_info["path"]
|
|
||||||
response = requests.get(url, stream=True, verify=False)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
with open(_path, 'wb') as f:
|
|
||||||
for chunk in response.iter_content(chunk_size=8192):
|
|
||||||
if chunk:
|
|
||||||
f.write(chunk)
|
|
||||||
elif _file_info["type"] == "vm_file":
|
|
||||||
# fixme: stream this part maybe as well
|
|
||||||
file = self.controller.get_file(_file_info["path"])
|
|
||||||
with open(_path, "wb") as f:
|
|
||||||
f.write(file)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
return _path
|
|
||||||
|
|
||||||
# todo: make this more flexible by refactoring
|
# todo: make this more flexible by refactoring
|
||||||
eval_func = eval_funcs[self.evaluator["func"]]
|
#eval_func = eval_funcs[self.evaluator["func"]]
|
||||||
eval_func_vars = {}
|
#eval_func_vars = {}
|
||||||
|
#
|
||||||
|
#for var_name, file_info in self.evaluator["paths"].items():
|
||||||
|
#path = copy_file_to_local(file_info)
|
||||||
|
#eval_func_vars[var_name] = path
|
||||||
|
#
|
||||||
|
#return eval_func(**eval_func_vars)
|
||||||
|
|
||||||
for var_name, file_info in self.evaluator["paths"].items():
|
result = self.result_getter(self, self.evaluator["result"])
|
||||||
path = copy_file_to_local(file_info)
|
expected = self.expected_getter(self, self.evaluator["expected"])
|
||||||
eval_func_vars[var_name] = path
|
metric: float = self.metric(result, expected)
|
||||||
|
|
||||||
return eval_func(**eval_func_vars)
|
return metric
|
||||||
|
|
||||||
def render(self, mode='rgb_array'):
|
def render(self, mode='rgb_array'):
|
||||||
if mode == 'rgb_array':
|
if mode == 'rgb_array':
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from .table import compare_table
|
#from .table import compare_table
|
||||||
|
|
||||||
eval_funcs = {
|
#eval_funcs = {
|
||||||
"compare_table(expected, actual)": compare_table
|
#"compare_table(expected, actual)": compare_table
|
||||||
}
|
#}
|
||||||
|
|||||||
1
desktop_env/evaluators/getters/__init__.py
Normal file
1
desktop_env/evaluators/getters/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .file import get_cloud_file, get_vm_file
|
||||||
43
desktop_env/evaluators/getters/file.py
Normal file
43
desktop_env/evaluators/getters/file.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
|
||||||
|
def get_cloud_file(env, config: Dict[str, str]) -> str:
    """
    Download a file from a URL into the environment's cache directory.

    Config:
        path (str): the url to download from
        dest (str): file name of the downloaded file

    Returns:
        str: local path of the file, i.e. ``env.cache_dir`` joined with
        ``config["dest"]``. If that path already exists, it is returned
        as-is and nothing is downloaded.

    Raises:
        requests.HTTPError: if the server answers with an error status
            (raised by ``response.raise_for_status()``).
    """

    _path = os.path.join(env.cache_dir, config["dest"])
    if os.path.exists(_path):
        return _path

    url = config["path"]

    # Stream into a temporary sibling file and rename it only once the whole
    # body has been written. Writing directly to `_path` would leave a
    # truncated file behind on failure, and the existence check above would
    # then serve that broken file as a cache hit forever.
    tmp_path = _path + ".part"
    try:
        # NOTE(review): verify=False disables TLS certificate verification.
        # Kept as in the original; confirm whether this is still required.
        with requests.get(url, stream=True, verify=False) as response:
            response.raise_for_status()

            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

        os.replace(tmp_path, _path)
    except BaseException:
        # Do not leave partial downloads in the cache directory.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    return _path
|
||||||
|
|
||||||
|
def get_vm_file(env, config: Dict[str, str]) -> str:
    """
    Fetch a file from the VM into the environment's cache directory.

    Config:
        path (str): absolute path on the VM to fetch
        dest (str): file name of the downloaded file

    Returns:
        str: local path of the file, i.e. ``env.cache_dir`` joined with
        ``config["dest"]``. If that path already exists, it is returned
        as-is and the VM is not contacted.
    """

    _path = os.path.join(env.cache_dir, config["dest"])
    if os.path.exists(_path):
        return _path

    content = env.controller.get_file(config["path"])

    # Write to a temporary sibling file first and rename on success. Opening
    # `_path` directly would create it before the write happens, so a failed
    # write would leave an empty file that the existence check above would
    # later return as a valid cached copy.
    tmp_path = _path + ".part"
    try:
        with open(tmp_path, "wb") as f:
            f.write(content)
        os.replace(tmp_path, _path)
    except BaseException:
        # Do not leave partial files in the cache directory.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    return _path
|
||||||
1
desktop_env/evaluators/metrics/__init__.py
Normal file
1
desktop_env/evaluators/metrics/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .table import compare_table
|
||||||
@@ -27,18 +27,16 @@
|
|||||||
"libreoffice calc"
|
"libreoffice calc"
|
||||||
],
|
],
|
||||||
"evaluator": {
|
"evaluator": {
|
||||||
"func": "compare_table(expected, actual)",
|
"func": "compare_table",
|
||||||
"paths": {
|
"result": {
|
||||||
"expected": {
|
"type": "cloud_file",
|
||||||
"type": "cloud_file",
|
"path": "http://101.43.24.67/s/BAfFwa3689XTYoo/download/Quarterly_Product_Sales_by_Zone_gold.xlsx",
|
||||||
"path": "http://101.43.24.67/s/BAfFwa3689XTYoo/download/Quarterly_Product_Sales_by_Zone_gold.xlsx",
|
"dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx"
|
||||||
"dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx"
|
},
|
||||||
},
|
"expected": {
|
||||||
"actual": {
|
"type": "vm_file",
|
||||||
"type": "vm_file",
|
"path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx",
|
||||||
"path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx",
|
"dest": "Quarterly_Product_Sales_by_Zone.xlsx"
|
||||||
"dest": "Quarterly_Product_Sales_by_Zone.xlsx"
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user