400 lines
16 KiB
Python
400 lines
16 KiB
Python
import json
|
|
import logging
|
|
import os.path
|
|
import time
|
|
import traceback
|
|
import uuid
|
|
from typing import Any, Union, Optional
|
|
from typing import Dict, List
|
|
|
|
import requests
|
|
from playwright.sync_api import sync_playwright
|
|
from requests_toolbelt.multipart.encoder import MultipartEncoder
|
|
|
|
from desktop_env.evaluators.metrics.utils import compare_urls
|
|
|
|
# Module-level logger used by SetupController. The dotted name places it
# under the "desktopenv" logger hierarchy, so handlers configured there apply.
logger = logging.getLogger("desktopenv.setup")
|
|
|
|
|
|
class SetupController:
    """Prepares a task environment on a VM before an episode starts.

    Most setup steps are forwarded as HTTP requests to the server running on
    the VM at port 5000 (endpoints under ``/setup``). The Chrome steps instead
    connect to the VM's Chrome instance over the Chrome DevTools Protocol.
    """

    # fixme: this port is hard-coded, need to be changed from config file
    _CHROME_REMOTE_DEBUGGING_PORT = 9222

    def __init__(self, vm_ip: str, cache_dir: str):
        """
        Args:
            vm_ip (str): IP address of the VM hosting the setup server.
            cache_dir (str): local directory used to cache downloaded files
                and to store captured command output.
        """
        self.vm_ip: str = vm_ip
        self.http_server: str = f"http://{vm_ip}:5000"
        self.http_server_setup_root: str = f"http://{vm_ip}:5000/setup"
        self.cache_dir: str = cache_dir

    def reset_cache_dir(self, cache_dir: str):
        """Point the controller at a different local cache directory."""
        self.cache_dir = cache_dir

    def setup(self, config: List[Dict[str, Any]]):
        """Run every setup step in ``config`` in order.

        Args:
            config (List[Dict[str, Any]]): list of step dicts. Each step dict
                has the structure like
                {
                    "type": str, selects the ``_{type}_setup`` method of
                        this class to call
                    "parameters": dict like {str: Any} providing the keyword
                        arguments for that method
                }

        Raises:
            AssertionError: if a step's type has no matching ``_{type}_setup``
                method.
        """
        for cfg in config:
            config_type: str = cfg["type"]
            parameters: Dict[str, Any] = cfg["parameters"]

            # All setup functions follow the `_{type}_setup` naming protocol.
            setup_function: str = "_{:}_setup".format(config_type)
            # An explicit raise (rather than `assert`) keeps this check active
            # under `python -O` while still raising AssertionError.
            if not hasattr(self, setup_function):
                raise AssertionError(
                    f"Unknown setup type {config_type!r}: no method named {setup_function}"
                )
            getattr(self, setup_function)(**parameters)

            logger.info("SETUP: %s(%s)", setup_function, str(parameters))

    # ------------------------------------------------------------------ #
    # HTTP helpers
    # ------------------------------------------------------------------ #

    def _post_setup_request(
            self,
            endpoint: str,
            failure_message: str,
            headers: Dict[str, str],
            data: Any,
    ) -> "Optional[requests.Response]":
        """POST ``data`` to ``{setup root}{endpoint}`` and log the outcome.

        Request errors are logged rather than raised, so one failed setup
        request does not abort the remaining setup steps.

        Args:
            endpoint (str): path below the setup root, e.g. "/open_file".
            failure_message (str): message logged when the server answers
                with a non-200 status.
            headers (Dict[str, str]): HTTP headers for the request.
            data (Any): request body passed through to ``requests.post``.

        Returns:
            The server response, or None if the request itself failed.
        """
        try:
            response = requests.post(self.http_server_setup_root + endpoint, headers=headers, data=data)
        except requests.exceptions.RequestException as e:
            logger.error("An error occurred while trying to send the request: %s", e)
            return None

        if response.status_code == 200:
            logger.info("Command executed successfully: %s", response.text)
        else:
            logger.error("%s Status code: %s. Response: %s",
                         failure_message, response.status_code, response.text)
        return response

    def _post_json_setup_request(
            self,
            endpoint: str,
            payload: Dict[str, Any],
            failure_message: str,
    ) -> "Optional[requests.Response]":
        """POST ``payload`` as a JSON body to a setup endpoint and log the outcome."""
        return self._post_setup_request(
            endpoint,
            failure_message,
            headers={"Content-Type": "application/json"},
            data=json.dumps(payload),
        )

    # ------------------------------------------------------------------ #
    # File setup
    # ------------------------------------------------------------------ #

    def _download_setup(self, files: List[Dict[str, str]]):
        """Download files (through a local cache) and upload them to the VM.

        Args:
            files (List[Dict[str, str]]): files to download. list of dict like
                {
                    "url": str, the url to download
                    "path": str, the path on the VM to store the downloaded file
                }

        Raises:
            Exception: if an entry has an empty URL or path.
            requests.RequestException: if a file cannot be downloaded after
                all retries.
        """
        for file_spec in files:
            url: str = file_spec["url"]
            path: str = file_spec["path"]
            # Validate before deriving the cache name: uuid5 cannot hash a None url.
            if not url or not path:
                raise Exception(f"Setup Download - Invalid URL ({url}) or path ({path}).")

            # The cache name is derived from the URL, so calls sharing the same
            # cache directory download each URL only once.
            cache_path: str = os.path.join(self.cache_dir, "{:}_{:}".format(
                uuid.uuid5(uuid.NAMESPACE_URL, url),
                os.path.basename(path)))

            if not os.path.exists(cache_path):
                self._download_to_cache(url, cache_path)

            self._upload_cached_file(path, cache_path)

    @staticmethod
    def _download_to_cache(url: str, cache_path: str, max_retries: int = 3):
        """Stream ``url`` into ``cache_path``, retrying on request errors.

        Raises:
            requests.RequestException: if every attempt fails.
        """
        last_error: Optional[Exception] = None
        for attempt in range(max_retries):
            try:
                with requests.get(url, stream=True) as response:
                    response.raise_for_status()
                    with open(cache_path, "wb") as out_file:
                        for chunk in response.iter_content(chunk_size=8192):
                            if chunk:
                                out_file.write(chunk)
                logger.info("File downloaded successfully")
                return
            except requests.RequestException as e:
                last_error = e
                # Never leave a truncated file behind: the caller treats any
                # existing cache file as a complete download.
                if os.path.exists(cache_path):
                    os.remove(cache_path)
                logger.error(f"Failed to download {url}. Retrying... ({max_retries - attempt - 1} attempts left)")
        # `e` is unbound once its except block ends, so report the saved error.
        raise requests.RequestException(f"Failed to download {url}. No retries left. Error: {last_error}")

    def _upload_cached_file(self, path: str, cache_path: str):
        """Upload the local file ``cache_path`` to the setup server as ``path``."""
        # The `with` keeps the file open for the multipart upload and closes it afterwards.
        with open(cache_path, "rb") as file_data:
            form = MultipartEncoder({
                "file_path": path,
                "file_data": (os.path.basename(path), file_data),
            })
            headers = {"Content-Type": form.content_type}
            logger.debug(form.content_type)

            logger.debug("REQUEST ADDRESS: %s", self.http_server_setup_root + "/upload")
            self._post_setup_request("/upload", "Failed to upload file.", headers=headers, data=form)

    def _change_wallpaper_setup(self, path: str):
        """Ask the setup server to change the desktop wallpaper to ``path``.

        Raises:
            Exception: if ``path`` is empty.
        """
        if not path:
            raise Exception(f"Setup Wallpaper - Invalid path ({path}).")
        self._post_json_setup_request("/change_wallpaper", {"path": path}, "Failed to change wallpaper.")

    def _tidy_desktop_setup(self, **config):
        """Not implemented yet."""
        raise NotImplementedError()

    def _open_setup(self, path: str):
        """Ask the setup server to open the file at ``path``.

        Raises:
            Exception: if ``path`` is empty.
        """
        if not path:
            raise Exception(f"Setup Open - Invalid path ({path}).")
        self._post_json_setup_request("/open_file", {"path": path}, "Failed to open file.")

    # ------------------------------------------------------------------ #
    # Command setup
    # ------------------------------------------------------------------ #

    def _launch_setup(self, command: Union[str, List[str]], shell: bool = False):
        """Ask the setup server to launch ``command``.

        A multi-word string is split on whitespace when ``shell`` is False.

        Raises:
            Exception: if ``command`` is empty.
        """
        if not command:
            raise Exception("Empty command to launch.")

        if not shell and isinstance(command, str) and len(command.split()) > 1:
            logger.warning("Command should be a list of strings. Now it is a string. Will split it by space.")
            command = command.split()

        self._post_json_setup_request("/launch", {"command": command, "shell": shell},
                                      "Failed to launch application.")

    def _execute_setup(
            self,
            command: Union[str, List[str]],
            stdout: str = "",
            stderr: str = "",
            shell: bool = False,
            until: Optional[Dict[str, Any]] = None
    ):
        """Execute ``command`` on the VM, optionally polling until a condition holds.

        Args:
            command: command to run. A list of arguments, or a string when
                ``shell`` is True.
            stdout (str): if non-empty, file name (relative to ``cache_dir``)
                that receives the result's "output" text.
            stderr (str): if non-empty, file name (relative to ``cache_dir``)
                that receives the result's "error" text.
            shell (bool): forwarded to the server as the "shell" flag.
            until (Optional[Dict[str, Any]]): if non-empty, the command is
                re-sent every 0.3 s until any listed key matches the result.
                "returncode" must be equal; "stdout" must be a substring of the
                output; "stderr" must be a substring of the error. Polling also
                stops after 5 failed requests.

        Raises:
            Exception: if ``command`` is empty.
        """
        if not command:
            raise Exception("Empty comman to launch.")

        until = until or {}
        payload = json.dumps({"command": command, "shell": shell})
        headers = {"Content-Type": "application/json"}
        # " ".join on a plain string would interleave spaces between its characters.
        command_text = command if isinstance(command, str) else " ".join(command)
        nb_failings = 0

        while True:
            results: Optional[Dict[str, Any]] = None
            try:
                response = requests.post(self.http_server_setup_root + "/execute", headers=headers, data=payload)
                if response.status_code == 200:
                    results = response.json()
                    if stdout:
                        with open(os.path.join(self.cache_dir, stdout), "w", encoding="utf-8") as out_file:
                            out_file.write(results["output"])
                    if stderr:
                        with open(os.path.join(self.cache_dir, stderr), "w", encoding="utf-8") as err_file:
                            err_file.write(results["error"])
                    logger.info("Command executed successfully: %s -> %s", command_text, response.text)
                else:
                    logger.error("Failed to execute command. Status code: %s. Response: %s",
                                 response.status_code, response.text)
                    nb_failings += 1
            except requests.exceptions.RequestException as e:
                logger.error("An error occurred while trying to send the request: %s", e)
                traceback.print_exc()
                results = None
                nb_failings += 1

            satisfied = results is not None and (
                ("returncode" in until and results["returncode"] == until["returncode"])
                or ("stdout" in until and until["stdout"] in results["output"])
                or ("stderr" in until and until["stderr"] in results["error"])
            )
            if not until or satisfied or nb_failings >= 5:
                break
            time.sleep(0.3)

    def _command_setup(self, command: List[str], **kwargs):
        """Alias of ``_execute_setup``."""
        self._execute_setup(command, **kwargs)

    def _sleep_setup(self, seconds: float):
        """Pause the setup sequence for ``seconds`` seconds."""
        time.sleep(seconds)

    def _act_setup(self, action_seq: List[Union[Dict[str, Any], str]]):
        # TODO
        raise NotImplementedError()

    def _replay_setup(self, trajectory: str):
        """
        Args:
            trajectory (str): path to the replay trajectory file
        """
        # TODO
        raise NotImplementedError()

    def _activate_window_setup(self, window_name: str):
        """Ask the setup server to activate the window named ``window_name``.

        Raises:
            Exception: if ``window_name`` is empty.
        """
        if not window_name:
            raise Exception(f"Setup Activate Window - Invalid window name ({window_name}).")
        # The window name is passed as a log argument, so a '%' in it cannot
        # break the log formatting.
        self._post_json_setup_request("/activate_window", {"window_name": window_name},
                                      f"Failed to activate window {window_name}.")

    # ------------------------------------------------------------------ #
    # Chrome setup
    # ------------------------------------------------------------------ #

    def _connect_to_chrome(self, playwright, max_attempts: int = 15):
        """Connect to the VM's Chrome over CDP, retrying once per second.

        Returns:
            The connected browser, or None if ``max_attempts`` is not positive.

        Raises:
            Exception: the last connection error once all attempts have failed.
        """
        remote_debugging_url = f"http://{self.vm_ip}:{self._CHROME_REMOTE_DEBUGGING_PORT}"
        for attempt in range(max_attempts):
            try:
                return playwright.chromium.connect_over_cdp(remote_debugging_url)
            except Exception as e:
                if attempt < max_attempts - 1:
                    logger.error(f"Attempt {attempt + 1}: Failed to connect, retrying. Error: {e}")
                    time.sleep(1)
                else:
                    logger.error(f"Failed to connect after multiple attempts: {e}")
                    raise
        return None

    def _chrome_open_tabs_setup(self, urls_to_open: List[str]):
        """Open each URL in a new tab of the VM's Chrome.

        After the first URL is opened, the tab that existed before it (if
        any) is closed.

        Returns:
            (browser, context), or None if no connection was made.
        """
        with sync_playwright() as p:
            browser = self._connect_to_chrome(p)
            if not browser:
                return

            # Use the first context (which should be the only one if using default profile)
            context = browser.contexts[0]
            # Only close a default tab if one existed before we opened ours;
            # otherwise pages[0] would be the tab we just created.
            had_default_page = len(context.pages) > 0

            for i, url in enumerate(urls_to_open):
                page = context.new_page()  # Create a new page (tab) within the existing context
                page.goto(url)
                logger.info(f"Opened tab {i + 1}: {url}")

                if i == 0 and had_default_page:
                    # clear the default tab
                    context.pages[0].close()

            # Do not close the context or browser; they will remain open after script ends
            return browser, context

    def _chrome_close_tabs_setup(self, urls_to_close: List[str]):
        """Close, for each URL, the first open tab whose URL matches it.

        Returns:
            (browser, context), or None if no connection was made.
        """
        time.sleep(5)  # Wait for Chrome to finish launching

        with sync_playwright() as p:
            browser = self._connect_to_chrome(p)
            if not browser:
                return

            # Use the first context (which should be the only one if using default profile)
            context = browser.contexts[0]

            for i, url in enumerate(urls_to_close):
                for page in context.pages:
                    # if two urls are the same, close the tab
                    if compare_urls(page.url, url):
                        page.close()
                        logger.info(f"Closed tab {i + 1}: {url}")
                        break

            # Do not close the context or browser; they will remain open after script ends
            return browser, context
|