diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index 31bc558..e02eacf 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -36,6 +36,7 @@ class DesktopEnv(gym.Env): headless: bool = False, require_a11y_tree: bool = True, require_terminal: bool = False, + os_type: str = "Ubuntu", ): """ Args: @@ -55,12 +56,14 @@ class DesktopEnv(gym.Env): self.region = region self.manager, self.provider = create_vm_manager_and_provider(provider_name, region) + self.os_type = os_type + # Initialize environment variables if path_to_vm: self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \ if provider_name in {"vmware", "virtualbox"} else path_to_vm else: - self.path_to_vm = self.manager.get_vm_path(region) + self.path_to_vm = self.manager.get_vm_path(self.os_type, region) self.snapshot_name = snapshot_name self.cache_dir_base: str = cache_dir @@ -86,7 +89,7 @@ class DesktopEnv(gym.Env): def _start_emulator(self): # Power on the virtual machine - self.provider.start_emulator(self.path_to_vm, self.headless) + self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type) # Get the ip from the virtual machine, and setup the controller self.vm_ip = self.provider.get_ip_address(self.path_to_vm) diff --git a/desktop_env/providers/vmware/manager.py b/desktop_env/providers/vmware/manager.py index 63377d1..7a788ba 100644 --- a/desktop_env/providers/vmware/manager.py +++ b/desktop_env/providers/vmware/manager.py @@ -26,6 +26,7 @@ MAX_RETRY_TIMES = 10 RETRY_INTERVAL = 5 UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-arm.zip" UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-x86.zip" +WINDOWS_X86_URL = "https://huggingface.co/datasets/xlangai/windows_osworld/resolve/main/Windows-x86.zip" # Determine the platform and CPU architecture to decide the correct VM image to download if platform.system() == 'Darwin': # macOS @@ -35,6 +36,7 @@ if platform.system() == 'Darwin': # macOS # url = UBUNTU_X86_URL elif platform.machine().lower() in ['amd64', 'x86_64']: URL = UBUNTU_X86_URL + else: raise Exception("Unsupported platform or architecture") @@ -45,13 +47,17 @@ VMS_DIR = "./vmware_vm_data" update_lock = threading.Lock() if platform.system() == 'Windows': - vboxmanage_path = r"C:\Program Files (x86)\VMware\VMware Workstation" + #vboxmanage_path = r"C:\Program Files (x86)\VMware\VMware Workstation" + vboxmanage_path = r"D:\VMware Workstation" os.environ["PATH"] += os.pathsep + vboxmanage_path -def generate_new_vm_name(vms_dir): +def generate_new_vm_name(vms_dir, os_type): registry_idx = 0 + prefix = os_type while True: - attempted_new_name = f"Ubuntu{registry_idx}" + #attempted_new_name = f"Ubuntu{registry_idx}" + + attempted_new_name = f"{prefix}{registry_idx}" if os.path.exists( os.path.join(vms_dir, attempted_new_name, attempted_new_name + ".vmx")): registry_idx += 1 @@ -99,7 +105,7 @@ def _update_vm(vmx_path, target_vm_name): vmx_file_base_name = os.path.splitext(vmx_file)[0] - assert vmx_file == "Ubuntu.vmx", "The VMX file should be named 'Ubuntu.vmx'." + # assert vmx_file == "Ubuntu.vmx", "The VMX file should be named 'Ubuntu.vmx'." files_to_rename = ['vmx', 'nvram', 'vmsd', 'vmxf'] for ext in files_to_rename: @@ -117,7 +123,7 @@ def _update_vm(vmx_path, target_vm_name): logger.info("VM files renamed successfully.") -def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu"): +def _install_vm(vm_name, vms_dir, downloaded_file_name, os_type, original_vm_name="Ubuntu"): os.makedirs(vms_dir, exist_ok=True) def __download_and_unzip_vm(): @@ -125,6 +131,17 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu logger.info("Downloading the virtual machine image...") downloaded_size = 0 + if os_type == "Ubuntu": + if platform.system() == 'Darwin': + URL = UBUNTU_ARM_URL + elif platform.machine().lower() in ['amd64', 'x86_64']: + URL = UBUNTU_X86_URL + elif os_type == "Windows": + if platform.machine().lower() in ['amd64', 'x86_64']: + URL = WINDOWS_X86_URL + DOWNLOADED_FILE_NAME = URL.split('/')[-1] + downloaded_file_name = DOWNLOADED_FILE_NAME + while True: downloaded_file_path = os.path.join(vms_dir, downloaded_file_name) headers = {} @@ -162,7 +179,7 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu logger.info("Download succeeds.") break # Download completed successfully - # Unzip the downloaded file + # # # Unzip the downloaded file logger.info("Unzipping the downloaded file...☕️") with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref: zip_ref.extractall(os.path.join(vms_dir, vm_name)) @@ -403,7 +420,7 @@ class VMwareVMManager(VMManager): free_vms.append((vm_path, pid_str)) return free_vms - def get_vm_path(self, region=None): + def get_vm_path(self, os_type, region=None): with self.lock: if not VMwareVMManager.checked_and_cleaned: VMwareVMManager.checked_and_cleaned = True @@ -417,15 +434,22 @@ class VMwareVMManager(VMManager): allocation_needed = True else: # Choose the first free virtual machine - chosen_vm_path = free_vms_paths[0][0] + chosen_vm_path = free_vms_paths[0][0] self._occupy_vm(chosen_vm_path, os.getpid()) return chosen_vm_path if allocation_needed: logger.info("No free virtual machine available. Generating a new one, which would take a while...☕") - new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR) + new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR, os_type=os_type) + + original_vm_name = None + if os_type == "Ubuntu": + original_vm_name = "Ubuntu" + elif os_type == "Windows": + original_vm_name = "Windows 10 x64" + new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR, - downloaded_file_name=DOWNLOADED_FILE_NAME) + downloaded_file_name=DOWNLOADED_FILE_NAME, original_vm_name=original_vm_name, os_type=os_type) with self.lock: self._add_vm(new_vm_path) self._occupy_vm(new_vm_path, os.getpid()) diff --git a/desktop_env/providers/vmware/provider.py b/desktop_env/providers/vmware/provider.py index 111a258..44f8fc3 100644 --- a/desktop_env/providers/vmware/provider.py +++ b/desktop_env/providers/vmware/provider.py @@ -44,7 +44,7 @@ class VMwareProvider(Provider): else: return None - def start_emulator(self, path_to_vm: str, headless: bool): + def start_emulator(self, path_to_vm: str, headless: bool, os_type: str): print("Starting VMware VM...") logger.info("Starting VMware VM...")