feat: allow windows virtual mmachine (#73)

This commit is contained in:
MillanK
2024-09-24 00:08:13 +08:00
committed by GitHub
parent e0d0041520
commit c2f68b9085
3 changed files with 40 additions and 13 deletions

View File

@@ -36,6 +36,7 @@ class DesktopEnv(gym.Env):
headless: bool = False,
require_a11y_tree: bool = True,
require_terminal: bool = False,
os_type: str = "Ubuntu",
):
"""
Args:
@@ -55,12 +56,14 @@ class DesktopEnv(gym.Env):
self.region = region
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
self.os_type = os_type
# Initialize environment variables
if path_to_vm:
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \
if provider_name in {"vmware", "virtualbox"} else path_to_vm
else:
self.path_to_vm = self.manager.get_vm_path(region)
self.path_to_vm = self.manager.get_vm_path(self.os_type, region)
self.snapshot_name = snapshot_name
self.cache_dir_base: str = cache_dir
@@ -86,7 +89,7 @@ class DesktopEnv(gym.Env):
def _start_emulator(self):
# Power on the virtual machine
self.provider.start_emulator(self.path_to_vm, self.headless)
self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
# Get the ip from the virtual machine, and setup the controller
self.vm_ip = self.provider.get_ip_address(self.path_to_vm)

View File

@@ -26,6 +26,7 @@ MAX_RETRY_TIMES = 10
RETRY_INTERVAL = 5
UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-arm.zip"
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-x86.zip"
WINDOWS_X86_URL = "https://huggingface.co/datasets/xlangai/windows_osworld/resolve/main/Windows-x86.zip"
# Determine the platform and CPU architecture to decide the correct VM image to download
if platform.system() == 'Darwin': # macOS
@@ -35,6 +36,7 @@ if platform.system() == 'Darwin': # macOS
# url = UBUNTU_X86_URL
elif platform.machine().lower() in ['amd64', 'x86_64']:
URL = UBUNTU_X86_URL
else:
raise Exception("Unsupported platform or architecture")
@@ -45,13 +47,17 @@ VMS_DIR = "./vmware_vm_data"
update_lock = threading.Lock()
if platform.system() == 'Windows':
vboxmanage_path = r"C:\Program Files (x86)\VMware\VMware Workstation"
#vboxmanage_path = r"C:\Program Files (x86)\VMware\VMware Workstation"
vboxmanage_path = r"D:\VMware Workstation"
os.environ["PATH"] += os.pathsep + vboxmanage_path
def generate_new_vm_name(vms_dir):
def generate_new_vm_name(vms_dir, os_type):
registry_idx = 0
prefix = os_type
while True:
attempted_new_name = f"Ubuntu{registry_idx}"
#attempted_new_name = f"Ubuntu{registry_idx}"
attempted_new_name = f"{prefix}{registry_idx}"
if os.path.exists(
os.path.join(vms_dir, attempted_new_name, attempted_new_name + ".vmx")):
registry_idx += 1
@@ -99,7 +105,7 @@ def _update_vm(vmx_path, target_vm_name):
vmx_file_base_name = os.path.splitext(vmx_file)[0]
assert vmx_file == "Ubuntu.vmx", "The VMX file should be named 'Ubuntu.vmx'."
# assert vmx_file == "Ubuntu.vmx", "The VMX file should be named 'Ubuntu.vmx'."
files_to_rename = ['vmx', 'nvram', 'vmsd', 'vmxf']
for ext in files_to_rename:
@@ -117,7 +123,7 @@ def _update_vm(vmx_path, target_vm_name):
logger.info("VM files renamed successfully.")
def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu"):
def _install_vm(vm_name, vms_dir, downloaded_file_name, os_type, original_vm_name="Ubuntu"):
os.makedirs(vms_dir, exist_ok=True)
def __download_and_unzip_vm():
@@ -125,6 +131,17 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu
logger.info("Downloading the virtual machine image...")
downloaded_size = 0
if os_type == "Ubuntu":
if platform.system() == 'Darwin':
URL = UBUNTU_ARM_URL
elif platform.machine().lower() in ['amd64', 'x86_64']:
URL = UBUNTU_X86_URL
elif os_type == "Windows":
if platform.machine().lower() in ['amd64', 'x86_64']:
URL = WINDOWS_X86_URL
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
downloaded_file_name = DOWNLOADED_FILE_NAME
while True:
downloaded_file_path = os.path.join(vms_dir, downloaded_file_name)
headers = {}
@@ -162,7 +179,7 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu
logger.info("Download succeeds.")
break # Download completed successfully
# Unzip the downloaded file
# # # Unzip the downloaded file
logger.info("Unzipping the downloaded file...☕️")
with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref:
zip_ref.extractall(os.path.join(vms_dir, vm_name))
@@ -403,7 +420,7 @@ class VMwareVMManager(VMManager):
free_vms.append((vm_path, pid_str))
return free_vms
def get_vm_path(self, region=None):
def get_vm_path(self, os_type, region=None):
with self.lock:
if not VMwareVMManager.checked_and_cleaned:
VMwareVMManager.checked_and_cleaned = True
@@ -417,15 +434,22 @@ class VMwareVMManager(VMManager):
allocation_needed = True
else:
# Choose the first free virtual machine
chosen_vm_path = free_vms_paths[0][0]
chosen_vm_path = free_vms_paths[0][0]
self._occupy_vm(chosen_vm_path, os.getpid())
return chosen_vm_path
if allocation_needed:
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR)
new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR, os_type=os_type)
original_vm_name = None
if os_type == "Ubuntu":
original_vm_name = "Ubuntu"
elif os_type == "Windows":
original_vm_name = "Windows 10 x64"
new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR,
downloaded_file_name=DOWNLOADED_FILE_NAME)
downloaded_file_name=DOWNLOADED_FILE_NAME, original_vm_name=original_vm_name, os_type=os_type)
with self.lock:
self._add_vm(new_vm_path)
self._occupy_vm(new_vm_path, os.getpid())

View File

@@ -44,7 +44,7 @@ class VMwareProvider(Provider):
else:
return None
def start_emulator(self, path_to_vm: str, headless: bool):
def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
print("Starting VMware VM...")
logger.info("Starting VMware VM...")