diff --git a/.gitignore b/.gitignore index 184e64e..490eebb 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,5 @@ frontend/.next/ frontend/.idea tags +tags-opts snapshots diff --git a/desktop_env/controllers/setup.py b/desktop_env/controllers/setup.py index b42b2ff..26cb474 100644 --- a/desktop_env/controllers/setup.py +++ b/desktop_env/controllers/setup.py @@ -1,12 +1,22 @@ import requests import json +from requests_toolbelt.multipart.encoder import MultipartEncoder + +import uuid +import os.path from typing import Dict, List from typing import Any class SetupController: - def __init__(self, http_server: str): + def __init__( self + , http_server: str + , cache_dir: str + ): self.http_server = http_server + "/setup" + self.cache_dir: str = cache_dir + def reset_cache_dir(self, cache_dir: str): + self.cache_dir = cache_dir def setup(self, config: List[Dict[str, Any]]): """ @@ -55,22 +65,56 @@ class SetupController: for f in files: url: str = f["url"] path: str = f["path"] + cache_path: str = os.path.join( self.cache_dir + , "{:}_{:}".format( + uuid.uuid5(uuid.NAMESPACE_URL, url) + , os.path.basename(path) + ) + ) if not url or not path: raise Exception(f"Setup Download - Invalid URL ({url}) or path ({path}).") - payload = json.dumps({"url": url, "path": path}) - headers = { - 'Content-Type': 'application/json' - } + if not os.path.exists(cache_path): + max_retries = 3 + downloaded = False + for i in range(max_retries): + try: + response = requests.get(url, stream=True, verify=False) + response.raise_for_status() - # send request to server to download file + with open(cache_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + print("File downloaded successfully") + downloaded = True + break + + except requests.RequestException as e: + print(f"Failed to download {url}. Retrying... ({max_retries - i - 1} attempts left)") + if not downloaded: + raise requests.RequestException(f"Failed to download {url}. No retries left. Error: {e}") + + #payload = json.dumps({"url": url, "path": path}) + #headers = { + #'Content-Type': 'application/json' + #} + + form = MultipartEncoder( { "file_path": path + , "file_data": (os.path.basename(path), open(cache_path, "rb")) + } + ) + headers = {"Content-Type": form.content_type} + print(form.content_type) + + # send request to server to upload file try: - response = requests.post(self.http_server + "/download_file", headers=headers, data=payload) + response = requests.post(self.http_server + "/upload", headers=headers, data=form) if response.status_code == 200: print("Command executed successfully:", response.text) else: - print("Failed to download file. Status code:", response.text) + print("Failed to upload file. Status code:", response.text) except requests.exceptions.RequestException as e: print("An error occurred while trying to send the request:", e) diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index d1a8026..eea834a 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -66,18 +66,6 @@ class DesktopEnv(gym.Env): self.tmp_dir_base: str = tmp_dir self.cache_dir_base: str = cache_dir - # Initialize emulator and controller - print("Initializing...") - self._start_emulator() - self.host = f"http://{self._get_vm_ip()}:5000" - self.controller = PythonController(http_server=self.host) - self.setup_controller = SetupController(http_server=self.host) - - # mode: human or machine - assert action_space in ["computer_13", "pyautogui"] - self.action_space = action_space - # todo: define the action space and the observation space as gym did, or extend theirs - # task-aware stuffs self.snapshot_path = task_config["snapshot"] # todo: handling the logic of snapshot directory self.task_id: str = task_config["id"] @@ -91,6 +79,18 @@ class DesktopEnv(gym.Env): self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"])) self.expected_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"])) + # Initialize emulator and controller + print("Initializing...") + self._start_emulator() + self.host = f"http://{self._get_vm_ip()}:5000" + self.controller = PythonController(http_server=self.host) + self.setup_controller = SetupController(http_server=self.host, cache_dir=self.cache_dir) + + # mode: human or machine + assert action_space in ["computer_13", "pyautogui"] + self.action_space = action_space + # todo: define the action space and the observation space as gym did, or extend theirs + # episodic stuffs, like tmp dir and counters, will be updated or reset # when calling self.reset() self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset @@ -165,6 +165,8 @@ class DesktopEnv(gym.Env): self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"])) self.expected_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"])) + self.setup_controller.reset_cache_dir(self.cache_dir) + print("Setting counters...") self._traj_no += 1 self._step_no = 0 diff --git a/desktop_env/server/main.py b/desktop_env/server/main.py index 1e88787..67dc015 100644 --- a/desktop_env/server/main.py +++ b/desktop_env/server/main.py @@ -96,6 +96,16 @@ def get_file(): # If the file is not found, return a 404 error return jsonify({"error": "File not found"}), 404 +@app.route("/setup/upload", methods=["POST"]) +def upload_file(): + # Retrieve filename from the POST request + if 'file_path' in request.form and 'file_data' in request.files: + file_path = request.form['file_path'] + file = request.files["file_data"] + file.save(file_path) + return "File Uploaded" + else: + return jsonify({"error": "file_path and file_data are required"}), 400 @app.route('/platform', methods=['GET']) def get_platform(): diff --git a/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json b/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json index 539c4aa..60decfa 100644 --- a/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json +++ b/evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json @@ -4,23 +4,23 @@ "instruction": "Fill the missing row and column which show the total value", "source": "https://youtube.com/shorts/feldd-Pn48c?si=9xJiem2uAHm6Jshb", "config": [ - { - "type": "download", - "parameters": { - "files": [ - { - "url": "http://101.43.24.67/s/DbaHsQpPA7dxrA8/download/Quarterly_Product_Sales_by_Zone.xlsx", - "path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx" - } - ] - } - }, - { - "type": "open", - "parameters": { - "path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx" - } - } + { + "type": "download", + "parameters": { + "files": [ + { + "url": "http://101.43.24.67/s/DbaHsQpPA7dxrA8/download/Quarterly_Product_Sales_by_Zone.xlsx", + "path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx" + } + ] + } + }, + { + "type": "open", + "parameters": { + "path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx" + } + } ], "trajectory": "trajectories/f9584479-3d0d-4c79-affa-9ad7afdd8850", "related_apps": [ @@ -28,15 +28,15 @@ ], "evaluator": { "func": "compare_table", - "expected": { - "type": "cloud_file", - "path": "http://101.43.24.67/s/BAfFwa3689XTYoo/download/Quarterly_Product_Sales_by_Zone_gold.xlsx", - "dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx" - }, - "result": { - "type": "vm_file", - "path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx", - "dest": "Quarterly_Product_Sales_by_Zone.xlsx" - } + "expected": { + "type": "cloud_file", + "path": "http://101.43.24.67/s/BAfFwa3689XTYoo/download/Quarterly_Product_Sales_by_Zone_gold.xlsx", + "dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx" + }, + "result": { + "type": "vm_file", + "path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx", + "dest": "Quarterly_Product_Sales_by_Zone.xlsx" + } } } diff --git a/main.py b/main.py index 911b34f..147a137 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,7 @@ def human_agent(): with open("evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json", "r") as f: example = json.load(f) - example["snapshot"] = "Init6" + example["snapshot"] = "Snapshot 10" #env = DesktopEnv( path_to_vm="/home/yuri/vmware/Windows 10 x64/Windows 10 x64.vmx" # path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx", diff --git a/requirements.txt b/requirements.txt index 0a27eff..f3907ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ pyautogui~=0.9.54 psutil~=5.9.6 tqdm~=4.65.0 pandas~=2.0.3 -flask~=3.0.0 \ No newline at end of file +flask~=3.0.0 +requests-toolbelt~=1.0.0