Refactoring VMware Integration and Implementing AWS Support (#44)

* Initailize aws support

* Add README for the VM server

* Refactor OSWorld for supporting more cloud services.

* Initialize vmware and aws implementation v1, waiting for verification

* Initlize files for azure, gcp and virtualbox support

* Debug on the VMware provider

* Fix on aws interface mapping

* Fix instance type

* Refactor

* Clean

* hk region; debug

* Fix lock

* Remove print

* Remove key_name requirements when allocating aws vm

* Clean README

---------

Co-authored-by: XinyuanWangCS <xywang626@gmail.com>
This commit is contained in:
Tianbao Xie
2024-06-15 20:52:29 +08:00
committed by GitHub
parent c121869219
commit fffa8f8da6
31 changed files with 847 additions and 302 deletions

View File

View File

@@ -0,0 +1,18 @@
from desktop_env.providers.base import VMManager, Provider
from desktop_env.providers.vmware.manager import VMwareVMManager
from desktop_env.providers.vmware.provider import VMwareProvider
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider
def create_vm_manager_and_provider(provider_name: str, region: str):
"""
Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name.
"""
provider_name = provider_name.lower().strip()
if provider_name == "vmware":
return VMwareVMManager(), VMwareProvider(region)
elif provider_name in ["aws", "amazon web services"]:
return AWSVMManager(), AWSProvider(region)
else:
raise NotImplementedError(f"{provider_name} not implemented!")

View File

@@ -0,0 +1,57 @@
# README for AWS VM Management
Welcome to the AWS VM Management documentation. Before you proceed with using the code to manage AWS services, please ensure the following variables are set correctly according to your AWS environment.
## Configuration Variables
You need to assign values to several variables crucial for the operation of these scripts on AWS:
- **`REGISTRY_PATH`**: Sets the file path for VM registration logging.
- Example: `'.aws_vms'`
- **`DEFAULT_REGION`**: Default AWS region where your instances will be launched.
- Example: `"us-east-1"`
- **`IMAGE_ID_MAP`**: Dictionary mapping regions to specific AMI IDs that should be used for instance creation.
- Example:
```python
IMAGE_ID_MAP = {
"us-east-1": "ami-09bab251951b4272c",
# Add other regions and corresponding AMIs
}
```
- **`INSTANCE_TYPE`**: Specifies the type of EC2 instance to be launched.
- Example: `"t3.medium"`
- **`KEY_NAME`**: Specifies the name of the key pair to be used for the instances.
- Example: `"osworld_key"`
- **`NETWORK_INTERFACES`**: Configuration settings for network interfaces, which include subnet IDs, security group IDs, and public IP addressing.
- Example:
```python
NETWORK_INTERFACES = {
"us-east-1": [
{
"SubnetId": "subnet-037edfff66c2eb894",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": ["sg-0342574803206ee9c"]
}
],
# Add configurations for other regions
}
```
### AWS CLI Configuration
Before using these scripts, you must configure your AWS CLI with your credentials. This can be done via the following commands:
```bash
aws configure
```
This command will prompt you for:
- AWS Access Key ID
- AWS Secret Access Key
- Default region name (Optional, you can press enter)
Enter your credentials as required. This setup will allow you to interact with AWS services using the credentials provided.
### Disclaimer
Use the provided scripts and configurations at your own risk. Ensure that you understand the AWS pricing model and potential costs associated with deploying instances, as using these scripts might result in charges on your AWS account.
> **Note:** Ensure all AMI images used in `IMAGE_ID_MAP` are accessible and permissioned correctly for your AWS account, and that they are available in the specified region.

View File

View File

@@ -0,0 +1,179 @@
import os
from filelock import FileLock
import boto3
import psutil
import logging
from desktop_env.providers.base import VMManager
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager")
logger.setLevel(logging.INFO)
REGISTRY_PATH = '.aws_vms'
DEFAULT_REGION = "us-east-1"
# todo: Add doc for the configuration of image, security group and network interface
# todo: public the AMI images
IMAGE_ID_MAP = {
"us-east-1": "ami-0b0531325a0d5d488",
"ap-east-1": "ami-0b92a0bf157fecaa9"
}
INSTANCE_TYPE = "t3.large"
NETWORK_INTERFACE_MAP = {
"us-east-1": [
{
"SubnetId": "subnet-037edfff66c2eb894",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-0342574803206ee9c"
]
}
],
"ap-east-1": [
{
"SubnetId": "subnet-011060501be0b589c",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-090470e64df78f6eb"
]
}
]
}
def _allocate_vm(region=DEFAULT_REGION):
run_instances_params = {
"MaxCount": 1,
"MinCount": 1,
"ImageId": IMAGE_ID_MAP[region],
"InstanceType": INSTANCE_TYPE,
"EbsOptimized": True,
"NetworkInterfaces": NETWORK_INTERFACE_MAP[region]
}
ec2_client = boto3.client('ec2', region_name=region)
response = ec2_client.run_instances(**run_instances_params)
instance_id = response['Instances'][0]['InstanceId']
logger.info(f"Waiting for instance {instance_id} to be running...")
ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
logger.info(f"Waiting for instance {instance_id} status checks to pass...")
ec2_client.get_waiter('instance_status_ok').wait(InstanceIds=[instance_id])
logger.info(f"Instance {instance_id} is ready.")
return instance_id
class AWSVMManager(VMManager):
def __init__(self, registry_path=REGISTRY_PATH):
self.registry_path = registry_path
self.lock = FileLock(".aws_lck", timeout=10)
self.initialize_registry()
def initialize_registry(self):
with self.lock: # Locking during initialization
if not os.path.exists(self.registry_path):
with open(self.registry_path, 'w') as file:
file.write('')
def add_vm(self, vm_path, region=DEFAULT_REGION):
with self.lock:
with open(self.registry_path, 'r') as file:
lines = file.readlines()
vm_path_at_vm_region = "{}@{}".format(vm_path, region)
new_lines = lines + [f'{vm_path_at_vm_region}|free\n']
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
with self.lock:
new_lines = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
registered_vm_path, _ = line.strip().split('|')
if registered_vm_path == "{}@{}".format(vm_path, region):
new_lines.append(f'{registered_vm_path}|{pid}\n')
else:
new_lines.append(line)
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def check_and_clean(self):
with self.lock: # Lock when cleaning up the registry and vms_dir
# Check and clean on the running vms, detect the released ones and mark then as 'free'
active_pids = {p.pid for p in psutil.process_iter()}
new_lines = []
vm_path_at_vm_regions = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
vm_path_at_vm_region, pid_str = line.strip().split('|')
vm_path, vm_region = vm_path_at_vm_region.split("@")
ec2_client = boto3.client('ec2', region_name=vm_region)
try:
response = ec2_client.describe_instances(InstanceIds=[vm_path])
if not response['Reservations'] or response['Reservations'][0]['Instances'][0]['State'][
'Name'] in ['terminated', 'shutting-down']:
logger.info(f"VM {vm_path} not found or terminated, releasing it.")
continue
elif response['Reservations'][0]['Instances'][0]['State'][
'Name'] == "Stopped":
logger.info(f"VM {vm_path} stopped, mark it as free")
new_lines.append(f'{vm_path}@{vm_region}|free\n')
continue
except ec2_client.exceptions.ClientError as e:
if 'InvalidInstanceID.NotFound' in str(e):
logger.info(f"VM {vm_path} not found, releasing it.")
continue
vm_path_at_vm_regions.append(vm_path_at_vm_region)
if pid_str == "free":
new_lines.append(line)
continue
if int(pid_str) in active_pids:
new_lines.append(line)
else:
new_lines.append(f'{vm_path_at_vm_region}|free\n')
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
# We won't check and clean on the files on aws and delete the unregistered ones
# Since this can lead to unexpected delete on other server
# PLease do monitor the instances to avoid additional cost
def list_free_vms(self, region=DEFAULT_REGION):
with self.lock: # Lock when reading the registry
free_vms = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
vm_path_at_vm_region, pid_str = line.strip().split('|')
vm_path, vm_region = vm_path_at_vm_region.split("@")
if pid_str == "free" and vm_region == region:
free_vms.append((vm_path, pid_str))
return free_vms
def get_vm_path(self, region=DEFAULT_REGION):
self.check_and_clean()
free_vms_paths = self.list_free_vms(region)
if len(free_vms_paths) == 0:
# No free virtual machine available, generate a new one
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
new_vm_path = _allocate_vm(region)
self.add_vm(new_vm_path, region)
self.occupy_vm(new_vm_path, os.getpid(), region)
return new_vm_path
else:
# Choose the first free virtual machine
chosen_vm_path = free_vms_paths[0][0]
self.occupy_vm(chosen_vm_path, os.getpid(), region)
return chosen_vm_path

