Merge branch 'zdy'

This commit is contained in:
David Chang
2023-12-25 21:05:44 +08:00
10 changed files with 269 additions and 61 deletions

View File

@@ -1,13 +1,23 @@
import requests
import json
from requests_toolbelt.multipart.encoder import MultipartEncoder
import uuid
import os.path
from typing import Dict, List
from typing import Any
class SetupController:
def __init__(self, http_server: str):
def __init__( self
, http_server: str
, cache_dir: str
):
self.http_server = http_server + "/setup"
self.cache_dir: str = cache_dir
def reset_cache_dir(self, cache_dir: str):
self.cache_dir = cache_dir
def setup(self, config: List[Dict[str, Any]]):
"""
@@ -56,22 +66,56 @@ class SetupController:
for f in files:
url: str = f["url"]
path: str = f["path"]
cache_path: str = os.path.join( self.cache_dir
, "{:}_{:}".format(
uuid.uuid5(uuid.NAMESPACE_URL, url)
, os.path.basename(path)
)
)
if not url or not path:
raise Exception(f"Setup Download - Invalid URL ({url}) or path ({path}).")
payload = json.dumps({"url": url, "path": path})
headers = {
'Content-Type': 'application/json'
}
if not os.path.exists(cache_path):
max_retries = 3
downloaded = False
for i in range(max_retries):
try:
response = requests.get(url, stream=True)
response.raise_for_status()
# send request to server to download file
with open(cache_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
print("File downloaded successfully")
downloaded = True
break
except requests.RequestException as e:
print(f"Failed to download {url}. Retrying... ({max_retries - i - 1} attempts left)")
if not downloaded:
raise requests.RequestException(f"Failed to download {url}. No retries left. Error: {e}")
#payload = json.dumps({"url": url, "path": path})
#headers = {
#'Content-Type': 'application/json'
#}
form = MultipartEncoder( { "file_path": path
, "file_data": (os.path.basename(path), open(cache_path, "rb"))
}
)
headers = {"Content-Type": form.content_type}
print(form.content_type)
# send request to server to upload file
try:
response = requests.post(self.http_server + "/download_file", headers=headers, data=payload)
response = requests.post(self.http_server + "/upload", headers=headers, data=form)
if response.status_code == 200:
print("Command executed successfully:", response.text)
else:
print("Failed to download file. Status code:", response.text)
print("Failed to upload file. Status code:", response.text)
except requests.exceptions.RequestException as e:
print("An error occurred while trying to send the request:", e)

View File

@@ -67,18 +67,6 @@ class DesktopEnv(gym.Env):
self.tmp_dir_base: str = tmp_dir
self.cache_dir_base: str = cache_dir
# Initialize emulator and controller
print("Initializing...")
self._start_emulator()
self.host = f"http://{self._get_vm_ip()}:5000"
self.controller = PythonController(http_server=self.host)
self.setup_controller = SetupController(http_server=self.host)
# mode: human or machine
assert action_space in ["computer_13", "pyautogui"]
self.action_space = action_space
# todo: define the action space and the observation space as gym did, or extend theirs
# task-aware stuffs
self.snapshot_path = task_config["snapshot"] # todo: handling the logic of snapshot directory
self.task_id: str = task_config["id"]
@@ -92,10 +80,24 @@ class DesktopEnv(gym.Env):
self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
# episodic stuffs, like tmp dir and counters
self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset
# Initialize emulator and controller
print("Initializing...")
self._start_emulator()
self.host = f"http://{self._get_vm_ip()}:5000"
self.controller = PythonController(http_server=self.host)
self.setup_controller = SetupController(http_server=self.host, cache_dir=self.cache_dir)
# mode: human or machine
assert action_space in ["computer_13", "pyautogui"]
self.action_space = action_space
# todo: define the action space and the observation space as gym did, or extend theirs
# episodic stuffs, like tmp dir and counters, will be updated or reset
# when calling self.reset()
self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset
self._traj_no: int = -1
self._step_no: int = 0
self.action_history: List[Dict[str, any]] = []
def _start_emulator(self):
while True:
@@ -164,9 +166,12 @@ class DesktopEnv(gym.Env):
self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
self.setup_controller.reset_cache_dir(self.cache_dir)
print("Setting counters...")
self._traj_no += 1
self._step_no = 0
self.action_history.clear()
print("Setup new temp dir...")
self.tmp_dir = tempfile.mkdtemp(
@@ -202,6 +207,7 @@ class DesktopEnv(gym.Env):
elif self.action_space == "pyautogui":
# the set of all possible python commands insides `pyautogui`
self.controller.execute_python_command(action)
self.action_history.append(action)
# todo: maybe for the better here we need to add a logic to wait until the rendering is done
time.sleep(pause)

View File

@@ -1 +1 @@
from .table import compare_table
from .table import compare_table, compare_with_sparklines

View File

@@ -1,14 +1,74 @@
def compare_table(expected, actual):
import pandas as pd
import pandas as pd
import zipfile
import lxml.etree
import lxml.cssselect
from lxml.etree import _Element
import xmltodict
#import pylightxl
from typing import Dict, List
#from typing import Any
def compare_table(actual, expected):
df1 = pd.read_excel(expected)
df2 = pd.read_excel(actual)
# Compare the DataFrames
return 1 if df1.equals(df2) else 0
_xlsx_namespaces = [ ("x14", "http://schemas.microsoft.com/office/spreadsheetml/2009/9/main")
, ("xm", "http://schemas.microsoft.com/office/excel/2006/main")
]
_xlsx_ns_mapping = dict(_xlsx_namespaces)
_xlsx_ns_imapping = dict(map(lambda itm: (itm[1], itm[0]), _xlsx_namespaces))
_sparklines_selector = lxml.cssselect.CSSSelector("x14|sparkline", namespaces=_xlsx_ns_mapping)
#print(_sparklines_selector.css)
def _load_sparklines(xlsx_file: str) -> Dict[str, str]:
"""
This function modifies data_frame in-place
Args:
xlsx_file (str): path to xlsx
Returns:
List[Dict[str, str]]: sparkline definitions in form of
{
"F3": "Sheet1!C3:E3"
}
"""
# read xlsx
with zipfile.ZipFile(xlsx_file, "r") as z_f:
with z_f.open("xl/worksheets/sheet1.xml") as f:
sheet1: _Element = lxml.etree.fromstring(f.read())
sparklines: List[_Element] = _sparklines_selector(sheet1)
sparklines_dict: Dict[str, str] = {}
for sp_l in sparklines:
sparkline_xml: str = lxml.etree.tostring(sp_l, encoding="unicode")
sparkline: Dict[str, Dict[str, str]] = xmltodict.parse( sparkline_xml
, process_namespaces=True
, namespaces=_xlsx_ns_imapping
)
sparklines_dict[sparkline["x14:sparkline"]["xm:sqref"]] = sparkline["x14:sparkline"]["xm:f"]
return sparklines_dict
def compare_with_sparklines(actual: str, expected: str) -> float:
df1 = pd.read_excel(actual)
df2 = pd.read_excel(expected)
normal_content_metric: bool = df1.equals(df2)
sp1 = _load_sparklines(actual)
sp2 = _load_sparklines(expected)
sparkline_metric: bool = sp1 == sp2
return float(normal_content_metric and sparkline_metric)
if __name__ == '__main__':
path1 = ""
path2 = ""
print(compare_table(path1, path2))
#path1 = ""
#path2 = ""
#print(compare_table(path1, path2))
path1 = "../../../../../任务数据/LibreOffice Calc/OrderId_Month_Chart_gold.xlsx"
path2 = "../../../../../任务数据/LibreOffice Calc/OrderId_Month_Chart.xlsx"
print(compare_with_sparklines(path1, path2))

View File

@@ -95,6 +95,16 @@ def get_file():
# If the file is not found, return a 404 error
return jsonify({"error": "File not found"}), 404
@app.route("/setup/upload", methods=["POST"])
def upload_file():
# Retrieve filename from the POST request
if 'file_path' in request.form and 'file_data' in request.files:
file_path = request.form['file_path']
file = request.files["file_data"]
file.save(file_path)
return "File Uploaded"
else:
return jsonify({"error": "file_path and file_data are required"}), 400
@app.route('/platform', methods=['GET'])
def get_platform():