From 6419d707bccfc2c3c5a12931caf5071362e81ee2 Mon Sep 17 00:00:00 2001 From: HappySix <33394488+FredWuCZ@users.noreply.github.com> Date: Sat, 28 Sep 2024 21:10:40 +0800 Subject: [PATCH 1/2] Support Docker VM manager and provider (#75) * Add docker provider framework * Update VM download link * Add stop container * Update docker manager & provider * Update * Update * Update provider --- desktop_env/controllers/python.py | 3 +- desktop_env/controllers/setup.py | 12 ++- desktop_env/desktop_env.py | 23 ++-- desktop_env/evaluators/getters/chrome.py | 42 +++++--- desktop_env/evaluators/getters/general.py | 4 +- desktop_env/providers/__init__.py | 20 ++-- desktop_env/providers/docker/manager.py | 123 ++++++++++++++++++++++ desktop_env/providers/docker/provider.py | 67 ++++++++++++ main.py | 2 +- requirements.txt | 4 - 10 files changed, 256 insertions(+), 44 deletions(-) create mode 100644 desktop_env/providers/docker/manager.py create mode 100644 desktop_env/providers/docker/provider.py diff --git a/desktop_env/controllers/python.py b/desktop_env/controllers/python.py index a3b19c5..c572083 100644 --- a/desktop_env/controllers/python.py +++ b/desktop_env/controllers/python.py @@ -12,9 +12,10 @@ logger = logging.getLogger("desktopenv.pycontroller") class PythonController: def __init__(self, vm_ip: str, + server_port: int, pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"): self.vm_ip = vm_ip - self.http_server = f"http://{vm_ip}:5000" + self.http_server = f"http://{vm_ip}:{server_port}" self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages self.retry_times = 3 self.retry_interval = 5 diff --git a/desktop_env/controllers/setup.py b/desktop_env/controllers/setup.py index b4ae4b2..f03e9cd 100644 --- a/desktop_env/controllers/setup.py +++ b/desktop_env/controllers/setup.py @@ -28,10 +28,12 @@ FILE_PATH = os.path.dirname(os.path.abspath(__file__)) class SetupController: - def __init__(self, vm_ip: str, cache_dir: str): + def __init__(self, vm_ip: str, server_port: int, chromium_port: int, cache_dir: str): self.vm_ip: str = vm_ip - self.http_server: str = f"http://{vm_ip}:5000" - self.http_server_setup_root: str = f"http://{vm_ip}:5000/setup" + self.server_port: int = server_port + self.chromium_port: int = chromium_port + self.http_server: str = f"http://{vm_ip}:{server_port}" + self.http_server_setup_root: str = f"http://{vm_ip}:{server_port}/setup" self.cache_dir: str = cache_dir def reset_cache_dir(self, cache_dir: str): @@ -348,7 +350,7 @@ class SetupController: # Chrome setup def _chrome_open_tabs_setup(self, urls_to_open: List[str]): host = self.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = self.chromium_port # fixme: this port is hard-coded, need to be changed from config file remote_debugging_url = f"http://{host}:{port}" logger.info("Connect to Chrome @: %s", remote_debugging_url) @@ -399,7 +401,7 @@ class SetupController: time.sleep(5) # Wait for Chrome to finish launching host = self.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = self.server_port # fixme: this port is hard-coded, need to be changed from config file remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index e02eacf..9b24d05 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -26,7 +26,7 @@ class DesktopEnv(gym.Env): def __init__( self, - provider_name: str = "vmware", + provider_name: str = "docker", region: str = None, path_to_vm: str = None, snapshot_name: str = "init_state", @@ -54,6 +54,11 @@ class DesktopEnv(gym.Env): """ # Initialize VM manager and vitualization provider self.region = region + + # Default + self.server_port = 5000 + self.chromium_port = 9222 + self.vnc_port = 8006 self.manager, self.provider = create_vm_manager_and_provider(provider_name, region) self.os_type = os_type @@ -73,8 +78,9 @@ class DesktopEnv(gym.Env): self.require_terminal = require_terminal # Initialize emulator and controller - logger.info("Initializing...") - self._start_emulator() + if provider_name != "docker": # Check if this is applicable to other VM providers + logger.info("Initializing...") + self._start_emulator() # mode: human or machine self.instruction = None @@ -92,9 +98,14 @@ class DesktopEnv(gym.Env): self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type) # Get the ip from the virtual machine, and setup the controller - self.vm_ip = self.provider.get_ip_address(self.path_to_vm) - self.controller = PythonController(vm_ip=self.vm_ip) - self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base) + vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':') + self.vm_ip = vm_ip_ports[0] + if len(vm_ip_ports) > 1: + self.server_port = int(vm_ip_ports[1]) + self.chromium_port = int(vm_ip_ports[2]) + self.vnc_port = int(vm_ip_ports[3]) + self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port) + self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, cache_dir=self.cache_dir_base) def _revert_to_snapshot(self): # Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm diff --git a/desktop_env/evaluators/getters/chrome.py b/desktop_env/evaluators/getters/chrome.py index 94c0dda..ba384de 100644 --- a/desktop_env/evaluators/getters/chrome.py +++ b/desktop_env/evaluators/getters/chrome.py @@ -54,7 +54,8 @@ def get_info_from_website(env, config: Dict[Any, Any]) -> Any: """ try: host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: # connect to remote Chrome instance @@ -68,7 +69,7 @@ def get_info_from_website(env, config: Dict[Any, Any]) -> Any: "--remote-debugging-port=1337" ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) @@ -454,7 +455,8 @@ def get_extensions_installed_from_shop(env, config: Dict[str, str]): def get_page_info(env, config: Dict[str, str]): host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port url = config["url"] remote_debugging_url = f"http://{host}:{port}" @@ -478,7 +480,7 @@ def get_page_info(env, config: Dict[str, str]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) @@ -505,7 +507,8 @@ def get_page_info(env, config: Dict[str, str]): def get_open_tabs_info(env, config: Dict[str, str]): host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -528,7 +531,7 @@ def get_open_tabs_info(env, config: Dict[str, str]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) try: browser = p.chromium.connect_over_cdp(remote_debugging_url) @@ -643,7 +646,7 @@ def get_active_tab_info(env, config: Dict[str, str]): logger.error("Failed to get the url of active tab") return None host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -683,7 +686,8 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str: _path = os.path.join(env.cache_dir, config["dest"]) host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" @@ -706,7 +710,7 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str: ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) @@ -721,7 +725,8 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str: # fixme: needs to be changed (maybe through post-processing) since it's not working def get_chrome_saved_address(env, config: Dict[str, str]): host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -744,7 +749,7 @@ def get_chrome_saved_address(env, config: Dict[str, str]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) @@ -800,7 +805,8 @@ def get_number_of_search_results(env, config: Dict[str, str]): # todo: move into the config file url, result_selector = "https://google.com/search?q=query", '.search-result' host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -822,7 +828,7 @@ def get_number_of_search_results(env, config: Dict[str, str]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) page = browser.new_page() @@ -1145,7 +1151,8 @@ def get_active_tab_html_parse(env, config: Dict[str, Any]): logger.error("active_tab_url is not a string") return None host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -1168,7 +1175,7 @@ def get_active_tab_html_parse(env, config: Dict[str, Any]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) target_page = None @@ -1237,7 +1244,8 @@ def get_gotoRecreationPage_and_get_html_content(env, config: Dict[str, Any]): especially used for www.recreation.gov examples """ host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -1259,7 +1267,7 @@ def get_gotoRecreationPage_and_get_html_content(env, config: Dict[str, Any]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) page = browser.new_page() diff --git a/desktop_env/evaluators/getters/general.py b/desktop_env/evaluators/getters/general.py index 81ad69b..2f5ed32 100644 --- a/desktop_env/evaluators/getters/general.py +++ b/desktop_env/evaluators/getters/general.py @@ -7,7 +7,7 @@ logger = logging.getLogger("desktopenv.getters.general") def get_vm_command_line(env, config: Dict[str, str]): vm_ip = env.vm_ip - port = 5000 + port = env.server_port command = config["command"] shell = config.get("shell", False) @@ -23,7 +23,7 @@ def get_vm_command_line(env, config: Dict[str, str]): def get_vm_command_error(env, config: Dict[str, str]): vm_ip = env.vm_ip - port = 5000 + port = env.server_port command = config["command"] shell = config.get("shell", False) diff --git a/desktop_env/providers/__init__.py b/desktop_env/providers/__init__.py index 8c12197..7c95382 100644 --- a/desktop_env/providers/__init__.py +++ b/desktop_env/providers/__init__.py @@ -1,12 +1,4 @@ from desktop_env.providers.base import VMManager, Provider -from desktop_env.providers.vmware.manager import VMwareVMManager -from desktop_env.providers.vmware.provider import VMwareProvider -from desktop_env.providers.aws.manager import AWSVMManager -from desktop_env.providers.aws.provider import AWSProvider -from desktop_env.providers.azure.manager import AzureVMManager -from desktop_env.providers.azure.provider import AzureProvider -from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager -from desktop_env.providers.virtualbox.provider import VirtualBoxProvider def create_vm_manager_and_provider(provider_name: str, region: str): """ @@ -14,12 +6,24 @@ def create_vm_manager_and_provider(provider_name: str, region: str): """ provider_name = provider_name.lower().strip() if provider_name == "vmware": + from desktop_env.providers.vmware.manager import VMwareVMManager + from desktop_env.providers.vmware.provider import VMwareProvider return VMwareVMManager(), VMwareProvider(region) elif provider_name == "virtualbox": + from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager + from desktop_env.providers.virtualbox.provider import VirtualBoxProvider return VirtualBoxVMManager(), VirtualBoxProvider(region) elif provider_name in ["aws", "amazon web services"]: + from desktop_env.providers.aws.manager import AWSVMManager + from desktop_env.providers.aws.provider import AWSProvider return AWSVMManager(), AWSProvider(region) elif provider_name == "azure": + from desktop_env.providers.azure.manager import AzureVMManager + from desktop_env.providers.azure.provider import AzureProvider return AzureVMManager(), AzureProvider(region) + elif provider_name == "docker": + from desktop_env.providers.docker.manager import DockerVMManager + from desktop_env.providers.docker.provider import DockerProvider + return DockerVMManager(), DockerProvider(region) else: raise NotImplementedError(f"{provider_name} not implemented!") diff --git a/desktop_env/providers/docker/manager.py b/desktop_env/providers/docker/manager.py new file mode 100644 index 0000000..e4b09a1 --- /dev/null +++ b/desktop_env/providers/docker/manager.py @@ -0,0 +1,123 @@ +import os +import platform +import random +import re + +import threading +from filelock import FileLock +import uuid +import zipfile + +from time import sleep +import shutil +import psutil +import subprocess +import requests +from tqdm import tqdm +import docker + +import logging + +from desktop_env.providers.base import VMManager + +logger = logging.getLogger("desktopenv.providers.vmware.VMwareVMManager") +logger.setLevel(logging.INFO) + +MAX_RETRY_TIMES = 10 +RETRY_INTERVAL = 5 + +UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu.qcow2" +VMS_DIR = "./docker_vm_data" + +# Determine the platform and CPU architecture to decide the correct VM image to download +# if platform.system() == 'Darwin': # macOS +# # if os.uname().machine == 'arm64': # Apple Silicon +# URL = UBUNTU_ARM_URL +# # else: +# # url = UBUNTU_X86_URL +# elif platform.machine().lower() in ['amd64', 'x86_64']: +# URL = UBUNTU_X86_URL +# else: +# raise Exception("Unsupported platform or architecture") +URL = UBUNTU_X86_URL + +DOWNLOADED_FILE_NAME = URL.split('/')[-1] + +if platform.system() == 'Windows': + docker_path = r"C:\Program Files\Docker\Docker" + os.environ["PATH"] += os.pathsep + docker_path + +def _download_vm(vms_dir: str): + # Download the virtual machine image + logger.info("Downloading the virtual machine image...") + downloaded_size = 0 + + URL = UBUNTU_X86_URL + DOWNLOADED_FILE_NAME = URL.split('/')[-1] + downloaded_file_name = DOWNLOADED_FILE_NAME + + os.makedirs(vms_dir, exist_ok=True) + + while True: + downloaded_file_path = os.path.join(vms_dir, downloaded_file_name) + headers = {} + if os.path.exists(downloaded_file_path): + downloaded_size = os.path.getsize(downloaded_file_path) + headers["Range"] = f"bytes={downloaded_size}-" + + with requests.get(URL, headers=headers, stream=True) as response: + if response.status_code == 416: + # This means the range was not satisfiable, possibly the file was fully downloaded + logger.info("Fully downloaded or the file size changed.") + break + + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + + with open(downloaded_file_path, "ab") as file, tqdm( + desc="Progress", + total=total_size, + unit='iB', + unit_scale=True, + unit_divisor=1024, + initial=downloaded_size, + ascii=True + ) as progress_bar: + try: + for data in response.iter_content(chunk_size=1024): + size = file.write(data) + progress_bar.update(size) + except (requests.exceptions.RequestException, IOError) as e: + logger.error(f"Download error: {e}") + sleep(RETRY_INTERVAL) + logger.error("Retrying...") + else: + logger.info("Download succeeds.") + break # Download completed successfully + +class DockerVMManager(VMManager): + def __init__(self, registry_path=""): + pass + + def add_vm(self, vm_path): + pass + + def check_and_clean(self): + pass + + def delete_vm(self, vm_path): + pass + + def initialize_registry(self): + pass + + def list_free_vms(self): + return os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME) + + def occupy_vm(self, vm_path): + pass + + def get_vm_path(self, os_type, region): + if not os.path.exists(os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME)): + _download_vm(VMS_DIR) + return os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME) \ No newline at end of file diff --git a/desktop_env/providers/docker/provider.py b/desktop_env/providers/docker/provider.py new file mode 100644 index 0000000..9ceab84 --- /dev/null +++ b/desktop_env/providers/docker/provider.py @@ -0,0 +1,67 @@ +import logging +import os +import platform +import subprocess +import time +import docker +import psutil +import requests + +from desktop_env.providers.base import Provider + +logger = logging.getLogger("desktopenv.providers.vmware.VMwareProvider") +logger.setLevel(logging.INFO) + +WAIT_TIME = 3 +RETRY_INTERVAL = 1 + +class DockerProvider(Provider): + def __init__(self, region: str): + self.client = docker.from_env() + self.vnc_port = self._get_available_port(8006) + self.server_port = self._get_available_port(5000) + # self.remote_debugging_port = self._get_available_port(1337) + self.chromium_port = self._get_available_port(9222) + self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"} # Modify if needed + + @staticmethod + def _get_available_port(port: int): + while port < 65354: + if port not in [conn.laddr.port for conn in psutil.net_connections()]: + return port + port += 1 + + def start_emulator(self, path_to_vm: str, headless: bool, os_type: str): + logger.info(f"Occupying ports: {self.vnc_port}, {self.server_port}, {self.chromium_port}") + self.container = self.client.containers.run("happysixd/osworld-docker", environment=self.environment, cap_add=["NET_ADMIN"], devices=["/dev/kvm"], volumes={os.path.abspath(path_to_vm): {"bind": "/Ubuntu.qcow2", "mode": "ro"}}, ports={8006: self.vnc_port, 5000: self.server_port, 9222: self.chromium_port}, detach=True) + def download_screenshot(ip, port): + url = f"http://{ip}:{port}/screenshot" + try: + # max trey times 1, max timeout 1 + response = requests.get(url, timeout=(10, 10)) + if response.status_code == 200: + return True + except Exception as e: + time.sleep(RETRY_INTERVAL) + return False + + # Try downloading the screenshot until successful + while not download_screenshot("localhost", self.server_port): + logger.info("Check whether the virtual machine is ready...") + + def get_ip_address(self, path_to_vm: str) -> str: + return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}" + + def save_state(self, path_to_vm: str, snapshot_name: str): + raise NotImplementedError("Not available for Docker.") + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + pass + + def stop_emulator(self, path_to_vm: str): + logger.info("Stopping VM...") + self.container.stop() + self.container.remove() + time.sleep(WAIT_TIME) + +# docker run -it --rm -e "DISK_SIZE=64G" -e "RAM_SIZE=8G" -e "CPU_CORES=8" --volume /home/$USER/osworld/docker_vm_data/Ubuntu.qcow2:/Ubuntu.qcow2:ro --cap-add NET_ADMIN --device /dev/kvm -p 8008:8006 -p 5002:5000 happysixd/osworld-docker \ No newline at end of file diff --git a/main.py b/main.py index 4c6817f..6f03227 100644 --- a/main.py +++ b/main.py @@ -78,7 +78,7 @@ def human_agent(): result = env.evaluate() logger.info("Result: %.2f", result) - # env.close() + env.close() logger.info("Environment closed.") diff --git a/requirements.txt b/requirements.txt index 439f85b..2cc96ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -53,7 +53,3 @@ wrapt_timeout_decorator gdown tiktoken groq -boto3 -azure-identity -azure-mgmt-compute -azure-mgmt-network From 24bad80b531efe0e6e4fd35740a2fd60c95724f6 Mon Sep 17 00:00:00 2001 From: FredWuCZ Date: Sat, 28 Sep 2024 22:01:06 +0800 Subject: [PATCH 2/2] Add requirements for docker --- requirements.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/requirements.txt b/requirements.txt index 2cc96ab..72ca6fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -53,3 +53,8 @@ wrapt_timeout_decorator gdown tiktoken groq +boto3 +azure-identity +azure-mgmt-compute +azure-mgmt-network +docker \ No newline at end of file