View File

@@ -0,0 +1,112 @@
import boto3
from botocore.exceptions import ClientError
import logging
from .manager import INSTANCE_TYPE
from desktop_env.providers.base import Provider
logger = logging.getLogger("desktopenv.providers.aws.AWSProvider")
logger.setLevel(logging.INFO)
WAIT_DELAY = 15
MAX_ATTEMPTS = 10
class AWSProvider(Provider):
def start_emulator(self, path_to_vm: str, headless: bool):
logger.info("Starting AWS VM...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
# Start the instance
ec2_client.start_instances(InstanceIds=[path_to_vm])
logger.info(f"Instance {path_to_vm} is starting...")
# Wait for the instance to be in the 'running' state
waiter = ec2_client.get_waiter('instance_running')
waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
logger.info(f"Instance {path_to_vm} is now running.")
except ClientError as e:
logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}")
raise
def get_ip_address(self, path_to_vm: str) -> str:
logger.info("Getting AWS VM IP address...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
response = ec2_client.describe_instances(InstanceIds=[path_to_vm])
for reservation in response['Reservations']:
for instance in reservation['Instances']:
private_ip_address = instance.get('PrivateIpAddress', '')
return private_ip_address
return '' # Return an empty string if no IP address is found
except ClientError as e:
logger.error(f"Failed to retrieve private IP address for the instance {path_to_vm}: {str(e)}")
raise
def save_state(self, path_to_vm: str, snapshot_name: str):
logger.info("Saving AWS VM state...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
image_response = ec2_client.create_image(InstanceId=path_to_vm, ImageId=snapshot_name)
image_id = image_response['ImageId']
logger.info(f"AMI {image_id} created successfully from instance {path_to_vm}.")
return image_id
except ClientError as e:
logger.error(f"Failed to create AMI from the instance {path_to_vm}: {str(e)}")
raise
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
logger.info(f"Reverting AWS VM to snapshot: {snapshot_name}...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
# Step 1: Retrieve the original instance details
instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm])
instance = instance_details['Reservations'][0]['Instances'][0]
security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']]
subnet_id = instance['SubnetId']
instance_type = instance['InstanceType']
iam_instance_profile = instance.get('IamInstanceProfile', {}).get('Arn', '')
# Step 2: Launch a new instance from the snapshot
logger.info(f"Launching a new instance from snapshot {snapshot_name}...")
new_instance = ec2_client.run_instances(
ImageId=snapshot_name,
InstanceType=instance_type,
SecurityGroupIds=security_groups,
SubnetId=subnet_id,
IamInstanceProfile={'Arn': iam_instance_profile} if iam_instance_profile else {},
MinCount=1,
MaxCount=1
)
new_instance_id = new_instance['Instances'][0]['InstanceId']
logger.info(f"New instance {new_instance_id} launched from snapshot {snapshot_name}.")
# Step 3: Terminate the old instance
ec2_client.terminate_instances(InstanceIds=[path_to_vm])
logger.info(f"Old instance {path_to_vm} has been terminated.")
return new_instance_id
except ClientError as e:
logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}")
raise
def stop_emulator(self, path_to_vm, region=None):
logger.info(f"Stopping AWS VM {path_to_vm}...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
ec2_client.stop_instances(InstanceIds=[path_to_vm])
waiter = ec2_client.get_waiter('instance_stopped')
waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
logger.info(f"Instance {path_to_vm} has been stopped.")
except ClientError as e:
logger.error(f"Failed to stop the AWS VM {path_to_vm}: {str(e)}")
raise

View File

View File

View File

View File

@@ -0,0 +1,89 @@
from abc import ABC, abstractmethod
class Provider(ABC):
def __init__(self, region: str = None):
"""
Region of the cloud service.
"""
self.region = region
@abstractmethod
def start_emulator(self, path_to_vm: str, headless: bool):
"""
Method to start the emulator.
"""
pass
@abstractmethod
def get_ip_address(self, path_to_vm: str) -> str:
"""
Method to get the private IP address of the VM. Private IP means inside the VPC.
"""
pass
@abstractmethod
def save_state(self, path_to_vm: str, snapshot_name: str):
"""
Method to save the state of the VM.
"""
pass
@abstractmethod
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str) -> str:
"""
Method to revert the VM to a given snapshot.
"""
pass
@abstractmethod
def stop_emulator(self, path_to_vm: str):
"""
Method to stop the emulator.
"""
pass
class VMManager(ABC):
@abstractmethod
def initialize_registry(self, **kwargs):
"""
Initialize registry.
"""
pass
@abstractmethod
def add_vm(self, vm_path, **kwargs):
"""
Add the path of new VM to the registration.
"""
pass
@abstractmethod
def occupy_vm(self, vm_path, pid, **kwargs):
"""
Mark the path of VM occupied by the pid.
"""
pass
@abstractmethod
def list_free_vms(self, **kwargs):
"""
List the paths of VM that are free to use allocated.
"""
pass
@abstractmethod
def check_and_clean(self, **kwargs):
"""
Check the registration list, and remove the paths of VM that are not in use.
"""
pass
@abstractmethod
def get_vm_path(self, **kwargs):
"""
Get a virtual machine that is not occupied, generate a new one if no free VM.
"""
pass

