setup controller
This commit is contained in:
40
desktop_env/controllers/setup.py
Normal file
40
desktop_env/controllers/setup.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import requests
|
||||
import json
|
||||
|
||||
class SetupController:
|
||||
def __init__(self, http_server: str):
|
||||
self.http_server = http_server + "/setup"
|
||||
|
||||
def setup(self, config):
|
||||
"""
|
||||
Setup Config:
|
||||
{
|
||||
download: list[tuple[string]], # a list of tuples of url of file to download and the save path
|
||||
...
|
||||
}
|
||||
"""
|
||||
self._download_setup(config)
|
||||
# can add other setup steps
|
||||
|
||||
|
||||
def _download_setup(self, config):
|
||||
if not 'download' in config:
|
||||
return
|
||||
for url, path in config['download']:
|
||||
if not url or not path:
|
||||
raise Exception(f"Setup Download - Invalid URL ({url}) or path ({path}).")
|
||||
|
||||
payload = json.dumps({"url": url, "path": path})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
# send request to server to download file
|
||||
try:
|
||||
response = requests.post(self.http_server + "/download_file", headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("Command executed successfully:", response.text)
|
||||
else:
|
||||
print("Failed to download file. Status code:", response.text)
|
||||
except requests.exceptions.RequestException as e:
|
||||
print("An error occurred while trying to send the request:", e)
|
||||
@@ -9,6 +9,7 @@ from typing import List
|
||||
import gymnasium as gym
|
||||
|
||||
from desktop_env.controllers.python import PythonController
|
||||
from desktop_env.controllers.setup import SetupController
|
||||
|
||||
|
||||
def _execute_command(command: List[str]) -> None:
|
||||
@@ -34,8 +35,9 @@ class DesktopEnv(gym.Env):
|
||||
# Initialize emulator and controller
|
||||
print("Initializing...")
|
||||
self._start_emulator()
|
||||
self.host = self._get_vm_ip()
|
||||
self.host = f"http://{self._get_vm_ip()}:5000"
|
||||
self.controller = PythonController(http_server=self.host)
|
||||
self.setup_controller = SetupController(http_server=self.host)
|
||||
|
||||
# mode: human or machine
|
||||
assert action_space in ["computer_13", "pyautogui"]
|
||||
@@ -118,6 +120,9 @@ class DesktopEnv(gym.Env):
|
||||
done = False # todo: Define episode termination condition for each example
|
||||
info = {}
|
||||
return observation, reward, done, info
|
||||
|
||||
def setup(self, config: dict):
|
||||
self.setup_controller.setup(config)
|
||||
|
||||
def render(self, mode='rgb_array'):
|
||||
if mode == 'rgb_array':
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import platform
|
||||
import subprocess
|
||||
import requests
|
||||
@@ -83,6 +84,34 @@ def get_platform():
|
||||
def get_cursor_position():
|
||||
return pyautogui.position().x, pyautogui.position().y
|
||||
|
||||
@app.route("/setup/download_file", methods=['POST'])
|
||||
def download_file():
|
||||
data = request.json
|
||||
url = data.get('url', None)
|
||||
path = data.get('path', None)
|
||||
|
||||
if not url or not path:
|
||||
return "Path or URL not supplied!", 400
|
||||
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
max_retries = 3
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
return "File downloaded successfully"
|
||||
|
||||
except requests.RequestException as e:
|
||||
print(f"Failed to download {url}. Retrying... ({max_retries - i - 1} attempts left)")
|
||||
|
||||
return f"Failed to download {url}. No retries left. Error: {e}", 500
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True, host="0.0.0.0")
|
||||
|
||||
3
main.py
3
main.py
@@ -10,6 +10,9 @@ def human_agent():
|
||||
# path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
|
||||
snapshot_path="base3",
|
||||
)
|
||||
# example setup
|
||||
env.setup({"download": [("https://images.unsplash.com/photo-1683009427051-00a2fe827a2c", "C:/Users/Yuri/Desktop/1.jpg")]})
|
||||
|
||||
|
||||
# reset the environment to certain snapshot
|
||||
observation = env.reset()
|
||||
|
||||
Reference in New Issue
Block a user