feat: allow windows virtual mmachine (#73)

This commit is contained in:
MillanK
2024-09-24 00:08:13 +08:00
committed by GitHub
parent e0d0041520
commit c2f68b9085
3 changed files with 40 additions and 13 deletions

View File

@@ -36,6 +36,7 @@ class DesktopEnv(gym.Env):
headless: bool = False, headless: bool = False,
require_a11y_tree: bool = True, require_a11y_tree: bool = True,
require_terminal: bool = False, require_terminal: bool = False,
os_type: str = "Ubuntu",
): ):
""" """
Args: Args:
@@ -55,12 +56,14 @@ class DesktopEnv(gym.Env):
self.region = region self.region = region
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region) self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
self.os_type = os_type
# Initialize environment variables # Initialize environment variables
if path_to_vm: if path_to_vm:
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \ self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \
if provider_name in {"vmware", "virtualbox"} else path_to_vm if provider_name in {"vmware", "virtualbox"} else path_to_vm
else: else:
self.path_to_vm = self.manager.get_vm_path(region) self.path_to_vm = self.manager.get_vm_path(self.os_type, region)
self.snapshot_name = snapshot_name self.snapshot_name = snapshot_name
self.cache_dir_base: str = cache_dir self.cache_dir_base: str = cache_dir
@@ -86,7 +89,7 @@ class DesktopEnv(gym.Env):
def _start_emulator(self): def _start_emulator(self):
# Power on the virtual machine # Power on the virtual machine
self.provider.start_emulator(self.path_to_vm, self.headless) self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
# Get the ip from the virtual machine, and setup the controller # Get the ip from the virtual machine, and setup the controller
self.vm_ip = self.provider.get_ip_address(self.path_to_vm) self.vm_ip = self.provider.get_ip_address(self.path_to_vm)

View File