View File

View File

View File

View File

@@ -0,0 +1,23 @@
## 💾 Installation of VMware Workstation Pro
---
1. Download VMware Workstation Pro from the [official website](https://www.vmware.com/products/workstation-pro/workstation-pro-evaluation.html). The version we are using is 17.5.1. For systems with Apple chips, you should install [VMware Fusion](https://www.vmware.com/go/getfusion).
2. Install VMware Workstation
- **[On Linux](https://docs.vmware.com/en/VMware-Workstation-Pro/17/com.vmware.ws.using.doc/GUID-1F5B1F14-A586-4A56-83FA-2E7D8333D5CA.html):** Run the following command in your terminal, where `xxxx-xxxxxxx` represents the version number and internal version number.
```
sudo sh VMware-Workstation-xxxx-xxxxxxx.architecture.bundle --console
```
- **[On Windows](https://docs.vmware.com/en/VMware-Workstation-Pro/17/com.vmware.ws.using.doc/GUID-F5A7B3CB-9141-458B-A256-E0C3EA805AAA.html):** Ensure that you're logged in as either the Administrator user or as a user who belongs to the local Administrators group. If you're logging in to a domain, make sure your domain account has local administrator privileges. Proceed by double-clicking the `VMware-workstation-xxxx-xxxxxxx.exe` file. Be aware that you might need to reboot your host system to finalize the installation.
- **[For systems with Apple chips](https://docs.vmware.com/en/VMware-Fusion/13/com.vmware.fusion.using.doc/GUID-ACC3A019-93D3-442C-A34E-F7755DF6733B.html):** Double-click the `VMware-Fusion-xxxx-xxxxxxx.dmg` file to open it. In the Finder window that appears, double-click the 'Install Fusion' icon. When prompted, enter your administrator username and password.
> **Note:** You need to fill the activation key during the installation process when prompted.
3. Verify the successful installation by running the following:
```
vmrun -T ws list
```
If the installation along with the environment variable set is successful, you will see the message showing the current running virtual machines.

View File

View File

@@ -0,0 +1,379 @@
import os
import platform
import random
import re
import threading
from filelock import FileLock
import uuid
import zipfile
from time import sleep
import shutil
import psutil
import subprocess
import requests
from tqdm import tqdm
import logging
from desktop_env.providers.base import VMManager
logger = logging.getLogger("desktopenv.providers.vmware.VMwareVMManager")
logger.setLevel(logging.INFO)
MAX_RETRY_TIMES = 10
UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_arm/resolve/main/Ubuntu.zip"
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_x86/resolve/main/Ubuntu.zip"
DOWNLOADED_FILE_NAME = "Ubuntu.zip"
REGISTRY_PATH = '.vmware_vms'
VMS_DIR = "./vmware_vm_data"
update_lock = threading.Lock()
def generate_new_vm_name(vms_dir):
registry_idx = 0
while True:
attempted_new_name = f"Ubuntu{registry_idx}"
if os.path.exists(
os.path.join(vms_dir, attempted_new_name, attempted_new_name, attempted_new_name + ".vmx")):
registry_idx += 1
else:
return attempted_new_name
def _update_vm(vmx_path, target_vm_name):
"""Update the VMX file with the new VM name and other parameters, so that the VM can be started successfully without conflict with the original VM."""
with update_lock:
dir_path, vmx_file = os.path.split(vmx_path)
def _generate_mac_address():
# VMware MAC address range starts with 00:0c:29
mac = [0x00, 0x0c, 0x29,
random.randint(0x00, 0x7f),
random.randint(0x00, 0xff),
random.randint(0x00, 0xff)]
return ':'.join(map(lambda x: "%02x" % x, mac))
# Backup the original file
with open(vmx_path, 'r') as file:
original_content = file.read()
# Generate new values
new_uuid_bios = str(uuid.uuid4())
new_uuid_location = str(uuid.uuid4())
new_mac_address = _generate_mac_address()
new_vmci_id = str(random.randint(-2147483648, 2147483647)) # Random 32-bit integer
# Update the content
updated_content = re.sub(r'displayName = ".*?"', f'displayName = "{target_vm_name}"', original_content)
updated_content = re.sub(r'uuid.bios = ".*?"', f'uuid.bios = "{new_uuid_bios}"', updated_content)
updated_content = re.sub(r'uuid.location = ".*?"', f'uuid.location = "{new_uuid_location}"', updated_content)
updated_content = re.sub(r'ethernet0.generatedAddress = ".*?"',
f'ethernet0.generatedAddress = "{new_mac_address}"',
updated_content)
updated_content = re.sub(r'vmci0.id = ".*?"', f'vmci0.id = "{new_vmci_id}"', updated_content)
# Write the updated content back to the file
with open(vmx_path, 'w') as file:
file.write(updated_content)
logger.info(".vmx file updated successfully.")
vmx_file_base_name = os.path.splitext(vmx_file)[0]
assert vmx_file == "Ubuntu.vmx", "The VMX file should be named 'Ubuntu.vmx'."
files_to_rename = ['vmx', 'nvram', 'vmsd', 'vmxf']
for ext in files_to_rename:
original_file = os.path.join(dir_path, f"{vmx_file_base_name}.{ext}")
target_file = os.path.join(dir_path, f"{target_vm_name}.{ext}")
os.rename(original_file, target_file)
# Update the dir_path to the target vm_name, only replace the last character
# Split the path into parts up to the last folder
path_parts = dir_path.rstrip(os.sep).split(os.sep)
path_parts[-1] = target_vm_name
target_dir_path = os.sep.join(path_parts)
os.rename(dir_path, target_dir_path)
logger.info("VM files renamed successfully.")
def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu"):
os.makedirs(vms_dir, exist_ok=True)
def __download_and_unzip_vm():
# Determine the platform and CPU architecture to decide the correct VM image to download
if platform.system() == 'Darwin': # macOS
# if os.uname().machine == 'arm64': # Apple Silicon
url = UBUNTU_ARM_URL
# else:
# url = UBUNTU_X86_URL
elif platform.machine().lower() in ['amd64', 'x86_64']:
url = UBUNTU_X86_URL
else:
raise Exception("Unsupported platform or architecture")
# Download the virtual machine image
logger.info("Downloading the virtual machine image...")
downloaded_size = 0
while True:
downloaded_file_path = os.path.join(vms_dir, downloaded_file_name)
headers = {}
if os.path.exists(downloaded_file_path):
downloaded_size = os.path.getsize(downloaded_file_path)
headers["Range"] = f"bytes={downloaded_size}-"
with requests.get(url, headers=headers, stream=True) as response:
if response.status_code == 416:
# This means the range was not satisfiable, possibly the file was fully downloaded
logger.info("Fully downloaded or the file sized changed.")
break
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
with open(downloaded_file_path, "ab") as file, tqdm(
desc="Progress",
total=total_size,
unit='iB',
unit_scale=True,
unit_divisor=1024,
initial=downloaded_size,
ascii=True
) as progress_bar:
try:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
progress_bar.update(size)
except (requests.exceptions.RequestException, IOError) as e:
logger.error(f"Download error: {e}")
sleep(1) # Wait for 1 second before retrying
logger.error("Retrying...")
else:
logger.info("Download succeeds.")
break # Download completed successfully
# Unzip the downloaded file
logger.info("Unzipping the downloaded file...☕️")
with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref:
zip_ref.extractall(os.path.join(vms_dir, vm_name))
logger.info("Files have been successfully extracted to the directory: " + str(os.path.join(vms_dir, vm_name)))
vm_path = os.path.join(vms_dir, vm_name, vm_name, vm_name + ".vmx")
# Execute the function to download and unzip the VM, and update the vm metadata
if not os.path.exists(vm_path):
__download_and_unzip_vm()
_update_vm(os.path.join(vms_dir, vm_name, original_vm_name, original_vm_name + ".vmx"), vm_name)
else:
logger.info(f"Virtual machine exists: {vm_path}")
# Determine the platform of the host machine and decide the parameter for vmrun
def get_vmrun_type():
if platform.system() == 'Windows' or platform.system() == 'Linux':
return '-T ws'
elif platform.system() == 'Darwin': # Darwin is the system name for macOS
return '-T fusion'
else:
raise Exception("Unsupported operating system")
# Start the virtual machine
def start_vm(vm_path, max_retries=20):
command = f'vmrun {get_vmrun_type()} start "{vm_path}" nogui'
for attempt in range(max_retries):
result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8")
if result.returncode == 0:
logger.info("Virtual machine started.")
return True
else:
if "Error" in result.stderr:
logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}")
else:
logger.error(f"Attempt {attempt + 1} failed: {result.stderr}")
if attempt == max_retries - 1:
logger.error("Maximum retry attempts reached, failed to start the virtual machine.")
return False
if not start_vm(vm_path):
raise ValueError("Error encountered during installation, please rerun the code for retrying.")
def get_vm_ip(vm_path, max_retries=20):
command = f'vmrun {get_vmrun_type()} getGuestIPAddress "{vm_path}" -wait'
for attempt in range(max_retries):
result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8")
if result.returncode == 0:
return result.stdout.strip()
else:
if "Error" in result.stderr:
logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}")
else:
logger.error(f"Attempt {attempt + 1} failed: {result.stderr}")
if attempt == max_retries - 1:
logger.error("Maximum retry attempts reached, failed to get the IP of virtual machine.")
return None
vm_ip = get_vm_ip(vm_path)
if not vm_ip:
raise ValueError("Error encountered during installation, please rerun the code for retrying.")
# Function used to check whether the virtual machine is ready
def download_screenshot(ip):
url = f"http://{ip}:5000/screenshot"
try:
# max trey times 1, max timeout 1
response = requests.get(url, timeout=(10, 10))
if response.status_code == 200:
return True
except Exception as e:
logger.error(f"Error: {e}")
logger.error(f"Type: {type(e).__name__}")
logger.error(f"Error detail: {str(e)}")
sleep(2)
return False
# Try downloading the screenshot until successful
while not download_screenshot(vm_ip):
logger.info("Check whether the virtual machine is ready...")
logger.info("Virtual machine is ready. Start to make a snapshot on the virtual machine. It would take a while...")
def create_vm_snapshot(vm_path, max_retries=20):
command = f'vmrun {get_vmrun_type()} snapshot "{vm_path}" "init_state"'
for attempt in range(max_retries):
result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8")
if result.returncode == 0:
logger.info("Snapshot created.")
return True
else:
if "Error" in result.stderr:
logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}")
else:
logger.error(f"Attempt {attempt + 1} failed: {result.stderr}")
if attempt == max_retries - 1:
logger.error("Maximum retry attempts reached, failed to create snapshot.")
return False
# Create a snapshot of the virtual machine
if create_vm_snapshot(vm_path, max_retries=MAX_RETRY_TIMES):
return vm_path
else:
raise ValueError("Error encountered during installation, please rerun the code for retrying.")
class VMwareVMManager(VMManager):
def __init__(self, registry_path=REGISTRY_PATH):
self.registry_path = registry_path
self.lock = FileLock(".vmware_lck", timeout=10)
self.initialize_registry()
def initialize_registry(self):
with self.lock: # Locking during initialization
if not os.path.exists(self.registry_path):
with open(self.registry_path, 'w') as file:
file.write('')
def add_vm(self, vm_path, region=None):
assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'."
with self.lock:
with open(self.registry_path, 'r') as file:
lines = file.readlines()
new_lines = lines + [f'{vm_path}|free\n']
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def occupy_vm(self, vm_path, pid, region=None):
assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'."
with self.lock:
new_lines = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
registered_vm_path, _ = line.strip().split('|')
if registered_vm_path == vm_path:
new_lines.append(f'{registered_vm_path}|{pid}\n')
else:
new_lines.append(line)
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def check_and_clean(self, vms_dir):
with self.lock: # Lock when cleaning up the registry and vms_dir
# Check and clean on the running vms, detect the released ones and mark then as 'free'
active_pids = {p.pid for p in psutil.process_iter()}
new_lines = []
vm_paths = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
vm_path, pid_str = line.strip().split('|')
if not os.path.exists(vm_path):
logger.info(f"VM {vm_path} not found, releasing it.")
new_lines.append(f'{vm_path}|free\n')
continue
vm_paths.append(vm_path)
if pid_str == "free":
new_lines.append(line)
continue
if int(pid_str) in active_pids:
new_lines.append(line)
else:
new_lines.append(f'{vm_path}|free\n')
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
# Check and clean on the files inside vms_dir, delete the unregistered ones
os.makedirs(vms_dir, exist_ok=True)
vm_names = os.listdir(vms_dir)
for vm_name in vm_names:
# skip the downloaded .zip file
if vm_name == DOWNLOADED_FILE_NAME:
continue
# Skip the .DS_Store file on macOS
if vm_name == ".DS_Store":
continue
flag = True
for vm_path in vm_paths:
if vm_name + ".vmx" in vm_path:
flag = False
if flag:
shutil.rmtree(os.path.join(vms_dir, vm_name))
def list_free_vms(self):
with self.lock: # Lock when reading the registry
free_vms = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
vm_path, pid_str = line.strip().split('|')
if pid_str == "free":
free_vms.append((vm_path, pid_str))
return free_vms
def get_vm_path(self, region=None):
assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'."
self.check_and_clean(vms_dir=VMS_DIR)
free_vms_paths = self.list_free_vms()
if len(free_vms_paths) == 0:
# No free virtual machine available, generate a new one
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR)
new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR,
downloaded_file_name=DOWNLOADED_FILE_NAME)
self.add_vm(new_vm_path)
self.occupy_vm(new_vm_path, os.getpid())
return new_vm_path
else:
# Choose the first free virtual machine
chosen_vm_path = free_vms_paths[0][0]
self.occupy_vm(chosen_vm_path, os.getpid())
return chosen_vm_path

