Files
sci-gui-agent-benchmark/desktop_env/providers/aws/manager_with_proxy.py

329 lines
12 KiB
Python

import os
from filelock import FileLock
import boto3
import psutil
import logging
from desktop_env.providers.base import VMManager
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManagerWithProxy")
logger.setLevel(logging.INFO)

# Registry file; each line is `vm_path@region|status|proxy_host:proxy_port`.
REGISTRY_PATH = '.aws_vms_proxy'
DEFAULT_REGION = "us-east-1"
INSTANCE_TYPE = "t3.medium"

# AMI used when launching a new instance, keyed by AWS region.
IMAGE_ID_MAP = {
    "us-east-1": "ami-05e7d7bd279ea4f14",
    "ap-east-1": "ami-0c092a5b8be4116f5",
}

# (subnet ID, security group ID) per region, expanded below into the
# network-interface specification list consumed by the allocator.
_REGION_SUBNET_AND_GROUP = {
    "us-east-1": ("subnet-037edfff66c2eb894", "sg-0342574803206ee9c"),
    "ap-east-1": ("subnet-011060501be0b589c", "sg-090470e64df78f6eb"),
}
NETWORK_INTERFACE_MAP = {
    region_name: [
        {
            "SubnetId": subnet_id,
            "AssociatePublicIpAddress": True,
            "DeviceIndex": 0,
            "Groups": [group_id],
        },
    ]
    for region_name, (subnet_id, group_id) in _REGION_SUBNET_AND_GROUP.items()
}
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
    """Launch a new EC2 instance with proxy configuration and return its ID.

    Args:
        region: AWS region to launch in. Must be a key of both ``IMAGE_ID_MAP``
            and ``NETWORK_INTERFACE_MAP``.
        proxy_config_file: Optional path to a proxy configuration file. When
            given, ``init_proxy_pool`` is called with it before a proxy is
            picked (whether repeated calls are idempotent depends on
            ``proxy_pool``, which is not visible here).

    Returns:
        The instance ID returned by
        ``AWSProviderWithProxy.create_instance_with_proxy``.

    Raises:
        KeyError: if no AMI or network interface is configured for ``region``.
    """
    # Validate the region before touching the proxy pool, so an unsupported
    # region fails fast without re-initializing the pool or advancing it.
    if region not in IMAGE_ID_MAP or region not in NETWORK_INTERFACE_MAP:
        raise KeyError(
            f"No AMI/network configuration for AWS region {region!r}; "
            f"configured regions: {sorted(IMAGE_ID_MAP)}"
        )

    from .provider_with_proxy import AWSProviderWithProxy

    if proxy_config_file:
        init_proxy_pool(proxy_config_file)

    # NOTE(review): this proxy is only logged; it is not passed to the
    # provider below. If get_next_proxy() advances the pool's rotation, this
    # call also shifts which proxy the provider will see.
    proxy_pool = get_global_proxy_pool()
    current_proxy = proxy_pool.get_next_proxy()
    if current_proxy:
        logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")

    provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)
    network = NETWORK_INTERFACE_MAP[region][0]
    instance_id = provider.create_instance_with_proxy(
        image_id=IMAGE_ID_MAP[region],
        instance_type=INSTANCE_TYPE,
        security_groups=network["Groups"],
        subnet_id=network["SubnetId"],
    )
    return instance_id
