import logging
import os

import boto3
import psutil
from filelock import FileLock

from desktop_env.providers.base import VMManager
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool

logger = logging.getLogger("desktopenv.providers.aws.AWSVMManagerWithProxy")
logger.setLevel(logging.INFO)

REGISTRY_PATH = '.aws_vms_proxy'

DEFAULT_REGION = "us-east-1"
IMAGE_ID_MAP = {
    "us-east-1": "ami-05e7d7bd279ea4f14",
    "ap-east-1": "ami-0c092a5b8be4116f5",
}

INSTANCE_TYPE = "t3.medium"

NETWORK_INTERFACE_MAP = {
    "us-east-1": [
        {
            "SubnetId": "subnet-037edfff66c2eb894",
            "AssociatePublicIpAddress": True,
            "DeviceIndex": 0,
            "Groups": [
                "sg-0342574803206ee9c"
            ]
        }
    ],
    "ap-east-1": [
        {
            "SubnetId": "subnet-011060501be0b589c",
            "AssociatePublicIpAddress": True,
            "DeviceIndex": 0,
            "Groups": [
                "sg-090470e64df78f6eb"
            ]
        }
    ]
}


def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
    """Launch a new EC2 instance with proxy configuration and return its instance ID.

    Args:
        region: AWS region to launch in. Must be a key of both IMAGE_ID_MAP and
            NETWORK_INTERFACE_MAP.
        proxy_config_file: Optional path to a proxy configuration file. When
            given, the global proxy pool is initialized from it and the path is
            passed on to the provider.

    Returns:
        Whatever ``AWSProviderWithProxy.create_instance_with_proxy`` returns
        (used by callers as the instance ID / VM path).

    Raises:
        KeyError: if ``region`` has no image or network configuration.
    """
    # Imported lazily, presumably to avoid a circular import with the provider module.
    from .provider_with_proxy import AWSProviderWithProxy

    # Resolve region-specific config first so an unsupported region fails
    # before the proxy pool is touched or a provider is created.
    image_id = IMAGE_ID_MAP[region]
    network_interface = NETWORK_INTERFACE_MAP[region][0]

    # NOTE(review): this runs on every allocation. Whether init_proxy_pool is a
    # no-op for an already-initialized pool (or resets it) is not visible here -- confirm.
    if proxy_config_file:
        init_proxy_pool(proxy_config_file)

    # NOTE(review): this get_next_proxy() call is only used for logging; it
    # advances the pool's rotation, and the provider may pick its own proxy.
    proxy_pool = get_global_proxy_pool()
    current_proxy = proxy_pool.get_next_proxy()
    if current_proxy:
        logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")

    provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)

    instance_id = provider.create_instance_with_proxy(
        image_id=image_id,
        instance_type=INSTANCE_TYPE,
        security_groups=network_interface["Groups"],
        subnet_id=network_interface["SubnetId"],
    )

    return instance_id


