AWS Enhancement (#48)
* Fix the path registration after reset * image id * Handle lock stuff * Debug * Debug * Update * Update * Update * Update * Update * Reorder the reset steps * Reorder the reset steps * Reorder the reset steps * Finish and clean --------- Co-authored-by: XinyuanWangCS <xywang626@gmail.com>
This commit is contained in:
@@ -52,6 +52,7 @@ class DesktopEnv(gym.Env):
|
|||||||
require_terminal (bool): whether to require terminal output
|
require_terminal (bool): whether to require terminal output
|
||||||
"""
|
"""
|
||||||
# Initialize VM manager and virtualization provider
|
# Initialize VM manager and virtualization provider
|
||||||
|
self.region = region
|
||||||
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
|
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
|
||||||
|
|
||||||
# Initialize environment variables
|
# Initialize environment variables
|
||||||
@@ -95,7 +96,13 @@ class DesktopEnv(gym.Env):
|
|||||||
def _revert_to_snapshot(self):
|
def _revert_to_snapshot(self):
|
||||||
# Revert to a certain snapshot of the virtual machine, and refresh the path to the VM and the IP of the VM,
|
# Revert to a certain snapshot of the virtual machine, and refresh the path to the VM and the IP of the VM,
|
||||||
# because they may change when the VM is provided by a cloud service
|
# because they may change when the VM is provided by a cloud service
|
||||||
self.path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name)
|
path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name)
|
||||||
|
if path_to_vm and not path_to_vm == self.path_to_vm:
|
||||||
|
# path_to_vm has to be a new path
|
||||||
|
self.manager.delete_vm(self.path_to_vm, self.region)
|
||||||
|
self.manager.add_vm(path_to_vm, self.region)
|
||||||
|
self.manager.occupy_vm(path_to_vm, os.getpid(), self.region)
|
||||||
|
self.path_to_vm = path_to_vm
|
||||||
|
|
||||||
def _save_state(self, snapshot_name=None):
|
def _save_state(self, snapshot_name=None):
|
||||||
# Save the current virtual machine state to a certain snapshot name
|
# Save the current virtual machine state to a certain snapshot name
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ import os
|
|||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
import boto3
|
import boto3
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from multiprocessing import Manager
|
||||||
|
|
||||||
from desktop_env.providers.base import VMManager
|
from desktop_env.providers.base import VMManager
|
||||||
|
|
||||||
@@ -16,11 +16,12 @@ DEFAULT_REGION = "us-east-1"
|
|||||||
# todo: Add doc for the configuration of image, security group and network interface
|
# todo: Add doc for the configuration of image, security group and network interface
|
||||||
# todo: make the AMI images public
|
# todo: make the AMI images public
|
||||||
IMAGE_ID_MAP = {
|
IMAGE_ID_MAP = {
|
||||||
"us-east-1": "ami-0b0531325a0d5d488",
|
"us-east-1": "ami-019f92c05df45031b",
|
||||||
"ap-east-1": "ami-0b92a0bf157fecaa9"
|
"ap-east-1": "ami-07b4956131da1b282"
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANCE_TYPE = "t3.large"
|
INSTANCE_TYPE = "t3.medium"
|
||||||
|
|
||||||
NETWORK_INTERFACE_MAP = {
|
NETWORK_INTERFACE_MAP = {
|
||||||
"us-east-1": [
|
"us-east-1": [
|
||||||
{
|
{
|
||||||
@@ -33,14 +34,14 @@ NETWORK_INTERFACE_MAP = {
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"ap-east-1": [
|
"ap-east-1": [
|
||||||
{
|
{
|
||||||
"SubnetId": "subnet-011060501be0b589c",
|
"SubnetId": "subnet-011060501be0b589c",
|
||||||
"AssociatePublicIpAddress": True,
|
"AssociatePublicIpAddress": True,
|
||||||
"DeviceIndex": 0,
|
"DeviceIndex": 0,
|
||||||
"Groups": [
|
"Groups": [
|
||||||
"sg-090470e64df78f6eb"
|
"sg-090470e64df78f6eb"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,17 +61,18 @@ def _allocate_vm(region=DEFAULT_REGION):
|
|||||||
instance_id = response['Instances'][0]['InstanceId']
|
instance_id = response['Instances'][0]['InstanceId']
|
||||||
logger.info(f"Waiting for instance {instance_id} to be running...")
|
logger.info(f"Waiting for instance {instance_id} to be running...")
|
||||||
ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
|
ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
|
||||||
logger.info(f"Waiting for instance {instance_id} status checks to pass...")
|
|
||||||
ec2_client.get_waiter('instance_status_ok').wait(InstanceIds=[instance_id])
|
|
||||||
logger.info(f"Instance {instance_id} is ready.")
|
logger.info(f"Instance {instance_id} is ready.")
|
||||||
|
|
||||||
return instance_id
|
return instance_id
|
||||||
|
|
||||||
|
|
||||||
class AWSVMManager(VMManager):
|
class AWSVMManager(VMManager):
|
||||||
|
manager = Manager()
|
||||||
|
check_and_clean_event = manager.Event()
|
||||||
|
|
||||||
def __init__(self, registry_path=REGISTRY_PATH):
|
def __init__(self, registry_path=REGISTRY_PATH):
|
||||||
self.registry_path = registry_path
|
self.registry_path = registry_path
|
||||||
self.lock = FileLock(".aws_lck", timeout=10)
|
self.lock = FileLock(".aws_lck", timeout=60)
|
||||||
self.initialize_registry()
|
self.initialize_registry()
|
||||||
|
|
||||||
def initialize_registry(self):
|
def initialize_registry(self):
|
||||||
@@ -79,101 +81,183 @@ class AWSVMManager(VMManager):
|
|||||||
with open(self.registry_path, 'w') as file:
|
with open(self.registry_path, 'w') as file:
|
||||||
file.write('')
|
file.write('')
|
||||||
|
|
||||||
def add_vm(self, vm_path, region=DEFAULT_REGION):
|
def add_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
|
||||||
with self.lock:
|
if lock_needed:
|
||||||
with open(self.registry_path, 'r') as file:
|
with self.lock:
|
||||||
lines = file.readlines()
|
self._add_vm(vm_path, region)
|
||||||
vm_path_at_vm_region = "{}@{}".format(vm_path, region)
|
else:
|
||||||
new_lines = lines + [f'{vm_path_at_vm_region}|free\n']
|
self._add_vm(vm_path, region)
|
||||||
with open(self.registry_path, 'w') as file:
|
|
||||||
file.writelines(new_lines)
|
|
||||||
|
|
||||||
def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
|
def _add_vm(self, vm_path, region=DEFAULT_REGION):
|
||||||
with self.lock:
|
with open(self.registry_path, 'r') as file:
|
||||||
new_lines = []
|
lines = file.readlines()
|
||||||
with open(self.registry_path, 'r') as file:
|
vm_path_at_vm_region = "{}@{}".format(vm_path, region)
|
||||||
lines = file.readlines()
|
new_lines = lines + [f'{vm_path_at_vm_region}|free\n']
|
||||||
for line in lines:
|
with open(self.registry_path, 'w') as file:
|
||||||
registered_vm_path, _ = line.strip().split('|')
|
file.writelines(new_lines)
|
||||||
if registered_vm_path == "{}@{}".format(vm_path, region):
|
|
||||||
new_lines.append(f'{registered_vm_path}|{pid}\n')
|
|
||||||
else:
|
|
||||||
new_lines.append(line)
|
|
||||||
with open(self.registry_path, 'w') as file:
|
|
||||||
file.writelines(new_lines)
|
|
||||||
|
|
||||||
def check_and_clean(self):
|
def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
|
||||||
with self.lock: # Lock when cleaning up the registry and vms_dir
|
if lock_needed:
|
||||||
# Check and clean on the running vms, detect the released ones and mark them as 'free'
|
with self.lock:
|
||||||
active_pids = {p.pid for p in psutil.process_iter()}
|
self._delete_vm(vm_path, region)
|
||||||
new_lines = []
|
else:
|
||||||
vm_path_at_vm_regions = []
|
self._delete_vm(vm_path, region)
|
||||||
|
|
||||||
with open(self.registry_path, 'r') as file:
|
def _delete_vm(self, vm_path, region=DEFAULT_REGION):
|
||||||
lines = file.readlines()
|
new_lines = []
|
||||||
for line in lines:
|
with open(self.registry_path, 'r') as file:
|
||||||
vm_path_at_vm_region, pid_str = line.strip().split('|')
|
lines = file.readlines()
|
||||||
vm_path, vm_region = vm_path_at_vm_region.split("@")
|
for line in lines:
|
||||||
ec2_client = boto3.client('ec2', region_name=vm_region)
|
vm_path_at_vm_region, pid_str = line.strip().split('|')
|
||||||
|
if vm_path_at_vm_region == "{}@{}".format(vm_path, region):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
new_lines.append(line)
|
||||||
|
with open(self.registry_path, 'w') as file:
|
||||||
|
file.writelines(new_lines)
|
||||||
|
|
||||||
try:
|
def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True):
|
||||||
response = ec2_client.describe_instances(InstanceIds=[vm_path])
|
if lock_needed:
|
||||||
if not response['Reservations'] or response['Reservations'][0]['Instances'][0]['State'][
|
with self.lock:
|
||||||
'Name'] in ['terminated', 'shutting-down']:
|
self._occupy_vm(vm_path, pid, region)
|
||||||
logger.info(f"VM {vm_path} not found or terminated, releasing it.")
|
else:
|
||||||
continue
|
self._occupy_vm(vm_path, pid, region)
|
||||||
elif response['Reservations'][0]['Instances'][0]['State'][
|
|
||||||
'Name'] == "Stopped":
|
|
||||||
logger.info(f"VM {vm_path} stopped, mark it as free")
|
|
||||||
new_lines.append(f'{vm_path}@{vm_region}|free\n')
|
|
||||||
continue
|
|
||||||
except ec2_client.exceptions.ClientError as e:
|
|
||||||
if 'InvalidInstanceID.NotFound' in str(e):
|
|
||||||
logger.info(f"VM {vm_path} not found, releasing it.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
vm_path_at_vm_regions.append(vm_path_at_vm_region)
|
def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
|
||||||
if pid_str == "free":
|
new_lines = []
|
||||||
new_lines.append(line)
|
with open(self.registry_path, 'r') as file:
|
||||||
|
lines = file.readlines()
|
||||||
|
for line in lines:
|
||||||
|
registered_vm_path, _ = line.strip().split('|')
|
||||||
|
if registered_vm_path == "{}@{}".format(vm_path, region):
|
||||||
|
new_lines.append(f'{registered_vm_path}|{pid}\n')
|
||||||
|
else:
|
||||||
|
new_lines.append(line)
|
||||||
|
with open(self.registry_path, 'w') as file:
|
||||||
|
file.writelines(new_lines)
|
||||||
|
|
||||||
|
def check_and_clean(self, lock_needed=True):
|
||||||
|
if lock_needed:
|
||||||
|
with self.lock:
|
||||||
|
self._check_and_clean()
|
||||||
|
else:
|
||||||
|
self._check_and_clean()
|
||||||
|
|
||||||
|
def _check_and_clean(self):
|
||||||
|
# Get active PIDs
|
||||||
|
active_pids = {p.pid for p in psutil.process_iter()}
|
||||||
|
|
||||||
|
new_lines = []
|
||||||
|
vm_path_at_vm_regions = {}
|
||||||
|
|
||||||
|
with open(self.registry_path, 'r') as file:
|
||||||
|
lines = file.readlines()
|
||||||
|
|
||||||
|
# Collect all VM paths and their regions
|
||||||
|
for line in lines:
|
||||||
|
vm_path_at_vm_region, pid_str = line.strip().split('|')
|
||||||
|
vm_path, vm_region = vm_path_at_vm_region.split("@")
|
||||||
|
if vm_region not in vm_path_at_vm_regions:
|
||||||
|
vm_path_at_vm_regions[vm_region] = []
|
||||||
|
vm_path_at_vm_regions[vm_region].append((vm_path_at_vm_region, pid_str))
|
||||||
|
|
||||||
|
# Process each region
|
||||||
|
for region, vm_info_list in vm_path_at_vm_regions.items():
|
||||||
|
ec2_client = boto3.client('ec2', region_name=region)
|
||||||
|
instance_ids = [vm_info[0].split('@')[0] for vm_info in vm_info_list]
|
||||||
|
|
||||||
|
# Batch describe instances
|
||||||
|
try:
|
||||||
|
response = ec2_client.describe_instances(InstanceIds=instance_ids)
|
||||||
|
reservations = response.get('Reservations', [])
|
||||||
|
|
||||||
|
terminated_ids = set()
|
||||||
|
stopped_ids = set()
|
||||||
|
active_ids = set()
|
||||||
|
|
||||||
|
# Collect states of all instances
|
||||||
|
for reservation in reservations:
|
||||||
|
for instance in reservation.get('Instances', []):
|
||||||
|
instance_id = instance.get('InstanceId')
|
||||||
|
instance_state = instance['State']['Name']
|
||||||
|
if instance_state in ['terminated', 'shutting-down']:
|
||||||
|
terminated_ids.add(instance_id)
|
||||||
|
elif instance_state == 'stopped':
|
||||||
|
stopped_ids.add(instance_id)
|
||||||
|
else:
|
||||||
|
active_ids.add(instance_id)
|
||||||
|
|
||||||
|
# Write results back to file
|
||||||
|
for vm_path_at_vm_region, pid_str in vm_info_list:
|
||||||
|
vm_path = vm_path_at_vm_region.split('@')[0]
|
||||||
|
|
||||||
|
if vm_path in terminated_ids:
|
||||||
|
logger.info(f"VM {vm_path} not found or terminated, releasing it.")
|
||||||
|
continue
|
||||||
|
elif vm_path in stopped_ids:
|
||||||
|
logger.info(f"VM {vm_path} stopped, mark it as free")
|
||||||
|
new_lines.append(f'{vm_path}@{region}|free\n')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if int(pid_str) in active_pids:
|
if pid_str == "free":
|
||||||
new_lines.append(line)
|
new_lines.append(f'{vm_path}@{region}|{pid_str}\n')
|
||||||
|
elif int(pid_str) in active_pids:
|
||||||
|
new_lines.append(f'{vm_path}@{region}|{pid_str}\n')
|
||||||
else:
|
else:
|
||||||
new_lines.append(f'{vm_path_at_vm_region}|free\n')
|
new_lines.append(f'{vm_path}@{region}|free\n')
|
||||||
|
|
||||||
with open(self.registry_path, 'w') as file:
|
except ec2_client.exceptions.ClientError as e:
|
||||||
file.writelines(new_lines)
|
if 'InvalidInstanceID.NotFound' in str(e):
|
||||||
|
logger.info(f"VM not found, releasing instances in region {region}.")
|
||||||
|
continue
|
||||||
|
|
||||||
# We won't check and clean on the files on aws and delete the unregistered ones
|
# Writing updated lines back to the registry file
|
||||||
# Since this can lead to unexpected deletions on other servers
|
with open(self.registry_path, 'w') as file:
|
||||||
# Please do monitor the instances to avoid additional cost
|
file.writelines(new_lines)
|
||||||
|
|
||||||
def list_free_vms(self, region=DEFAULT_REGION):
|
# We won't check and clean on the files on aws and delete the unregistered ones
|
||||||
with self.lock: # Lock when reading the registry
|
# Since this can lead to unexpected deletions on other servers
|
||||||
free_vms = []
|
# Please do monitor the instances to avoid additional cost
|
||||||
with open(self.registry_path, 'r') as file:
|
|
||||||
lines = file.readlines()
|
|
||||||
for line in lines:
|
|
||||||
vm_path_at_vm_region, pid_str = line.strip().split('|')
|
|
||||||
vm_path, vm_region = vm_path_at_vm_region.split("@")
|
|
||||||
if pid_str == "free" and vm_region == region:
|
|
||||||
free_vms.append((vm_path, pid_str))
|
|
||||||
|
|
||||||
return free_vms
|
def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True):
|
||||||
|
if lock_needed:
|
||||||
|
with self.lock:
|
||||||
|
return self._list_free_vms(region)
|
||||||
|
else:
|
||||||
|
return self._list_free_vms(region)
|
||||||
|
|
||||||
|
def _list_free_vms(self, region=DEFAULT_REGION):
|
||||||
|
free_vms = []
|
||||||
|
with open(self.registry_path, 'r') as file:
|
||||||
|
lines = file.readlines()
|
||||||
|
for line in lines:
|
||||||
|
vm_path_at_vm_region, pid_str = line.strip().split('|')
|
||||||
|
vm_path, vm_region = vm_path_at_vm_region.split("@")
|
||||||
|
if pid_str == "free" and vm_region == region:
|
||||||
|
free_vms.append((vm_path, pid_str))
|
||||||
|
|
||||||
|
return free_vms
|
||||||
|
|
||||||
def get_vm_path(self, region=DEFAULT_REGION):
|
def get_vm_path(self, region=DEFAULT_REGION):
|
||||||
self.check_and_clean()
|
with self.lock:
|
||||||
free_vms_paths = self.list_free_vms(region)
|
if not AWSVMManager.check_and_clean_event.is_set():
|
||||||
|
AWSVMManager.check_and_clean_event.set()
|
||||||
|
self._check_and_clean()
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
free_vms_paths = self._list_free_vms(region)
|
||||||
|
|
||||||
if len(free_vms_paths) == 0:
|
if len(free_vms_paths) == 0:
|
||||||
# No free virtual machine available, generate a new one
|
# No free virtual machine available, generate a new one
|
||||||
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
|
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
|
||||||
new_vm_path = _allocate_vm(region)
|
new_vm_path = _allocate_vm(region)
|
||||||
self.add_vm(new_vm_path, region)
|
with self.lock:
|
||||||
self.occupy_vm(new_vm_path, os.getpid(), region)
|
self._add_vm(new_vm_path, region)
|
||||||
|
self._occupy_vm(new_vm_path, os.getpid(), region)
|
||||||
return new_vm_path
|
return new_vm_path
|
||||||
else:
|
else:
|
||||||
# Choose the first free virtual machine
|
# Choose the first free virtual machine
|
||||||
chosen_vm_path = free_vms_paths[0][0]
|
chosen_vm_path = free_vms_paths[0][0]
|
||||||
self.occupy_vm(chosen_vm_path, os.getpid(), region)
|
with self.lock:
|
||||||
|
self._occupy_vm(chosen_vm_path, os.getpid(), region)
|
||||||
return chosen_vm_path
|
return chosen_vm_path
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from botocore.exceptions import ClientError
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .manager import INSTANCE_TYPE
|
|
||||||
from desktop_env.providers.base import Provider
|
from desktop_env.providers.base import Provider
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.providers.aws.AWSProvider")
|
logger = logging.getLogger("desktopenv.providers.aws.AWSProvider")
|
||||||
@@ -73,7 +72,11 @@ class AWSProvider(Provider):
|
|||||||
subnet_id = instance['SubnetId']
|
subnet_id = instance['SubnetId']
|
||||||
instance_type = instance['InstanceType']
|
instance_type = instance['InstanceType']
|
||||||
|
|
||||||
# Step 2: Launch a new instance from the snapshot
|
# Step 2: Terminate the old instance
|
||||||
|
ec2_client.terminate_instances(InstanceIds=[path_to_vm])
|
||||||
|
logger.info(f"Old instance {path_to_vm} has been terminated.")
|
||||||
|
|
||||||
|
# Step 3: Launch a new instance from the snapshot
|
||||||
logger.info(f"Launching a new instance from snapshot {snapshot_name}...")
|
logger.info(f"Launching a new instance from snapshot {snapshot_name}...")
|
||||||
|
|
||||||
run_instances_params = {
|
run_instances_params = {
|
||||||
@@ -97,14 +100,8 @@ class AWSProvider(Provider):
|
|||||||
logger.info(f"New instance {new_instance_id} launched from snapshot {snapshot_name}.")
|
logger.info(f"New instance {new_instance_id} launched from snapshot {snapshot_name}.")
|
||||||
logger.info(f"Waiting for instance {new_instance_id} to be running...")
|
logger.info(f"Waiting for instance {new_instance_id} to be running...")
|
||||||
ec2_client.get_waiter('instance_running').wait(InstanceIds=[new_instance_id])
|
ec2_client.get_waiter('instance_running').wait(InstanceIds=[new_instance_id])
|
||||||
logger.info(f"Waiting for instance {new_instance_id} status checks to pass...")
|
|
||||||
ec2_client.get_waiter('instance_status_ok').wait(InstanceIds=[new_instance_id])
|
|
||||||
logger.info(f"Instance {new_instance_id} is ready.")
|
logger.info(f"Instance {new_instance_id} is ready.")
|
||||||
|
|
||||||
# Step 3: Terminate the old instance
|
|
||||||
ec2_client.terminate_instances(InstanceIds=[path_to_vm])
|
|
||||||
logger.info(f"Old instance {path_to_vm} has been terminated.")
|
|
||||||
|
|
||||||
return new_instance_id
|
return new_instance_id
|
||||||
|
|
||||||
except ClientError as e:
|
except ClientError as e:
|
||||||
|
|||||||
@@ -60,6 +60,13 @@ class VMManager(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_vm(self, vm_path, **kwargs):
|
||||||
|
"""
|
||||||
|
Delete the registration of VM by path.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def occupy_vm(self, vm_path, pid, **kwargs):
|
def occupy_vm(self, vm_path, pid, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user