diff --git a/desktop_env/controllers/setup.py b/desktop_env/controllers/setup.py new file mode 100644 index 0000000..d2abc97 --- /dev/null +++ b/desktop_env/controllers/setup.py @@ -0,0 +1,40 @@ +import requests +import json + +class SetupController: + def __init__(self, http_server: str): + self.http_server = http_server + "/setup" + + def setup(self, config): + """ + Setup Config: + { + download: list[tuple[string]], # a list of tuples of url of file to download and the save path + ... + } + """ + self._download_setup(config) + # can add other setup steps + + + def _download_setup(self, config): + if not 'download' in config: + return + for url, path in config['download']: + if not url or not path: + raise Exception(f"Setup Download - Invalid URL ({url}) or path ({path}).") + + payload = json.dumps({"url": url, "path": path}) + headers = { + 'Content-Type': 'application/json' + } + + # send request to server to download file + try: + response = requests.post(self.http_server + "/download_file", headers=headers, data=payload) + if response.status_code == 200: + print("Command executed successfully:", response.text) + else: + print("Failed to download file. Status code:", response.text) + except requests.exceptions.RequestException as e: + print("An error occurred while trying to send the request:", e) diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index e2f8376..99a9a26 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -9,6 +9,7 @@ from typing import List import gymnasium as gym from desktop_env.controllers.python import PythonController +from desktop_env.controllers.setup import SetupController def _execute_command(command: List[str]) -> None: @@ -34,8 +35,9 @@ class DesktopEnv(gym.Env): # Initialize emulator and controller print("Initializing...") self._start_emulator() - self.host = self._get_vm_ip() + self.host = f"http://{self._get_vm_ip()}:5000" self.controller = PythonController(http_server=self.host) + self.setup_controller = SetupController(http_server=self.host) # mode: human or machine assert action_space in ["computer_13", "pyautogui"] @@ -118,6 +120,9 @@ class DesktopEnv(gym.Env): done = False # todo: Define episode termination condition for each example info = {} return observation, reward, done, info + + def setup(self, config: dict): + self.setup_controller.setup(config) def render(self, mode='rgb_array'): if mode == 'rgb_array': diff --git a/desktop_env/server/main.py b/desktop_env/server/main.py index f9e8dcd..93884ca 100644 --- a/desktop_env/server/main.py +++ b/desktop_env/server/main.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import platform import subprocess import requests @@ -83,6 +84,34 @@ def get_platform(): def get_cursor_position(): return pyautogui.position().x, pyautogui.position().y +@app.route("/setup/download_file", methods=['POST']) +def download_file(): + data = request.json + url = data.get('url', None) + path = data.get('path', None) + + if not url or not path: + return "Path or URL not supplied!", 400 + + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + max_retries = 3 + for i in range(max_retries): + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + with open(path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + return "File downloaded successfully" + + except requests.RequestException as e: + print(f"Failed to download {url}. Retrying... ({max_retries - i - 1} attempts left)") + + return f"Failed to download {url}. No retries left. Error: {e}", 500 if __name__ == '__main__': app.run(debug=True, host="0.0.0.0") diff --git a/main.py b/main.py index bbfd8d2..7b993f2 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,9 @@ def human_agent(): # path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx", snapshot_path="base3", ) + # example setup + env.setup({"download": [("https://images.unsplash.com/photo-1683009427051-00a2fe827a2c", "C:/Users/Yuri/Desktop/1.jpg")]}) + # reset the environment to certain snapshot observation = env.reset()