diff --git a/.gitignore b/.gitignore index cae53ca..3ccc515 100644 --- a/.gitignore +++ b/.gitignore @@ -190,6 +190,7 @@ test2.xlsx docker_vm_data vmware_vm_data .vmware* +.aws* # result **/result*/**/* diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index 2b8da86..0c0f9aa 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -69,7 +69,7 @@ class DesktopEnv(gym.Env): self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \ if provider_name in {"vmware", "virtualbox"} else path_to_vm else: - self.path_to_vm = self.manager.get_vm_path(self.os_type, region) + self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region) self.snapshot_name = snapshot_name self.cache_dir_base: str = cache_dir diff --git a/desktop_env/providers/aws/manager.py b/desktop_env/providers/aws/manager.py index 188fb6b..78fecb5 100644 --- a/desktop_env/providers/aws/manager.py +++ b/desktop_env/providers/aws/manager.py @@ -16,7 +16,7 @@ DEFAULT_REGION = "us-east-1" # todo: public the AMI images IMAGE_ID_MAP = { "us-east-1": "ami-05e7d7bd279ea4f14", - "ap-east-1": "ami-0c092a5b8be4116f5" + "ap-east-1": "ami-0c092a5b8be4116f5", } INSTANCE_TYPE = "t3.medium" @@ -71,13 +71,13 @@ class AWSVMManager(VMManager): self.lock = FileLock(".aws_lck", timeout=60) self.initialize_registry() - def initialize_registry(self): + def initialize_registry(self, **kwargs): with self.lock: # Locking during initialization if not os.path.exists(self.registry_path): with open(self.registry_path, 'w') as file: file.write('') - def add_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True): + def add_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True, **kwargs): if lock_needed: with self.lock: self._add_vm(vm_path, region) @@ -92,7 +92,7 @@ class AWSVMManager(VMManager): with open(self.registry_path, 'w') as file: file.writelines(new_lines) - def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True): + def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True, **kwargs): if lock_needed: with self.lock: self._delete_vm(vm_path, region) @@ -112,7 +112,7 @@ class AWSVMManager(VMManager): with open(self.registry_path, 'w') as file: file.writelines(new_lines) - def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True): + def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True, **kwargs): if lock_needed: with self.lock: self._occupy_vm(vm_path, pid, region) @@ -132,7 +132,7 @@ class AWSVMManager(VMManager): with open(self.registry_path, 'w') as file: file.writelines(new_lines) - def check_and_clean(self, lock_needed=True): + def check_and_clean(self, lock_needed=True, **kwargs): if lock_needed: with self.lock: self._check_and_clean() @@ -215,7 +215,7 @@ class AWSVMManager(VMManager): # Since this can lead to unexpected delete on other server # PLease do monitor the instances to avoid additional cost - def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True): + def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True, **kwargs): if lock_needed: with self.lock: return self._list_free_vms(region) @@ -234,7 +234,7 @@ class AWSVMManager(VMManager): return free_vms - def get_vm_path(self, region=DEFAULT_REGION): + def get_vm_path(self, region=DEFAULT_REGION, **kwargs): with self.lock: if not AWSVMManager.checked_and_cleaned: AWSVMManager.checked_and_cleaned = True diff --git a/desktop_env/providers/aws/provider.py b/desktop_env/providers/aws/provider.py index c15d15b..d7338d6 100644 --- a/desktop_env/providers/aws/provider.py +++ b/desktop_env/providers/aws/provider.py @@ -14,7 +14,7 @@ MAX_ATTEMPTS = 10 class AWSProvider(Provider): - def start_emulator(self, path_to_vm: str, headless: bool): + def start_emulator(self, path_to_vm: str, headless: bool, *args, **kwargs): logger.info("Starting AWS VM...") ec2_client = boto3.client('ec2', region_name=self.region) diff --git a/run_multienv.py b/run_multienv.py index 8a8a417..823061b 100644 --- a/run_multienv.py +++ b/run_multienv.py @@ -221,6 +221,10 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: env = DesktopEnv( path_to_vm=args.path_to_vm, action_space=agent.action_space, + + provider_name="aws", + region="us-east-1", + screen_size=(args.screen_width, args.screen_height), headless=args.headless, os_type="Ubuntu",