Merge remote-tracking branch 'origin/main'
# Conflicts: # main.py
This commit is contained in:
40
desktop_env/controllers/setup.py
Normal file
40
desktop_env/controllers/setup.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
class SetupController:
|
||||||
|
def __init__(self, http_server: str):
|
||||||
|
self.http_server = http_server + "/setup"
|
||||||
|
|
||||||
|
def setup(self, config):
|
||||||
|
"""
|
||||||
|
Setup Config:
|
||||||
|
{
|
||||||
|
download: list[tuple[string]], # a list of tuples of url of file to download and the save path
|
||||||
|
...
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
self._download_setup(config)
|
||||||
|
# can add other setup steps
|
||||||
|
|
||||||
|
|
||||||
|
def _download_setup(self, config):
|
||||||
|
if not 'download' in config:
|
||||||
|
return
|
||||||
|
for url, path in config['download']:
|
||||||
|
if not url or not path:
|
||||||
|
raise Exception(f"Setup Download - Invalid URL ({url}) or path ({path}).")
|
||||||
|
|
||||||
|
payload = json.dumps({"url": url, "path": path})
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
|
||||||
|
# send request to server to download file
|
||||||
|
try:
|
||||||
|
response = requests.post(self.http_server + "/download_file", headers=headers, data=payload)
|
||||||
|
if response.status_code == 200:
|
||||||
|
print("Command executed successfully:", response.text)
|
||||||
|
else:
|
||||||
|
print("Failed to download file. Status code:", response.text)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print("An error occurred while trying to send the request:", e)
|
||||||
@@ -9,12 +9,14 @@ from typing import List
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from desktop_env.controllers.python import PythonController
|
from desktop_env.controllers.python import PythonController
|
||||||
|
from desktop_env.controllers.setup import SetupController
|
||||||
|
|
||||||
|
|
||||||
def _execute_command(command: List[str]) -> None:
|
def _execute_command(command: List[str]) -> None:
|
||||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
|
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
||||||
|
return result.stdout
|
||||||
|
|
||||||
|
|
||||||
class DesktopEnv(gym.Env):
|
class DesktopEnv(gym.Env):
|
||||||
@@ -23,19 +25,19 @@ class DesktopEnv(gym.Env):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
path_to_vm: str,
|
path_to_vm: str,
|
||||||
host: str = "192.168.7.128:5000",
|
|
||||||
snapshot_path: str = "base",
|
snapshot_path: str = "base",
|
||||||
action_space: str = "pyautogui",
|
action_space: str = "pyautogui",
|
||||||
):
|
):
|
||||||
# Initialize environment variables
|
# Initialize environment variables
|
||||||
self.path_to_vm = path_to_vm
|
self.path_to_vm = path_to_vm
|
||||||
self.host = host
|
|
||||||
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
|
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
|
||||||
|
|
||||||
# Initialize emulator and controller
|
# Initialize emulator and controller
|
||||||
print("Initializing...")
|
print("Initializing...")
|
||||||
self._start_emulator()
|
self._start_emulator()
|
||||||
|
self.host = f"http://{self._get_vm_ip()}:5000"
|
||||||
self.controller = PythonController(http_server=self.host)
|
self.controller = PythonController(http_server=self.host)
|
||||||
|
self.setup_controller = SetupController(http_server=self.host)
|
||||||
|
|
||||||
# mode: human or machine
|
# mode: human or machine
|
||||||
assert action_space in ["computer_13", "pyautogui"]
|
assert action_space in ["computer_13", "pyautogui"]
|
||||||
@@ -53,10 +55,23 @@ class DesktopEnv(gym.Env):
|
|||||||
else:
|
else:
|
||||||
print("Starting VM...")
|
print("Starting VM...")
|
||||||
_execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
|
_execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
|
||||||
time.sleep(10)
|
time.sleep(3)
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
print(f"Error executing command: {e.output.decode().strip()}")
|
print(f"Error executing command: {e.output.decode().strip()}")
|
||||||
|
|
||||||
|
def _get_vm_ip(self):
|
||||||
|
max_retries = 10
|
||||||
|
print("Getting IP Address...")
|
||||||
|
for _ in range(max_retries):
|
||||||
|
try:
|
||||||
|
output = _execute_command(["vmrun", "-T", "ws", "getGuestIPAddress", self.path_to_vm]).strip()
|
||||||
|
print(f"IP address: {output}")
|
||||||
|
return output
|
||||||
|
except:
|
||||||
|
time.sleep(5)
|
||||||
|
print("Retrying...")
|
||||||
|
raise Exception("Failed to get VM IP address!")
|
||||||
|
|
||||||
def _save_state(self):
|
def _save_state(self):
|
||||||
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
||||||
|
|
||||||
@@ -105,6 +120,9 @@ class DesktopEnv(gym.Env):
|
|||||||
done = False # todo: Define episode termination condition for each example
|
done = False # todo: Define episode termination condition for each example
|
||||||
info = {}
|
info = {}
|
||||||
return observation, reward, done, info
|
return observation, reward, done, info
|
||||||
|
|
||||||
|
def setup(self, config: dict):
|
||||||
|
self.setup_controller.setup(config)
|
||||||
|
|
||||||
def render(self, mode='rgb_array'):
|
def render(self, mode='rgb_array'):
|
||||||
if mode == 'rgb_array':
|
if mode == 'rgb_array':
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
import platform
|
import platform
|
||||||
import subprocess
|
import subprocess
|
||||||
import requests
|
import requests
|
||||||
@@ -83,6 +84,34 @@ def get_platform():
|
|||||||
def get_cursor_position():
|
def get_cursor_position():
|
||||||
return pyautogui.position().x, pyautogui.position().y
|
return pyautogui.position().x, pyautogui.position().y
|
||||||
|
|
||||||
|
@app.route("/setup/download_file", methods=['POST'])
|
||||||
|
def download_file():
|
||||||
|
data = request.json
|
||||||
|
url = data.get('url', None)
|
||||||
|
path = data.get('path', None)
|
||||||
|
|
||||||
|
if not url or not path:
|
||||||
|
return "Path or URL not supplied!", 400
|
||||||
|
|
||||||
|
path = Path(path)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
for i in range(max_retries):
|
||||||
|
try:
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
with open(path, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
if chunk:
|
||||||
|
f.write(chunk)
|
||||||
|
return "File downloaded successfully"
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
print(f"Failed to download {url}. Retrying... ({max_retries - i - 1} attempts left)")
|
||||||
|
|
||||||
|
return f"Failed to download {url}. No retries left. Error: {e}", 500
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(debug=True, host="0.0.0.0")
|
app.run(debug=True, host="0.0.0.0")
|
||||||
|
|||||||
Reference in New Issue
Block a user