feat: allow windows virtual mmachine (#73)
This commit is contained in:
@@ -36,6 +36,7 @@ class DesktopEnv(gym.Env):
|
||||
headless: bool = False,
|
||||
require_a11y_tree: bool = True,
|
||||
require_terminal: bool = False,
|
||||
os_type: str = "Ubuntu",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -55,12 +56,14 @@ class DesktopEnv(gym.Env):
|
||||
self.region = region
|
||||
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
|
||||
|
||||
self.os_type = os_type
|
||||
|
||||
# Initialize environment variables
|
||||
if path_to_vm:
|
||||
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \
|
||||
if provider_name in {"vmware", "virtualbox"} else path_to_vm
|
||||
else:
|
||||
self.path_to_vm = self.manager.get_vm_path(region)
|
||||
self.path_to_vm = self.manager.get_vm_path(self.os_type, region)
|
||||
|
||||
self.snapshot_name = snapshot_name
|
||||
self.cache_dir_base: str = cache_dir
|
||||
@@ -86,7 +89,7 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
def _start_emulator(self):
|
||||
# Power on the virtual machine
|
||||
self.provider.start_emulator(self.path_to_vm, self.headless)
|
||||
self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
|
||||
|
||||
# Get the ip from the virtual machine, and setup the controller
|
||||
self.vm_ip = self.provider.get_ip_address(self.path_to_vm)
|
||||
|
||||
@@ -26,6 +26,7 @@ MAX_RETRY_TIMES = 10
|
||||
RETRY_INTERVAL = 5
|
||||
UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-arm.zip"
|
||||
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-x86.zip"
|
||||
WINDOWS_X86_URL = "https://huggingface.co/datasets/xlangai/windows_osworld/resolve/main/Windows-x86.zip"
|
||||
|
||||
# Determine the platform and CPU architecture to decide the correct VM image to download
|
||||
if platform.system() == 'Darwin': # macOS
|
||||
@@ -35,6 +36,7 @@ if platform.system() == 'Darwin': # macOS
|
||||
# url = UBUNTU_X86_URL
|
||||
elif platform.machine().lower() in ['amd64', 'x86_64']:
|
||||
URL = UBUNTU_X86_URL
|
||||
|
||||
else:
|
||||
raise Exception("Unsupported platform or architecture")
|
||||
|
||||
@@ -45,13 +47,17 @@ VMS_DIR = "./vmware_vm_data"
|
||||
update_lock = threading.Lock()
|
||||
|
||||
if platform.system() == 'Windows':
|
||||
vboxmanage_path = r"C:\Program Files (x86)\VMware\VMware Workstation"
|
||||
#vboxmanage_path = r"C:\Program Files (x86)\VMware\VMware Workstation"
|
||||
vboxmanage_path = r"D:\VMware Workstation"
|
||||
os.environ["PATH"] += os.pathsep + vboxmanage_path
|
||||
|
||||
def generate_new_vm_name(vms_dir):
|
||||
def generate_new_vm_name(vms_dir, os_type):
|
||||
registry_idx = 0
|
||||
prefix = os_type
|
||||
while True:
|
||||
attempted_new_name = f"Ubuntu{registry_idx}"
|
||||
#attempted_new_name = f"Ubuntu{registry_idx}"
|
||||
|
||||
attempted_new_name = f"{prefix}{registry_idx}"
|
||||
if os.path.exists(
|
||||
os.path.join(vms_dir, attempted_new_name, attempted_new_name + ".vmx")):
|
||||
registry_idx += 1
|
||||
@@ -99,7 +105,7 @@ def _update_vm(vmx_path, target_vm_name):
|
||||
|
||||
vmx_file_base_name = os.path.splitext(vmx_file)[0]
|
||||
|
||||
assert vmx_file == "Ubuntu.vmx", "The VMX file should be named 'Ubuntu.vmx'."
|
||||
# assert vmx_file == "Ubuntu.vmx", "The VMX file should be named 'Ubuntu.vmx'."
|
||||
files_to_rename = ['vmx', 'nvram', 'vmsd', 'vmxf']
|
||||
|
||||
for ext in files_to_rename:
|
||||
@@ -117,7 +123,7 @@ def _update_vm(vmx_path, target_vm_name):
|
||||
logger.info("VM files renamed successfully.")
|
||||
|
||||
|
||||
def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu"):
|
||||
def _install_vm(vm_name, vms_dir, downloaded_file_name, os_type, original_vm_name="Ubuntu"):
|
||||
os.makedirs(vms_dir, exist_ok=True)
|
||||
|
||||
def __download_and_unzip_vm():
|
||||
@@ -125,6 +131,17 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu
|
||||
logger.info("Downloading the virtual machine image...")
|
||||
downloaded_size = 0
|
||||
|
||||
if os_type == "Ubuntu":
|
||||
if platform.system() == 'Darwin':
|
||||
URL = UBUNTU_ARM_URL
|
||||
elif platform.machine().lower() in ['amd64', 'x86_64']:
|
||||
URL = UBUNTU_X86_URL
|
||||
elif os_type == "Windows":
|
||||
if platform.machine().lower() in ['amd64', 'x86_64']:
|
||||
URL = WINDOWS_X86_URL
|
||||
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
|
||||
downloaded_file_name = DOWNLOADED_FILE_NAME
|
||||
|
||||
while True:
|
||||
downloaded_file_path = os.path.join(vms_dir, downloaded_file_name)
|
||||
headers = {}
|
||||
@@ -162,7 +179,7 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu
|
||||
logger.info("Download succeeds.")
|
||||
break # Download completed successfully
|
||||
|
||||
# Unzip the downloaded file
|
||||
# # # Unzip the downloaded file
|
||||
logger.info("Unzipping the downloaded file...☕️")
|
||||
with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(os.path.join(vms_dir, vm_name))
|
||||
@@ -403,7 +420,7 @@ class VMwareVMManager(VMManager):
|
||||
free_vms.append((vm_path, pid_str))
|
||||
return free_vms
|
||||
|
||||
def get_vm_path(self, region=None):
|
||||
def get_vm_path(self, os_type, region=None):
|
||||
with self.lock:
|
||||
if not VMwareVMManager.checked_and_cleaned:
|
||||
VMwareVMManager.checked_and_cleaned = True
|
||||
@@ -417,15 +434,22 @@ class VMwareVMManager(VMManager):
|
||||
allocation_needed = True
|
||||
else:
|
||||
# Choose the first free virtual machine
|
||||
chosen_vm_path = free_vms_paths[0][0]
|
||||
chosen_vm_path = free_vms_paths[0][0]
|
||||
self._occupy_vm(chosen_vm_path, os.getpid())
|
||||
return chosen_vm_path
|
||||
|
||||
if allocation_needed:
|
||||
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
|
||||
new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR)
|
||||
new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR, os_type=os_type)
|
||||
|
||||
original_vm_name = None
|
||||
if os_type == "Ubuntu":
|
||||
original_vm_name = "Ubuntu"
|
||||
elif os_type == "Windows":
|
||||
original_vm_name = "Windows 10 x64"
|
||||
|
||||
new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR,
|
||||
downloaded_file_name=DOWNLOADED_FILE_NAME)
|
||||
downloaded_file_name=DOWNLOADED_FILE_NAME, original_vm_name=original_vm_name, os_type=os_type)
|
||||
with self.lock:
|
||||
self._add_vm(new_vm_path)
|
||||
self._occupy_vm(new_vm_path, os.getpid())
|
||||
|
||||
@@ -44,7 +44,7 @@ class VMwareProvider(Provider):
|
||||
else:
|
||||
return None
|
||||
|
||||
def start_emulator(self, path_to_vm: str, headless: bool):
|
||||
def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
|
||||
print("Starting VMware VM...")
|
||||
logger.info("Starting VMware VM...")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user