View File

@@ -0,0 +1,90 @@
import logging
import platform
import subprocess
import time
import os
from desktop_env.providers.base import Provider
logger = logging.getLogger("desktopenv.providers.vmware.VMwareProvider")
logger.setLevel(logging.INFO)
WAIT_TIME = 3
def get_vmrun_type(return_list=False):
if platform.system() == 'Windows' or platform.system() == 'Linux':
if return_list:
return ['-T', 'ws']
else:
return '-T ws'
elif platform.system() == 'Darwin': # Darwin is the system name for macOS
if return_list:
return ['-T', 'fusion']
else:
return '-T fusion'
else:
raise Exception("Unsupported operating system")
class VMwareProvider(Provider):
@staticmethod
def _execute_command(command: list):
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True,
encoding="utf-8")
if result.returncode != 0:
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
return result.stdout.strip()
def start_emulator(self, path_to_vm: str, headless: bool):
print("Starting VMware VM...")
logger.info("Starting VMware VM...")
while True:
try:
output = subprocess.check_output(f"vmrun {get_vmrun_type()} list", shell=True, stderr=subprocess.STDOUT)
output = output.decode()
output = output.splitlines()
normalized_path_to_vm = os.path.abspath(os.path.normpath(path_to_vm))
if any(os.path.abspath(os.path.normpath(line)) == normalized_path_to_vm for line in output):
logger.info("VM is running.")
break
else:
logger.info("Starting VM...")
VMwareProvider._execute_command(["vmrun"] + get_vmrun_type(return_list=True) + ["start", path_to_vm]) if not headless else \
VMwareProvider._execute_command(
["vmrun"] + get_vmrun_type(return_list=True) + ["start", path_to_vm, "nogui"])
time.sleep(WAIT_TIME)
except subprocess.CalledProcessError as e:
logger.error(f"Error executing command: {e.output.decode().strip()}")
def get_ip_address(self, path_to_vm: str) -> str:
logger.info("Getting VMware VM IP address...")
while True:
try:
output = VMwareProvider._execute_command(
["vmrun"] + get_vmrun_type(return_list=True) + ["getGuestIPAddress", path_to_vm, "-wait"]
)
logger.info(f"VMware VM IP address: {output}")
return output
except Exception as e:
logger.error(e)
time.sleep(WAIT_TIME)
logger.info("Retrying to get VMware VM IP address...")
def save_state(self, path_to_vm: str, snapshot_name: str):
logger.info("Saving VMware VM state...")
VMwareProvider._execute_command(["vmrun"] + get_vmrun_type(return_list=True) + ["snapshot", path_to_vm, snapshot_name])
time.sleep(WAIT_TIME) # Wait for the VM to save
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
logger.info(f"Reverting VMware VM to snapshot: {snapshot_name}...")
VMwareProvider._execute_command(["vmrun"] + get_vmrun_type(return_list=True) + ["revertToSnapshot", path_to_vm, snapshot_name])
time.sleep(WAIT_TIME) # Wait for the VM to revert
return path_to_vm
def stop_emulator(self, path_to_vm: str):
logger.info("Stopping VMware VM...")
VMwareProvider._execute_command(["vmrun"] + get_vmrun_type(return_list=True) + ["stop", path_to_vm])
time.sleep(WAIT_TIME) # Wait for the VM to stop