class AWSVMManagerWithProxy(VMManager):
    """File-backed registry of proxied AWS EC2 instances.

    Each registry line has the form::

        <vm_path>@<region>|<status>|<proxy_host>:<proxy_port>

    ``vm_path`` is the EC2 instance ID. ``status`` is ``free`` or the PID of
    the process that occupies the VM. The proxy field may be empty. Public
    methods hold a file lock around registry access unless called with
    ``lock_needed=False``. The ``_``-prefixed variants expect the caller to
    already hold it.
    """

    # Set on the class once get_vm_path() has reconciled the registry with EC2,
    # so the (AWS-calling) cleanup runs only once per process.
    checked_and_cleaned = False

    def __init__(self, registry_path=REGISTRY_PATH, proxy_config_file=None):
        self.registry_path = registry_path
        # The lock file name is fixed, so all managers in the same working
        # directory share one lock regardless of registry_path.
        self.lock = FileLock(".aws_proxy_lck", timeout=60)
        self.proxy_config_file = proxy_config_file
        self.initialize_registry()

        # Initialize the global proxy pool up front when a config file is given.
        if proxy_config_file:
            init_proxy_pool(proxy_config_file)
            logger.info(f"Proxy pool initialized with config: {proxy_config_file}")

    @staticmethod
    def _parse_registry_line(line):
        """Split a registry line into ``(vm_path@region, status, proxy_str)``.

        Returns None for lines with fewer than two ``|``-separated fields (e.g.
        blank lines). ``proxy_str`` is "" when the third field is absent.
        """
        parts = line.strip().split('|')
        if len(parts) < 2:
            return None
        return parts[0], parts[1], parts[2] if len(parts) > 2 else ""

    def initialize_registry(self):
        """Create an empty registry file if it does not exist yet."""
        with self.lock:
            if not os.path.exists(self.registry_path):
                with open(self.registry_path, 'w') as file:
                    file.write('')

    def add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None, lock_needed=True):
        """Register ``vm_path`` in ``region`` as free (optionally with proxy info)."""
        if lock_needed:
            with self.lock:
                self._add_vm(vm_path, region, proxy_info)
        else:
            self._add_vm(vm_path, region, proxy_info)

    def _add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None):
        """Append a ``free`` registry entry.

        ``proxy_info`` is a dict with ``host`` and ``port`` keys, or None for
        an empty proxy field.
        """
        proxy_str = f"{proxy_info['host']}:{proxy_info['port']}" if proxy_info else ""
        new_line = f'{vm_path}@{region}|free|{proxy_str}\n'
        # Every line this class writes ends in "\n", so appending is equivalent
        # to rewriting the file with the new line at the end.
        with open(self.registry_path, 'a') as file:
            file.write(new_line)

    def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
        """Remove every registry entry for ``vm_path`` in ``region``."""
        if lock_needed:
            with self.lock:
                self._delete_vm(vm_path, region)
        else:
            self._delete_vm(vm_path, region)

    def _delete_vm(self, vm_path, region=DEFAULT_REGION):
        """Rewrite the registry without entries for ``vm_path@region``.

        Other lines, including unparseable ones, are kept as-is.
        """
        target = f"{vm_path}@{region}"
        with open(self.registry_path, 'r') as file:
            lines = file.readlines()

        new_lines = []
        for line in lines:
            entry = self._parse_registry_line(line)
            if entry is not None and entry[0] == target:
                continue
            new_lines.append(line)

        with open(self.registry_path, 'w') as file:
            file.writelines(new_lines)

    def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True):
        """Mark ``vm_path`` in ``region`` as occupied by process ``pid``."""
        if lock_needed:
            with self.lock:
                self._occupy_vm(vm_path, pid, region)
        else:
            self._occupy_vm(vm_path, pid, region)

    def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
        """Set the status field of ``vm_path@region`` to ``pid``, keeping its proxy field."""
        target = f"{vm_path}@{region}"
        with open(self.registry_path, 'r') as file:
            lines = file.readlines()

        new_lines = []
        for line in lines:
            entry = self._parse_registry_line(line)
            if entry is not None and entry[0] == target:
                new_lines.append(f'{target}|{pid}|{entry[2]}\n')
            else:
                new_lines.append(line)

        with open(self.registry_path, 'w') as file:
            file.writelines(new_lines)

    def check_and_clean(self, lock_needed=True):
        """Reconcile the registry with EC2 instance states and live local PIDs."""
        if lock_needed:
            with self.lock:
                self._check_and_clean()
        else:
            self._check_and_clean()

    @staticmethod
    def _describe_instance_states(ec2_client, region, instance_ids):
        """Return ``(states, failed_ids)`` for ``instance_ids``.

        ``states`` maps each instance ID that EC2 returned to its state name.
        A batched DescribeInstances request fails as a whole if any listed ID
        is unknown (e.g. an instance purged after termination). In that case
        each ID is retried on its own, so one bad ID does not hide the others.
        IDs that still fail individually are returned in ``failed_ids``.
        """
        def lookup(ids):
            states = {}
            response = ec2_client.describe_instances(InstanceIds=ids)
            for reservation in response.get('Reservations', []):
                for instance in reservation.get('Instances', []):
                    states[instance.get('InstanceId')] = instance['State']['Name']
            return states

        try:
            return lookup(list(instance_ids)), set()
        except Exception as e:
            logger.error(f"Error checking instances in region {region}: {e}")

        states, failed_ids = {}, set()
        for instance_id in instance_ids:
            try:
                states.update(lookup([instance_id]))
            except Exception as e:
                logger.error(f"Error checking instance {instance_id} in region {region}: {e}")
                failed_ids.add(instance_id)
        return states, failed_ids

    def _check_and_clean(self):
        """Rewrite the registry according to EC2 state and local process liveness.

        Per entry:
          * instance ``terminated`` / ``shutting-down`` -> removed;
          * instance ``stopped``                         -> marked ``free``;
          * status ``free``                              -> kept ``free``;
          * status is a PID that is currently running    -> kept as occupied;
          * anything else (dead PID, unexpected status)  -> marked ``free``.
        Entries whose instance state cannot be fetched from EC2 are removed.
        Lines without at least two ``|``-separated fields are dropped.
        Expected to be called with ``self.lock`` held.
        """
        active_pids = {p.pid for p in psutil.process_iter()}

        with open(self.registry_path, 'r') as file:
            lines = file.readlines()

        # Group entries by region (dict preserves first-seen order).
        entries_by_region = {}
        for line in lines:
            entry = self._parse_registry_line(line)
            if entry is None:
                continue
            vm_path_at_vm_region, status, proxy_str = entry
            _, vm_region = vm_path_at_vm_region.split("@")
            entries_by_region.setdefault(vm_region, []).append((vm_path_at_vm_region, status, proxy_str))

        new_lines = []
        for region, entries in entries_by_region.items():
            ec2_client = boto3.client('ec2', region_name=region)
            instance_ids = [entry[0].split('@')[0] for entry in entries]
            states, failed_ids = self._describe_instance_states(ec2_client, region, instance_ids)

            for vm_path_at_vm_region, status, proxy_str in entries:
                vm_path = vm_path_at_vm_region.split('@')[0]
                if vm_path in failed_ids:
                    # State could not be verified: drop the entry.
                    continue

                state = states.get(vm_path)
                if state in ('terminated', 'shutting-down'):
                    logger.info(f"VM {vm_path} not found or terminated, releasing it.")
                    continue
                if state == 'stopped':
                    logger.info(f"VM {vm_path} stopped, mark it as free")
                    new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')
                    continue

                # isdecimal() guarantees int() below cannot raise.
                if status == "free" or (status.isdecimal() and int(status) in active_pids):
                    new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n')
                else:
                    new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')

        with open(self.registry_path, 'w') as file:
            file.writelines(new_lines)

    def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True):
        """Return ``(vm_path, status, proxy_str)`` tuples for free VMs in ``region``."""
        if lock_needed:
            with self.lock:
                return self._list_free_vms(region)
        else:
            return self._list_free_vms(region)

    def _list_free_vms(self, region=DEFAULT_REGION):
        """Return free registry entries for ``region`` in file order."""
        with open(self.registry_path, 'r') as file:
            lines = file.readlines()

        free_vms = []
        for line in lines:
            entry = self._parse_registry_line(line)
            if entry is None:
                continue
            vm_path_at_vm_region, status, proxy_str = entry
            vm_path, vm_region = vm_path_at_vm_region.split("@")
            if status == "free" and vm_region == region:
                free_vms.append((vm_path, status, proxy_str))
        return free_vms

    def get_vm_path(self, region=DEFAULT_REGION):
        """Return a VM path (instance ID) in ``region`` that is now occupied by this process.

        On the first call in the process, the registry is reconciled with EC2.
        The first free registered VM is reused if one exists. Otherwise a new
        proxied instance is launched outside the lock, since launching can be
        slow, and is then registered and occupied.
        """
        with self.lock:
            if not AWSVMManagerWithProxy.checked_and_cleaned:
                AWSVMManagerWithProxy.checked_and_cleaned = True
                self._check_and_clean()

        with self.lock:
            free_vms = self._list_free_vms(region)
            if free_vms:
                chosen_vm_path, _, proxy_str = free_vms[0]
                self._occupy_vm(chosen_vm_path, os.getpid(), region)
                logger.info(f"Using existing VM {chosen_vm_path} with proxy: {proxy_str}")
                return chosen_vm_path

        logger.info("No free virtual machine available. Generating a new one with proxy configuration...☕")
        new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)

        # NOTE(review): this is a separate get_next_proxy() call made after
        # allocation, so the proxy recorded here is not verified to be the one
        # the new instance was configured with -- confirm against the provider.
        proxy_pool = get_global_proxy_pool()
        current_proxy = proxy_pool.get_next_proxy()
        proxy_info = None
        if current_proxy:
            proxy_info = {
                'host': current_proxy.host,
                'port': current_proxy.port,
            }

        with self.lock:
            self._add_vm(new_vm_path, region, proxy_info)
            self._occupy_vm(new_vm_path, os.getpid(), region)

        return new_vm_path

    def get_proxy_stats(self):
        """Return the global proxy pool's statistics (``proxy_pool.get_stats()``)."""
        proxy_pool = get_global_proxy_pool()
        return proxy_pool.get_stats()

    def test_all_proxies(self):
        """Run the global proxy pool's ``test_all_proxies()`` and return its result."""
        proxy_pool = get_global_proxy_pool()
        return proxy_pool.test_all_proxies()

    def force_rotate_proxy_for_vm(self, vm_path, region=DEFAULT_REGION):
        """Advance the global proxy pool and log the next proxy for ``vm_path``.

        This does not reconfigure the running instance or update the registry.
        Applying a new proxy would require recreating the instance, and in
        practice possibly saving and restoring its state. ``region`` is
        currently unused.

        Returns:
            True if the pool returned a proxy, False otherwise.
        """
        logger.info(f"Force rotating proxy for VM {vm_path}")

        proxy_pool = get_global_proxy_pool()
        new_proxy = proxy_pool.get_next_proxy()

        if new_proxy:
            logger.info(f"New proxy for VM {vm_path}: {new_proxy.host}:{new_proxy.port}")
            return True
        logger.warning(f"No available proxy for VM {vm_path}")
        return False