"""AWS EC2 implementation of the desktop_env ``Provider`` interface.

Throughout this module ``path_to_vm`` is an EC2 instance ID.
"""
import boto3
from botocore.exceptions import ClientError

import logging
from desktop_env.providers.base import Provider
from datetime import datetime
import time

logger = logging.getLogger("desktopenv.providers.aws.AWSProvider")
logger.setLevel(logging.INFO)

# Polling configuration for the start/stop waiters: poll every WAIT_DELAY
# seconds, at most MAX_ATTEMPTS times.
WAIT_DELAY = 15
MAX_ATTEMPTS = 10

# Extra time given to a freshly launched instance after EC2 reports it as
# 'running', so the guest OS has a chance to finish booting.
BOOT_SETTLE_SECONDS = 60


class AWSProvider(Provider):
    """Manage desktop-environment VMs as AWS EC2 instances.

    Every method creates a boto3 EC2 client for ``self.region``. That
    attribute is presumably set by the ``Provider`` base class, which is not
    visible here (TODO confirm).
    """

    def _ec2_client(self, region=None):
        """Return a boto3 EC2 client for *region*, defaulting to ``self.region``."""
        return boto3.client('ec2', region_name=region or self.region)

    def start_emulator(self, path_to_vm: str, headless: bool, *args, **kwargs):
        """Ensure the EC2 instance *path_to_vm* ends up in the 'running' state.

        ``headless`` and any extra arguments are accepted for interface
        compatibility with other providers and are not used.

        Behaviour by current instance state:
          - 'running':  nothing to do.
          - 'pending':  wait until it is 'running'.
          - 'stopping': wait until it is 'stopped', then start it.
          - 'stopped':  start it and wait until it is 'running'.
          - any other state (e.g. 'terminated', 'shutting-down'): log a warning.

        Raises:
            ClientError: if an EC2 API call fails (logged, then re-raised).
            botocore.exceptions.WaiterError: if a waiter exhausts its attempts.
        """
        logger.info("Starting AWS VM...")
        ec2_client = self._ec2_client()
        waiter_config = {'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS}
        try:
            response = ec2_client.describe_instances(InstanceIds=[path_to_vm])
            state = response['Reservations'][0]['Instances'][0]['State']['Name']
            logger.info(f"Instance {path_to_vm} current state: {state}")

            if state == 'running':
                logger.info(f"Instance {path_to_vm} is already running. Skipping start.")
                return

            if state == 'stopping':
                # EC2 rejects StartInstances while an instance is still stopping,
                # so let it settle into 'stopped' first.
                logger.info(f"Instance {path_to_vm} is stopping; waiting for it to stop...")
                ec2_client.get_waiter('instance_stopped').wait(
                    InstanceIds=[path_to_vm],
                    WaiterConfig=waiter_config,
                )
                state = 'stopped'

            if state == 'stopped':
                ec2_client.start_instances(InstanceIds=[path_to_vm])
                logger.info(f"Instance {path_to_vm} is starting...")
            elif state != 'pending':
                # 'pending' instances are already booting and fall through to the
                # wait below; anything else cannot be brought up by this method.
                logger.warning(f"Instance {path_to_vm} is in state '{state}' and cannot be started.")
                return

            # Wait until the instance reaches 'running' state.
            ec2_client.get_waiter('instance_running').wait(
                InstanceIds=[path_to_vm],
                WaiterConfig=waiter_config,
            )
            logger.info(f"Instance {path_to_vm} is now running.")
        except ClientError as e:
            logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}")
            raise

    def get_ip_address(self, path_to_vm: str) -> str:
        """Return the private IPv4 address of instance *path_to_vm*.

        Uses the first instance in the ``describe_instances`` response. Returns
        ``''`` when the response contains no instance or the instance has no
        ``PrivateIpAddress``.

        Raises:
            ClientError: if ``describe_instances`` fails (logged, then re-raised).
        """
        logger.info("Getting AWS VM IP address...")
        ec2_client = self._ec2_client()
        try:
            response = ec2_client.describe_instances(InstanceIds=[path_to_vm])
        except ClientError as e:
            logger.error(f"Failed to retrieve private IP address for the instance {path_to_vm}: {str(e)}")
            raise
        for reservation in response['Reservations']:
            for instance in reservation['Instances']:
                return instance.get('PrivateIpAddress', '')
        return ''  # Return an empty string if no instance is found

    def save_state(self, path_to_vm: str, snapshot_name: str):
        """Create an AMI named *snapshot_name* from instance *path_to_vm*.

        Returns the new AMI ID as soon as EC2 accepts the request; the image
        may still be in the 'pending' state. Because ``NoReboot`` is not
        passed, EC2 uses its default behaviour, which reboots the instance
        while the image is created.

        Raises:
            ClientError: if ``create_image`` fails (logged, then re-raised).
        """
        logger.info("Saving AWS VM state...")
        ec2_client = self._ec2_client()
        try:
            # CreateImage takes the image name via ``Name``; it has no ``ImageId`` input.
            image_response = ec2_client.create_image(InstanceId=path_to_vm, Name=snapshot_name)
            image_id = image_response['ImageId']
            logger.info(f"AMI {image_id} created successfully from instance {path_to_vm}.")
            return image_id
        except ClientError as e:
            logger.error(f"Failed to create AMI from the instance {path_to_vm}: {str(e)}")
            raise

    def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
        """Replace instance *path_to_vm* with a fresh instance of its own AMI.

        The new instance reuses the old instance's AMI (``ImageId``), instance
        type, subnet and security groups, and gets a public IP. ``snapshot_name``
        only appears in log messages; the image launched is always the old
        instance's AMI.

        Returns:
            The ID of the newly launched instance, after EC2 reports it
            'running' and a further BOOT_SETTLE_SECONDS have elapsed.

        Raises:
            ClientError: if any EC2 API call fails (logged, then re-raised).
        """
        logger.info(f"Reverting AWS VM to snapshot: {snapshot_name}...")
        ec2_client = self._ec2_client()
        try:
            # Step 1: capture the launch configuration of the current instance.
            instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm])
            instance = instance_details['Reservations'][0]['Instances'][0]
            security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']]
            subnet_id = instance['SubnetId']
            instance_type = instance['InstanceType']
            instance_snapshot = instance['ImageId']

            # Step 2: terminate the old instance. Termination is not awaited.
            # NOTE(review): this happens before the replacement is launched, so a
            # failing run_instances below leaves no VM at all.
            ec2_client.terminate_instances(InstanceIds=[path_to_vm])
            logger.info(f"Old instance {path_to_vm} has been terminated.")

            # Step 3: launch a replacement from the same AMI and network settings.
            logger.info(f"Launching a new instance from snapshot {instance_snapshot}...")
            new_instance = ec2_client.run_instances(
                MaxCount=1,
                MinCount=1,
                ImageId=instance_snapshot,
                InstanceType=instance_type,
                EbsOptimized=True,
                NetworkInterfaces=[
                    {
                        "SubnetId": subnet_id,
                        "AssociatePublicIpAddress": True,
                        "DeviceIndex": 0,
                        "Groups": security_groups,
                    }
                ],
            )
            new_instance_id = new_instance['Instances'][0]['InstanceId']
            logger.info(f"New instance {new_instance_id} launched from snapshot {instance_snapshot}.")

            logger.info(f"Waiting for instance {new_instance_id} to be running...")
            ec2_client.get_waiter('instance_running').wait(InstanceIds=[new_instance_id])
            # 'running' only means the hypervisor started the VM; give the guest
            # OS time to boot before callers try to reach it.
            time.sleep(BOOT_SETTLE_SECONDS)
            logger.info(f"Instance {new_instance_id} is ready.")
            return new_instance_id
        except ClientError as e:
            logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}")
            raise

    def stop_emulator(self, path_to_vm, region=None):
        """Stop instance *path_to_vm* and wait until it reaches 'stopped'.

        Args:
            path_to_vm: EC2 instance ID.
            region: optional region override; defaults to ``self.region``.

        Raises:
            ClientError: if an EC2 API call fails (logged, then re-raised).
            botocore.exceptions.WaiterError: if the waiter exhausts its attempts.
        """
        logger.info(f"Stopping AWS VM {path_to_vm}...")
        ec2_client = self._ec2_client(region)
        try:
            ec2_client.stop_instances(InstanceIds=[path_to_vm])
            ec2_client.get_waiter('instance_stopped').wait(
                InstanceIds=[path_to_vm],
                WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS},
            )
            logger.info(f"Instance {path_to_vm} has been stopped.")
        except ClientError as e:
            logger.error(f"Failed to stop the AWS VM {path_to_vm}: {str(e)}")
            raise