diff --git a/desktop_env/controllers/setup.py b/desktop_env/controllers/setup.py index ac1c372..04859a0 100644 --- a/desktop_env/controllers/setup.py +++ b/desktop_env/controllers/setup.py @@ -61,6 +61,7 @@ class SetupController: # protocol setup_function: str = "_{:}_setup".format(config_type) assert hasattr(self, setup_function), f'Setup controller cannot find init function {setup_function}' + logger.info(f"call function {setup_function}") getattr(self, setup_function)(**parameters) logger.info("SETUP: %s(%s)", setup_function, str(parameters)) @@ -229,6 +230,7 @@ class SetupController: headers = {"Content-Type": "application/json"} try: + logger.info("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/launch") response = requests.post(self.http_server + "/setup" + "/launch", headers=headers, data=payload) if response.status_code == 200: logger.info("Command executed successfully: %s", response.text) diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index 3fec51c..6cb58e7 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -69,7 +69,8 @@ class DesktopEnv(gym.Env): self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \ if provider_name in {"vmware", "virtualbox"} else path_to_vm else: - self.path_to_vm = self.manager.get_vm_path(region) # self.os_type, + + self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region) self.snapshot_name = snapshot_name self.cache_dir_base: str = cache_dir diff --git a/desktop_env/providers/aws/manager.py b/desktop_env/providers/aws/manager.py index 83e7561..1e36054 100644 --- a/desktop_env/providers/aws/manager.py +++ b/desktop_env/providers/aws/manager.py @@ -16,7 +16,7 @@ DEFAULT_REGION = "us-east-1" # todo: public the AMI images IMAGE_ID_MAP = { "us-east-1": "ami-05e7d7bd279ea4f14", - "ap-east-1": "ami-0c092a5b8be4116f5" + "ap-east-1": "ami-0c092a5b8be4116f5", } INSTANCE_TYPE = "t3.medium" @@ -72,13 +72,13 @@ class AWSVMManager(VMManager): self.lock = FileLock(".aws_lck", timeout=60) self.initialize_registry() - def initialize_registry(self): + def initialize_registry(self, **kwargs): with self.lock: # Locking during initialization if not os.path.exists(self.registry_path): with open(self.registry_path, 'w') as file: file.write('') - def add_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True): + def add_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True, **kwargs): if lock_needed: with self.lock: self._add_vm(vm_path, region) @@ -93,7 +93,7 @@ class AWSVMManager(VMManager): with open(self.registry_path, 'w') as file: file.writelines(new_lines) - def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True): + def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True, **kwargs): if lock_needed: with self.lock: self._delete_vm(vm_path, region) @@ -113,7 +113,7 @@ class AWSVMManager(VMManager): with open(self.registry_path, 'w') as file: file.writelines(new_lines) - def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True): + def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True, **kwargs): if lock_needed: with self.lock: self._occupy_vm(vm_path, pid, region) @@ -133,7 +133,7 @@ class AWSVMManager(VMManager): with open(self.registry_path, 'w') as file: file.writelines(new_lines) - def check_and_clean(self, lock_needed=True): + def check_and_clean(self, lock_needed=True, **kwargs): if lock_needed: with self.lock: self._check_and_clean() @@ -216,7 +216,7 @@ class AWSVMManager(VMManager): # Since this can lead to unexpected delete on other server # PLease do monitor the instances to avoid additional cost - def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True): + def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True, **kwargs): if lock_needed: with self.lock: return self._list_free_vms(region) @@ -235,7 +235,7 @@ class AWSVMManager(VMManager): return free_vms - def get_vm_path(self, region=DEFAULT_REGION): + def get_vm_path(self, region=DEFAULT_REGION, **kwargs): with self.lock: if not AWSVMManager.checked_and_cleaned: AWSVMManager.checked_and_cleaned = True diff --git a/desktop_env/providers/aws/provider.py b/desktop_env/providers/aws/provider.py index 57c4d3a..b7cb675 100644 --- a/desktop_env/providers/aws/provider.py +++ b/desktop_env/providers/aws/provider.py @@ -17,25 +17,8 @@ MAX_ATTEMPTS = 10 class AWSProvider(Provider): - # def start_emulator(self, path_to_vm: str, headless: bool, os_type: str): - # logger.info("Starting AWS VM...") - # ec2_client = boto3.client('ec2', region_name=self.region) - # try: - # # Start the instance - # ec2_client.start_instances(InstanceIds=[path_to_vm]) - # logger.info(f"Instance {path_to_vm} is starting...") - - # # Wait for the instance to be in the 'running' state - # waiter = ec2_client.get_waiter('instance_running') - # waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS}) - # logger.info(f"Instance {path_to_vm} is now running.") - - # except ClientError as e: - # logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}") - # raise - - def start_emulator(self, path_to_vm: str, headless: bool, os_type: str): + def start_emulator(self, path_to_vm: str, headless: bool, *args, **kwargs): logger.info("Starting AWS VM...") ec2_client = boto3.client('ec2', region_name=self.region) diff --git a/run_multienv.py b/run_multienv.py index f09ca66..e697b54 100644 --- a/run_multienv.py +++ b/run_multienv.py @@ -227,6 +227,11 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: path_to_vm=args.path_to_vm, region=args.region, action_space=agent.action_space, + + provider_name="aws", + region="us-east-1", + snapshot_name="ami-05e7d7bd279ea4f14", + screen_size=(args.screen_width, args.screen_height), headless=args.headless, os_type="Ubuntu",