From 7c64b6273596f6be94401ba02a088ae5a43e78f4 Mon Sep 17 00:00:00 2001 From: FredWuCZ Date: Fri, 27 Sep 2024 11:03:21 +0800 Subject: [PATCH] Add docker provider framework --- desktop_env/providers/__init__.py | 20 +- desktop_env/providers/docker/manager.py | 331 +++++++++++++++++++++++ desktop_env/providers/docker/provider.py | 64 +++++ 3 files changed, 407 insertions(+), 8 deletions(-) create mode 100644 desktop_env/providers/docker/manager.py create mode 100644 desktop_env/providers/docker/provider.py diff --git a/desktop_env/providers/__init__.py b/desktop_env/providers/__init__.py index 8c12197..7c95382 100644 --- a/desktop_env/providers/__init__.py +++ b/desktop_env/providers/__init__.py @@ -1,12 +1,4 @@ from desktop_env.providers.base import VMManager, Provider -from desktop_env.providers.vmware.manager import VMwareVMManager -from desktop_env.providers.vmware.provider import VMwareProvider -from desktop_env.providers.aws.manager import AWSVMManager -from desktop_env.providers.aws.provider import AWSProvider -from desktop_env.providers.azure.manager import AzureVMManager -from desktop_env.providers.azure.provider import AzureProvider -from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager -from desktop_env.providers.virtualbox.provider import VirtualBoxProvider def create_vm_manager_and_provider(provider_name: str, region: str): """ @@ -14,12 +6,24 @@ def create_vm_manager_and_provider(provider_name: str, region: str): """ provider_name = provider_name.lower().strip() if provider_name == "vmware": + from desktop_env.providers.vmware.manager import VMwareVMManager + from desktop_env.providers.vmware.provider import VMwareProvider return VMwareVMManager(), VMwareProvider(region) elif provider_name == "virtualbox": + from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager + from desktop_env.providers.virtualbox.provider import VirtualBoxProvider return VirtualBoxVMManager(), VirtualBoxProvider(region) elif provider_name in ["aws", "amazon web services"]: + from desktop_env.providers.aws.manager import AWSVMManager + from desktop_env.providers.aws.provider import AWSProvider return AWSVMManager(), AWSProvider(region) elif provider_name == "azure": + from desktop_env.providers.azure.manager import AzureVMManager + from desktop_env.providers.azure.provider import AzureProvider return AzureVMManager(), AzureProvider(region) + elif provider_name == "docker": + from desktop_env.providers.docker.manager import DockerVMManager + from desktop_env.providers.docker.provider import DockerProvider + return DockerVMManager(), DockerProvider(region) else: raise NotImplementedError(f"{provider_name} not implemented!") diff --git a/desktop_env/providers/docker/manager.py b/desktop_env/providers/docker/manager.py new file mode 100644 index 0000000..e1ebdb8 --- /dev/null +++ b/desktop_env/providers/docker/manager.py @@ -0,0 +1,331 @@ +import os +import platform +import random +import re + +import threading +from filelock import FileLock +import uuid +import zipfile + +from time import sleep +import shutil +import psutil +import subprocess +import requests +from tqdm import tqdm +import docker + +import logging + +from desktop_env.providers.base import VMManager + +logger = logging.getLogger("desktopenv.providers.vmware.VMwareVMManager") +logger.setLevel(logging.INFO) + +MAX_RETRY_TIMES = 10 +RETRY_INTERVAL = 5 +UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-x86.zip" + +# Determine the platform and CPU architecture to decide the correct VM image to download +# if platform.system() == 'Darwin': # macOS +# # if os.uname().machine == 'arm64': # Apple Silicon +# URL = UBUNTU_ARM_URL +# # else: +# # url = UBUNTU_X86_URL +# elif platform.machine().lower() in ['amd64', 'x86_64']: +# URL = UBUNTU_X86_URL +# else: +# raise Exception("Unsupported platform or architecture") + +URL = UBUNTU_X86_URL + +DOWNLOADED_FILE_NAME = URL.split('/')[-1] +REGISTRY_PATH = '.docker_vms' +LOCK_FILE_NAME = '.docker_lck' +VMS_DIR = "./docker_vm_data" +update_lock = threading.Lock() + +if platform.system() == 'Windows': + docker_path = r"C:\Program Files\Docker\Docker" + os.environ["PATH"] += os.pathsep + docker_path + +def generate_new_vm_name(vms_dir, os_type): + registry_idx = 0 + prefix = os_type + while True: + attempted_new_name = f"{prefix}{registry_idx}" + if os.path.exists( + os.path.join(vms_dir, attempted_new_name, attempted_new_name + ".qcow2")): + registry_idx += 1 + else: + return attempted_new_name + +def _install_vm(vm_name, vms_dir, downloaded_file_name, os_type, original_vm_name="Ubuntu"): + os.makedirs(vms_dir, exist_ok=True) + + def __download_and_unzip_vm(): + # Download the virtual machine image + logger.info("Downloading the virtual machine image...") + downloaded_size = 0 + + if os_type == "Ubuntu": + if platform.system() == 'Darwin': + URL = UBUNTU_X86_URL + elif platform.machine().lower() in ['amd64', 'x86_64']: + URL = UBUNTU_X86_URL + elif os_type == "Windows": + if platform.machine().lower() in ['amd64', 'x86_64']: + URL = WINDOWS_X86_URL + DOWNLOADED_FILE_NAME = URL.split('/')[-1] + downloaded_file_name = DOWNLOADED_FILE_NAME + + while True: + downloaded_file_path = os.path.join(vms_dir, downloaded_file_name) + headers = {} + if os.path.exists(downloaded_file_path): + downloaded_size = os.path.getsize(downloaded_file_path) + headers["Range"] = f"bytes={downloaded_size}-" + + with requests.get(URL, headers=headers, stream=True) as response: + if response.status_code == 416: + # This means the range was not satisfiable, possibly the file was fully downloaded + logger.info("Fully downloaded or the file size changed.") + break + + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + + with open(downloaded_file_path, "ab") as file, tqdm( + desc="Progress", + total=total_size, + unit='iB', + unit_scale=True, + unit_divisor=1024, + initial=downloaded_size, + ascii=True + ) as progress_bar: + try: + for data in response.iter_content(chunk_size=1024): + size = file.write(data) + progress_bar.update(size) + except (requests.exceptions.RequestException, IOError) as e: + logger.error(f"Download error: {e}") + sleep(RETRY_INTERVAL) + logger.error("Retrying...") + else: + logger.info("Download succeeds.") + break # Download completed successfully + + # Unzip the downloaded file + logger.info("Unzipping the downloaded file...☕️") + with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref: + zip_ref.extractall(os.path.join(vms_dir, vm_name)) + logger.info("Files have been successfully extracted to the directory: " + str(os.path.join(vms_dir, vm_name))) + + vm_path = os.path.join(vms_dir, vm_name, vm_name, vm_name + ".vmx") + + # Start the virtual machine + def start_vm(vm_path, max_retries=20): + pass + + if not start_vm(vm_path): + raise ValueError("Error encountered during installation, please rerun the code for retrying.") + + def get_vm_ip_and_port(vm_path, max_retries=20): + pass + + vm_ip, vm_port = get_vm_ip_and_port(vm_path) + if not vm_ip: + raise ValueError("Error encountered during installation, please rerun the code for retrying.") + + # Function used to check whether the virtual machine is ready + def download_screenshot(ip, port): + url = f"http://{ip}:{port}/screenshot" + try: + # max trey times 1, max timeout 1 + response = requests.get(url, timeout=(10, 10)) + if response.status_code == 200: + return True + except Exception as e: + logger.error(f"Error: {e}") + logger.error(f"Type: {type(e).__name__}") + logger.error(f"Error detail: {str(e)}") + sleep(RETRY_INTERVAL) + return False + + # Try downloading the screenshot until successful + while not download_screenshot(vm_ip, vm_port): + logger.info("Check whether the virtual machine is ready...") + + logger.info("Virtual machine is ready. Start to make a snapshot on the virtual machine. It would take a while...") + + +class DockerVMManager(VMManager): + def __init__(self, registry_path=REGISTRY_PATH): + self.registry_path = registry_path + self.lock = FileLock(LOCK_FILE_NAME, timeout=60) + self.initialize_registry() + self.client = docker.from_env() + + def initialize_registry(self): + with self.lock: # Locking during initialization + if not os.path.exists(self.registry_path): + with open(self.registry_path, 'w') as file: + file.write('') + + def add_vm(self, vm_path, lock_needed=True): + if lock_needed: + with self.lock: + self._add_vm(vm_path) + else: + self._add_vm(vm_path) + + def _add_vm(self, vm_path, region=None): + assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'." + with self.lock: + with open(self.registry_path, 'r') as file: + lines = file.readlines() + new_lines = lines + [f'{vm_path}|free\n'] + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def occupy_vm(self, vm_path, pid, lock_needed=True): + if lock_needed: + with self.lock: + self._occupy_vm(vm_path, pid) + else: + self._occupy_vm(vm_path, pid) + + def _occupy_vm(self, vm_path, pid, region=None): + assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'." + with self.lock: + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + registered_vm_path, _ = line.strip().split('|') + if registered_vm_path == vm_path: + new_lines.append(f'{registered_vm_path}|{pid}\n') + else: + new_lines.append(line) + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def delete_vm(self, vm_path, lock_needed=True): + if lock_needed: + with self.lock: + self._delete_vm(vm_path) + else: + self._delete_vm(vm_path) + + def _delete_vm(self, vm_path): + raise NotImplementedError + + def check_and_clean(self, vms_dir, lock_needed=True): + if lock_needed: + with self.lock: + self._check_and_clean(vms_dir) + else: + self._check_and_clean(vms_dir) + + def _check_and_clean(self, vms_dir): + with self.lock: # Lock when cleaning up the registry and vms_dir + # Check and clean on the running vms, detect the released ones and mark then as 'free' + active_pids = {p.pid for p in psutil.process_iter()} + new_lines = [] + vm_paths = [] + + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path, pid_str = line.strip().split('|') + if not os.path.exists(vm_path): + logger.info(f"VM {vm_path} not found, releasing it.") + new_lines.append(f'{vm_path}|free\n') + continue + + vm_paths.append(vm_path) + if pid_str == "free": + new_lines.append(line) + continue + + if int(pid_str) in active_pids: + new_lines.append(line) + else: + new_lines.append(f'{vm_path}|free\n') + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + # Check and clean on the files inside vms_dir, delete the unregistered ones + os.makedirs(vms_dir, exist_ok=True) + vm_names = os.listdir(vms_dir) + for vm_name in vm_names: + # skip the downloaded .zip file + if vm_name == DOWNLOADED_FILE_NAME: + continue + # Skip the .DS_Store file on macOS + if vm_name == ".DS_Store": + continue + + flag = True + for vm_path in vm_paths: + if vm_name + ".qcow2" in vm_path: + flag = False + elif vm_name + ".img" in vm_path: + flag = False + if flag: + shutil.rmtree(os.path.join(vms_dir, vm_name)) + + def list_free_vms(self, lock_needed=True): + if lock_needed: + with self.lock: + return self._list_free_vms() + else: + return self._list_free_vms() + + def _list_free_vms(self): + with self.lock: # Lock when reading the registry + free_vms = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path, pid_str = line.strip().split('|') + if pid_str == "free": + free_vms.append((vm_path, pid_str)) + return free_vms + + def get_vm_path(self, os_type, region=None): + with self.lock: + if not DockerVMManager.checked_and_cleaned: + DockerVMManager.checked_and_cleaned = True + self._check_and_clean(vms_dir=VMS_DIR) + + allocation_needed = False + with self.lock: + free_vms_paths = self._list_free_vms() + if len(free_vms_paths) == 0: + # No free virtual machine available, generate a new one + allocation_needed = True + else: + # Choose the first free virtual machine + chosen_vm_path = free_vms_paths[0][0] + self._occupy_vm(chosen_vm_path, os.getpid()) + return chosen_vm_path + + if allocation_needed: + logger.info("No free virtual machine available. Generating a new one, which would take a while...☕") + new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR, os_type=os_type) + + original_vm_name = None + if os_type == "Ubuntu": + original_vm_name = "Ubuntu" + elif os_type == "Windows": + original_vm_name = "Windows 10 x64" + + new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR, + downloaded_file_name=DOWNLOADED_FILE_NAME, original_vm_name=original_vm_name, os_type=os_type) + with self.lock: + self._add_vm(new_vm_path) + self._occupy_vm(new_vm_path, os.getpid()) + return new_vm_path diff --git a/desktop_env/providers/docker/provider.py b/desktop_env/providers/docker/provider.py new file mode 100644 index 0000000..b60eaca --- /dev/null +++ b/desktop_env/providers/docker/provider.py @@ -0,0 +1,64 @@ +import logging +import os +import platform +import subprocess +import time +import docker + +from desktop_env.providers.base import Provider + +logger = logging.getLogger("desktopenv.providers.vmware.VMwareProvider") +logger.setLevel(logging.INFO) + +WAIT_TIME = 3 + + +def get_vmrun_type(return_list=False): + if platform.system() == 'Windows' or platform.system() == 'Linux': + if return_list: + return ['-T', 'ws'] + else: + return '-T ws' + elif platform.system() == 'Darwin': # Darwin is the system name for macOS + if return_list: + return ['-T', 'fusion'] + else: + return '-T fusion' + else: + raise Exception("Unsupported operating system") + + +class DockerProvider(Provider): + def __init__(self, region: str): + self.client = docker.from_env() + + @staticmethod + def _execute_command(command: list, return_output=False): + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + encoding="utf-8" + ) + + if return_output: + output = process.communicate()[0].strip() + return output + else: + return None + + def start_emulator(self, path_to_vm: str, headless: bool, os_type: str): + pass + + def get_ip_address(self, path_to_vm: str) -> str: + pass + + def save_state(self, path_to_vm: str, snapshot_name: str): + pass + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + pass + + def stop_emulator(self, path_to_vm: str): + pass \ No newline at end of file