AWS Enhancement (#48)

* Fix the path registration after reset

* image id

* Handle lock stuff

* Debug

* Debug

* Update

* Update

* Update

* Update

* Update

* Reorder the reset steps

* Reorder the reset steps

* Reorder the reset steps

* Finish and clean

---------

Co-authored-by: XinyuanWangCS <xywang626@gmail.com>
This commit is contained in:
Tianbao Xie
2024-06-20 19:03:02 +08:00
committed by GitHub
parent 536c92b0ce
commit b4901cdad0
4 changed files with 194 additions and 99 deletions

View File

@@ -52,6 +52,7 @@ class DesktopEnv(gym.Env):
require_terminal (bool): whether to require terminal output
"""
# Initialize VM manager and vitualization provider
self.region = region
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
# Initialize environment variables
@@ -95,7 +96,13 @@ class DesktopEnv(gym.Env):
def _revert_to_snapshot(self):
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm
# due to the fact it could be changed when implemented by cloud services
self.path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name)
path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name)
if path_to_vm and not path_to_vm == self.path_to_vm:
# path_to_vm has to be a new path
self.manager.delete_vm(self.path_to_vm, self.region)
self.manager.add_vm(path_to_vm, self.region)
self.manager.occupy_vm(path_to_vm, os.getpid(), self.region)
self.path_to_vm = path_to_vm
def _save_state(self, snapshot_name=None):
# Save the current virtual machine state to a certain snapshot name

View File

@@ -2,8 +2,8 @@ import os
from filelock import FileLock
import boto3
import psutil
import logging
from multiprocessing import Manager
from desktop_env.providers.base import VMManager
@@ -16,11 +16,12 @@ DEFAULT_REGION = "us-east-1"
# todo: Add doc for the configuration of image, security group and network interface
# todo: public the AMI images
IMAGE_ID_MAP = {
"us-east-1": "ami-0b0531325a0d5d488",
"ap-east-1": "ami-0b92a0bf157fecaa9"
"us-east-1": "ami-019f92c05df45031b",
"ap-east-1": "ami-07b4956131da1b282"
}
INSTANCE_TYPE = "t3.large"
INSTANCE_TYPE = "t3.medium"
NETWORK_INTERFACE_MAP = {
"us-east-1": [
{
@@ -33,14 +34,14 @@ NETWORK_INTERFACE_MAP = {
}
],
"ap-east-1": [
{
"SubnetId": "subnet-011060501be0b589c",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-090470e64df78f6eb"
]
}
{
"SubnetId": "subnet-011060501be0b589c",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-090470e64df78f6eb"
]
}
]
}
@@ -60,17 +61,18 @@ def _allocate_vm(region=DEFAULT_REGION):
instance_id = response['Instances'][0]['InstanceId']
logger.info(f"Waiting for instance {instance_id} to be running...")
ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
logger.info(f"Waiting for instance {instance_id} status checks to pass...")
ec2_client.get_waiter('instance_status_ok').wait(InstanceIds=[instance_id])
logger.info(f"Instance {instance_id} is ready.")
return instance_id
class AWSVMManager(VMManager):
manager = Manager()
check_and_clean_event = manager.Event()
def __init__(self, registry_path=REGISTRY_PATH):
self.registry_path = registry_path
self.lock = FileLock(".aws_lck", timeout=10)
self.lock = FileLock(".aws_lck", timeout=60)
self.initialize_registry()
def initialize_registry(self):
@@ -79,101 +81,183 @@ class AWSVMManager(VMManager):
with open(self.registry_path, 'w') as file:
file.write('')
def add_vm(self, vm_path, region=DEFAULT_REGION):
with self.lock:
with open(self.registry_path, 'r') as file:
lines = file.readlines()
vm_path_at_vm_region = "{}@{}".format(vm_path, region)
new_lines = lines + [f'{vm_path_at_vm_region}|free\n']
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def add_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
if lock_needed:
with self.lock:
self._add_vm(vm_path, region)
else:
self._add_vm(vm_path, region)
def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
with self.lock:
new_lines = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
registered_vm_path, _ = line.strip().split('|')
if registered_vm_path == "{}@{}".format(vm_path, region):
new_lines.append(f'{registered_vm_path}|{pid}\n')
else:
new_lines.append(line)
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def _add_vm(self, vm_path, region=DEFAULT_REGION):
with open(self.registry_path, 'r') as file:
lines = file.readlines()
vm_path_at_vm_region = "{}@{}".format(vm_path, region)
new_lines = lines + [f'{vm_path_at_vm_region}|free\n']
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def check_and_clean(self):
with self.lock: # Lock when cleaning up the registry and vms_dir
# Check and clean on the running vms, detect the released ones and mark then as 'free'
active_pids = {p.pid for p in psutil.process_iter()}
new_lines = []
vm_path_at_vm_regions = []
def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
if lock_needed:
with self.lock:
self._delete_vm(vm_path, region)
else:
self._delete_vm(vm_path, region)
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
vm_path_at_vm_region, pid_str = line.strip().split('|')
vm_path, vm_region = vm_path_at_vm_region.split("@")
ec2_client = boto3.client('ec2', region_name=vm_region)
def _delete_vm(self, vm_path, region=DEFAULT_REGION):
new_lines = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
vm_path_at_vm_region, pid_str = line.strip().split('|')
if vm_path_at_vm_region == "{}@{}".format(vm_path, region):
continue
else:
new_lines.append(line)
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
try:
response = ec2_client.describe_instances(InstanceIds=[vm_path])
if not response['Reservations'] or response['Reservations'][0]['Instances'][0]['State'][
'Name'] in ['terminated', 'shutting-down']:
logger.info(f"VM {vm_path} not found or terminated, releasing it.")
continue
elif response['Reservations'][0]['Instances'][0]['State'][
'Name'] == "Stopped":
logger.info(f"VM {vm_path} stopped, mark it as free")
new_lines.append(f'{vm_path}@{vm_region}|free\n')
continue
except ec2_client.exceptions.ClientError as e:
if 'InvalidInstanceID.NotFound' in str(e):
logger.info(f"VM {vm_path} not found, releasing it.")
continue
def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True):
if lock_needed:
with self.lock:
self._occupy_vm(vm_path, pid, region)
else:
self._occupy_vm(vm_path, pid, region)
vm_path_at_vm_regions.append(vm_path_at_vm_region)
if pid_str == "free":
new_lines.append(line)
def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
new_lines = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
registered_vm_path, _ = line.strip().split('|')
if registered_vm_path == "{}@{}".format(vm_path, region):
new_lines.append(f'{registered_vm_path}|{pid}\n')
else:
new_lines.append(line)
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def check_and_clean(self, lock_needed=True):
if lock_needed:
with self.lock:
self._check_and_clean()
else:
self._check_and_clean()
def _check_and_clean(self):
# Get active PIDs
active_pids = {p.pid for p in psutil.process_iter()}
new_lines = []
vm_path_at_vm_regions = {}
with open(self.registry_path, 'r') as file:
lines = file.readlines()
# Collect all VM paths and their regions
for line in lines:
vm_path_at_vm_region, pid_str = line.strip().split('|')
vm_path, vm_region = vm_path_at_vm_region.split("@")
if vm_region not in vm_path_at_vm_regions:
vm_path_at_vm_regions[vm_region] = []
vm_path_at_vm_regions[vm_region].append((vm_path_at_vm_region, pid_str))
# Process each region
for region, vm_info_list in vm_path_at_vm_regions.items():
ec2_client = boto3.client('ec2', region_name=region)
instance_ids = [vm_info[0].split('@')[0] for vm_info in vm_info_list]
# Batch describe instances
try:
response = ec2_client.describe_instances(InstanceIds=instance_ids)
reservations = response.get('Reservations', [])
terminated_ids = set()
stopped_ids = set()
active_ids = set()
# Collect states of all instances
for reservation in reservations:
for instance in reservation.get('Instances', []):
instance_id = instance.get('InstanceId')
instance_state = instance['State']['Name']
if instance_state in ['terminated', 'shutting-down']:
terminated_ids.add(instance_id)
elif instance_state == 'stopped':
stopped_ids.add(instance_id)
else:
active_ids.add(instance_id)
# Write results back to file
for vm_path_at_vm_region, pid_str in vm_info_list:
vm_path = vm_path_at_vm_region.split('@')[0]
if vm_path in terminated_ids:
logger.info(f"VM {vm_path} not found or terminated, releasing it.")
continue
elif vm_path in stopped_ids:
logger.info(f"VM {vm_path} stopped, mark it as free")
new_lines.append(f'{vm_path}@{region}|free\n')
continue
if int(pid_str) in active_pids:
new_lines.append(line)
if pid_str == "free":
new_lines.append(f'{vm_path}@{region}|{pid_str}\n')
elif int(pid_str) in active_pids:
new_lines.append(f'{vm_path}@{region}|{pid_str}\n')
else:
new_lines.append(f'{vm_path_at_vm_region}|free\n')
new_lines.append(f'{vm_path}@{region}|free\n')
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
except ec2_client.exceptions.ClientError as e:
if 'InvalidInstanceID.NotFound' in str(e):
logger.info(f"VM not found, releasing instances in region {region}.")
continue
# We won't check and clean on the files on aws and delete the unregistered ones
# Since this can lead to unexpected delete on other server
# PLease do monitor the instances to avoid additional cost
# Writing updated lines back to the registry file
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def list_free_vms(self, region=DEFAULT_REGION):
with self.lock: # Lock when reading the registry
free_vms = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
vm_path_at_vm_region, pid_str = line.strip().split('|')
vm_path, vm_region = vm_path_at_vm_region.split("@")
if pid_str == "free" and vm_region == region:
free_vms.append((vm_path, pid_str))
# We won't check and clean on the files on aws and delete the unregistered ones
# Since this can lead to unexpected delete on other server
# PLease do monitor the instances to avoid additional cost
return free_vms
def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True):
if lock_needed:
with self.lock:
return self._list_free_vms(region)
else:
return self._list_free_vms(region)
def _list_free_vms(self, region=DEFAULT_REGION):
free_vms = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
vm_path_at_vm_region, pid_str = line.strip().split('|')
vm_path, vm_region = vm_path_at_vm_region.split("@")
if pid_str == "free" and vm_region == region:
free_vms.append((vm_path, pid_str))
return free_vms
def get_vm_path(self, region=DEFAULT_REGION):
self.check_and_clean()
free_vms_paths = self.list_free_vms(region)
with self.lock:
if not AWSVMManager.check_and_clean_event.is_set():
AWSVMManager.check_and_clean_event.set()
self._check_and_clean()
with self.lock:
free_vms_paths = self._list_free_vms(region)
if len(free_vms_paths) == 0:
# No free virtual machine available, generate a new one
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
new_vm_path = _allocate_vm(region)
self.add_vm(new_vm_path, region)
self.occupy_vm(new_vm_path, os.getpid(), region)
with self.lock:
self._add_vm(new_vm_path, region)
self._occupy_vm(new_vm_path, os.getpid(), region)
return new_vm_path
else:
# Choose the first free virtual machine
chosen_vm_path = free_vms_paths[0][0]
self.occupy_vm(chosen_vm_path, os.getpid(), region)
with self.lock:
self._occupy_vm(chosen_vm_path, os.getpid(), region)
return chosen_vm_path

View File

@@ -3,7 +3,6 @@ from botocore.exceptions import ClientError
import logging
from .manager import INSTANCE_TYPE
from desktop_env.providers.base import Provider
logger = logging.getLogger("desktopenv.providers.aws.AWSProvider")
@@ -73,7 +72,11 @@ class AWSProvider(Provider):
subnet_id = instance['SubnetId']
instance_type = instance['InstanceType']
# Step 2: Launch a new instance from the snapshot
# Step 2: Terminate the old instance
ec2_client.terminate_instances(InstanceIds=[path_to_vm])
logger.info(f"Old instance {path_to_vm} has been terminated.")
# Step 3: Launch a new instance from the snapshot
logger.info(f"Launching a new instance from snapshot {snapshot_name}...")
run_instances_params = {
@@ -97,14 +100,8 @@ class AWSProvider(Provider):
logger.info(f"New instance {new_instance_id} launched from snapshot {snapshot_name}.")
logger.info(f"Waiting for instance {new_instance_id} to be running...")
ec2_client.get_waiter('instance_running').wait(InstanceIds=[new_instance_id])
logger.info(f"Waiting for instance {new_instance_id} status checks to pass...")
ec2_client.get_waiter('instance_status_ok').wait(InstanceIds=[new_instance_id])
logger.info(f"Instance {new_instance_id} is ready.")
# Step 3: Terminate the old instance
ec2_client.terminate_instances(InstanceIds=[path_to_vm])
logger.info(f"Old instance {path_to_vm} has been terminated.")
return new_instance_id
except ClientError as e:

View File

@@ -60,6 +60,13 @@ class VMManager(ABC):
"""
pass
@abstractmethod
def delete_vm(self, vm_path, **kwargs):
"""
Delete the registration of VM by path.
"""
pass
@abstractmethod
def occupy_vm(self, vm_path, pid, **kwargs):
"""