diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index 108d24f..31bc558 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -52,6 +52,7 @@ class DesktopEnv(gym.Env): require_terminal (bool): whether to require terminal output """ # Initialize VM manager and vitualization provider + self.region = region self.manager, self.provider = create_vm_manager_and_provider(provider_name, region) # Initialize environment variables @@ -95,7 +96,13 @@ class DesktopEnv(gym.Env): def _revert_to_snapshot(self): # Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm # due to the fact it could be changed when implemented by cloud services - self.path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name) + path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name) + if path_to_vm and not path_to_vm == self.path_to_vm: + # path_to_vm has to be a new path + self.manager.delete_vm(self.path_to_vm, self.region) + self.manager.add_vm(path_to_vm, self.region) + self.manager.occupy_vm(path_to_vm, os.getpid(), self.region) + self.path_to_vm = path_to_vm def _save_state(self, snapshot_name=None): # Save the current virtual machine state to a certain snapshot name diff --git a/desktop_env/providers/aws/manager.py b/desktop_env/providers/aws/manager.py index fa427b8..041536d 100644 --- a/desktop_env/providers/aws/manager.py +++ b/desktop_env/providers/aws/manager.py @@ -2,8 +2,8 @@ import os from filelock import FileLock import boto3 import psutil - import logging +from multiprocessing import Manager from desktop_env.providers.base import VMManager @@ -16,11 +16,12 @@ DEFAULT_REGION = "us-east-1" # todo: Add doc for the configuration of image, security group and network interface # todo: public the AMI images IMAGE_ID_MAP = { - "us-east-1": "ami-0b0531325a0d5d488", - "ap-east-1": "ami-0b92a0bf157fecaa9" + "us-east-1": "ami-019f92c05df45031b", + "ap-east-1": "ami-07b4956131da1b282" } -INSTANCE_TYPE = "t3.large" +INSTANCE_TYPE = "t3.medium" + NETWORK_INTERFACE_MAP = { "us-east-1": [ { @@ -33,14 +34,14 @@ NETWORK_INTERFACE_MAP = { } ], "ap-east-1": [ - { - "SubnetId": "subnet-011060501be0b589c", - "AssociatePublicIpAddress": True, - "DeviceIndex": 0, - "Groups": [ - "sg-090470e64df78f6eb" - ] - } + { + "SubnetId": "subnet-011060501be0b589c", + "AssociatePublicIpAddress": True, + "DeviceIndex": 0, + "Groups": [ + "sg-090470e64df78f6eb" + ] + } ] } @@ -60,17 +61,18 @@ def _allocate_vm(region=DEFAULT_REGION): instance_id = response['Instances'][0]['InstanceId'] logger.info(f"Waiting for instance {instance_id} to be running...") ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id]) - logger.info(f"Waiting for instance {instance_id} status checks to pass...") - ec2_client.get_waiter('instance_status_ok').wait(InstanceIds=[instance_id]) logger.info(f"Instance {instance_id} is ready.") return instance_id class AWSVMManager(VMManager): + manager = Manager() + check_and_clean_event = manager.Event() + def __init__(self, registry_path=REGISTRY_PATH): self.registry_path = registry_path - self.lock = FileLock(".aws_lck", timeout=10) + self.lock = FileLock(".aws_lck", timeout=60) self.initialize_registry() def initialize_registry(self): @@ -79,101 +81,183 @@ class AWSVMManager(VMManager): with open(self.registry_path, 'w') as file: file.write('') - def add_vm(self, vm_path, region=DEFAULT_REGION): - with self.lock: - with open(self.registry_path, 'r') as file: - lines = file.readlines() - vm_path_at_vm_region = "{}@{}".format(vm_path, region) - new_lines = lines + [f'{vm_path_at_vm_region}|free\n'] - with open(self.registry_path, 'w') as file: - file.writelines(new_lines) + def add_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True): + if lock_needed: + with self.lock: + self._add_vm(vm_path, region) + else: + self._add_vm(vm_path, region) - def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION): - with self.lock: - new_lines = [] - with open(self.registry_path, 'r') as file: - lines = file.readlines() - for line in lines: - registered_vm_path, _ = line.strip().split('|') - if registered_vm_path == "{}@{}".format(vm_path, region): - new_lines.append(f'{registered_vm_path}|{pid}\n') - else: - new_lines.append(line) - with open(self.registry_path, 'w') as file: - file.writelines(new_lines) + def _add_vm(self, vm_path, region=DEFAULT_REGION): + with open(self.registry_path, 'r') as file: + lines = file.readlines() + vm_path_at_vm_region = "{}@{}".format(vm_path, region) + new_lines = lines + [f'{vm_path_at_vm_region}|free\n'] + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) - def check_and_clean(self): - with self.lock: # Lock when cleaning up the registry and vms_dir - # Check and clean on the running vms, detect the released ones and mark then as 'free' - active_pids = {p.pid for p in psutil.process_iter()} - new_lines = [] - vm_path_at_vm_regions = [] + def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True): + if lock_needed: + with self.lock: + self._delete_vm(vm_path, region) + else: + self._delete_vm(vm_path, region) - with open(self.registry_path, 'r') as file: - lines = file.readlines() - for line in lines: - vm_path_at_vm_region, pid_str = line.strip().split('|') - vm_path, vm_region = vm_path_at_vm_region.split("@") - ec2_client = boto3.client('ec2', region_name=vm_region) + def _delete_vm(self, vm_path, region=DEFAULT_REGION): + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path_at_vm_region, pid_str = line.strip().split('|') + if vm_path_at_vm_region == "{}@{}".format(vm_path, region): + continue + else: + new_lines.append(line) + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) - try: - response = ec2_client.describe_instances(InstanceIds=[vm_path]) - if not response['Reservations'] or response['Reservations'][0]['Instances'][0]['State'][ - 'Name'] in ['terminated', 'shutting-down']: - logger.info(f"VM {vm_path} not found or terminated, releasing it.") - continue - elif response['Reservations'][0]['Instances'][0]['State'][ - 'Name'] == "Stopped": - logger.info(f"VM {vm_path} stopped, mark it as free") - new_lines.append(f'{vm_path}@{vm_region}|free\n') - continue - except ec2_client.exceptions.ClientError as e: - if 'InvalidInstanceID.NotFound' in str(e): - logger.info(f"VM {vm_path} not found, releasing it.") - continue + def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True): + if lock_needed: + with self.lock: + self._occupy_vm(vm_path, pid, region) + else: + self._occupy_vm(vm_path, pid, region) - vm_path_at_vm_regions.append(vm_path_at_vm_region) - if pid_str == "free": - new_lines.append(line) + def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION): + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + registered_vm_path, _ = line.strip().split('|') + if registered_vm_path == "{}@{}".format(vm_path, region): + new_lines.append(f'{registered_vm_path}|{pid}\n') + else: + new_lines.append(line) + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def check_and_clean(self, lock_needed=True): + if lock_needed: + with self.lock: + self._check_and_clean() + else: + self._check_and_clean() + + def _check_and_clean(self): + # Get active PIDs + active_pids = {p.pid for p in psutil.process_iter()} + + new_lines = [] + vm_path_at_vm_regions = {} + + with open(self.registry_path, 'r') as file: + lines = file.readlines() + + # Collect all VM paths and their regions + for line in lines: + vm_path_at_vm_region, pid_str = line.strip().split('|') + vm_path, vm_region = vm_path_at_vm_region.split("@") + if vm_region not in vm_path_at_vm_regions: + vm_path_at_vm_regions[vm_region] = [] + vm_path_at_vm_regions[vm_region].append((vm_path_at_vm_region, pid_str)) + + # Process each region + for region, vm_info_list in vm_path_at_vm_regions.items(): + ec2_client = boto3.client('ec2', region_name=region) + instance_ids = [vm_info[0].split('@')[0] for vm_info in vm_info_list] + + # Batch describe instances + try: + response = ec2_client.describe_instances(InstanceIds=instance_ids) + reservations = response.get('Reservations', []) + + terminated_ids = set() + stopped_ids = set() + active_ids = set() + + # Collect states of all instances + for reservation in reservations: + for instance in reservation.get('Instances', []): + instance_id = instance.get('InstanceId') + instance_state = instance['State']['Name'] + if instance_state in ['terminated', 'shutting-down']: + terminated_ids.add(instance_id) + elif instance_state == 'stopped': + stopped_ids.add(instance_id) + else: + active_ids.add(instance_id) + + # Write results back to file + for vm_path_at_vm_region, pid_str in vm_info_list: + vm_path = vm_path_at_vm_region.split('@')[0] + + if vm_path in terminated_ids: + logger.info(f"VM {vm_path} not found or terminated, releasing it.") + continue + elif vm_path in stopped_ids: + logger.info(f"VM {vm_path} stopped, mark it as free") + new_lines.append(f'{vm_path}@{region}|free\n') continue - if int(pid_str) in active_pids: - new_lines.append(line) + if pid_str == "free": + new_lines.append(f'{vm_path}@{region}|{pid_str}\n') + elif int(pid_str) in active_pids: + new_lines.append(f'{vm_path}@{region}|{pid_str}\n') else: - new_lines.append(f'{vm_path_at_vm_region}|free\n') + new_lines.append(f'{vm_path}@{region}|free\n') - with open(self.registry_path, 'w') as file: - file.writelines(new_lines) + except ec2_client.exceptions.ClientError as e: + if 'InvalidInstanceID.NotFound' in str(e): + logger.info(f"VM not found, releasing instances in region {region}.") + continue - # We won't check and clean on the files on aws and delete the unregistered ones - # Since this can lead to unexpected delete on other server - # PLease do monitor the instances to avoid additional cost + # Writing updated lines back to the registry file + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) - def list_free_vms(self, region=DEFAULT_REGION): - with self.lock: # Lock when reading the registry - free_vms = [] - with open(self.registry_path, 'r') as file: - lines = file.readlines() - for line in lines: - vm_path_at_vm_region, pid_str = line.strip().split('|') - vm_path, vm_region = vm_path_at_vm_region.split("@") - if pid_str == "free" and vm_region == region: - free_vms.append((vm_path, pid_str)) + # We won't check and clean on the files on aws and delete the unregistered ones + # Since this can lead to unexpected delete on other server + # PLease do monitor the instances to avoid additional cost - return free_vms + def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True): + if lock_needed: + with self.lock: + return self._list_free_vms(region) + else: + return self._list_free_vms(region) + + def _list_free_vms(self, region=DEFAULT_REGION): + free_vms = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + vm_path_at_vm_region, pid_str = line.strip().split('|') + vm_path, vm_region = vm_path_at_vm_region.split("@") + if pid_str == "free" and vm_region == region: + free_vms.append((vm_path, pid_str)) + + return free_vms def get_vm_path(self, region=DEFAULT_REGION): - self.check_and_clean() - free_vms_paths = self.list_free_vms(region) + with self.lock: + if not AWSVMManager.check_and_clean_event.is_set(): + AWSVMManager.check_and_clean_event.set() + self._check_and_clean() + + with self.lock: + free_vms_paths = self._list_free_vms(region) + if len(free_vms_paths) == 0: # No free virtual machine available, generate a new one logger.info("No free virtual machine available. Generating a new one, which would take a while...☕") new_vm_path = _allocate_vm(region) - self.add_vm(new_vm_path, region) - self.occupy_vm(new_vm_path, os.getpid(), region) + with self.lock: + self._add_vm(new_vm_path, region) + self._occupy_vm(new_vm_path, os.getpid(), region) return new_vm_path else: # Choose the first free virtual machine chosen_vm_path = free_vms_paths[0][0] - self.occupy_vm(chosen_vm_path, os.getpid(), region) + with self.lock: + self._occupy_vm(chosen_vm_path, os.getpid(), region) return chosen_vm_path diff --git a/desktop_env/providers/aws/provider.py b/desktop_env/providers/aws/provider.py index 2f0c5cb..c15d15b 100644 --- a/desktop_env/providers/aws/provider.py +++ b/desktop_env/providers/aws/provider.py @@ -3,7 +3,6 @@ from botocore.exceptions import ClientError import logging -from .manager import INSTANCE_TYPE from desktop_env.providers.base import Provider logger = logging.getLogger("desktopenv.providers.aws.AWSProvider") @@ -73,7 +72,11 @@ class AWSProvider(Provider): subnet_id = instance['SubnetId'] instance_type = instance['InstanceType'] - # Step 2: Launch a new instance from the snapshot + # Step 2: Terminate the old instance + ec2_client.terminate_instances(InstanceIds=[path_to_vm]) + logger.info(f"Old instance {path_to_vm} has been terminated.") + + # Step 3: Launch a new instance from the snapshot logger.info(f"Launching a new instance from snapshot {snapshot_name}...") run_instances_params = { @@ -97,14 +100,8 @@ class AWSProvider(Provider): logger.info(f"New instance {new_instance_id} launched from snapshot {snapshot_name}.") logger.info(f"Waiting for instance {new_instance_id} to be running...") ec2_client.get_waiter('instance_running').wait(InstanceIds=[new_instance_id]) - logger.info(f"Waiting for instance {new_instance_id} status checks to pass...") - ec2_client.get_waiter('instance_status_ok').wait(InstanceIds=[new_instance_id]) logger.info(f"Instance {new_instance_id} is ready.") - # Step 3: Terminate the old instance - ec2_client.terminate_instances(InstanceIds=[path_to_vm]) - logger.info(f"Old instance {path_to_vm} has been terminated.") - return new_instance_id except ClientError as e: diff --git a/desktop_env/providers/base.py b/desktop_env/providers/base.py index ac46215..5888dce 100644 --- a/desktop_env/providers/base.py +++ b/desktop_env/providers/base.py @@ -60,6 +60,13 @@ class VMManager(ABC): """ pass + @abstractmethod + def delete_vm(self, vm_path, **kwargs): + """ + Delete the registration of VM by path. + """ + pass + @abstractmethod def occupy_vm(self, vm_path, pid, **kwargs): """