diff --git a/desktop_env/envs/__init__.py b/desktop_env/envs/__init__.py index 470c203..7f5baf6 100644 --- a/desktop_env/envs/__init__.py +++ b/desktop_env/envs/__init__.py @@ -12,7 +12,7 @@ import psutil import requests from tqdm import tqdm -__version__ = "0.1.9" +__version__ = "0.1.12" MAX_RETRY_TIMES = 10 UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_arm/resolve/main/Ubuntu.zip" @@ -283,31 +283,45 @@ def _install_virtual_machine(vm_name, vms_dir, downloaded_file_name, original_vm raise Exception("Unsupported operating system") # Start the virtual machine - subprocess.run(f'vmrun {get_vmrun_type()} start "{vm_path}" nogui', shell=True) - print("Starting virtual machine...") + def start_vm(vm_path, max_retries=20): + command = f'vmrun {get_vmrun_type()} start "{vm_path}" nogui' + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True) + if result.returncode == 0: + print("Virtual machine started.") + return True + else: + if "Error" in result.stderr: + print(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + print(f"Attempt {attempt + 1} failed: {result.stderr}") - # Get the IP address of the virtual machine - for i in range(MAX_RETRY_TIMES): - get_vm_ip = subprocess.run(f'vmrun {get_vmrun_type()} getGuestIPAddress "{vm_path}" -wait', shell=True, - capture_output=True, - text=True) - if "Error" in get_vm_ip.stdout: - print("Retry on getting IP") - continue - print("Virtual machine IP address:", get_vm_ip.stdout.strip()) - break + if attempt == max_retries - 1: + print("Maximum retry attempts reached, failed to start the virtual machine.") + return False - vm_ip = get_vm_ip.stdout.strip() + if not start_vm(vm_path): + raise ValueError("Error encountered during installation, please rerun the code for retrying.") - def is_url_accessible(url, timeout=1): - try: - response = requests.head(url, timeout=timeout) - return response.status_code == 200 - except requests.exceptions.RequestException: - return False + def get_vm_ip(vm_path, max_retries=20): + command = f'vmrun {get_vmrun_type()} getGuestIPAddress "{vm_path}" -wait' + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True) + if result.returncode == 0: + return result.stdout.strip() + else: + if "Error" in result.stderr: + print(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + print(f"Attempt {attempt + 1} failed: {result.stderr}") - url = f"http://{vm_ip}:5000/screenshot" - check_url = is_url_accessible(url) + if attempt == max_retries - 1: + print("Maximum retry attempts reached, failed to get the IP of virtual machine.") + return None + + vm_ip = get_vm_ip(vm_path) + if not vm_ip: + raise ValueError("Error encountered during installation, please rerun the code for retrying.") # Function used to check whether the virtual machine is ready def download_screenshot(ip): @@ -330,11 +344,28 @@ def _install_virtual_machine(vm_name, vms_dir, downloaded_file_name, original_vm print("Virtual machine is ready. Start to make a snapshot on the virtual machine. It would take a while...") - # Create a snapshot of the virtual machine - subprocess.run(f'vmrun {get_vmrun_type()} snapshot "{vm_path}" "init_state"', shell=True) - print("Snapshot created.") + def create_vm_snapshot(vm_path, max_retries=20): + command = f'vmrun {get_vmrun_type()} snapshot "{vm_path}" "init_state"' + for attempt in range(max_retries): + result = subprocess.run(command, shell=True, text=True, capture_output=True) + if result.returncode == 0: + print("Snapshot created.") + return True + else: + if "Error" in result.stderr: + print(f"Attempt {attempt + 1} failed with specific error: {result.stderr}") + else: + print(f"Attempt {attempt + 1} failed: {result.stderr}") - return vm_path + if attempt == max_retries - 1: + print("Maximum retry attempts reached, failed to create snapshot.") + return False + + # Create a snapshot of the virtual machine + if create_vm_snapshot(vm_path, max_retries=MAX_RETRY_TIMES): + return vm_path + else: + raise ValueError("Error encountered during installation, please rerun the code for retrying.") def _get_vm_path(): diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index 125679a..2f84efe 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -217,9 +217,6 @@ class DesktopEnv(gym.Env): logger.info("Resetting environment...") logger.info("Switching task...") - if task_config is not None: - self._set_task_info(task_config) - self.setup_controller.reset_cache_dir(self.cache_dir) logger.info("Setting counters...") self._traj_no += 1 @@ -234,11 +231,13 @@ class DesktopEnv(gym.Env): self._start_emulator() logger.info("Emulator started.") - logger.info("Setting up environment...") - self.setup_controller.setup(self.config) - - time.sleep(5) - logger.info("Environment setup complete.") + if task_config is not None: + self._set_task_info(task_config) + self.setup_controller.reset_cache_dir(self.cache_dir) + logger.info("Setting up environment...") + self.setup_controller.setup(self.config) + time.sleep(5) + logger.info("Environment setup complete.") observation = self._get_obs() return observation