Merge branch 'feat/aws-provider-support' into main
This commit is contained in:
@@ -61,6 +61,7 @@ class SetupController:
|
|||||||
# protocol
|
# protocol
|
||||||
setup_function: str = "_{:}_setup".format(config_type)
|
setup_function: str = "_{:}_setup".format(config_type)
|
||||||
assert hasattr(self, setup_function), f'Setup controller cannot find init function {setup_function}'
|
assert hasattr(self, setup_function), f'Setup controller cannot find init function {setup_function}'
|
||||||
|
logger.info(f"call function {setup_function}")
|
||||||
getattr(self, setup_function)(**parameters)
|
getattr(self, setup_function)(**parameters)
|
||||||
|
|
||||||
logger.info("SETUP: %s(%s)", setup_function, str(parameters))
|
logger.info("SETUP: %s(%s)", setup_function, str(parameters))
|
||||||
@@ -229,6 +230,7 @@ class SetupController:
|
|||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/launch")
|
||||||
response = requests.post(self.http_server + "/setup" + "/launch", headers=headers, data=payload)
|
response = requests.post(self.http_server + "/setup" + "/launch", headers=headers, data=payload)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
logger.info("Command executed successfully: %s", response.text)
|
logger.info("Command executed successfully: %s", response.text)
|
||||||
|
|||||||
@@ -69,7 +69,8 @@ class DesktopEnv(gym.Env):
|
|||||||
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \
|
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \
|
||||||
if provider_name in {"vmware", "virtualbox"} else path_to_vm
|
if provider_name in {"vmware", "virtualbox"} else path_to_vm
|
||||||
else:
|
else:
|
||||||
self.path_to_vm = self.manager.get_vm_path(region) # self.os_type,
|
|
||||||
|
self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region)
|
||||||
|
|
||||||
self.snapshot_name = snapshot_name
|
self.snapshot_name = snapshot_name
|
||||||
self.cache_dir_base: str = cache_dir
|
self.cache_dir_base: str = cache_dir
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ DEFAULT_REGION = "us-east-1"
|
|||||||
# todo: public the AMI images
|
# todo: public the AMI images
|
||||||
IMAGE_ID_MAP = {
|
IMAGE_ID_MAP = {
|
||||||
"us-east-1": "ami-05e7d7bd279ea4f14",
|
"us-east-1": "ami-05e7d7bd279ea4f14",
|
||||||
"ap-east-1": "ami-0c092a5b8be4116f5"
|
"ap-east-1": "ami-0c092a5b8be4116f5",
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANCE_TYPE = "t3.medium"
|
INSTANCE_TYPE = "t3.medium"
|
||||||
@@ -72,13 +72,13 @@ class AWSVMManager(VMManager):
|
|||||||
self.lock = FileLock(".aws_lck", timeout=60)
|
self.lock = FileLock(".aws_lck", timeout=60)
|
||||||
self.initialize_registry()
|
self.initialize_registry()
|
||||||
|
|
||||||
def initialize_registry(self):
|
def initialize_registry(self, **kwargs):
|
||||||
with self.lock: # Locking during initialization
|
with self.lock: # Locking during initialization
|
||||||
if not os.path.exists(self.registry_path):
|
if not os.path.exists(self.registry_path):
|
||||||
with open(self.registry_path, 'w') as file:
|
with open(self.registry_path, 'w') as file:
|
||||||
file.write('')
|
file.write('')
|
||||||
|
|
||||||
def add_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
|
def add_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True, **kwargs):
|
||||||
if lock_needed:
|
if lock_needed:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self._add_vm(vm_path, region)
|
self._add_vm(vm_path, region)
|
||||||
@@ -93,7 +93,7 @@ class AWSVMManager(VMManager):
|
|||||||
with open(self.registry_path, 'w') as file:
|
with open(self.registry_path, 'w') as file:
|
||||||
file.writelines(new_lines)
|
file.writelines(new_lines)
|
||||||
|
|
||||||
def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
|
def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True, **kwargs):
|
||||||
if lock_needed:
|
if lock_needed:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self._delete_vm(vm_path, region)
|
self._delete_vm(vm_path, region)
|
||||||
@@ -113,7 +113,7 @@ class AWSVMManager(VMManager):
|
|||||||
with open(self.registry_path, 'w') as file:
|
with open(self.registry_path, 'w') as file:
|
||||||
file.writelines(new_lines)
|
file.writelines(new_lines)
|
||||||
|
|
||||||
def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True):
|
def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True, **kwargs):
|
||||||
if lock_needed:
|
if lock_needed:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self._occupy_vm(vm_path, pid, region)
|
self._occupy_vm(vm_path, pid, region)
|
||||||
@@ -133,7 +133,7 @@ class AWSVMManager(VMManager):
|
|||||||
with open(self.registry_path, 'w') as file:
|
with open(self.registry_path, 'w') as file:
|
||||||
file.writelines(new_lines)
|
file.writelines(new_lines)
|
||||||
|
|
||||||
def check_and_clean(self, lock_needed=True):
|
def check_and_clean(self, lock_needed=True, **kwargs):
|
||||||
if lock_needed:
|
if lock_needed:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self._check_and_clean()
|
self._check_and_clean()
|
||||||
@@ -216,7 +216,7 @@ class AWSVMManager(VMManager):
|
|||||||
# Since this can lead to unexpected delete on other server
|
# Since this can lead to unexpected delete on other server
|
||||||
# PLease do monitor the instances to avoid additional cost
|
# PLease do monitor the instances to avoid additional cost
|
||||||
|
|
||||||
def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True):
|
def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True, **kwargs):
|
||||||
if lock_needed:
|
if lock_needed:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
return self._list_free_vms(region)
|
return self._list_free_vms(region)
|
||||||
@@ -235,7 +235,7 @@ class AWSVMManager(VMManager):
|
|||||||
|
|
||||||
return free_vms
|
return free_vms
|
||||||
|
|
||||||
def get_vm_path(self, region=DEFAULT_REGION):
|
def get_vm_path(self, region=DEFAULT_REGION, **kwargs):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if not AWSVMManager.checked_and_cleaned:
|
if not AWSVMManager.checked_and_cleaned:
|
||||||
AWSVMManager.checked_and_cleaned = True
|
AWSVMManager.checked_and_cleaned = True
|
||||||
|
|||||||
@@ -17,25 +17,8 @@ MAX_ATTEMPTS = 10
|
|||||||
|
|
||||||
class AWSProvider(Provider):
|
class AWSProvider(Provider):
|
||||||
|
|
||||||
# def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
|
|
||||||
# logger.info("Starting AWS VM...")
|
|
||||||
# ec2_client = boto3.client('ec2', region_name=self.region)
|
|
||||||
|
|
||||||
# try:
|
def start_emulator(self, path_to_vm: str, headless: bool, *args, **kwargs):
|
||||||
# # Start the instance
|
|
||||||
# ec2_client.start_instances(InstanceIds=[path_to_vm])
|
|
||||||
# logger.info(f"Instance {path_to_vm} is starting...")
|
|
||||||
|
|
||||||
# # Wait for the instance to be in the 'running' state
|
|
||||||
# waiter = ec2_client.get_waiter('instance_running')
|
|
||||||
# waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
|
|
||||||
# logger.info(f"Instance {path_to_vm} is now running.")
|
|
||||||
|
|
||||||
# except ClientError as e:
|
|
||||||
# logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}")
|
|
||||||
# raise
|
|
||||||
|
|
||||||
def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
|
|
||||||
logger.info("Starting AWS VM...")
|
logger.info("Starting AWS VM...")
|
||||||
ec2_client = boto3.client('ec2', region_name=self.region)
|
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||||
|
|
||||||
|
|||||||
@@ -227,6 +227,11 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
path_to_vm=args.path_to_vm,
|
path_to_vm=args.path_to_vm,
|
||||||
region=args.region,
|
region=args.region,
|
||||||
action_space=agent.action_space,
|
action_space=agent.action_space,
|
||||||
|
|
||||||
|
provider_name="aws",
|
||||||
|
region="us-east-1",
|
||||||
|
snapshot_name="ami-05e7d7bd279ea4f14",
|
||||||
|
|
||||||
screen_size=(args.screen_width, args.screen_height),
|
screen_size=(args.screen_width, args.screen_height),
|
||||||
headless=args.headless,
|
headless=args.headless,
|
||||||
os_type="Ubuntu",
|
os_type="Ubuntu",
|
||||||
|
|||||||
Reference in New Issue
Block a user