From 0b3e7dca248a36221ecf97458f8cd3328626a636 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Sun, 21 Apr 2024 19:51:15 +0800 Subject: [PATCH] Add support for automatic VM download and configuration, enable auto-scaling management; move metadata retrieval out of the init function to speed up environment initialization. --- README.md | 10 +- desktop_env/envs/__init__.py | 335 ++++++++++++++++++++++++++++++++ desktop_env/envs/desktop_env.py | 23 +-- setup_vm.py | 143 -------------- 4 files changed, 349 insertions(+), 162 deletions(-) delete mode 100644 setup_vm.py diff --git a/README.md b/README.md index 5706ede..4ffb97a 100644 --- a/README.md +++ b/README.md @@ -64,10 +64,7 @@ vmrun -T ws list ``` If the installation along with the environment variable set is successful, you will see the message showing the current running virtual machines. -3. Run our setup script to download the necessary virtual machines and set up the environment☕: -```bash -python setup_vm.py -``` +All set! Our setup script will automatically download the necessary virtual machines and configure the environment for you. ### On AWS or Azure (Virtualized platform) We are working on supporting it 👷. Please hold tight! @@ -108,10 +105,7 @@ example = { } } -env = DesktopEnv( - path_to_vm=r"Ubuntu/DesktopEnv-Ubuntu 64-bit Arm.vmx", - action_space="pyautogui" -) +env = DesktopEnv(action_space="pyautogui") obs = env.reset(task_config=example) obs, reward, done, info = env.step("pyautogui.rightClick()") diff --git a/desktop_env/envs/__init__.py b/desktop_env/envs/__init__.py index e69de29..3f54d12 100644 --- a/desktop_env/envs/__init__.py +++ b/desktop_env/envs/__init__.py @@ -0,0 +1,335 @@ +import os +import platform +import random +import re +import subprocess +import threading +import uuid +import zipfile +from time import sleep + +import psutil +import requests +from tqdm import tqdm + +__version__ = "0.1.6" + +MAX_RETRY_TIMES = 10 +UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_arm/resolve/main/Ubuntu.zip" +UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_x86/resolve/main/Ubuntu.zip" +REGISTRY_PATH = '.vms' +REGISTRY_IDX_PATH = ".vms_idx" +update_lock = threading.Lock() + + +class VirtualMachineManager: + def __init__(self, registry_path=REGISTRY_PATH, registry_idx_path=REGISTRY_IDX_PATH): + self.registry_path = registry_path + self.registry_idx_path = registry_idx_path + self.lock = threading.Lock() + self.initialize_registry() + + def initialize_registry(self): + with self.lock: # Locking during initialization + if not os.path.exists(self.registry_path): + with open(self.registry_path, 'w') as file: + file.write('') + if not os.path.exists(self.registry_idx_path): + with open(self.registry_idx_path, 'w') as file: + file.write('0') + + def add_vm(self, vm_path): + with self.lock: + with open(self.registry_path, 'r') as file: + lines = file.readlines() + new_lines = lines + [f'{vm_path}|free\n'] + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def occupy_vm(self, vm_path, pid): + with self.lock: + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + registered_vm_path, _ = line.strip().split('|') + if registered_vm_path == vm_path: + new_lines.append(f'{registered_vm_path}|{pid}\n') + else: + new_lines.append(line) + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def release_vm(self, vm_path): + with self.lock: # Lock when modifying the registry + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + registered_vm_path, _ = line.strip().split('|') + if registered_vm_path != vm_path: + new_lines.append(line) + else: + new_lines.append(f'{registered_vm_path}|free\n') + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def check_and_clean(self): + with self.lock: # Lock when cleaning up the registry + active_pids = {p.pid for p in psutil.process_iter()} + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path, pid_str = line.strip().split('|') + if pid_str == "free": + new_lines.append(line) + continue + + if int(pid_str) in active_pids: + new_lines.append(line) + else: + new_lines.append(f'{vm_path}|free\n') + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def list_vms(self): + with self.lock: # Lock when reading the registry + all_vms = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path, pid_str = line.strip().split('|') + all_vms.append((vm_path, pid_str)) + return all_vms + + def list_free_vms(self): + with self.lock: # Lock when reading the registry + free_vms = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path, pid_str = line.strip().split('|') + if pid_str == "free": + free_vms.append((vm_path, pid_str)) + return free_vms + + def generate_new_vm_name(self): + with self.lock: # Lock when generating a new path + with open(self.registry_idx_path, 'r') as file: + idx = int(file.read()) + + new_name = f"Ubuntu{idx}" + + with open(self.registry_idx_path, 'w') as file: + file.write(str(idx + 1)) + + return new_name + + +def _update_vm(vmx_path, target_vm_name): + """Update the VMX file with the new VM name and other parameters, so that the VM can be started successfully without conflict with the original VM.""" + with update_lock: + dir_path, vmx_file = os.path.split(vmx_path) + + def _generate_mac_address(): + # VMware MAC address range starts with 00:0c:29 + mac = [0x00, 0x0c, 0x29, + random.randint(0x00, 0x7f), + random.randint(0x00, 0xff), + random.randint(0x00, 0xff)] + return ':'.join(map(lambda x: "%02x" % x, mac)) + + # Backup the original file + with open(vmx_path, 'r') as file: + original_content = file.read() + + # Generate new values + new_uuid_bios = str(uuid.uuid4()) + new_uuid_location = str(uuid.uuid4()) + new_mac_address = _generate_mac_address() + new_vmci_id = str(random.randint(-2147483648, 2147483647)) # Random 32-bit integer + + # Update the content + updated_content = re.sub(r'displayName = ".*?"', f'displayName = "{target_vm_name}"', original_content) + updated_content = re.sub(r'uuid.bios = ".*?"', f'uuid.bios = "{new_uuid_bios}"', updated_content) + updated_content = re.sub(r'uuid.location = ".*?"', f'uuid.location = "{new_uuid_location}"', updated_content) + updated_content = re.sub(r'ethernet0.generatedAddress = ".*?"', + f'ethernet0.generatedAddress = "{new_mac_address}"', + updated_content) + updated_content = re.sub(r'vmci0.id = ".*?"', f'vmci0.id = "{new_vmci_id}"', updated_content) + + # Write the updated content back to the file + with open(vmx_path, 'w') as file: + file.write(updated_content) + + print(".vmx file updated successfully.") + + vmx_file_base_name = os.path.splitext(vmx_file)[0] + + assert vmx_file == "Ubuntu.vmx", "The VMX file should be named 'Ubuntu.vmx'." + files_to_rename = ['vmx', 'nvram', 'vmsd', 'vmxf'] + + for ext in files_to_rename: + original_file = os.path.join(dir_path, f"{vmx_file_base_name}.{ext}") + target_file = os.path.join(dir_path, f"{target_vm_name}.{ext}") + os.rename(original_file, target_file) + + # Update the dir_path to the target vm_name, only replace the last character + # Split the path into parts up to the last folder + path_parts = dir_path.rstrip(os.sep).split(os.sep) + path_parts[-1] = target_vm_name + target_dir_path = os.sep.join(path_parts) + os.rename(dir_path, target_dir_path) + + print("VM files renamed successfully.") + + +def _install_virtual_machine(vm_name, working_dir="./vm_data", downloaded_file_name="Ubuntu.zip", original_vm_name="Ubuntu"): + os.makedirs(working_dir, exist_ok=True) + def __download_and_unzip_vm(): + # Determine the platform and CPU architecture to decide the correct VM image to download + if platform.machine() == 'arm64': # macOS with Apple Silicon + url = UBUNTU_ARM_URL + elif platform.machine().lower() in ['amd64', "x86_64"]: + url = UBUNTU_X86_URL + else: + raise Exception("Unsupported platform or architecture") + + # Download the virtual machine image + print("Downloading the virtual machine image...") + downloaded_size = 0 + + while True: + downloaded_file_path = os.path.join(working_dir, downloaded_file_name) + headers = {} + if os.path.exists(downloaded_file_path): + downloaded_size = os.path.getsize(downloaded_file_path) + headers["Range"] = f"bytes={downloaded_size}-" + + with requests.get(url, headers=headers, stream=True) as response: + if response.status_code == 416: + # This means the range was not satisfiable, possibly the file was fully downloaded + print("Fully downloaded or the file sized changed.") + break + + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + + with open(downloaded_file_path, "ab") as file, tqdm( + desc="Progress", + total=total_size, + unit='iB', + unit_scale=True, + unit_divisor=1024, + initial=downloaded_size, + ascii=True + ) as progress_bar: + try: + for data in response.iter_content(chunk_size=1024): + size = file.write(data) + progress_bar.update(size) + except (requests.exceptions.RequestException, IOError) as e: + print(f"Download error: {e}") + sleep(1) # Wait for 1 second before retrying + print("Retrying...") + else: + print("Download succeeds.") + break # Download completed successfully + + # Unzip the downloaded file + print("Unzipping the downloaded file...☕️") + with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref: + zip_ref.extractall(os.path.join(working_dir, vm_name)) + print("Files have been successfully extracted to the directory:", os.path.join(working_dir, vm_name)) + + vm_path = os.path.join(working_dir, vm_name, vm_name, vm_name + ".vmx") + + # Execute the function to download and unzip the VM, and update the vm metadata + if not os.path.exists(vm_path): + __download_and_unzip_vm() + _update_vm(os.path.join(working_dir, vm_name, original_vm_name, original_vm_name + ".vmx"), vm_name) + else: + print(f"Virtual machine exists: {vm_path}") + + # Determine the platform of the host machine and decide the parameter for vmrun + def get_vmrun_type(): + if platform.system() == 'Windows' or platform.system() == 'Linux': + return '-T ws' + elif platform.system() == 'Darwin': # Darwin is the system name for macOS + return '-T fusion' + else: + raise Exception("Unsupported operating system") + + # Start the virtual machine + subprocess.run(f'vmrun {get_vmrun_type()} start "{vm_path}" nogui', shell=True) + print("Starting virtual machine...") + + # Get the IP address of the virtual machine + for i in range(MAX_RETRY_TIMES): + get_vm_ip = subprocess.run(f'vmrun {get_vmrun_type()} getGuestIPAddress "{vm_path}" -wait', shell=True, + capture_output=True, + text=True) + if "Error" in get_vm_ip.stdout: + print("Retry on getting IP") + continue + print("Virtual machine IP address:", get_vm_ip.stdout.strip()) + break + + vm_ip = get_vm_ip.stdout.strip() + + def is_url_accessible(url, timeout=1): + try: + response = requests.head(url, timeout=timeout) + return response.status_code == 200 + except requests.exceptions.RequestException: + return False + + url = f"http://{vm_ip}:5000/screenshot" + check_url = is_url_accessible(url) + + # Function used to check whether the virtual machine is ready + def download_screenshot(ip): + url = f"http://{ip}:5000/screenshot" + try: + # max trey times 1, max timeout 1 + response = requests.get(url, timeout=(1, 1)) + if response.status_code == 200: + return True + except Exception as e: + print(f"Error: {e}") + print(f"Type: {type(e).__name__}") + print(f"Error detail: {str(e)}") + sleep(2) + return False + + # Try downloading the screenshot until successful + while not download_screenshot(vm_ip): + print("Check whether the virtual machine is ready...") + + print("Virtual machine is ready. Start to make a snapshot on the virtual machine. It would take a while...") + + # Create a snapshot of the virtual machine + subprocess.run(f'vmrun {get_vmrun_type()} snapshot "{vm_path}" "init_state"', shell=True) + print("Snapshot created.") + + return vm_path + + +def _get_vm_path(): + vm_manager = VirtualMachineManager(REGISTRY_PATH) + vm_manager.check_and_clean() + free_vms_paths = vm_manager.list_free_vms() + if len(free_vms_paths) == 0: + # No free virtual machine available, generate a new one + print("No free virtual machine available. Generating a new one, which would take a while...☕") + new_vm_name = vm_manager.generate_new_vm_name() + new_vm_path = _install_virtual_machine(new_vm_name) + vm_manager.add_vm(new_vm_path) + vm_manager.occupy_vm(new_vm_path, os.getpid()) + return new_vm_path + else: + # Choose the first free virtual machine + chosen_vm_path = free_vms_paths[0][0] + vm_manager.occupy_vm(chosen_vm_path, os.getpid()) + return chosen_vm_path diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index 5123a07..125679a 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -12,6 +12,7 @@ import gymnasium as gym from desktop_env.controllers.python import PythonController from desktop_env.controllers.setup import SetupController from desktop_env.evaluators import metrics, getters +from . import _get_vm_path logger = logging.getLogger("desktopenv.env") @@ -45,7 +46,7 @@ class DesktopEnv(gym.Env): def __init__( self, - path_to_vm: str, + path_to_vm: str = None, snapshot_name: str = "init_state", action_space: str = "computer_13", cache_dir: str = "cache", @@ -68,10 +69,10 @@ class DesktopEnv(gym.Env): """ # Initialize environment variables - self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) + self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm if path_to_vm else _get_vm_path()))) self.snapshot_name = snapshot_name self.cache_dir_base: str = cache_dir - self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM + # todo: add the logic to get the screen size from the VM self.headless = headless self.require_a11y_tree = require_a11y_tree self.require_terminal = require_terminal @@ -83,10 +84,6 @@ class DesktopEnv(gym.Env): self.controller = PythonController(vm_ip=self.vm_ip) self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base) - # Meta info of the VM - self.vm_platform: str = self.controller.get_vm_platform() - self.vm_screen_size = self.controller.get_vm_screen_size() - # mode: human or machine self.instruction = None assert action_space in ["computer_13", "pyautogui"] @@ -98,6 +95,14 @@ class DesktopEnv(gym.Env): self._step_no: int = 0 self.action_history: List[Dict[str, any]] = [] + @property + def vm_platform(self): + return self.controller.get_vm_platform() + + @property + def vm_screen_size(self): + return self.controller.get_vm_screen_size() + def _start_emulator(self): while True: try: @@ -229,10 +234,6 @@ class DesktopEnv(gym.Env): self._start_emulator() logger.info("Emulator started.") - logger.info("Get meta info of the VM...") - self.vm_platform = self.controller.get_vm_platform() - self.vm_screen_size = self.controller.get_vm_screen_size() - logger.info("Setting up environment...") self.setup_controller.setup(self.config) diff --git a/setup_vm.py b/setup_vm.py deleted file mode 100644 index 1f3202b..0000000 --- a/setup_vm.py +++ /dev/null @@ -1,143 +0,0 @@ -import os -import platform -import subprocess -import requests -from tqdm import tqdm -import zipfile -from time import sleep - -import socket - -# Define the path to the virtual machine -VM_PATH = r"Ubuntu\Ubuntu.vmx" # change this to the path of your downloaded virtual machine -MAX_RETRY_TIMES = 10 - - -def download_and_unzip_vm(): - # Determine the platform and CPU architecture to decide the correct VM image to download - if platform.machine() == 'arm64': # macOS with Apple Silicon - url = "https://huggingface.co/datasets/xlangai/ubuntu_arm/resolve/main/Ubuntu.zip" - elif platform.machine().lower() in ['amd64', "x86_64"]: - url = "https://huggingface.co/datasets/xlangai/ubuntu_x86/resolve/main/Ubuntu.zip" - else: - raise Exception("Unsupported platform or architecture") - - # Download the virtual machine image - print("Downloading the virtual machine image...") - filename = "Ubuntu.zip" - downloaded_size = 0 - - while True: - headers = {} - if os.path.exists(filename): - downloaded_size = os.path.getsize(filename) - headers["Range"] = f"bytes={downloaded_size}-" - - with requests.get(url, headers=headers, stream=True) as response: - response.raise_for_status() - total_size = int(response.headers.get('content-length', 0)) - - with open(filename, "ab") as file, tqdm( - desc="Progress", - total=total_size, - unit='iB', - unit_scale=True, - unit_divisor=1024, - initial=downloaded_size, - ascii=True - ) as progress_bar: - try: - for data in response.iter_content(chunk_size=1024): - size = file.write(data) - progress_bar.update(size) - except (requests.exceptions.RequestException, IOError) as e: - print(f"Download error: {e}") - sleep(1) # Wait for 1 second before retrying - print("Retrying...") - else: - print("Download succeeds.") - break # Download completed successfully - - # Unzip the downloaded file - print("Unzipping the downloaded file...☕️") - current_directory = os.getcwd() - with zipfile.ZipFile('Ubuntu.zip', 'r') as zip_ref: - zip_ref.extractall(current_directory) - print("Files have been successfully extracted to the current working directory:", current_directory) - - -# Execute the function to download and unzip the VM -if not os.path.exists(VM_PATH): - download_and_unzip_vm() -else: - print(f"Virtual machine exists: {VM_PATH}") - - -# Determine the platform of the host machine and decide the parameter for vmrun -def get_vmrun_type(): - if platform.system() == 'Windows' or platform.system() == 'Linux': - return '-T ws' - elif platform.system() == 'Darwin': # Darwin is the system name for macOS - return '-T fusion' - else: - raise Exception("Unsupported operating system") - - -# Start the virtual machine -subprocess.run(f'vmrun {get_vmrun_type()} start "{VM_PATH}" nogui', shell=True) -print("Starting virtual machine...") - -# Get the IP address of the virtual machine -for i in range(MAX_RETRY_TIMES): - get_vm_ip = subprocess.run(f'vmrun {get_vmrun_type()} getGuestIPAddress "{VM_PATH}" -wait', shell=True, - capture_output=True, - text=True) - if "Error" in get_vm_ip.stdout: - print("Retry on getting IP") - continue - print("Virtual machine IP address:", get_vm_ip.stdout.strip()) - break - -vm_ip = get_vm_ip.stdout.strip() - - -def is_url_accessible(url, timeout=1): - try: - response = requests.head(url, timeout=timeout) - return response.status_code == 200 - except requests.exceptions.RequestException: - return False - - -print("--------------------------------") -url = f"http://{vm_ip}:5000/screenshot" -ckeck_url = is_url_accessible(url) -print(f"check url: {url} | is accessible: {ckeck_url}") -print("--------------------------------") - - -# Function used to check whether the virtual machine is ready -def download_screenshot(ip): - url = f"http://{ip}:5000/screenshot" - try: - # max trey times 1, max timeout 1 - response = requests.get(url, timeout=(1, 1)) - if response.status_code == 200: - return True - except Exception as e: - print(f"Error: {e}") - print(f"Type: {type(e).__name__}") - print(f"Error detail: {str(e)}") - sleep(2) - return False - - -# Try downloading the screenshot until successful -while not download_screenshot(vm_ip): - print("Check whether the virtual machine is ready...") - -print("Virtual machine is ready. Start to make a snapshot on the virtual machine. It would take a while...") - -# Create a snapshot of the virtual machine -subprocess.run(f'vmrun {get_vmrun_type()} snapshot "{VM_PATH}" "init_state"', shell=True) -print("Snapshot created.")