Add docker provider framework

This commit is contained in:
FredWuCZ
2024-09-27 11:03:21 +08:00
parent 3b94cb4b74
commit 7c64b62735
3 changed files with 407 additions and 8 deletions

View File

@@ -1,12 +1,4 @@
from desktop_env.providers.base import VMManager, Provider
from desktop_env.providers.vmware.manager import VMwareVMManager
from desktop_env.providers.vmware.provider import VMwareProvider
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider
from desktop_env.providers.azure.manager import AzureVMManager
from desktop_env.providers.azure.provider import AzureProvider
from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
def create_vm_manager_and_provider(provider_name: str, region: str):
"""
@@ -14,12 +6,24 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
"""
provider_name = provider_name.lower().strip()
if provider_name == "vmware":
from desktop_env.providers.vmware.manager import VMwareVMManager
from desktop_env.providers.vmware.provider import VMwareProvider
return VMwareVMManager(), VMwareProvider(region)
elif provider_name == "virtualbox":
from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
return VirtualBoxVMManager(), VirtualBoxProvider(region)
elif provider_name in ["aws", "amazon web services"]:
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider
return AWSVMManager(), AWSProvider(region)
elif provider_name == "azure":
from desktop_env.providers.azure.manager import AzureVMManager
from desktop_env.providers.azure.provider import AzureProvider
return AzureVMManager(), AzureProvider(region)
elif provider_name == "docker":
from desktop_env.providers.docker.manager import DockerVMManager
from desktop_env.providers.docker.provider import DockerProvider
return DockerVMManager(), DockerProvider(region)
else:
raise NotImplementedError(f"{provider_name} not implemented!")

View File