@@ -26,6 +26,7 @@ MAX_RETRY_TIMES = 10
RETRY_INTERVAL = 5 RETRY_INTERVAL = 5
UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-arm.zip" UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-arm.zip"
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-x86.zip" UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-x86.zip"
WINDOWS_X86_URL = "https://huggingface.co/datasets/xlangai/windows_osworld/resolve/main/Windows-x86.zip"
# Determine the platform and CPU architecture to decide the correct VM image to download # Determine the platform and CPU architecture to decide the correct VM image to download
if platform.system() == 'Darwin': # macOS if platform.system() == 'Darwin': # macOS
@@ -35,6 +36,7 @@ if platform.system() == 'Darwin': # macOS
# url = UBUNTU_X86_URL # url = UBUNTU_X86_URL
elif platform.machine().lower() in ['amd64', 'x86_64']: elif platform.machine().lower() in ['amd64', 'x86_64']:
URL = UBUNTU_X86_URL URL = UBUNTU_X86_URL
else: else:
raise Exception("Unsupported platform or architecture") raise Exception("Unsupported platform or architecture")
@@ -45,13 +47,17 @@ VMS_DIR = "./vmware_vm_data"
update_lock = threading.Lock() update_lock = threading.Lock()
if platform.system() == 'Windows': if platform.system() == 'Windows':
vboxmanage_path = r"C:\Program Files (x86)\VMware\VMware Workstation" #vboxmanage_path = r"C:\Program Files (x86)\VMware\VMware Workstation"
vboxmanage_path = r"D:\VMware Workstation"
os.environ["PATH"] += os.pathsep + vboxmanage_path os.environ["PATH"] += os.pathsep + vboxmanage_path
def generate_new_vm_name(vms_dir): def generate_new_vm_name(vms_dir, os_type):
registry_idx = 0 registry_idx = 0
prefix = os_type
while True: while True:
attempted_new_name = f"Ubuntu{registry_idx}" #attempted_new_name = f"Ubuntu{registry_idx}"
attempted_new_name = f"{prefix}{registry_idx}"
if os.path.exists( if os.path.exists(
os.path.join(vms_dir, attempted_new_name, attempted_new_name + ".vmx")): os.path.join(vms_dir, attempted_new_name, attempted_new_name + ".vmx")):
registry_idx += 1 registry_idx += 1
@@ -99,7 +105,7 @@ def _update_vm(vmx_path, target_vm_name):
vmx_file_base_name = os.path.splitext(vmx_file)[0] vmx_file_base_name = os.path.splitext(vmx_file)[0]
assert vmx_file == "Ubuntu.vmx", "The VMX file should be named 'Ubuntu.vmx'." # assert vmx_file == "Ubuntu.vmx", "The VMX file should be named 'Ubuntu.vmx'."
files_to_rename = ['vmx', 'nvram', 'vmsd', 'vmxf'] files_to_rename = ['vmx', 'nvram', 'vmsd', 'vmxf']
for ext in files_to_rename: for ext in files_to_rename:
@@ -117,7 +123,7 @@ def _update_vm(vmx_path, target_vm_name):
logger.info("VM files renamed successfully.") logger.info("VM files renamed successfully.")
def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu"): def _install_vm(vm_name, vms_dir, downloaded_file_name, os_type, original_vm_name="Ubuntu"):
os.makedirs(vms_dir, exist_ok=True) os.makedirs(vms_dir, exist_ok=True)
def __download_and_unzip_vm(): def __download_and_unzip_vm():
@@ -125,6 +131,17 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu
logger.info("Downloading the virtual machine image...") logger.info("Downloading the virtual machine image...")
downloaded_size = 0 downloaded_size = 0
if os_type == "Ubuntu":
if platform.system() == 'Darwin':
URL = UBUNTU_ARM_URL
elif platform.machine().lower() in ['amd64', 'x86_64']:
URL = UBUNTU_X86_URL
elif os_type == "Windows":
if platform.machine().lower() in ['amd64', 'x86_64']:
URL = WINDOWS_X86_URL
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
downloaded_file_name = DOWNLOADED_FILE_NAME
while True: while True:
downloaded_file_path = os.path.join(vms_dir, downloaded_file_name) downloaded_file_path = os.path.join(vms_dir, downloaded_file_name)
headers = {} headers = {}
@@ -162,7 +179,7 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu
logger.info("Download succeeds.") logger.info("Download succeeds.")
break # Download completed successfully break # Download completed successfully
# Unzip the downloaded file # # # Unzip the downloaded file
logger.info("Unzipping the downloaded file...☕️") logger.info("Unzipping the downloaded file...☕️")
with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref: with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref:
zip_ref.extractall(os.path.join(vms_dir, vm_name)) zip_ref.extractall(os.path.join(vms_dir, vm_name))
@@ -403,7 +420,7 @@ class VMwareVMManager(VMManager):
free_vms.append((vm_path, pid_str)) free_vms.append((vm_path, pid_str))
return free_vms return free_vms
def get_vm_path(self, region=None): def get_vm_path(self, os_type, region=None):
with self.lock: with self.lock:
if not VMwareVMManager.checked_and_cleaned: if not VMwareVMManager.checked_and_cleaned:
VMwareVMManager.checked_and_cleaned = True VMwareVMManager.checked_and_cleaned = True
@@ -423,9 +440,16 @@ class VMwareVMManager(VMManager):
if allocation_needed: if allocation_needed:
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕") logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR) new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR, os_type=os_type)
original_vm_name = None
if os_type == "Ubuntu":
original_vm_name = "Ubuntu"
elif os_type == "Windows":
original_vm_name = "Windows 10 x64"
new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR, new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR,
downloaded_file_name=DOWNLOADED_FILE_NAME) downloaded_file_name=DOWNLOADED_FILE_NAME, original_vm_name=original_vm_name, os_type=os_type)
with self.lock: with self.lock:
self._add_vm(new_vm_path) self._add_vm(new_vm_path)
self._occupy_vm(new_vm_path, os.getpid()) self._occupy_vm(new_vm_path, os.getpid())

View File

@@ -44,7 +44,7 @@ class VMwareProvider(Provider):
else: else:
return None return None
def start_emulator(self, path_to_vm: str, headless: bool): def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
print("Starting VMware VM...") print("Starting VMware VM...")
logger.info("Starting VMware VM...") logger.info("Starting VMware VM...")