From 19106467f8719a767e717832a7171d245e48f23b Mon Sep 17 00:00:00 2001 From: HappySix <33394488+FredWuCZ@users.noreply.github.com> Date: Mon, 17 Jun 2024 22:46:04 +0800 Subject: [PATCH] VirtualBox (#46) * Initailize aws support * Add README for the VM server * Refactor OSWorld for supporting more cloud services. * Initialize vmware and aws implementation v1, waiting for verification * Initlize files for azure, gcp and virtualbox support * Debug on the VMware provider * Fix on aws interface mapping * Fix instance type * Refactor * Clean * Add Azure provider * hk region; debug * Fix lock * Remove print * Remove key_name requirements when allocating aws vm * Clean README * Fix reset * Fix bugs * Add VirtualBox and Azure providers * Add VirtualBox OVF link * Raise exception on macOS host * Init RAEDME for VBox * Update VirtualBox VM download link * Update requirements and setup.py; Improve robustness on Windows * Fix network adapter * Go through on Windows machine * Add default adapter option * Fix minor error --------- Co-authored-by: Timothyxxx <384084775@qq.com> Co-authored-by: XinyuanWangCS Co-authored-by: Tianbao Xie <47296835+Timothyxxx@users.noreply.github.com> --- desktop_env/desktop_env.py | 9 +- desktop_env/providers/__init__.py | 9 +- desktop_env/providers/aws/AWS_GUIDELINE.md | 8 +- desktop_env/providers/azure/manager.py | 85 ++++ desktop_env/providers/azure/provider.py | 205 +++++++++ .../providers/virtualbox/INSTALL_VITUALBOX.md | 11 + desktop_env/providers/virtualbox/manager.py | 398 ++++++++++++++++++ desktop_env/providers/virtualbox/provider.py | 121 ++++++ desktop_env/providers/vmware/manager.py | 13 +- main.py | 10 +- requirements.txt | 4 + setup.py | 6 +- 12 files changed, 864 insertions(+), 15 deletions(-) create mode 100644 desktop_env/providers/virtualbox/INSTALL_VITUALBOX.md diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index 9a47275..dc102c8 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -26,7 +26,7 @@ class DesktopEnv(gym.Env): def __init__( self, - provider_name: str = "vmware", + provider_name: str = "virtualbox", region: str = None, path_to_vm: str = None, snapshot_name: str = "init_state", @@ -55,8 +55,11 @@ class DesktopEnv(gym.Env): self.manager, self.provider = create_vm_manager_and_provider(provider_name, region) # Initialize environment variables - self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) if path_to_vm else \ - self.manager.get_vm_path(region) + if path_to_vm: + self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \ + if provider_name in {"vmware", "virtualbox"} else path_to_vm + else: + self.path_to_vm = self.manager.get_vm_path(region) self.snapshot_name = snapshot_name self.cache_dir_base: str = cache_dir diff --git a/desktop_env/providers/__init__.py b/desktop_env/providers/__init__.py index 9a91620..8c12197 100644 --- a/desktop_env/providers/__init__.py +++ b/desktop_env/providers/__init__.py @@ -3,7 +3,10 @@ from desktop_env.providers.vmware.manager import VMwareVMManager from desktop_env.providers.vmware.provider import VMwareProvider from desktop_env.providers.aws.manager import AWSVMManager from desktop_env.providers.aws.provider import AWSProvider - +from desktop_env.providers.azure.manager import AzureVMManager +from desktop_env.providers.azure.provider import AzureProvider +from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager +from desktop_env.providers.virtualbox.provider import VirtualBoxProvider def create_vm_manager_and_provider(provider_name: str, region: str): """ @@ -12,7 +15,11 @@ def create_vm_manager_and_provider(provider_name: str, region: str): provider_name = provider_name.lower().strip() if provider_name == "vmware": return VMwareVMManager(), VMwareProvider(region) + elif provider_name == "virtualbox": + return VirtualBoxVMManager(), VirtualBoxProvider(region) elif provider_name in ["aws", "amazon web services"]: return AWSVMManager(), AWSProvider(region) + elif provider_name == "azure": + return AzureVMManager(), AzureProvider(region) else: raise NotImplementedError(f"{provider_name} not implemented!") diff --git a/desktop_env/providers/aws/AWS_GUIDELINE.md b/desktop_env/providers/aws/AWS_GUIDELINE.md index bab8a0b..1029b43 100644 --- a/desktop_env/providers/aws/AWS_GUIDELINE.md +++ b/desktop_env/providers/aws/AWS_GUIDELINE.md @@ -1,4 +1,6 @@ -# README for AWS VM Management +# ☁ Configuration of AWS + +--- Welcome to the AWS VM Management documentation. Before you proceed with using the code to manage AWS services, please ensure the following variables are set correctly according to your AWS environment. @@ -9,8 +11,8 @@ You need to assign values to several variables crucial for the operation of thes - Example: `'.aws_vms'` - **`DEFAULT_REGION`**: Default AWS region where your instances will be launched. - Example: `"us-east-1"` -- **`IMAGE_ID_MAP`**: Dictionary mapping regions to specific AMI IDs that should be used for instance creation. - - Example: +- **`IMAGE_ID_MAP`**: Dictionary mapping regions to specific AMI IDs that should be used for instance creation. Here we already set the AMI id to the official OSWorld image of Ubuntu supported by us. + - Formatted as follows: ```python IMAGE_ID_MAP = { "us-east-1": "ami-09bab251951b4272c", diff --git a/desktop_env/providers/azure/manager.py b/desktop_env/providers/azure/manager.py index e69de29..6076511 100644 --- a/desktop_env/providers/azure/manager.py +++ b/desktop_env/providers/azure/manager.py @@ -0,0 +1,85 @@ +import os +import threading +import boto3 +import psutil + +import logging + +from desktop_env.providers.base import VMManager + +logger = logging.getLogger("desktopenv.providers.azure.AzureVMManager") +logger.setLevel(logging.INFO) + +REGISTRY_PATH = '.azure_vms' + + +def _allocate_vm(region): + raise NotImplementedError + + +class AzureVMManager(VMManager): + def __init__(self, registry_path=REGISTRY_PATH): + self.registry_path = registry_path + self.lock = threading.Lock() + self.initialize_registry() + + def initialize_registry(self): + with self.lock: # Locking during initialization + if not os.path.exists(self.registry_path): + with open(self.registry_path, 'w') as file: + file.write('') + + def add_vm(self, vm_path, region): + with self.lock: + with open(self.registry_path, 'r') as file: + lines = file.readlines() + vm_path_at_vm_region = "{}@{}".format(vm_path, region) + new_lines = lines + [f'{vm_path_at_vm_region}|free\n'] + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def occupy_vm(self, vm_path, pid, region): + with self.lock: + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + registered_vm_path, _ = line.strip().split('|') + if registered_vm_path == "{}@{}".format(vm_path, region): + new_lines.append(f'{registered_vm_path}|{pid}\n') + else: + new_lines.append(line) + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def check_and_clean(self): + raise NotImplementedError + + def list_free_vms(self, region): + with self.lock: # Lock when reading the registry + free_vms = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path_at_vm_region, pid_str = line.strip().split('|') + vm_path, vm_region = vm_path_at_vm_region.split("@") + if pid_str == "free" and vm_region == region: + free_vms.append((vm_path, pid_str)) + return free_vms + + def get_vm_path(self, region): + self.check_and_clean() + free_vms_paths = self.list_free_vms(region) + if len(free_vms_paths) == 0: + # No free virtual machine available, generate a new one + logger.info("No free virtual machine available. Generating a new one, which would take a while...☕") + new_vm_path = _allocate_vm(region) + self.add_vm(new_vm_path, region) + self.occupy_vm(new_vm_path, os.getpid(), region) + return new_vm_path + else: + # Choose the first free virtual machine + chosen_vm_path = free_vms_paths[0][0] + self.occupy_vm(chosen_vm_path, os.getpid(), region) + return chosen_vm_path + diff --git a/desktop_env/providers/azure/provider.py b/desktop_env/providers/azure/provider.py index e69de29..fb43503 100644 --- a/desktop_env/providers/azure/provider.py +++ b/desktop_env/providers/azure/provider.py @@ -0,0 +1,205 @@ +import os +import time +from azure.identity import DefaultAzureCredential +from azure.mgmt.compute import ComputeManagementClient +from azure.mgmt.network import NetworkManagementClient +from azure.core.exceptions import ResourceNotFoundError + +import logging + +from desktop_env.providers.base import Provider + +logger = logging.getLogger("desktopenv.providers.azure.AzureProvider") +logger.setLevel(logging.INFO) + +WAIT_DELAY = 15 +MAX_ATTEMPTS = 10 + +# To use the Azure provider, download azure-cli by https://learn.microsoft.com/en-us/cli/azure/install-azure-cli, +# use "az login" to log into you Azure account, +# and set environment variable "AZURE_SUBSCRIPTION_ID" to your subscription ID. +# Provide your resource group name and VM name in the format "RESOURCE_GROUP_NAME/VM_NAME" and pass as an argument for "-p". + +class AzureProvider(Provider): + def __init__(self, region: str = None): + super().__init__(region) + credential = DefaultAzureCredential() + try: + self.subscription_id = os.environ["AZURE_SUBSCRIPTION_ID"] + except: + logger.error("Azure subscription ID not found. Please set environment variable \"AZURE_SUBSCRIPTION_ID\".") + raise + self.compute_client = ComputeManagementClient(credential, self.subscription_id) + self.network_client = NetworkManagementClient(credential, self.subscription_id) + + def start_emulator(self, path_to_vm: str, headless: bool): + logger.info("Starting Azure VM...") + resource_group_name, vm_name = path_to_vm.split('/') + + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') + power_state = vm.instance_view.statuses[-1].code + if power_state == "PowerState/running": + logger.info("VM is already running.") + return + + try: + # Start the instance + for _ in range(MAX_ATTEMPTS): + async_vm_start = self.compute_client.virtual_machines.begin_start(resource_group_name, vm_name) + logger.info(f"VM {path_to_vm} is starting...") + # Wait for the instance to start + async_vm_start.wait(timeout=WAIT_DELAY) + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') + power_state = vm.instance_view.statuses[-1].code + if power_state == "PowerState/running": + logger.info(f"VM {path_to_vm} is already running.") + break + except Exception as e: + logger.error(f"Failed to start the Azure VM {path_to_vm}: {str(e)}") + raise + + def get_ip_address(self, path_to_vm: str) -> str: + logger.info("Getting Azure VM IP address...") + resource_group_name, vm_name = path_to_vm.split('/') + + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name) + + for interface in vm.network_profile.network_interfaces: + name=" ".join(interface.id.split('/')[-1:]) + sub="".join(interface.id.split('/')[4]) + + try: + thing=self.network_client.network_interfaces.get(sub, name).ip_configurations + + network_card_id = thing[0].public_ip_address.id.split('/')[-1] + public_ip_address = self.network_client.public_ip_addresses.get(resource_group_name, network_card_id) + logger.info(f"VM IP address is {public_ip_address.ip_address}") + return public_ip_address.ip_address + + except Exception as e: + logger.error(f"Cannot get public IP for VM {path_to_vm}") + raise + + def save_state(self, path_to_vm: str, snapshot_name: str): + print("Saving Azure VM state...") + resource_group_name, vm_name = path_to_vm.split('/') + + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name) + + try: + # Backup each disk attached to the VM + for disk in vm.storage_profile.data_disks + [vm.storage_profile.os_disk]: + # Create a snapshot of the disk + snapshot = { + 'location': vm.location, + 'creation_data': { + 'create_option': 'Copy', + 'source_uri': disk.managed_disk.id + } + } + async_snapshot_creation = self.compute_client.snapshots.begin_create_or_update(resource_group_name, snapshot_name, snapshot) + async_snapshot_creation.wait(timeout=WAIT_DELAY) + + logger.info(f"Successfully created snapshot {snapshot_name} for VM {path_to_vm}.") + except Exception as e: + logger.error(f"Failed to create snapshot {snapshot_name} of the Azure VM {path_to_vm}: {str(e)}") + raise + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + logger.info(f"Reverting VM to snapshot: {snapshot_name}...") + resource_group_name, vm_name = path_to_vm.split('/') + + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name) + + # Stop the VM for disk creation + logger.info(f"Stopping VM: {vm_name}") + async_vm_stop = self.compute_client.virtual_machines.begin_deallocate(resource_group_name, vm_name) + async_vm_stop.wait(timeout=WAIT_DELAY) # Wait for the VM to stop + + try: + # Get the snapshot + snapshot = self.compute_client.snapshots.get(resource_group_name, snapshot_name) + + # Get the original disk information + original_disk_id = vm.storage_profile.os_disk.managed_disk.id + disk_name = original_disk_id.split('/')[-1] + if disk_name[-1] in ['0', '1']: + new_disk_name = disk_name[:-1] + str(int(disk_name[-1])^1) + else: + new_disk_name = disk_name + "0" + + # Delete the disk if it exists + self.compute_client.disks.begin_delete(resource_group_name, new_disk_name).wait(timeout=WAIT_DELAY) + + # Make sure the disk is deleted before proceeding to the next step + disk_deleted = False + polling_interval = 10 + attempts = 0 + while not disk_deleted and attempts < MAX_ATTEMPTS: + try: + self.compute_client.disks.get(resource_group_name, new_disk_name) + # If the above line does not raise an exception, the disk still exists + time.sleep(polling_interval) + attempts += 1 + except ResourceNotFoundError: + disk_deleted = True + + if not disk_deleted: + logger.error(f"Disk {new_disk_name} deletion timed out.") + raise + + # Create a new managed disk from the snapshot + snapshot = self.compute_client.snapshots.get(resource_group_name, snapshot_name) + disk_creation = { + 'location': snapshot.location, + 'creation_data': { + 'create_option': 'Copy', + 'source_resource_id': snapshot.id + }, + 'zones': vm.zones if vm.zones else None # Preserve the original disk's zone + } + async_disk_creation = self.compute_client.disks.begin_create_or_update(resource_group_name, new_disk_name, disk_creation) + restored_disk = async_disk_creation.result() # Wait for the disk creation to complete + + vm.storage_profile.os_disk = { + 'create_option': vm.storage_profile.os_disk.create_option, + 'managed_disk': { + 'id': restored_disk.id + } + } + + async_vm_creation = self.compute_client.virtual_machines.begin_create_or_update(resource_group_name, vm_name, vm) + async_vm_creation.wait(timeout=WAIT_DELAY) + + # Delete the original disk + self.compute_client.disks.begin_delete(resource_group_name, disk_name).wait() + + logger.info(f"Successfully reverted to snapshot {snapshot_name}.") + except Exception as e: + logger.error(f"Failed to revert the Azure VM {path_to_vm} to snapshot {snapshot_name}: {str(e)}") + raise + + def stop_emulator(self, path_to_vm, region=None): + logger.info(f"Stopping Azure VM {path_to_vm}...") + resource_group_name, vm_name = path_to_vm.split('/') + + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') + power_state = vm.instance_view.statuses[-1].code + if power_state == "PowerState/deallocated": + print("VM is already stopped.") + return + + try: + for _ in range(MAX_ATTEMPTS): + async_vm_deallocate = self.compute_client.virtual_machines.begin_deallocate(resource_group_name, vm_name) + logger.info(f"Stopping VM {path_to_vm}...") + # Wait for the instance to start + async_vm_deallocate.wait(timeout=WAIT_DELAY) + vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') + power_state = vm.instance_view.statuses[-1].code + if power_state == "PowerState/deallocated": + logger.info(f"VM {path_to_vm} is already stopped.") + break + except Exception as e: + logger.error(f"Failed to stop the Azure VM {path_to_vm}: {str(e)}") + raise diff --git a/desktop_env/providers/virtualbox/INSTALL_VITUALBOX.md b/desktop_env/providers/virtualbox/INSTALL_VITUALBOX.md new file mode 100644 index 0000000..e1d4ae4 --- /dev/null +++ b/desktop_env/providers/virtualbox/INSTALL_VITUALBOX.md @@ -0,0 +1,11 @@ +## 💾 Installation of VirtualBox + + +1. Download the VirtualBox from the [official website](https://www.virtualbox.org/wiki/Downloads). Unfortunately, for Apple chips (M1 chips, M2 chips, etc.), VirtualBox is not supported. You can only use VMware Fusion instead. +2. Install VirtualBox. Just follow the instructions provided by the installer. +For Windows, you also need to append the installation path to the environment variable `PATH` for enabling the `VBoxManage` command. The default installation path is `C:\Program Files\Oracle\VirtualBox`. +3. Verify the successful installation by running the following: + ```bash + VBoxManage --version + ``` + If the installation along with the environment variable set is successful, you will see the version of VirtualBox installed on your system. diff --git a/desktop_env/providers/virtualbox/manager.py b/desktop_env/providers/virtualbox/manager.py index e69de29..4f9e569 100644 --- a/desktop_env/providers/virtualbox/manager.py +++ b/desktop_env/providers/virtualbox/manager.py @@ -0,0 +1,398 @@ +import logging +import os +import platform +import shutil +import subprocess +import threading +import time +import zipfile + +import psutil +import requests +from filelock import FileLock +from tqdm import tqdm + +from desktop_env.providers.base import VMManager + +logger = logging.getLogger("desktopenv.providers.virtualbox.VirtualBoxVMManager") +logger.setLevel(logging.INFO) + +MAX_RETRY_TIMES = 10 +RETRY_INTERVAL = 5 +UBUNTU_ARM_URL = "NOT_AVAILABLE" +UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_x86_virtualbox/resolve/main/Ubuntu.zip" +DOWNLOADED_FILE_NAME = "Ubuntu.zip" +REGISTRY_PATH = '.virtualbox_vms' + +LOCK_FILE_NAME = '.virtualbox_lck' +VMS_DIR = "./virtualbox_vm_data" +update_lock = threading.Lock() + +if platform.system() == 'Windows': + vboxmanage_path = r"C:\Program Files\Oracle\VirtualBox" + os.environ["PATH"] += os.pathsep + vboxmanage_path + + +def generate_new_vm_name(vms_dir): + registry_idx = 0 + while True: + attempted_new_name = f"Ubuntu{registry_idx}" + if os.path.exists( + os.path.join(vms_dir, attempted_new_name, attempted_new_name, attempted_new_name + ".vbox")): + registry_idx += 1 + else: + return attempted_new_name + + +def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu", bridged_adapter_name=None): + os.makedirs(vms_dir, exist_ok=True) + + def __download_and_unzip_vm(): + # Determine the platform and CPU architecture to decide the correct VM image to download + if platform.system() == 'Darwin': # macOS + url = UBUNTU_ARM_URL + raise Exception("MacOS host is not currently supported for VirtualBox.") + elif platform.machine().lower() in ['amd64', 'x86_64']: + url = UBUNTU_X86_URL + else: + raise Exception("Unsupported platform or architecture.") + + # Download the virtual machine image + logger.info("Downloading the virtual machine image...") + downloaded_size = 0 + + while True: + downloaded_file_path = os.path.join(vms_dir, downloaded_file_name) + headers = {} + if os.path.exists(downloaded_file_path): + downloaded_size = os.path.getsize(downloaded_file_path) + headers["Range"] = f"bytes={downloaded_size}-" + + with requests.get(url, headers=headers, stream=True) as response: + if response.status_code == 416: + # This means the range was not satisfiable, possibly the file was fully downloaded + logger.info("Fully downloaded or the file size changed.") + break + + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + + with open(downloaded_file_path, "ab") as file, tqdm( + desc="Progress", + total=total_size, + unit='iB', + unit_scale=True, + unit_divisor=1024, + initial=downloaded_size, + ascii=True + ) as progress_bar: + try: + for data in response.iter_content(chunk_size=1024): + size = file.write(data) + progress_bar.update(size) + except (requests.exceptions.RequestException, IOError) as e: + logger.error(f"Download error: {e}") + time.sleep(RETRY_INTERVAL) + logger.error("Retrying...") + else: + logger.info("Download succeeds.") + break # Download completed successfully + + # Unzip the downloaded file + logger.info("Unzipping the downloaded file...☕️") + with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref: + zip_ref.extractall(vms_dir) + logger.info("Files have been successfully extracted to the directory: " + vms_dir) + + def import_vm(vms_dir, target_vm_name, max_retries=1): + """Import the .ovf file into VirtualBox.""" + logger.info(f"Starting to import VM {target_vm_name}...") + command = ( + f"VBoxManage import {os.path.abspath(os.path.join(vms_dir, original_vm_name, original_vm_name + '.ovf'))} " + f"--vsys 0 " + f"--vmname {target_vm_name} " + f"--settingsfile {os.path.abspath(os.path.join(vms_dir, target_vm_name, target_vm_name + '.vbox'))} " + f"--basefolder {vms_dir} " + f"--unit 14 " + f"--disk {os.path.abspath(os.path.join(vms_dir, target_vm_name, target_vm_name + '_disk1.vmdk'))}") + + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + errors='ignore') + if result.returncode == 0: + logger.info("Successfully imported VM.") + return True + else: + if not result.stderr or "Error" in result.stderr: + logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + logger.error(f"Attempt {attempt + 1} failed: {result.stderr}") + + if attempt == max_retries - 1: + logger.error("Maximum retry attempts reached, failed to import the virtual machine.") + return False + + def configure_vm_network(vm_name, interface_name=None): + # Config of bridged network + command = f'VBoxManage modifyvm "{vm_name}" --nic1 bridged' + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + errors='ignore') + if not interface_name: + output = subprocess.check_output(f"VBoxManage list bridgedifs", shell=True, stderr=subprocess.STDOUT) + output = output.decode() + output = output.splitlines() + result = [] + for line in output: + entries = line.split() + if entries and entries[0] == "Name:": + name = ' '.join(entries[1:]) + if entries and entries[0] == "IPAddress:": + ip = entries[1] + result.append((name, ip)) + logger.info("Found the following network adapters, default to the first. If you want to change it, please set the argument -r to the name of the adapter.") + for i, (name, ip) in enumerate(result): + logger.info(f"{i+1}: {name} ({ip})") + interface_id = 1 + interface_name = result[interface_id-1][0] + command = f'vboxmanage modifyvm "{vm_name}" --bridgeadapter1 "{interface_name}"' + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + errors='ignore') + if result.returncode == 0: + logger.info(f"Changed to bridge adapter {interface_name}.") + return True + else: + logger.error(f"Failed to change to bridge adapter {interface_name}: {result.stderr}") + return False + + # # Config of NAT network + # command = f"VBoxManage natnetwork add --netname natnet --network {nat_network} --dhcp on" + # result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + # errors='ignore') + # if result.returncode == 0: + # logger.info(f"Created NAT network {nat_network}.") + # else: + # logger.error(f"Failed to create NAT network {nat_network}") + # return False + # command = f"VBoxManage modifyvm {vm_name} --nic1 natnetwork" + # result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + # errors='ignore') + # command = f"VBoxManage modifyvm {vm_name} --natnet1 natnet" + # result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + # errors='ignore') + # if result.returncode == 0: + # logger.info("Switched VM to the NAT network.") + # else: + # logger.error("Failed to switch VM to the NAT network") + # return False + # logger.info("Start to configure port forwarding...") + # command = f"VBoxManage modifyvm {vm_name} --natpf1 'server,tcp,,5000,,5000'" + # result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8", + # errors='ignore') + # if result.returncode == 0: + # logger.info("Successfully created port forwarding rule.") + # return True + # logger.error("Failed to create port forwarding rule.") + # return False + + + vm_path = os.path.join(vms_dir, vm_name, vm_name + ".vbox") + + # Execute the function to download and unzip the VM, and update the vm metadata + if not os.path.exists(vm_path): + __download_and_unzip_vm() + import_vm(vms_dir, vm_name) + if not configure_vm_network(vm_name, bridged_adapter_name): + raise Exception("Failed to configure VM network!") + else: + logger.info(f"Virtual machine exists: {vm_path}") + + # Start the virtual machine + def start_vm(vm_name, max_retries=20): + command = f'VBoxManage startvm "{vm_name}" --type headless' + + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8") + if result.returncode == 0: + logger.info("Virtual machine started.") + return True + else: + if not result.stderr or "Error" in result.stderr: + logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + logger.error(f"Attempt {attempt + 1} failed: {result.stderr}") + + if attempt == max_retries - 1: + logger.error("Maximum retry attempts reached, failed to start the virtual machine.") + return False + + if not start_vm(vm_name): + raise ValueError("Error encountered during installation, please rerun the code for retrying.") + + def get_vm_ip(vm_name, max_retries=20): + command = f'VBoxManage guestproperty get "{vm_name}" /VirtualBox/GuestInfo/Net/0/V4/IP' + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8") + if result.returncode == 0: + return result.stdout.strip().split()[1] + else: + logger.error(f"Get VM IP failed: {result.stderr}") + return None + + # Function used to check whether the virtual machine is ready + def download_screenshot(): + ip = get_vm_ip(vm_name) + url = f"http://{ip}:5000/screenshot" + try: + # max trey times 1, max timeout 1 + response = requests.get(url, timeout=(10, 10)) + if response.status_code == 200: + return True + except Exception as e: + logger.error(f"Error: {e}") + logger.error(f"Type: {type(e).__name__}") + logger.error(f"Error detail: {str(e)}") + return False + + # Try downloading the screenshot until successful + while not download_screenshot(): + logger.info("Check whether the virtual machine is ready...") + time.sleep(RETRY_INTERVAL) + + logger.info("Virtual machine is ready. Start to make a snapshot on the virtual machine. It would take a while...") + + def create_vm_snapshot(vm_name, max_retries=20): + logger.info("Saving VirtualBox VM state...") + command = f'VBoxManage snapshot "{vm_name}" take init_state' + + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8") + if result.returncode == 0: + logger.info("Snapshot created.") + return True + else: + if "Error" in result.stderr: + logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + logger.error(f"Attempt {attempt + 1} failed: {result.stderr}") + + if attempt == max_retries - 1: + logger.error("Maximum retry attempts reached, failed to create snapshot.") + return False + + # Create a snapshot of the virtual machine + if create_vm_snapshot(vm_name, max_retries=MAX_RETRY_TIMES): + return vm_path + else: + raise ValueError("Error encountered during installation, please rerun the code for retrying.") + + +class VirtualBoxVMManager(VMManager): + def __init__(self, registry_path=REGISTRY_PATH): + self.registry_path = registry_path + self.lock = FileLock(LOCK_FILE_NAME, timeout=10) + self.initialize_registry() + + def initialize_registry(self): + with self.lock: # Locking during initialization + if not os.path.exists(self.registry_path): + with open(self.registry_path, 'w') as file: + file.write('') + + def add_vm(self, vm_path, region=None): + assert region in [None, 'local'], "For VirtualBox provider, the region should be neither None or 'local'." + with self.lock: + with open(self.registry_path, 'r') as file: + lines = file.readlines() + new_lines = lines + [f'{vm_path}|free\n'] + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def occupy_vm(self, vm_path, pid, region=None): + assert region in [None, 'local'], "For VirtualBox provider, the region should be neither None or 'local'." + with self.lock: + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + registered_vm_path, _ = line.strip().split('|') + if registered_vm_path == vm_path: + new_lines.append(f'{registered_vm_path}|{pid}\n') + else: + new_lines.append(line) + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def check_and_clean(self, vms_dir): + with self.lock: # Lock when cleaning up the registry and vms_dir + # Check and clean on the running vms, detect the released ones and mark then as 'free' + active_pids = {p.pid for p in psutil.process_iter()} + new_lines = [] + vm_paths = [] + + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path, pid_str = line.strip().split('|') + if not os.path.exists(vm_path): + logger.info(f"VM {vm_path} not found, releasing it.") + new_lines.append(f'{vm_path}|free\n') + continue + + vm_paths.append(vm_path) + if pid_str == "free": + new_lines.append(line) + continue + + if int(pid_str) in active_pids: + new_lines.append(line) + else: + new_lines.append(f'{vm_path}|free\n') + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + # Check and clean on the files inside vms_dir, delete the unregistered ones + os.makedirs(vms_dir, exist_ok=True) + vm_names = os.listdir(vms_dir) + for vm_name in vm_names: + # skip the downloaded .zip file + if vm_name == DOWNLOADED_FILE_NAME: + continue + # Skip the .DS_Store file on macOS + if vm_name == ".DS_Store": + continue + + flag = True + for vm_path in vm_paths: + if vm_name + ".vbox" in vm_path: + flag = False + if flag: + shutil.rmtree(os.path.join(vms_dir, vm_name)) + + def list_free_vms(self): + with self.lock: # Lock when reading the registry + free_vms = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path, pid_str = line.strip().split('|') + if pid_str == "free": + free_vms.append((vm_path, pid_str)) + return free_vms + + def get_vm_path(self, region=None): + self.check_and_clean(vms_dir=VMS_DIR) + free_vms_paths = self.list_free_vms() + if len(free_vms_paths) == 0: + # No free virtual machine available, generate a new one + logger.info("No free virtual machine available. Generating a new one, which would take a while...☕") + new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR) + new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR, + downloaded_file_name=DOWNLOADED_FILE_NAME, + bridged_adapter_name=region) + self.add_vm(new_vm_path) + self.occupy_vm(new_vm_path, os.getpid()) + return new_vm_path + else: + # Choose the first free virtual machine + chosen_vm_path = free_vms_paths[0][0] + self.occupy_vm(chosen_vm_path, os.getpid()) + return chosen_vm_path diff --git a/desktop_env/providers/virtualbox/provider.py b/desktop_env/providers/virtualbox/provider.py index e69de29..71e086e 100644 --- a/desktop_env/providers/virtualbox/provider.py +++ b/desktop_env/providers/virtualbox/provider.py @@ -0,0 +1,121 @@ +import logging +import platform +import subprocess +import time +import os +from desktop_env.providers.base import Provider +import xml.etree.ElementTree as ET + +logger = logging.getLogger("desktopenv.providers.virtualbox.VirtualBoxProvider") +logger.setLevel(logging.INFO) + +WAIT_TIME = 3 + +# Note: Windows will not add command VBoxManage to PATH by default. Please add the folder where VBoxManage executable is in (Default should be "C:\Program Files\Oracle\VirtualBox" for Windows) to PATH. + +class VirtualBoxProvider(Provider): + @staticmethod + def _execute_command(command: list): + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True, + encoding="utf-8") + if result.returncode != 0: + raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m") + return result.stdout.strip() + + @staticmethod + def _get_vm_uuid(path_to_vm: str): + try: + output = subprocess.check_output(f"VBoxManage list vms", shell=True, stderr=subprocess.STDOUT) + output = output.decode() + output = output.splitlines() + if path_to_vm.endswith('.vbox'): + # Load and parse the XML content from the file + tree = ET.parse(path_to_vm) + root = tree.getroot() + + # Find the element and retrieve its 'uuid' attribute + machine_element = root.find('.//{http://www.virtualbox.org/}Machine') + if machine_element is not None: + uuid = machine_element.get('uuid')[1:-1] + return uuid + else: + logger.error(f"UUID not found in file {path_to_vm}") + raise + elif any(line.split()[1] == "{" + path_to_vm + "}" for line in output): + logger.info(f"Got valid UUID {path_to_vm}.") + return path_to_vm + else: + for line in output: + if line.split()[0] == '"' + path_to_vm + '"': + uuid = line.split()[1][1:-1] + return uuid + logger.error(f"The path you provided does not match any of the \".vbox\" file, name, or UUID of VM.") + raise + except subprocess.CalledProcessError as e: + logger.error(f"Error executing command: {e.output.decode().strip()}") + + + def start_emulator(self, path_to_vm: str, headless: bool): + print("Starting VirtualBox VM...") + logger.info("Starting VirtualBox VM...") + + while True: + try: + uuid = VirtualBoxProvider._get_vm_uuid(path_to_vm) + output = subprocess.check_output(f"VBoxManage list runningvms", shell=True, stderr=subprocess.STDOUT) + output = output.decode() + output = output.splitlines() + + if any(line.split()[1] == "{" + uuid + "}" for line in output): + logger.info("VM is running.") + break + else: + logger.info("Starting VM...") + VirtualBoxProvider._execute_command(["VBoxManage", "startvm", uuid]) if not headless else \ + VirtualBoxProvider._execute_command( + ["VBoxManage", "startvm", uuid, "--type", "headless"]) + time.sleep(WAIT_TIME) + + except subprocess.CalledProcessError as e: + logger.error(f"Error executing command: {e.output.decode().strip()}") + + def get_ip_address(self, path_to_vm: str) -> str: + logger.info("Getting VirtualBox VM IP address...") + while True: + try: + uuid = VirtualBoxProvider._get_vm_uuid(path_to_vm) + output = VirtualBoxProvider._execute_command( + ["VBoxManage", "guestproperty", "get", uuid, "/VirtualBox/GuestInfo/Net/0/V4/IP"] + ) + result = output.split()[1] + if result != "value": + logger.info(f"VirtualBox VM IP address: {result}") + return result + else: + logger.error("VM IP address not found. Have you installed the guest additions?") + raise + except Exception as e: + logger.error(e) + time.sleep(WAIT_TIME) + logger.info("Retrying to get VirtualBox VM IP address...") + + def save_state(self, path_to_vm: str, snapshot_name: str): + logger.info("Saving VirtualBox VM state...") + uuid = VirtualBoxProvider._get_vm_uuid(path_to_vm) + VirtualBoxProvider._execute_command(["VBoxManage", "snapshot", uuid, "take", snapshot_name]) + time.sleep(WAIT_TIME) # Wait for the VM to save + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + logger.info(f"Reverting VirtualBox VM to snapshot: {snapshot_name}...") + uuid = VirtualBoxProvider._get_vm_uuid(path_to_vm) + VirtualBoxProvider._execute_command(["VBoxManage", "controlvm", uuid, "savestate"]) + time.sleep(WAIT_TIME) # Wait for the VM to stop + VirtualBoxProvider._execute_command(["VBoxManage", "snapshot", uuid, "restore", snapshot_name]) + time.sleep(WAIT_TIME) # Wait for the VM to revert + return path_to_vm + + def stop_emulator(self, path_to_vm: str): + logger.info("Stopping VirtualBox VM...") + uuid = VirtualBoxProvider._get_vm_uuid(path_to_vm) + VirtualBoxProvider._execute_command(["VBoxManage", "controlvm", uuid, "savestate"]) + time.sleep(WAIT_TIME) # Wait for the VM to stop diff --git a/desktop_env/providers/vmware/manager.py b/desktop_env/providers/vmware/manager.py index 03b5c4a..17b76fd 100644 --- a/desktop_env/providers/vmware/manager.py +++ b/desktop_env/providers/vmware/manager.py @@ -23,13 +23,18 @@ logger = logging.getLogger("desktopenv.providers.vmware.VMwareVMManager") logger.setLevel(logging.INFO) MAX_RETRY_TIMES = 10 +RETRY_INTERVAL = 5 UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_arm/resolve/main/Ubuntu.zip" UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_x86/resolve/main/Ubuntu.zip" DOWNLOADED_FILE_NAME = "Ubuntu.zip" REGISTRY_PATH = '.vmware_vms' +LOCK_FILE_NAME = '.vmware_lck' VMS_DIR = "./vmware_vm_data" update_lock = threading.Lock() +if platform.system() == 'Windows': + vboxmanage_path = r"C:\Program Files (x86)\VMware\VMware Workstation" + os.environ["PATH"] += os.pathsep + vboxmanage_path def generate_new_vm_name(vms_dir): registry_idx = 0 @@ -129,7 +134,7 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu with requests.get(url, headers=headers, stream=True) as response: if response.status_code == 416: # This means the range was not satisfiable, possibly the file was fully downloaded - logger.info("Fully downloaded or the file sized changed.") + logger.info("Fully downloaded or the file size changed.") break response.raise_for_status() @@ -150,7 +155,7 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu progress_bar.update(size) except (requests.exceptions.RequestException, IOError) as e: logger.error(f"Download error: {e}") - sleep(1) # Wait for 1 second before retrying + sleep(RETRY_INTERVAL) logger.error("Retrying...") else: logger.info("Download succeeds.") @@ -233,7 +238,7 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu logger.error(f"Error: {e}") logger.error(f"Type: {type(e).__name__}") logger.error(f"Error detail: {str(e)}") - sleep(2) + sleep(RETRY_INTERVAL) return False # Try downloading the screenshot until successful @@ -269,7 +274,7 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu class VMwareVMManager(VMManager): def __init__(self, registry_path=REGISTRY_PATH): self.registry_path = registry_path - self.lock = FileLock(".vmware_lck", timeout=10) + self.lock = FileLock(LOCK_FILE_NAME, timeout=10) self.initialize_registry() def initialize_registry(self): diff --git a/main.py b/main.py index 6dbbdee..bc87d54 100644 --- a/main.py +++ b/main.py @@ -47,8 +47,10 @@ def human_agent(): Runs the Gym environment with human input. """ parser = argparse.ArgumentParser() - parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx", help="Path to the virtual machine .vmx file.") + parser.add_argument('-p', '--path', type=str, default="", help="Path to the virtual machine.") parser.add_argument('-e', '--example', type=str, help="Path to the example json file.") + parser.add_argument('-s', '--snapshot', type=str, help="Name of the snapshot to load.") + parser.add_argument('-r', '--region', type=str, help="(For VirtualBox) Name of the bridged adapter. (For AWS) Name of the region.") args = parser.parse_args(sys.argv[1:]) example_path = args.example if args.example is not None and os.path.exists(args.example) else \ @@ -56,10 +58,12 @@ def human_agent(): with open(example_path, "r", encoding="utf-8") as f: example = json.load(f) - assert os.path.exists(args.path), "The specified path to the .vmx file does not exist." + # assert os.path.exists(args.path), "The specified path to the .vmx file does not exist." env = DesktopEnv( path_to_vm=args.path, - action_space="computer_13" + action_space="computer_13", + snapshot_name=args.snapshot, + region=args.region ) # reset the environment to certain snapshot observation = env.reset(task_config=example) diff --git a/requirements.txt b/requirements.txt index 2cc96ab..439f85b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -53,3 +53,7 @@ wrapt_timeout_decorator gdown tiktoken groq +boto3 +azure-identity +azure-mgmt-compute +azure-mgmt-network diff --git a/setup.py b/setup.py index c4a5c06..1968141 100644 --- a/setup.py +++ b/setup.py @@ -84,7 +84,11 @@ setup( "borb", "pypdf2", "pdfplumber", - "wrapt_timeout_decorator" + "wrapt_timeout_decorator", + "boto3", + "azure-identity", + "azure-mgmt-compute", + "azure-mgmt-network", ], cmdclass={ 'install': InstallPlaywrightCommand, # Use the custom install command