@@ -0,0 +1,331 @@
import os
import platform
import random
import re
import threading
from filelock import FileLock
import uuid
import zipfile
from time import sleep
import shutil
import psutil
import subprocess
import requests
from tqdm import tqdm
import docker
import logging
from desktop_env.providers.base import VMManager
logger = logging.getLogger("desktopenv.providers.vmware.VMwareVMManager")
logger.setLevel(logging.INFO)
MAX_RETRY_TIMES = 10
RETRY_INTERVAL = 5
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-x86.zip"
# Determine the platform and CPU architecture to decide the correct VM image to download
# if platform.system() == 'Darwin': # macOS
# # if os.uname().machine == 'arm64': # Apple Silicon
# URL = UBUNTU_ARM_URL
# # else:
# # url = UBUNTU_X86_URL
# elif platform.machine().lower() in ['amd64', 'x86_64']:
# URL = UBUNTU_X86_URL
# else:
# raise Exception("Unsupported platform or architecture")
URL = UBUNTU_X86_URL
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
REGISTRY_PATH = '.docker_vms'
LOCK_FILE_NAME = '.docker_lck'
VMS_DIR = "./docker_vm_data"
update_lock = threading.Lock()
if platform.system() == 'Windows':
docker_path = r"C:\Program Files\Docker\Docker"
os.environ["PATH"] += os.pathsep + docker_path
def generate_new_vm_name(vms_dir, os_type):
registry_idx = 0
prefix = os_type
while True:
attempted_new_name = f"{prefix}{registry_idx}"
if os.path.exists(
os.path.join(vms_dir, attempted_new_name, attempted_new_name + ".qcow2")):
registry_idx += 1
else:
return attempted_new_name
def _install_vm(vm_name, vms_dir, downloaded_file_name, os_type, original_vm_name="Ubuntu"):
os.makedirs(vms_dir, exist_ok=True)
def __download_and_unzip_vm():
# Download the virtual machine image
logger.info("Downloading the virtual machine image...")
downloaded_size = 0
if os_type == "Ubuntu":
if platform.system() == 'Darwin':
URL = UBUNTU_X86_URL
elif platform.machine().lower() in ['amd64', 'x86_64']:
URL = UBUNTU_X86_URL
elif os_type == "Windows":
if platform.machine().lower() in ['amd64', 'x86_64']:
URL = WINDOWS_X86_URL
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
downloaded_file_name = DOWNLOADED_FILE_NAME
while True:
downloaded_file_path = os.path.join(vms_dir, downloaded_file_name)
headers = {}
if os.path.exists(downloaded_file_path):
downloaded_size = os.path.getsize(downloaded_file_path)
headers["Range"] = f"bytes={downloaded_size}-"
with requests.get(URL, headers=headers, stream=True) as response:
if response.status_code == 416:
# This means the range was not satisfiable, possibly the file was fully downloaded
logger.info("Fully downloaded or the file size changed.")
break
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
with open(downloaded_file_path, "ab") as file, tqdm(
desc="Progress",
total=total_size,
unit='iB',
unit_scale=True,
unit_divisor=1024,
initial=downloaded_size,
ascii=True
) as progress_bar:
try:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
progress_bar.update(size)
except (requests.exceptions.RequestException, IOError) as e:
logger.error(f"Download error: {e}")
sleep(RETRY_INTERVAL)
logger.error("Retrying...")
else:
logger.info("Download succeeds.")
break # Download completed successfully
# Unzip the downloaded file
logger.info("Unzipping the downloaded file...☕️")
with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref:
zip_ref.extractall(os.path.join(vms_dir, vm_name))
logger.info("Files have been successfully extracted to the directory: " + str(os.path.join(vms_dir, vm_name)))
vm_path = os.path.join(vms_dir, vm_name, vm_name, vm_name + ".vmx")
# Start the virtual machine
def start_vm(vm_path, max_retries=20):
pass
if not start_vm(vm_path):
raise ValueError("Error encountered during installation, please rerun the code for retrying.")
def get_vm_ip_and_port(vm_path, max_retries=20):
pass
vm_ip, vm_port = get_vm_ip_and_port(vm_path)
if not vm_ip:
raise ValueError("Error encountered during installation, please rerun the code for retrying.")
# Function used to check whether the virtual machine is ready
def download_screenshot(ip, port):
url = f"http://{ip}:{port}/screenshot"
try:
# max trey times 1, max timeout 1
response = requests.get(url, timeout=(10, 10))
if response.status_code == 200:
return True
except Exception as e:
logger.error(f"Error: {e}")
logger.error(f"Type: {type(e).__name__}")
logger.error(f"Error detail: {str(e)}")
sleep(RETRY_INTERVAL)
return False
# Try downloading the screenshot until successful
while not download_screenshot(vm_ip, vm_port):
logger.info("Check whether the virtual machine is ready...")
logger.info("Virtual machine is ready. Start to make a snapshot on the virtual machine. It would take a while...")
class DockerVMManager(VMManager):
def __init__(self, registry_path=REGISTRY_PATH):
self.registry_path = registry_path
self.lock = FileLock(LOCK_FILE_NAME, timeout=60)
self.initialize_registry()
self.client = docker.from_env()
def initialize_registry(self):
with self.lock: # Locking during initialization
if not os.path.exists(self.registry_path):
with open(self.registry_path, 'w') as file:
file.write('')
def add_vm(self, vm_path, lock_needed=True):
if lock_needed:
with self.lock:
self._add_vm(vm_path)
else:
self._add_vm(vm_path)
def _add_vm(self, vm_path, region=None):
assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'."
with self.lock:
with open(self.registry_path, 'r') as file:
lines = file.readlines()
new_lines = lines + [f'{vm_path}|free\n']
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def occupy_vm(self, vm_path, pid, lock_needed=True):
if lock_needed:
with self.lock:
self._occupy_vm(vm_path, pid)
else:
self._occupy_vm(vm_path, pid)
def _occupy_vm(self, vm_path, pid, region=None):
assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'."
with self.lock:
new_lines = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
registered_vm_path, _ = line.strip().split('|')
if registered_vm_path == vm_path:
new_lines.append(f'{registered_vm_path}|{pid}\n')
else:
new_lines.append(line)
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def delete_vm(self, vm_path, lock_needed=True):
if lock_needed:
with self.lock:
self._delete_vm(vm_path)
else:
self._delete_vm(vm_path)
def _delete_vm(self, vm_path):
raise NotImplementedError
def check_and_clean(self, vms_dir, lock_needed=True):
if lock_needed:
with self.lock:
self._check_and_clean(vms_dir)
else:
self._check_and_clean(vms_dir)
def _check_and_clean(self, vms_dir):
with self.lock: # Lock when cleaning up the registry and vms_dir
# Check and clean on the running vms, detect the released ones and mark then as 'free'
active_pids = {p.pid for p in psutil.process_iter()}
new_lines = []
vm_paths = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
vm_path, pid_str = line.strip().split('|')
if not os.path.exists(vm_path):
logger.info(f"VM {vm_path} not found, releasing it.")
new_lines.append(f'{vm_path}|free\n')
continue
vm_paths.append(vm_path)
if pid_str == "free":
new_lines.append(line)
continue
if int(pid_str) in active_pids:
new_lines.append(line)
else:
new_lines.append(f'{vm_path}|free\n')
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
# Check and clean on the files inside vms_dir, delete the unregistered ones
os.makedirs(vms_dir, exist_ok=True)
vm_names = os.listdir(vms_dir)
for vm_name in vm_names:
# skip the downloaded .zip file
if vm_name == DOWNLOADED_FILE_NAME:
continue
# Skip the .DS_Store file on macOS
if vm_name == ".DS_Store":
continue
flag = True
for vm_path in vm_paths:
if vm_name + ".qcow2" in vm_path:
flag = False
elif vm_name + ".img" in vm_path:
flag = False
if flag:
shutil.rmtree(os.path.join(vms_dir, vm_name))
def list_free_vms(self, lock_needed=True):
if lock_needed:
with self.lock:
return self._list_free_vms()
else:
return self._list_free_vms()
def _list_free_vms(self):
with self.lock: # Lock when reading the registry
free_vms = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
vm_path, pid_str = line.strip().split('|')
if pid_str == "free":
free_vms.append((vm_path, pid_str))
return free_vms
def get_vm_path(self, os_type, region=None):
with self.lock:
if not DockerVMManager.checked_and_cleaned:
DockerVMManager.checked_and_cleaned = True
self._check_and_clean(vms_dir=VMS_DIR)
allocation_needed = False
with self.lock:
free_vms_paths = self._list_free_vms()
if len(free_vms_paths) == 0:
# No free virtual machine available, generate a new one
allocation_needed = True
else:
# Choose the first free virtual machine
chosen_vm_path = free_vms_paths[0][0]
self._occupy_vm(chosen_vm_path, os.getpid())
return chosen_vm_path
if allocation_needed:
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR, os_type=os_type)
original_vm_name = None
if os_type == "Ubuntu":
original_vm_name = "Ubuntu"
elif os_type == "Windows":
original_vm_name = "Windows 10 x64"
new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR,
downloaded_file_name=DOWNLOADED_FILE_NAME, original_vm_name=original_vm_name, os_type=os_type)
with self.lock:
self._add_vm(new_vm_path)
self._occupy_vm(new_vm_path, os.getpid())
return new_vm_path

View File

@@ -0,0 +1,64 @@
import logging
import os
import platform
import subprocess
import time
import docker
from desktop_env.providers.base import Provider
logger = logging.getLogger("desktopenv.providers.vmware.VMwareProvider")
logger.setLevel(logging.INFO)
WAIT_TIME = 3
def get_vmrun_type(return_list=False):
if platform.system() == 'Windows' or platform.system() == 'Linux':
if return_list:
return ['-T', 'ws']
else:
return '-T ws'
elif platform.system() == 'Darwin': # Darwin is the system name for macOS
if return_list:
return ['-T', 'fusion']
else:
return '-T fusion'
else:
raise Exception("Unsupported operating system")
class DockerProvider(Provider):
def __init__(self, region: str):
self.client = docker.from_env()
@staticmethod
def _execute_command(command: list, return_output=False):
process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding="utf-8"
)
if return_output:
output = process.communicate()[0].strip()
return output
else:
return None
def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
pass
def get_ip_address(self, path_to_vm: str) -> str:
pass
def save_state(self, path_to_vm: str, snapshot_name: str):
pass
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
pass
def stop_emulator(self, path_to_vm: str):
pass