Merge branch 'zdy'
This commit is contained in:
@@ -5,16 +5,20 @@ import subprocess
|
||||
import time
|
||||
#import uuid
|
||||
#import platform
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
from typing import Callable, Any
|
||||
import tempfile
|
||||
|
||||
import gymnasium as gym
|
||||
import requests
|
||||
#import requests
|
||||
|
||||
from desktop_env.controllers.python import PythonController
|
||||
from desktop_env.controllers.setup import SetupController
|
||||
from desktop_env.evaluators import eval_funcs
|
||||
#from desktop_env.evaluators import eval_funcs
|
||||
from desktop_env.evaluators import metrics, getters
|
||||
|
||||
Metric = Callable[[Any, Any], float]
|
||||
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
|
||||
|
||||
def _execute_command(command: List[str]) -> None:
|
||||
if command[:4] == ["vmrun", "-T", "ws", "start"]:
|
||||
@@ -56,7 +60,14 @@ class DesktopEnv(gym.Env):
|
||||
like
|
||||
{
|
||||
"func": str as the evaluation method
|
||||
"paths": paths to the involved files needed for evaluation
|
||||
"result": {
|
||||
"type": str
|
||||
...other keys
|
||||
}
|
||||
"expected": {
|
||||
"type": str
|
||||
...other keys
|
||||
}
|
||||
}
|
||||
action_space (str): "computer_13" | "pyautogui"
|
||||
|
||||
@@ -84,7 +95,11 @@ class DesktopEnv(gym.Env):
|
||||
self.setup_controller = SetupController(http_server=self.host)
|
||||
self.instruction = instruction
|
||||
self.config = config
|
||||
|
||||
self.evaluator = evaluator
|
||||
self.metric: Metric = getattr(metrics, evaluator["func"])
|
||||
self.result_getter: Getter = getattr(getters, "get_{:}".format(evaluator["result"]["type"]))
|
||||
self.expected_getter: Getter = getattr(getters, "get_{:}".format(evaluator["expected"]["type"]))
|
||||
|
||||
# mode: human or machine
|
||||
assert action_space in ["computer_13", "pyautogui"]
|
||||
@@ -201,42 +216,22 @@ class DesktopEnv(gym.Env):
|
||||
"""
|
||||
Evaluate whether the task is successfully completed.
|
||||
"""
|
||||
def copy_file_to_local(_file_info):
|
||||
#random_uuid = str(uuid.uuid4())
|
||||
#os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
||||
#_path = os.path.join("tmp", random_uuid, "tmp.xlsx")
|
||||
_path = os.path.join(self.cache_dir, _file_info["dest"])
|
||||
if os.path.exists(_path):
|
||||
return _path
|
||||
|
||||
if _file_info["type"] == "cloud_file":
|
||||
url = _file_info["path"]
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
elif _file_info["type"] == "vm_file":
|
||||
# fixme: stream this part maybe as well
|
||||
file = self.controller.get_file(_file_info["path"])
|
||||
with open(_path, "wb") as f:
|
||||
f.write(file)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return _path
|
||||
|
||||
# todo: make this more flexible by refactoring
|
||||
eval_func = eval_funcs[self.evaluator["func"]]
|
||||
eval_func_vars = {}
|
||||
#eval_func = eval_funcs[self.evaluator["func"]]
|
||||
#eval_func_vars = {}
|
||||
#
|
||||
#for var_name, file_info in self.evaluator["paths"].items():
|
||||
#path = copy_file_to_local(file_info)
|
||||
#eval_func_vars[var_name] = path
|
||||
#
|
||||
#return eval_func(**eval_func_vars)
|
||||
|
||||
for var_name, file_info in self.evaluator["paths"].items():
|
||||
path = copy_file_to_local(file_info)
|
||||
eval_func_vars[var_name] = path
|
||||
result = self.result_getter(self, self.evaluator["result"])
|
||||
expected = self.expected_getter(self, self.evaluator["expected"])
|
||||
metric: float = self.metric(result, expected)
|
||||
|
||||
return eval_func(**eval_func_vars)
|
||||
return metric
|
||||
|
||||
def render(self, mode='rgb_array'):
|
||||
if mode == 'rgb_array':
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .table import compare_table
|
||||
#from .table import compare_table
|
||||
|
||||
eval_funcs = {
|
||||
"compare_table(expected, actual)": compare_table
|
||||
}
|
||||
#eval_funcs = {
|
||||
#"compare_table(expected, actual)": compare_table
|
||||
#}
|
||||
|
||||
1
desktop_env/evaluators/getters/__init__.py
Normal file
1
desktop_env/evaluators/getters/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .file import get_cloud_file, get_vm_file
|
||||
43
desktop_env/evaluators/getters/file.py
Normal file
43
desktop_env/evaluators/getters/file.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Dict
|
||||
|
||||
import os
|
||||
import requests
|
||||
|
||||
def get_cloud_file(env, config: Dict[str, str]) -> str:
    """
    Download a remote file into the environment's cache directory.

    Config:
        path (str): the url to download from
        dest (str): file name of the downloaded file

    Args:
        env: object exposing a `cache_dir` attribute (the directory the
            downloaded file is stored in).
        config: getter configuration, see "Config" above.

    Returns:
        The local path of the cached file. If a file already exists at that
        path it is returned as-is and nothing is downloaded.

    Raises:
        requests.HTTPError: propagated from `raise_for_status()` when the
            server answers with an error status.
    """

    _path = os.path.join(env.cache_dir, config["dest"])
    if os.path.exists(_path):
        return _path

    # `dest` may contain sub-directories and the cache dir may not exist yet.
    os.makedirs(os.path.dirname(_path) or ".", exist_ok=True)

    # Stream into a side file and only promote it to the final name once the
    # download has completed. Otherwise an interrupted transfer would leave a
    # truncated file behind that the existence check above would treat as a
    # valid cache hit on every later call.
    _tmp_path = _path + ".part"
    try:
        # The context manager releases the streamed connection even when the
        # body is not fully consumed (e.g. on an error status).
        with requests.get(config["path"], stream=True) as response:
            response.raise_for_status()

            with open(_tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

        os.replace(_tmp_path, _path)
    finally:
        # After a successful replace the side file no longer exists; this only
        # cleans up after a failed or interrupted download.
        if os.path.exists(_tmp_path):
            os.remove(_tmp_path)

    return _path
|
||||
|
||||
def get_vm_file(env, config: Dict[str, str]) -> str:
    """
    Fetch a file from the VM into the environment's cache directory.

    Config:
        path (str): absolute path on the VM to fetch
        dest (str): file name of the downloaded file

    Args:
        env: object exposing `cache_dir` and a `controller` whose
            `get_file(path)` returns the file content to be written
            (presumably bytes; confirm against the controller).
        config: getter configuration, see "Config" above.

    Returns:
        The local path of the cached file. If a file already exists at that
        path it is returned as-is and the VM is not contacted.

    Raises:
        TypeError: propagated from the binary write when the controller
            returns something that is not bytes-like; no cache file is left
            behind in that case.
    """

    _path = os.path.join(env.cache_dir, config["dest"])
    if os.path.exists(_path):
        return _path

    file = env.controller.get_file(config["path"])

    # `dest` may contain sub-directories and the cache dir may not exist yet.
    os.makedirs(os.path.dirname(_path) or ".", exist_ok=True)

    # Write to a side file first: opening the final path directly would create
    # an empty file even when the write itself fails, and the existence check
    # above would then serve that empty file as a cache hit.
    _tmp_path = _path + ".part"
    try:
        with open(_tmp_path, "wb") as f:
            f.write(file)
        os.replace(_tmp_path, _path)
    finally:
        # No-op after a successful replace; removes the leftover on failure.
        if os.path.exists(_tmp_path):
            os.remove(_tmp_path)

    return _path
|
||||
1
desktop_env/evaluators/metrics/__init__.py
Normal file
1
desktop_env/evaluators/metrics/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .table import compare_table
|
||||
@@ -27,18 +27,16 @@
|
||||
"libreoffice calc"
|
||||
],
|
||||
"evaluator": {
|
||||
"func": "compare_table(expected, actual)",
|
||||
"paths": {
|
||||
"expected": {
|
||||
"type": "cloud_file",
|
||||
"path": "https://drive.usercontent.google.com/download?id=17f1wZuJPvUEc5at_Fy3c18VFdOk0x7xz&export=download&authuser=0&confirm=t&uuid=6d2edffd-0ce0-426e-9820-8af25b4667f3&at=APZUnTVh7JS85dwZBaV2hytWQgDK:1702361510956"
|
||||
"dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx"
|
||||
},
|
||||
"actual": {
|
||||
"type": "vm_file",
|
||||
"path": "C:\\Users\\tianbaox\\Desktop\\Quarterly_Product_Sales_by_Zone.xlsx"
|
||||
"dest": "Quarterly_Product_Sales_by_Zone.xlsx"
|
||||
}
|
||||
}
|
||||
"func": "compare_table",
|
||||
"expected": {
|
||||
"type": "cloud_file",
|
||||
"path": "https://drive.usercontent.google.com/download?id=17f1wZuJPvUEc5at_Fy3c18VFdOk0x7xz&export=download&authuser=0&confirm=t&uuid=6d2edffd-0ce0-426e-9820-8af25b4667f3&at=APZUnTVh7JS85dwZBaV2hytWQgDK:1702361510956",
|
||||
"dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx"
|
||||
},
|
||||
"result": {
|
||||
"type": "vm_file",
|
||||
"path": "C:\\Users\\tianbaox\\Desktop\\Quarterly_Product_Sales_by_Zone.xlsx",
|
||||
"dest": "Quarterly_Product_Sales_by_Zone.xlsx"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user