class AWSVMManagerWithProxy(VMManager):
    """Track proxy-enabled AWS EC2 VMs in a registry file shared between processes.

    Each registry line has the form ``vm_path@region|status|proxy_host:proxy_port``.
    Here ``vm_path`` is the EC2 instance ID, ``status`` is either ``free`` or the
    PID of the process occupying the VM, and the proxy field may be empty.
    Access to the registry is serialized with a ``FileLock``. The public methods
    accept ``lock_needed=False`` so callers that already hold the lock can skip
    acquiring it.
    """

    # Set to True by get_vm_path after the first registry reconciliation, so
    # _check_and_clean runs at most once per process for this class. Declared
    # here so get_vm_path never depends on the base class defining it.
    checked_and_cleaned = False

    def __init__(self, registry_path=REGISTRY_PATH, proxy_config_file=None):
        """Create the manager, ensure the registry file exists, and init the proxy pool.

        Args:
            registry_path: Path of the registry file.
            proxy_config_file: Optional proxy configuration file. When given,
                the global proxy pool is initialized from it.
        """
        self.registry_path = registry_path
        # Relative path: the lock is shared by processes started from the same
        # working directory; acquisition waits up to 60 seconds.
        self.lock = FileLock(".aws_proxy_lck", timeout=60)
        self.proxy_config_file = proxy_config_file
        self.initialize_registry()
        if proxy_config_file:
            init_proxy_pool(proxy_config_file)
            logger.info(f"Proxy pool initialized with config: {proxy_config_file}")

    def initialize_registry(self):
        """Create an empty registry file if it does not exist yet."""
        with self.lock:
            if not os.path.exists(self.registry_path):
                with open(self.registry_path, 'w') as file:
                    file.write('')

    # ------------------------------------------------------------------
    # Registry line helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _line_matches(line, vm_key):
        """Return True if ``line`` is a well-formed entry whose key equals ``vm_key``."""
        parts = line.strip().split('|')
        return len(parts) >= 2 and parts[0] == vm_key

    @staticmethod
    def _parse_registry_line(line):
        """Split a registry line into ``(vm_path, region, status, proxy_str)``.

        Returns None for lines with fewer than two ``|``-separated fields or
        whose first field has no ``@region`` suffix.
        """
        parts = line.strip().split('|')
        if len(parts) < 2:
            return None
        vm_path, sep, vm_region = parts[0].rpartition('@')
        if not sep:
            return None
        proxy_str = parts[2] if len(parts) > 2 else ""
        return vm_path, vm_region, parts[1], proxy_str

    # ------------------------------------------------------------------
    # Registry mutations
    # ------------------------------------------------------------------
    def add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None, lock_needed=True):
        """Register ``vm_path`` in ``region`` as free, with optional proxy info.

        ``proxy_info`` is a mapping with ``'host'`` and ``'port'`` keys, or None.
        """
        if lock_needed:
            with self.lock:
                self._add_vm(vm_path, region, proxy_info)
        else:
            self._add_vm(vm_path, region, proxy_info)

    def _add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None):
        # Line format: vm_path@region|status|proxy_host:proxy_port
        proxy_str = f"{proxy_info['host']}:{proxy_info['port']}" if proxy_info else ""
        # Append instead of rewriting the whole file: same resulting content,
        # without a window in which a crash could truncate the registry.
        with open(self.registry_path, 'a') as file:
            file.write(f'{vm_path}@{region}|free|{proxy_str}\n')

    def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
        """Remove every registry entry for ``vm_path`` in ``region``."""
        if lock_needed:
            with self.lock:
                self._delete_vm(vm_path, region)
        else:
            self._delete_vm(vm_path, region)

    def _delete_vm(self, vm_path, region=DEFAULT_REGION):
        vm_key = f"{vm_path}@{region}"
        with open(self.registry_path, 'r') as file:
            lines = file.readlines()
        new_lines = [line for line in lines if not self._line_matches(line, vm_key)]
        with open(self.registry_path, 'w') as file:
            file.writelines(new_lines)

    def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True):
        """Mark ``vm_path`` in ``region`` as occupied by ``pid``, keeping its proxy field."""
        if lock_needed:
            with self.lock:
                self._occupy_vm(vm_path, pid, region)
        else:
            self._occupy_vm(vm_path, pid, region)

    def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
        vm_key = f"{vm_path}@{region}"
        with open(self.registry_path, 'r') as file:
            lines = file.readlines()
        new_lines = []
        for line in lines:
            if self._line_matches(line, vm_key):
                parts = line.strip().split('|')
                proxy_str = parts[2] if len(parts) > 2 else ""
                line = f'{vm_key}|{pid}|{proxy_str}\n'
            new_lines.append(line)
        with open(self.registry_path, 'w') as file:
            file.writelines(new_lines)

    # ------------------------------------------------------------------
    # Reconciliation with EC2 and local processes
    # ------------------------------------------------------------------
    def check_and_clean(self, lock_needed=True):
        """Reconcile the registry with EC2 instance states and running local PIDs."""
        if lock_needed:
            with self.lock:
                self._check_and_clean()
        else:
            self._check_and_clean()

    @staticmethod
    def _aws_error_code(exc):
        """Return the AWS error code of a botocore ``ClientError``, or '' for other exceptions."""
        response = getattr(exc, 'response', None)
        if not isinstance(response, dict):
            return ""
        return response.get('Error', {}).get('Code', "")

    @classmethod
    def _describe_instance_states(cls, ec2_client, instance_ids):
        """Map each instance ID to its EC2 state name (e.g. 'running', 'stopped').

        EC2 rejects the whole DescribeInstances call with an
        ``InvalidInstanceID.*`` error if any listed ID is unknown or malformed.
        In that case each ID is queried on its own, and the offending IDs are
        reported as 'terminated' so the caller releases them. Any other error
        propagates.
        """
        def collect(response):
            return {
                instance.get('InstanceId'): instance['State']['Name']
                for reservation in response.get('Reservations', [])
                for instance in reservation.get('Instances', [])
            }

        try:
            return collect(ec2_client.describe_instances(InstanceIds=instance_ids))
        except Exception as e:
            if not cls._aws_error_code(e).startswith('InvalidInstanceID'):
                raise

        states = {}
        for instance_id in instance_ids:
            try:
                states.update(collect(ec2_client.describe_instances(InstanceIds=[instance_id])))
            except Exception as e:
                if not cls._aws_error_code(e).startswith('InvalidInstanceID'):
                    raise
                states[instance_id] = 'terminated'
        return states

    def _check_and_clean(self):
        """Rewrite the registry so it reflects current EC2 and process state.

        The rules applied to each entry are:

        - Instances that are terminated, shutting down, or unknown to EC2 are removed.
        - Stopped instances are marked free.
        - Entries whose status is neither 'free' nor the PID of a running process
          are marked free.
        - If EC2 cannot be queried for a region, that region's entries are kept
          unchanged rather than dropped.
        - Unparseable lines are discarded.
        """
        active_pids = {p.pid for p in psutil.process_iter()}
        with open(self.registry_path, 'r') as file:
            lines = file.readlines()

        # region -> [(original_line, vm_path, status, proxy_str), ...]
        entries_by_region = {}
        for line in lines:
            parsed = self._parse_registry_line(line)
            if parsed is None:
                if line.strip():
                    logger.warning(f"Skipping malformed registry line: {line.strip()!r}")
                continue
            vm_path, vm_region, status, proxy_str = parsed
            entries_by_region.setdefault(vm_region, []).append((line, vm_path, status, proxy_str))

        new_lines = []
        for region, entries in entries_by_region.items():
            try:
                ec2_client = boto3.client('ec2', region_name=region)
                states = self._describe_instance_states(
                    ec2_client, [vm_path for _, vm_path, _, _ in entries]
                )
            except Exception as e:
                logger.error(f"Error checking instances in region {region}: {e}")
                # Keep this region's entries as they were; dropping them would
                # make the registry forget VMs that may still be running.
                new_lines.extend(line.rstrip('\n') + '\n' for line, _, _, _ in entries)
                continue

            for _, vm_path, status, proxy_str in entries:
                state = states.get(vm_path)
                if state in ('terminated', 'shutting-down'):
                    logger.info(f"VM {vm_path} not found or terminated, releasing it.")
                    continue
                if state == 'stopped':
                    logger.info(f"VM {vm_path} stopped, mark it as free")
                    status = "free"
                elif status != "free" and not (status.isdigit() and int(status) in active_pids):
                    # Occupying process is gone (or status is garbage): release it.
                    status = "free"
                new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n')

        with open(self.registry_path, 'w') as file:
            file.writelines(new_lines)

    # ------------------------------------------------------------------
    # Queries and allocation
    # ------------------------------------------------------------------
    def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True):
        """Return ``(vm_path, status, proxy_str)`` tuples for free VMs in ``region``."""
        if lock_needed:
            with self.lock:
                return self._list_free_vms(region)
        else:
            return self._list_free_vms(region)

    def _list_free_vms(self, region=DEFAULT_REGION):
        with open(self.registry_path, 'r') as file:
            lines = file.readlines()
        free_vms = []
        for line in lines:
            parsed = self._parse_registry_line(line)
            if parsed is None:
                continue
            vm_path, vm_region, status, proxy_str = parsed
            if status == "free" and vm_region == region:
                free_vms.append((vm_path, status, proxy_str))
        return free_vms

    def get_vm_path(self, region=DEFAULT_REGION):
        """Return a VM (instance ID) in ``region`` occupied by the current process.

        Reuses the first free registered VM if there is one. Otherwise it launches
        a new instance via ``_allocate_vm_with_proxy`` and registers it.
        """
        with self.lock:
            if not AWSVMManagerWithProxy.checked_and_cleaned:
                AWSVMManagerWithProxy.checked_and_cleaned = True
                self._check_and_clean()

        with self.lock:
            free_vms = self._list_free_vms(region)
            if free_vms:
                chosen_vm_path, _, proxy_str = free_vms[0]
                self._occupy_vm(chosen_vm_path, os.getpid(), region)
                logger.info(f"Using existing VM {chosen_vm_path} with proxy: {proxy_str}")
                return chosen_vm_path

        # The lock is not held during allocation (presumably so other processes
        # are not blocked while an instance launches).
        logger.info("No free virtual machine available. Generating a new one with proxy configuration...☕")
        new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)

        # NOTE(review): get_next_proxy() is called again here only to record a
        # proxy. If it advances the pool's rotation (as its name suggests), the
        # proxy written to the registry may differ from the one logged during
        # allocation or actually used by the provider. Recording the real proxy
        # needs AWSProviderWithProxy to expose it — TODO confirm and fix.
        proxy_pool = get_global_proxy_pool()
        current_proxy = proxy_pool.get_next_proxy()
        proxy_info = None
        if current_proxy:
            proxy_info = {
                'host': current_proxy.host,
                'port': current_proxy.port,
            }

        with self.lock:
            self._add_vm(new_vm_path, region, proxy_info)
            self._occupy_vm(new_vm_path, os.getpid(), region)
        return new_vm_path

    # ------------------------------------------------------------------
    # Proxy pool passthroughs
    # ------------------------------------------------------------------
    def get_proxy_stats(self):
        """Return the global proxy pool's statistics (``proxy_pool.get_stats()``)."""
        proxy_pool = get_global_proxy_pool()
        return proxy_pool.get_stats()

    def test_all_proxies(self):
        """Run the global proxy pool's ``test_all_proxies()`` and return its result."""
        proxy_pool = get_global_proxy_pool()
        return proxy_pool.test_all_proxies()

    def force_rotate_proxy_for_vm(self, vm_path, region=DEFAULT_REGION):
        """Advance the global proxy pool on behalf of ``vm_path``.

        This only fetches the next proxy from the pool and logs it. It does not
        reconfigure the VM and does not update the registry. Applying a new
        proxy to an existing VM would require recreating the instance (and
        saving/restoring its state). ``region`` is currently unused.

        Returns:
            True if the pool returned a proxy, otherwise False.
        """
        logger.info(f"Force rotating proxy for VM {vm_path}")
        proxy_pool = get_global_proxy_pool()
        new_proxy = proxy_pool.get_next_proxy()
        if new_proxy:
            logger.info(f"New proxy for VM {vm_path}: {new_proxy.host}:{new_proxy.port}")
            return True
        logger.warning(f"No available proxy for VM {vm_path}")
        return False