Merge branch 'zdy'

This commit is contained in:
David Chang
2023-12-22 14:11:26 +08:00
8 changed files with 91 additions and 53 deletions

View File

@@ -5,16 +5,20 @@ import subprocess
import time
#import uuid
#import platform
from typing import List
from typing import List, Dict
from typing import Callable, Any
import tempfile
import gymnasium as gym
import requests
#import requests
from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController
from desktop_env.evaluators import eval_funcs
#from desktop_env.evaluators import eval_funcs
from desktop_env.evaluators import metrics, getters
Metric = Callable[[Any, Any], float]
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
def _execute_command(command: List[str]) -> None:
if command[:4] == ["vmrun", "-T", "ws", "start"]:
@@ -56,7 +60,14 @@ class DesktopEnv(gym.Env):
like
{
"func": str as the evaluation method
"paths": paths to the involved files needed for evaluation
"result": {
"type": str
...other keys
}
"condition": {
"type": str
...other keys
}
}
action_space (str): "computer_13" | "pyautogui"
@@ -84,7 +95,11 @@ class DesktopEnv(gym.Env):
self.setup_controller = SetupController(http_server=self.host)
self.instruction = instruction
self.config = config
self.evaluator = evaluator
self.metric: Metric = getattr(metrics, evaluator["func"])
self.result_getter: Getter = getattr(getters, "get_{:}".format(evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(evaluator["expected"]["type"]))
# mode: human or machine
assert action_space in ["computer_13", "pyautogui"]
@@ -201,42 +216,22 @@ class DesktopEnv(gym.Env):
"""
Evaluate whether the task is successfully completed.
"""
def copy_file_to_local(_file_info):
#random_uuid = str(uuid.uuid4())
#os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
#_path = os.path.join("tmp", random_uuid, "tmp.xlsx")
_path = os.path.join(self.cache_dir, _file_info["dest"])
if os.path.exists(_path):
return _path
if _file_info["type"] == "cloud_file":
url = _file_info["path"]
response = requests.get(url, stream=True)
response.raise_for_status()
with open(_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
elif _file_info["type"] == "vm_file":
# fixme: stream this part maybe as well
file = self.controller.get_file(_file_info["path"])
with open(_path, "wb") as f:
f.write(file)
else:
raise NotImplementedError
return _path
# todo: make this more flexible by refactoring
eval_func = eval_funcs[self.evaluator["func"]]
eval_func_vars = {}
#eval_func = eval_funcs[self.evaluator["func"]]
#eval_func_vars = {}
#
#for var_name, file_info in self.evaluator["paths"].items():
#path = copy_file_to_local(file_info)
#eval_func_vars[var_name] = path
#
#return eval_func(**eval_func_vars)
for var_name, file_info in self.evaluator["paths"].items():
path = copy_file_to_local(file_info)
eval_func_vars[var_name] = path
result = self.result_getter(self, self.evaluator["result"])
expected = self.expected_getter(self, self.evaluator["expected"])
metric: float = self.metric(result, expected)
return eval_func(**eval_func_vars)
return metric
def render(self, mode='rgb_array'):
if mode == 'rgb_array':

View File

@@ -1,5 +1,5 @@
from .table import compare_table
#from .table import compare_table
eval_funcs = {
"compare_table(expected, actual)": compare_table
}
#eval_funcs = {
#"compare_table(expected, actual)": compare_table
#}

View File

@@ -0,0 +1 @@
from .file import get_cloud_file, get_vm_file

View File

@@ -0,0 +1,43 @@
from typing import Dict
import os
import requests
def get_cloud_file(env, config: Dict[str, str]) -> str:
"""
Config:
path (str): the url to download from
dest (str): file name of the downloaded file
"""
_path = os.path.join(env.cache_dir, config["dest"])
if os.path.exists(_path):
return _path
url = config["path"]
response = requests.get(url, stream=True)
response.raise_for_status()
with open(_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
return _path
def get_vm_file(env, config: Dict[str, str]) -> str:
"""
Config:
path (str): absolute path on the VM to fetch
dest (str): file name of the downloaded file
"""
_path = os.path.join(env.cache_dir, config["dest"])
if os.path.exists(_path):
return _path
file = env.controller.get_file(config["path"])
with open(_path, "wb") as f:
f.write(file)
return _path

View File

@@ -0,0 +1 @@
from .table import compare_table

View File

@@ -27,18 +27,16 @@
"libreoffice calc"
],
"evaluator": {
"func": "compare_table(expected, actual)",
"paths": {
"expected": {
"type": "cloud_file",
"path": "https://drive.usercontent.google.com/download?id=17f1wZuJPvUEc5at_Fy3c18VFdOk0x7xz&export=download&authuser=0&confirm=t&uuid=6d2edffd-0ce0-426e-9820-8af25b4667f3&at=APZUnTVh7JS85dwZBaV2hytWQgDK:1702361510956"
"dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx"
},
"actual": {
"type": "vm_file",
"path": "C:\\Users\\tianbaox\\Desktop\\Quarterly_Product_Sales_by_Zone.xlsx"
"dest": "Quarterly_Product_Sales_by_Zone.xlsx"
}
}
"func": "compare_table",
"expected": {
"type": "cloud_file",
"path": "https://drive.usercontent.google.com/download?id=17f1wZuJPvUEc5at_Fy3c18VFdOk0x7xz&export=download&authuser=0&confirm=t&uuid=6d2edffd-0ce0-426e-9820-8af25b4667f3&at=APZUnTVh7JS85dwZBaV2hytWQgDK:1702361510956",
"dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx"
},
"result": {
"type": "vm_file",
"path": "C:\\Users\\tianbaox\\Desktop\\Quarterly_Product_Sales_by_Zone.xlsx",
"dest": "Quarterly_Product_Sales_by_Zone.xlsx"
}
}
}