"""Registry-backed manager for AWS EC2 instances used as desktop VMs.

The registry is a plain text file with one entry per line in the form
``<instance_id>@<region>|<status>`` where ``<status>`` is either the literal
``free`` or the PID of the process currently occupying the instance.
All registry accesses are serialized across processes with a file lock.
"""
import os
from filelock import FileLock
import boto3
import psutil

import logging

from desktop_env.providers.base import VMManager

logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager")
logger.setLevel(logging.INFO)

REGISTRY_PATH = '.aws_vms'

DEFAULT_REGION = "us-east-1"
# TODO: Add documentation for the configuration of image, security group and network interface
# TODO: publish the AMI images
IMAGE_ID_MAP = {
    "us-east-1": "ami-0b0531325a0d5d488",
    "ap-east-1": "ami-0b92a0bf157fecaa9"
}

INSTANCE_TYPE = "t3.large"

NETWORK_INTERFACE_MAP = {
    "us-east-1": [
        {
            "SubnetId": "subnet-037edfff66c2eb894",
            "AssociatePublicIpAddress": True,
            "DeviceIndex": 0,
            "Groups": [
                "sg-0342574803206ee9c"
            ]
        }
    ],
    "ap-east-1": [
        {
            "SubnetId": "subnet-011060501be0b589c",
            "AssociatePublicIpAddress": True,
            "DeviceIndex": 0,
            "Groups": [
                "sg-090470e64df78f6eb"
            ]
        }
    ]
}

# Registry status value marking an instance that no process is using.
_FREE = "free"


def _parse_registry_line(line):
    """Split one registry line into ``(instance_id, region, status)``.

    Returns ``None`` for blank lines so that a stray empty line in the
    registry file does not crash every operation. Lines that are non-blank
    but malformed still raise ``ValueError``, as before.
    """
    stripped = line.strip()
    if not stripped:
        return None
    vm_path_at_vm_region, status = stripped.split('|')
    vm_path, vm_region = vm_path_at_vm_region.split("@")
    return vm_path, vm_region, status


def _allocate_vm(region=DEFAULT_REGION):
    """Launch one EC2 instance in *region* and block until it is usable.

    The image and network interface are looked up in ``IMAGE_ID_MAP`` and
    ``NETWORK_INTERFACE_MAP``; a region missing from those maps raises
    ``KeyError``.

    Returns:
        The instance id of the newly launched instance.
    """
    run_instances_params = {
        "MaxCount": 1,
        "MinCount": 1,
        "ImageId": IMAGE_ID_MAP[region],
        "InstanceType": INSTANCE_TYPE,
        "EbsOptimized": True,
        "NetworkInterfaces": NETWORK_INTERFACE_MAP[region]
    }

    ec2_client = boto3.client('ec2', region_name=region)
    response = ec2_client.run_instances(**run_instances_params)
    instance_id = response['Instances'][0]['InstanceId']

    # "running" only means the hypervisor started the instance; wait for the
    # status checks as well before handing the instance out.
    logger.info(f"Waiting for instance {instance_id} to be running...")
    ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
    logger.info(f"Waiting for instance {instance_id} status checks to pass...")
    ec2_client.get_waiter('instance_status_ok').wait(InstanceIds=[instance_id])
    logger.info(f"Instance {instance_id} is ready.")

    return instance_id


class AWSVMManager(VMManager):
    """Tracks EC2 instances in a lock-protected registry file."""

    def __init__(self, registry_path=REGISTRY_PATH):
        """Create the manager and make sure the registry file exists."""
        self.registry_path = registry_path
        self.lock = FileLock(".aws_lck", timeout=10)
        self.initialize_registry()

    def initialize_registry(self):
        """Create an empty registry file if none exists yet."""
        with self.lock:  # Locking during initialization
            if not os.path.exists(self.registry_path):
                with open(self.registry_path, 'w') as file:
                    file.write('')

    def add_vm(self, vm_path, region=DEFAULT_REGION):
        """Append instance *vm_path* in *region* to the registry as free."""
        with self.lock:
            with open(self.registry_path, 'r') as file:
                lines = file.readlines()
            vm_path_at_vm_region = "{}@{}".format(vm_path, region)
            new_lines = lines + [f'{vm_path_at_vm_region}|{_FREE}\n']
            with open(self.registry_path, 'w') as file:
                file.writelines(new_lines)

    def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
        """Mark instance *vm_path* in *region* as occupied by process *pid*.

        Entries that do not match are written back unchanged. If no entry
        matches, the registry is left as it was.
        """
        target = "{}@{}".format(vm_path, region)
        with self.lock:
            new_lines = []
            with open(self.registry_path, 'r') as file:
                lines = file.readlines()
            for line in lines:
                parsed = _parse_registry_line(line)
                if parsed is None:
                    continue  # drop blank lines instead of crashing on them
                registered_vm_path = "{}@{}".format(parsed[0], parsed[1])
                if registered_vm_path == target:
                    new_lines.append(f'{registered_vm_path}|{pid}\n')
                else:
                    new_lines.append(line)
            with open(self.registry_path, 'w') as file:
                file.writelines(new_lines)

    def check_and_clean(self):
        """Reconcile the registry with EC2 and with the local process table.

        For every registered instance:
          * terminated, shutting-down or unknown to EC2 -> entry is removed;
          * stopped -> entry is kept and marked free;
          * otherwise -> entry is kept; if its owning PID is no longer
            alive on this machine it is marked free.
        """
        with self.lock:  # Lock when cleaning up the registry
            # Check the running VMs, detect the released ones and mark them as 'free'
            active_pids = {p.pid for p in psutil.process_iter()}
            new_lines = []
            clients = {}  # one EC2 client per region instead of one per line

            with open(self.registry_path, 'r') as file:
                lines = file.readlines()

            for line in lines:
                parsed = _parse_registry_line(line)
                if parsed is None:
                    continue
                vm_path, vm_region, pid_str = parsed
                vm_path_at_vm_region = f'{vm_path}@{vm_region}'

                ec2_client = clients.get(vm_region)
                if ec2_client is None:
                    ec2_client = boto3.client('ec2', region_name=vm_region)
                    clients[vm_region] = ec2_client

                try:
                    response = ec2_client.describe_instances(InstanceIds=[vm_path])
                    reservations = response['Reservations']
                    state = reservations[0]['Instances'][0]['State']['Name'] if reservations else None
                    if state is None or state in ('terminated', 'shutting-down'):
                        logger.info(f"VM {vm_path} not found or terminated, releasing it.")
                        continue
                    # EC2 reports state names in lowercase ("stopped").
                    if state == "stopped":
                        logger.info(f"VM {vm_path} stopped, mark it as free")
                        new_lines.append(f'{vm_path_at_vm_region}|{_FREE}\n')
                        continue
                except ec2_client.exceptions.ClientError as e:
                    if 'InvalidInstanceID.NotFound' in str(e):
                        logger.info(f"VM {vm_path} not found, releasing it.")
                        continue
                    # Any other API failure (throttling, credentials, ...):
                    # keep the entry rather than lose track of a live VM.
                    logger.warning(f"Could not query VM {vm_path}: {e}")

                if pid_str == _FREE or int(pid_str) in active_pids:
                    new_lines.append(f'{vm_path_at_vm_region}|{pid_str}\n')
                else:
                    # The owning process is gone; release the VM.
                    new_lines.append(f'{vm_path_at_vm_region}|{_FREE}\n')

            with open(self.registry_path, 'w') as file:
                file.writelines(new_lines)

        # We do not scan AWS for unregistered instances and delete them,
        # since that could unexpectedly delete instances owned by another server.
        # Please monitor the instances yourself to avoid additional cost.

    def list_free_vms(self, region=DEFAULT_REGION):
        """Return ``(instance_id, "free")`` tuples for free VMs in *region*."""
        with self.lock:  # Lock when reading the registry
            free_vms = []
            with open(self.registry_path, 'r') as file:
                lines = file.readlines()
            for line in lines:
                parsed = _parse_registry_line(line)
                if parsed is None:
                    continue
                vm_path, vm_region, pid_str = parsed
                if pid_str == _FREE and vm_region == region:
                    free_vms.append((vm_path, pid_str))
            return free_vms

    def get_vm_path(self, region=DEFAULT_REGION):
        """Reserve a VM in *region* for the current process.

        Reuses a free registered instance when one exists, otherwise launches
        a new one. Returns the instance id.
        """
        self.check_and_clean()

        # Pick and occupy under a single hold of the (re-entrant) lock so two
        # processes cannot both claim the same free VM.
        with self.lock:
            free_vms_paths = self.list_free_vms(region)
            if free_vms_paths:
                # Choose the first free virtual machine
                chosen_vm_path = free_vms_paths[0][0]
                self.occupy_vm(chosen_vm_path, os.getpid(), region)
                return chosen_vm_path

        # No free virtual machine available, generate a new one.
        # Allocation takes minutes, so it is done without holding the lock.
        logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
        new_vm_path = _allocate_vm(region)
        with self.lock:
            # Register and occupy atomically so nobody grabs it in between.
            self.add_vm(new_vm_path, region)
            self.occupy_vm(new_vm_path, os.getpid(), region)
        return new_vm_path