feat: implement proxy management for AWS VM provider and enhance task configuration handling

This commit is contained in:
Timothyxxx
2025-06-06 00:36:21 +08:00
parent fb7bafb885
commit bfd0a7ad0d
5 changed files with 835 additions and 5 deletions

View File

@@ -54,13 +54,17 @@ class DesktopEnv(gym.Env):
""" """
# Initialize VM manager and vitualization provider # Initialize VM manager and vitualization provider
self.region = region self.region = region
self.provider_name = provider_name
# Default TODO: # Default TODO:
self.server_port = 5000 self.server_port = 5000
self.chromium_port = 9222 self.chromium_port = 9222
self.vnc_port = 8006 self.vnc_port = 8006
self.vlc_port = 8080 self.vlc_port = 8080
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
# Initialize with default (no proxy) provider
self.current_use_proxy = False
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region, use_proxy=False)
self.os_type = os_type self.os_type = os_type
@@ -149,6 +153,32 @@ class DesktopEnv(gym.Env):
self._step_no = 0 self._step_no = 0
self.action_history.clear() self.action_history.clear()
# Check and handle proxy requirement changes BEFORE starting emulator
if task_config is not None:
task_use_proxy = task_config.get("proxy", False)
if task_use_proxy != self.current_use_proxy:
logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}")
# Close current provider if it exists
if hasattr(self, 'provider') and self.provider:
try:
self.provider.stop_emulator(self.path_to_vm)
except Exception as e:
logger.warning(f"Failed to stop current provider: {e}")
# Create new provider with appropriate proxy setting
self.current_use_proxy = task_use_proxy
self.manager, self.provider = create_vm_manager_and_provider(
self.provider_name,
self.region,
use_proxy=task_use_proxy
)
if task_use_proxy:
logger.info("Using proxy-enabled AWS provider.")
else:
logger.info("Using regular AWS provider.")
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name)) logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
self._revert_to_snapshot() self._revert_to_snapshot()
logger.info("Starting emulator...") logger.info("Starting emulator...")
@@ -184,12 +214,17 @@ class DesktopEnv(gym.Env):
return self.controller.get_vm_screen_size() return self.controller.get_vm_screen_size()
def _set_task_info(self, task_config: Dict[str, Any]): def _set_task_info(self, task_config: Dict[str, Any]):
"""Set task info (proxy logic is handled in reset method)"""
self.task_id: str = task_config["id"] self.task_id: str = task_config["id"]
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id) self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True) os.makedirs(self.cache_dir, exist_ok=True)
self.instruction = task_config["instruction"] self.instruction = task_config["instruction"]
self.config = task_config["config"] if "config" in task_config else [] self.config = task_config["config"] if "config" in task_config else []
self._set_evaluator_info(task_config)
def _set_evaluator_info(self, task_config: Dict[str, Any]):
"""Set evaluator information from task config"""
# evaluator dict # evaluator dict
# func -> metric function string, or list of metric function strings # func -> metric function string, or list of metric function strings
# conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or" # conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"

View File

@@ -1,9 +1,14 @@
from desktop_env.providers.base import VMManager, Provider from desktop_env.providers.base import VMManager, Provider
def create_vm_manager_and_provider(provider_name: str, region: str): def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: bool = False):
""" """
Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name. Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name.
Args:
provider_name (str): The name of the provider (e.g., "aws", "vmware", etc.)
region (str): The region for the provider
use_proxy (bool): Whether to use proxy-enabled providers (currently only supported for AWS)
""" """
provider_name = provider_name.lower().strip() provider_name = provider_name.lower().strip()
if provider_name == "vmware": if provider_name == "vmware":
@@ -15,9 +20,16 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
return VirtualBoxVMManager(), VirtualBoxProvider(region) return VirtualBoxVMManager(), VirtualBoxProvider(region)
elif provider_name in ["aws", "amazon web services"]: elif provider_name in ["aws", "amazon web services"]:
from desktop_env.providers.aws.manager import AWSVMManager if use_proxy:
from desktop_env.providers.aws.provider import AWSProvider # Use proxy-enabled AWS provider
return AWSVMManager(), AWSProvider(region) from desktop_env.providers.aws.manager_with_proxy import AWSVMManagerWithProxy
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
return AWSVMManagerWithProxy(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json")
else:
# Use regular AWS provider
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider
return AWSVMManager(), AWSProvider(region)
elif provider_name == "azure": elif provider_name == "azure":
from desktop_env.providers.azure.manager import AzureVMManager from desktop_env.providers.azure.manager import AzureVMManager
from desktop_env.providers.azure.provider import AzureProvider from desktop_env.providers.azure.provider import AzureProvider

View File

@@ -0,0 +1,329 @@
import os
from filelock import FileLock
import boto3
import psutil
import logging
from desktop_env.providers.base import VMManager
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool
# Module-level logger for the proxy-enabled AWS VM manager.
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManagerWithProxy")
logger.setLevel(logging.INFO)
# Default path of the registry file used by AWSVMManagerWithProxy; each line
# written there has the form "<vm_path>@<region>|<status>|<proxy_host:proxy_port>".
REGISTRY_PATH = '.aws_vms_proxy'
# Region used whenever a caller does not pass one explicitly.
DEFAULT_REGION = "us-east-1"
# Region -> AMI id handed to create_instance_with_proxy as `image_id`.
# Only the regions listed here can be used for allocation.
IMAGE_ID_MAP = {
"us-east-1": "ami-05e7d7bd279ea4f14",
"ap-east-1": "ami-0c092a5b8be4116f5"
}
# EC2 instance type requested for every newly allocated VM.
INSTANCE_TYPE = "t3.medium"
# Region -> list of network interface specs. The allocator in this module reads
# only the first entry of each list, and from it only "SubnetId" and "Groups".
NETWORK_INTERFACE_MAP = {
"us-east-1": [
{
"SubnetId": "subnet-037edfff66c2eb894",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-0342574803206ee9c"
]
}
],
"ap-east-1": [
{
"SubnetId": "subnet-011060501be0b589c",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-090470e64df78f6eb"
]
}
]
}
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
    """Launch a new proxy-configured EC2 instance and return its instance id.

    Args:
        region: AWS region to launch in. Must be a key of ``IMAGE_ID_MAP`` and
            ``NETWORK_INTERFACE_MAP``.
        proxy_config_file: Optional path of the proxy list; it is forwarded to
            ``AWSProviderWithProxy``, whose constructor initialises the global
            proxy pool from it and selects the proxy for the new instance.

    Returns:
        The id of the newly created instance.

    Raises:
        KeyError: If ``region`` has no entry in the module's region maps.
    """
    from .provider_with_proxy import AWSProviderWithProxy

    # Resolve the region-specific settings first so that an unsupported region
    # fails before a proxy is taken out of the rotation.
    image_id = IMAGE_ID_MAP[region]
    network_interface = NETWORK_INTERFACE_MAP[region][0]

    # The provider picks its own proxy on construction. Previously this
    # function pulled a separate proxy from the pool just to log it, which
    # advanced the rotation and could log a proxy other than the one applied.
    provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)

    current_proxy = provider.current_proxy
    if current_proxy:
        logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")

    return provider.create_instance_with_proxy(
        image_id=image_id,
        instance_type=INSTANCE_TYPE,
        security_groups=network_interface["Groups"],
        subnet_id=network_interface["SubnetId"],
    )
class AWSVMManagerWithProxy(VMManager):
    """File-backed registry of proxy-enabled AWS VMs.

    Every registry line has the form ``<instance_id>@<region>|<status>|<proxy>``
    where ``status`` is ``free`` or the PID of the process occupying the VM and
    ``proxy`` is ``host:port`` (possibly empty). All public mutators take the
    file lock unless the caller passes ``lock_needed=False`` because it already
    holds it.
    """

    # Class-wide flag so the EC2 consistency check in get_vm_path runs only once
    # per process. It was read but never defined before, which made the first
    # get_vm_path() call raise AttributeError.
    checked_and_cleaned = False

    def __init__(self, registry_path=REGISTRY_PATH, proxy_config_file=None):
        """Create the registry file if needed and initialise the proxy pool.

        Args:
            registry_path: Path of the registry file.
            proxy_config_file: Optional path of the proxy list used to
                initialise the global proxy pool.
        """
        self.registry_path = registry_path
        self.lock = FileLock(".aws_proxy_lck", timeout=60)
        self.proxy_config_file = proxy_config_file
        self.initialize_registry()

        if proxy_config_file:
            init_proxy_pool(proxy_config_file)
            logger.info(f"Proxy pool initialized with config: {proxy_config_file}")

    def initialize_registry(self):
        """Create an empty registry file when none exists yet."""
        with self.lock:
            if not os.path.exists(self.registry_path):
                with open(self.registry_path, 'w') as file:
                    file.write('')

    # -- small private helpers -------------------------------------------

    def _read_registry(self):
        """Return all raw registry lines (including trailing newlines)."""
        with open(self.registry_path, 'r') as file:
            return file.readlines()

    def _write_registry(self, lines):
        """Overwrite the registry with ``lines``."""
        with open(self.registry_path, 'w') as file:
            file.writelines(lines)

    @staticmethod
    def _parse_line(line):
        """Split a registry line into ``(key, status, proxy)``; None if malformed."""
        parts = line.strip().split('|')
        if len(parts) < 2:
            return None
        return parts[0], parts[1], (parts[2] if len(parts) > 2 else "")

    # -- registry mutation -----------------------------------------------

    def add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None, lock_needed=True):
        """Register ``vm_path`` as a free VM, optionally tagged with its proxy."""
        if lock_needed:
            with self.lock:
                self._add_vm(vm_path, region, proxy_info)
        else:
            self._add_vm(vm_path, region, proxy_info)

    def _add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None):
        """Append a ``free`` entry for ``vm_path``; the caller must hold the lock.

        ``proxy_info`` is a dict with ``host`` and ``port`` keys, or None.
        """
        lines = self._read_registry()
        proxy_str = f"{proxy_info['host']}:{proxy_info['port']}" if proxy_info else ""
        # Line format: vm_path@region|status|proxy_host:proxy_port
        lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')
        self._write_registry(lines)

    def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
        """Remove the entry for ``vm_path`` in ``region`` from the registry."""
        if lock_needed:
            with self.lock:
                self._delete_vm(vm_path, region)
        else:
            self._delete_vm(vm_path, region)

    def _delete_vm(self, vm_path, region=DEFAULT_REGION):
        """Drop every line whose key equals ``vm_path@region``; keep all others."""
        target = f"{vm_path}@{region}"
        kept = []
        for line in self._read_registry():
            parsed = self._parse_line(line)
            if parsed is not None and parsed[0] == target:
                continue
            kept.append(line)
        self._write_registry(kept)

    def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True):
        """Mark ``vm_path`` in ``region`` as occupied by process ``pid``."""
        if lock_needed:
            with self.lock:
                self._occupy_vm(vm_path, pid, region)
        else:
            self._occupy_vm(vm_path, pid, region)

    def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
        """Replace the status of the matching entry with ``pid``, keeping its proxy."""
        target = f"{vm_path}@{region}"
        updated = []
        for line in self._read_registry():
            parsed = self._parse_line(line)
            if parsed is not None and parsed[0] == target:
                updated.append(f'{target}|{pid}|{parsed[2]}\n')
            else:
                # Non-matching and malformed lines are carried over untouched.
                updated.append(line)
        self._write_registry(updated)

    # -- consistency check -----------------------------------------------

    def check_and_clean(self, lock_needed=True):
        """Reconcile the registry with EC2 instance states and live PIDs."""
        if lock_needed:
            with self.lock:
                self._check_and_clean()
        else:
            self._check_and_clean()

    def _check_and_clean(self):
        """Rewrite the registry after checking each VM against EC2 and live PIDs.

        Per region, one ``describe_instances`` call classifies the registered
        instances:

        * terminated / shutting-down -> entry removed,
        * stopped -> entry marked ``free``,
        * otherwise -> entry stays occupied only if its PID is still alive,
          else it is marked ``free``.

        Lines that cannot be parsed are dropped. If the EC2 query fails because
        an instance id is unknown to AWS, the region's entries are released (the
        previous behaviour for every error). For any other error the entries
        are kept unchanged, so a transient API or network failure no longer
        makes the registry forget VMs that are still running.
        """
        active_pids = {p.pid for p in psutil.process_iter()}

        # region -> [(instance_id, status, proxy, raw_line), ...]
        entries_by_region = {}
        for line in self._read_registry():
            parsed = self._parse_line(line)
            if parsed is None:
                continue
            key, status, proxy_str = parsed
            instance_id, vm_region = key.split("@")
            entries_by_region.setdefault(vm_region, []).append(
                (instance_id, status, proxy_str, line)
            )

        new_lines = []
        for region, entries in entries_by_region.items():
            ec2_client = boto3.client('ec2', region_name=region)
            instance_ids = [entry[0] for entry in entries]
            region_lines = []

            try:
                response = ec2_client.describe_instances(InstanceIds=instance_ids)

                terminated_ids = set()
                stopped_ids = set()
                for reservation in response.get('Reservations', []):
                    for instance in reservation.get('Instances', []):
                        instance_state = instance['State']['Name']
                        if instance_state in ('terminated', 'shutting-down'):
                            terminated_ids.add(instance.get('InstanceId'))
                        elif instance_state == 'stopped':
                            stopped_ids.add(instance.get('InstanceId'))

                for vm_path, status, proxy_str, _raw in entries:
                    if vm_path in terminated_ids:
                        logger.info(f"VM {vm_path} not found or terminated, releasing it.")
                        continue
                    if vm_path in stopped_ids:
                        logger.info(f"VM {vm_path} stopped, mark it as free")
                        status = "free"
                    elif status != "free" and not (status.isdigit() and int(status) in active_pids):
                        # The owning process is gone: reclaim the VM.
                        status = "free"
                    region_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n')
            except Exception as e:
                logger.error(f"Error checking instances in region {region}: {e}")
                error_info = getattr(e, "response", None)
                error_code = ""
                if isinstance(error_info, dict):
                    error_code = str(error_info.get("Error", {}).get("Code", ""))
                if error_code.startswith("InvalidInstanceID"):
                    # At least one registered id is unknown to AWS; release the
                    # region's entries as before.
                    continue
                # Transient failure: keep what we knew about this region.
                new_lines.extend(raw for _, _, _, raw in entries)
                continue

            new_lines.extend(region_lines)

        self._write_registry(new_lines)

    # -- lookup / allocation ---------------------------------------------

    def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True):
        """Return ``(vm_path, status, proxy)`` tuples of free VMs in ``region``."""
        if lock_needed:
            with self.lock:
                return self._list_free_vms(region)
        else:
            return self._list_free_vms(region)

    def _list_free_vms(self, region=DEFAULT_REGION):
        """Lock-free implementation of :meth:`list_free_vms`."""
        free_vms = []
        for line in self._read_registry():
            parsed = self._parse_line(line)
            if parsed is None:
                continue
            key, status, proxy_str = parsed
            vm_path, vm_region = key.split("@")
            if status == "free" and vm_region == region:
                free_vms.append((vm_path, status, proxy_str))
        return free_vms

    def get_vm_path(self, region=DEFAULT_REGION):
        """Return the id of a VM reserved for the current process.

        A free registered VM in ``region`` is reused when available; otherwise
        a new proxy-configured instance is allocated, registered and occupied.
        """
        with self.lock:
            if not AWSVMManagerWithProxy.checked_and_cleaned:
                AWSVMManagerWithProxy.checked_and_cleaned = True
                self._check_and_clean()

        with self.lock:
            free_vms = self._list_free_vms(region)
            if free_vms:
                chosen_vm_path, _, proxy_str = free_vms[0]
                self._occupy_vm(chosen_vm_path, os.getpid(), region)
                logger.info(f"Using existing VM {chosen_vm_path} with proxy: {proxy_str}")
                return chosen_vm_path

        # Allocation happens outside the lock because launching an instance is slow.
        logger.info("No free virtual machine available. Generating a new one with proxy configuration...☕")
        new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)

        # NOTE(review): get_next_proxy() advances the pool's rotation, so the
        # proxy recorded here is not guaranteed to be the one the instance was
        # actually configured with inside _allocate_vm_with_proxy. Recording the
        # real one needs that helper to report it - confirm and follow up.
        current_proxy = get_global_proxy_pool().get_next_proxy()
        proxy_info = None
        if current_proxy:
            proxy_info = {
                'host': current_proxy.host,
                'port': current_proxy.port
            }

        with self.lock:
            self._add_vm(new_vm_path, region, proxy_info)
            self._occupy_vm(new_vm_path, os.getpid(), region)
        return new_vm_path

    # -- proxy pool passthroughs -----------------------------------------

    def get_proxy_stats(self):
        """Return the statistics dict of the global proxy pool."""
        return get_global_proxy_pool().get_stats()

    def test_all_proxies(self):
        """Test every proxy in the global pool and return the pool's result."""
        return get_global_proxy_pool().test_all_proxies()

    def force_rotate_proxy_for_vm(self, vm_path, region=DEFAULT_REGION):
        """Draw the next proxy from the pool on behalf of ``vm_path``.

        This only advances the pool and logs the outcome; it neither
        reconfigures the instance nor updates the registry. Applying the new
        proxy would require re-creating the instance.

        Returns:
            True if the pool returned a proxy, False otherwise.
        """
        logger.info(f"Force rotating proxy for VM {vm_path}")

        new_proxy = get_global_proxy_pool().get_next_proxy()
        if new_proxy:
            logger.info(f"New proxy for VM {vm_path}: {new_proxy.host}:{new_proxy.port}")
            return True

        logger.warning(f"No available proxy for VM {vm_path}")
        return False

View File

@@ -0,0 +1,261 @@
import boto3
from botocore.exceptions import ClientError
import base64
import logging
import json
from typing import Optional
from desktop_env.providers.base import Provider
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool, ProxyInfo
# Module-level logger for the proxy-enabled AWS provider.
logger = logging.getLogger("desktopenv.providers.aws.AWSProviderWithProxy")
logger.setLevel(logging.INFO)
# Passed as the "Delay" entry of WaiterConfig when waiting for an instance to
# start or stop (boto3 documents this value as seconds between polls).
WAIT_DELAY = 15
# Passed as the "MaxAttempts" entry of the same WaiterConfig.
MAX_ATTEMPTS = 10
class AWSProviderWithProxy(Provider):
    """AWS EC2 provider whose new instances are configured to use a proxy.

    A proxy is drawn from the global proxy pool on construction and again on
    every snapshot revert. New instances receive a user-data shell script that
    writes the proxy settings for the system environment, apt, Chrome/Chromium
    and Firefox.
    """

    def __init__(self, region: Optional[str] = None, proxy_config_file: Optional[str] = None):
        """Initialise the provider and select the first proxy.

        Args:
            region: AWS region used for every EC2 client created by this object.
            proxy_config_file: Optional path of the proxy list used to
                initialise the global proxy pool.
        """
        super().__init__(region)
        self.current_proxy: Optional[ProxyInfo] = None

        if proxy_config_file:
            init_proxy_pool(proxy_config_file)
            logger.info(f"Initialized proxy pool from {proxy_config_file}")

        # Pick the proxy that the next created instance will use.
        self._rotate_proxy()

    def _rotate_proxy(self):
        """Switch ``current_proxy`` to the next proxy offered by the pool."""
        self.current_proxy = get_global_proxy_pool().get_next_proxy()

        if self.current_proxy:
            logger.info(f"Switched to proxy: {self.current_proxy.host}:{self.current_proxy.port}")
        else:
            logger.warning("No proxy available, using direct connection")

    def _build_proxy_setup_script(self) -> str:
        """Return the plain-text user-data shell script, or "" without a proxy.

        The heredoc delimiters are quoted so the shell writes the proxy URL
        literally even if the credentials contain ``$`` or backticks.
        """
        if not self.current_proxy:
            return ""

        proxy_url = self._format_proxy_url(self.current_proxy)
        endpoint = f"{self.current_proxy.host}:{self.current_proxy.port}"

        return f"""#!/bin/bash
# 配置系统代理
echo 'export http_proxy={proxy_url}' >> /etc/environment
echo 'export https_proxy={proxy_url}' >> /etc/environment
echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment
echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment
# 配置apt代理
cat > /etc/apt/apt.conf.d/95proxy << 'EOF'
Acquire::http::Proxy "{proxy_url}";
Acquire::https::Proxy "{proxy_url}";
EOF
# 配置chrome/chromium代理
mkdir -p /etc/opt/chrome/policies/managed
cat > /etc/opt/chrome/policies/managed/proxy.json << 'EOF'
{{
"ProxyMode": "fixed_servers",
"ProxyServer": "{endpoint}"
}}
EOF
# 配置firefox代理
mkdir -p /etc/firefox/policies
cat > /etc/firefox/policies/policies.json << 'EOF'
{{
"policies": {{
"Proxy": {{
"Mode": "manual",
"HTTPProxy": "{endpoint}",
"HTTPSProxy": "{endpoint}",
"UseHTTPProxyForAllProtocols": true
}}
}}
}}
EOF
# 重新加载环境变量
source /etc/environment
# 记录代理配置日志
echo "$(date): Configured proxy {endpoint}" >> /var/log/proxy-setup.log
"""

    def _generate_proxy_user_data(self) -> str:
        """Return the proxy setup script base64-encoded, or "" without a proxy.

        Kept for callers that need pre-encoded user data. Do not pass this
        value to ``run_instances``: boto3 base64-encodes ``UserData`` itself.
        """
        script = self._build_proxy_setup_script()
        if not script:
            return ""
        return base64.b64encode(script.encode()).decode()

    def _format_proxy_url(self, proxy: ProxyInfo) -> str:
        """Build ``protocol://[user:password@]host:port`` for ``proxy``."""
        if proxy.username and proxy.password:
            credentials = f"{proxy.username}:{proxy.password}@"
        else:
            credentials = ""
        return f"{proxy.protocol}://{credentials}{proxy.host}:{proxy.port}"

    def start_emulator(self, path_to_vm: str, headless: bool):
        """Start the existing instance ``path_to_vm`` and wait until it runs.

        ``headless`` is accepted for interface compatibility and is not used.

        Raises:
            ClientError: Re-raised after logging when the EC2 call fails.
        """
        logger.info("Starting AWS VM with proxy configuration...")
        ec2_client = boto3.client('ec2', region_name=self.region)

        try:
            ec2_client.start_instances(InstanceIds=[path_to_vm])
            logger.info(f"Instance {path_to_vm} is starting...")

            waiter = ec2_client.get_waiter('instance_running')
            waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
            logger.info(f"Instance {path_to_vm} is now running.")
        except ClientError as e:
            logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}")
            raise

    def create_instance_with_proxy(self, image_id: str, instance_type: str,
                                   security_groups: list, subnet_id: str) -> str:
        """Launch a new instance configured for ``current_proxy``.

        Args:
            image_id: AMI to launch.
            instance_type: EC2 instance type.
            security_groups: Security group ids for the primary interface.
            subnet_id: Subnet of the primary interface.

        Returns:
            The id of the new instance, once it has reached the running state.

        Raises:
            ClientError: Re-raised after logging when the EC2 call fails. The
                current proxy is then marked as failed and the provider rotates
                to the next one.
        """
        ec2_client = boto3.client('ec2', region_name=self.region)

        run_instances_params = {
            "MaxCount": 1,
            "MinCount": 1,
            "ImageId": image_id,
            "InstanceType": instance_type,
            "EbsOptimized": True,
            "NetworkInterfaces": [
                {
                    "SubnetId": subnet_id,
                    "AssociatePublicIpAddress": True,
                    "DeviceIndex": 0,
                    "Groups": security_groups
                }
            ]
        }

        # boto3 base64-encodes UserData for RunInstances on its own. Passing
        # pre-encoded text (as before) double-encodes it, so the instance would
        # receive base64 text instead of a script and never apply the proxy.
        user_data = self._build_proxy_setup_script()
        if user_data:
            run_instances_params["UserData"] = user_data

        try:
            response = ec2_client.run_instances(**run_instances_params)
            instance_id = response['Instances'][0]['InstanceId']
            logger.info(f"Created new instance {instance_id} with proxy configuration")

            logger.info(f"Waiting for instance {instance_id} to be running...")
            ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
            logger.info(f"Instance {instance_id} is ready.")

            return instance_id
        except ClientError as e:
            logger.error(f"Failed to create instance with proxy: {str(e)}")
            # NOTE(review): an EC2 API error is attributed to the proxy here
            # although the proxy is not used to reach EC2 - confirm intent.
            if self.current_proxy:
                get_global_proxy_pool().mark_proxy_failed(self.current_proxy)
                self._rotate_proxy()
            raise

    def get_ip_address(self, path_to_vm: str) -> str:
        """Return the private IP of ``path_to_vm``.

        Returns "" when the instance has no private IP or is not in the response.

        Raises:
            ClientError: Re-raised after logging when the EC2 call fails.
        """
        logger.info("Getting AWS VM IP address...")
        ec2_client = boto3.client('ec2', region_name=self.region)

        try:
            response = ec2_client.describe_instances(InstanceIds=[path_to_vm])
            for reservation in response['Reservations']:
                for instance in reservation['Instances']:
                    # Only the first instance of the response is considered.
                    return instance.get('PrivateIpAddress', '')
            return ''
        except ClientError as e:
            logger.error(f"Failed to retrieve private IP address for the instance {path_to_vm}: {str(e)}")
            raise

    def save_state(self, path_to_vm: str, snapshot_name: str):
        """Create an AMI named ``snapshot_name`` from the instance; return its id.

        Raises:
            ClientError: Re-raised after logging when the EC2 call fails.
        """
        logger.info("Saving AWS VM state...")
        ec2_client = boto3.client('ec2', region_name=self.region)

        try:
            image_response = ec2_client.create_image(InstanceId=path_to_vm, Name=snapshot_name)
            image_id = image_response['ImageId']
            logger.info(f"AMI {image_id} created successfully from instance {path_to_vm}.")
            return image_id
        except ClientError as e:
            logger.error(f"Failed to create AMI from the instance {path_to_vm}: {str(e)}")
            raise

    def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
        """Replace ``path_to_vm`` with a fresh instance launched from an AMI.

        The old instance is terminated, the provider rotates to the next proxy,
        and a new instance is created from ``snapshot_name`` (used as the image
        id) with the old instance's type, subnet and security groups.

        Returns:
            The id of the replacement instance.

        Raises:
            ClientError: Re-raised after logging when an EC2 call fails.
        """
        logger.info(f"Reverting AWS VM to snapshot: {snapshot_name}...")
        ec2_client = boto3.client('ec2', region_name=self.region)

        try:
            # Capture the networking/type details before the instance goes away.
            instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm])
            instance = instance_details['Reservations'][0]['Instances'][0]
            security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']]
            subnet_id = instance['SubnetId']
            instance_type = instance['InstanceType']

            ec2_client.terminate_instances(InstanceIds=[path_to_vm])
            logger.info(f"Old instance {path_to_vm} has been terminated.")

            # Each reverted VM gets the next proxy in the rotation.
            self._rotate_proxy()

            return self.create_instance_with_proxy(
                snapshot_name, instance_type, security_groups, subnet_id
            )
        except ClientError as e:
            logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}")
            raise

    def stop_emulator(self, path_to_vm, region=None):
        """Stop ``path_to_vm`` and wait until it is stopped.

        ``region`` is accepted for interface compatibility; the provider's own
        region is always used.

        Raises:
            ClientError: Re-raised after logging when the EC2 call fails.
        """
        logger.info(f"Stopping AWS VM {path_to_vm}...")
        ec2_client = boto3.client('ec2', region_name=self.region)

        try:
            ec2_client.stop_instances(InstanceIds=[path_to_vm])
            waiter = ec2_client.get_waiter('instance_stopped')
            waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
            logger.info(f"Instance {path_to_vm} has been stopped.")
        except ClientError as e:
            logger.error(f"Failed to stop the AWS VM {path_to_vm}: {str(e)}")
            raise

    def get_current_proxy_info(self) -> Optional[dict]:
        """Return host, port, protocol and failed_count of the current proxy, or None."""
        proxy = self.current_proxy
        if not proxy:
            return None
        return {
            'host': proxy.host,
            'port': proxy.port,
            'protocol': proxy.protocol,
            'failed_count': proxy.failed_count
        }

    def force_rotate_proxy(self):
        """Mark the current proxy as failed (if any) and switch to the next one."""
        logger.info("Force rotating proxy...")
        if self.current_proxy:
            get_global_proxy_pool().mark_proxy_failed(self.current_proxy)
        self._rotate_proxy()

    def get_proxy_stats(self) -> dict:
        """Return the statistics dict of the global proxy pool."""
        return get_global_proxy_pool().get_stats()

View File

@@ -0,0 +1,193 @@
import random
import requests
import logging
import time
from typing import List, Dict, Optional
from dataclasses import dataclass
from threading import Lock
import json
logger = logging.getLogger("desktopenv.providers.aws.ProxyPool")
logger.setLevel(logging.INFO)


@dataclass
class ProxyInfo:
    """One proxy endpoint plus the health bookkeeping maintained by ProxyPool."""

    host: str
    port: int
    username: Optional[str] = None
    password: Optional[str] = None
    protocol: str = "http"  # http, https, socks5
    failed_count: int = 0
    last_used: float = 0  # time.time() of the last hand-out by get_next_proxy
    is_active: bool = True
    last_failed: float = 0  # time.time() of the last mark_proxy_failed call


class ProxyPool:
    """Thread-safe round-robin pool of proxies with failure-based cooldown.

    A proxy that has accumulated ``max_failures`` failures is skipped until
    ``cooldown_time`` seconds have passed, after which its failure counter is
    reset and it becomes eligible again.
    """

    def __init__(self, config_file: Optional[str] = None):
        """Create the pool, optionally pre-loading proxies from ``config_file``."""
        self.proxies: List[ProxyInfo] = []
        self.current_index = 0
        self.lock = Lock()
        self.max_failures = 3  # failures before a proxy is benched
        self.cooldown_time = 300  # seconds a benched proxy sits out (5 minutes)

        if config_file:
            self.load_proxies_from_file(config_file)

    def load_proxies_from_file(self, config_file: str):
        """Append the proxies listed in the JSON file ``config_file``.

        The file must contain a list of objects with ``host`` and ``port`` and
        optionally ``username``, ``password`` and ``protocol``. Loading is
        best-effort: any error is logged and the proxies read so far are kept.
        """
        try:
            with open(config_file, 'r') as f:
                proxy_configs = json.load(f)

            for config in proxy_configs:
                proxy = ProxyInfo(
                    host=config['host'],
                    port=config['port'],
                    username=config.get('username'),
                    password=config.get('password'),
                    protocol=config.get('protocol', 'http')
                )
                # Guard the shared list like every other mutator does.
                with self.lock:
                    self.proxies.append(proxy)

            logger.info(f"Loaded {len(self.proxies)} proxies from {config_file}")
        except Exception as e:
            # Deliberately broad: a bad proxy file must not abort the caller.
            logger.error(f"Failed to load proxies from {config_file}: {e}")

    def add_proxy(self, host: str, port: int, username: str = None,
                  password: str = None, protocol: str = "http"):
        """Add a single proxy to the pool."""
        proxy = ProxyInfo(host=host, port=port, username=username,
                          password=password, protocol=protocol)
        with self.lock:
            self.proxies.append(proxy)
        logger.info(f"Added proxy {host}:{port}")

    def get_next_proxy(self) -> Optional[ProxyInfo]:
        """Return the next available proxy in round-robin order, or None."""
        with self.lock:
            if not self.proxies:
                return None

            # Skip proxies that are disabled or still cooling down.
            active_proxies = [p for p in self.proxies if self._is_proxy_available(p)]
            if not active_proxies:
                logger.warning("No active proxies available")
                return None

            proxy = active_proxies[self.current_index % len(active_proxies)]
            self.current_index += 1
            proxy.last_used = time.time()
            return proxy

    def _is_proxy_available(self, proxy: ProxyInfo) -> bool:
        """Tell whether ``proxy`` may be handed out right now.

        Side effect: once a benched proxy's cooldown has expired, its failure
        counter is reset to 0.
        """
        if not proxy.is_active:
            return False

        if proxy.failed_count >= self.max_failures:
            # Measure the cooldown from the most recent failure or hand-out.
            # Using last_used alone reinstated a proxy immediately when it
            # failed without having been handed out (e.g. via test_proxy).
            benched_since = max(proxy.last_failed, proxy.last_used)
            if time.time() - benched_since < self.cooldown_time:
                return False
            proxy.failed_count = 0

        return True

    def mark_proxy_failed(self, proxy: ProxyInfo):
        """Record one failure for ``proxy`` and remember when it happened."""
        with self.lock:
            proxy.failed_count += 1
            proxy.last_failed = time.time()
            if proxy.failed_count >= self.max_failures:
                logger.warning(f"Proxy {proxy.host}:{proxy.port} marked as failed "
                               f"(failures: {proxy.failed_count})")

    def mark_proxy_success(self, proxy: ProxyInfo):
        """Reset the failure counter of ``proxy``."""
        with self.lock:
            proxy.failed_count = 0

    def test_proxy(self, proxy: ProxyInfo, test_url: str = "http://httpbin.org/ip",
                   timeout: int = 10) -> bool:
        """Fetch ``test_url`` through ``proxy`` and record the outcome.

        Returns True on an HTTP 200 response; any other status or any
        exception counts as a failure.
        """
        try:
            response = requests.get(test_url, proxies=self.get_proxy_dict(proxy),
                                    timeout=timeout)
        except Exception as e:
            logger.debug(f"Proxy test failed for {proxy.host}:{proxy.port}: {e}")
            self.mark_proxy_failed(proxy)
            return False

        if response.status_code == 200:
            self.mark_proxy_success(proxy)
            return True

        self.mark_proxy_failed(proxy)
        return False

    def _format_proxy_url(self, proxy: ProxyInfo) -> str:
        """Build ``protocol://[user:password@]host:port`` for ``proxy``."""
        if proxy.username and proxy.password:
            return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}"
        return f"{proxy.protocol}://{proxy.host}:{proxy.port}"

    def get_proxy_dict(self, proxy: ProxyInfo) -> Dict[str, str]:
        """Return the ``proxies`` mapping expected by the requests library."""
        proxy_url = self._format_proxy_url(proxy)
        return {
            'http': proxy_url,
            'https': proxy_url
        }

    def test_all_proxies(self, test_url: str = "http://httpbin.org/ip"):
        """Test every proxy in the pool and return how many are working."""
        logger.info("Testing all proxies...")
        working_count = 0

        # Iterate over a snapshot so concurrent add_proxy calls cannot disturb the loop.
        for proxy in list(self.proxies):
            if self.test_proxy(proxy, test_url):
                working_count += 1
                logger.info(f"✓ Proxy {proxy.host}:{proxy.port} is working")
            else:
                logger.warning(f"✗ Proxy {proxy.host}:{proxy.port} failed")

        logger.info(f"Proxy test completed: {working_count}/{len(self.proxies)} working")
        return working_count

    def get_stats(self) -> Dict:
        """Return ``total``, ``active``, ``failed`` and ``success_rate``.

        ``active`` counts proxies that are currently available, ``failed``
        those at or above ``max_failures``, and ``success_rate`` is
        ``active / total`` (0 for an empty pool).
        """
        with self.lock:
            total = len(self.proxies)
            active = sum(1 for p in self.proxies if self._is_proxy_available(p))
            failed = sum(1 for p in self.proxies if p.failed_count >= self.max_failures)

            return {
                'total': total,
                'active': active,
                'failed': failed,
                'success_rate': active / total if total > 0 else 0
            }
# Process-wide proxy pool singleton.
_proxy_pool = None
# Config file the singleton was last initialised from (None if none was given).
_proxy_pool_config_file = None


def get_global_proxy_pool() -> ProxyPool:
    """Return the global proxy pool, creating an empty one on first use."""
    global _proxy_pool
    if _proxy_pool is None:
        _proxy_pool = ProxyPool()
    return _proxy_pool


def init_proxy_pool(config_file: str = None, force: bool = False):
    """Initialise the global proxy pool from ``config_file`` and return it.

    The manager, the provider and the allocator each call this with the same
    file. Rebuilding the pool on every call threw away the rotation position
    and the failure counters, so each caller started again at the first proxy.
    A repeated call with the same ``config_file`` now returns the existing
    pool, provided that pool actually holds proxies.

    Args:
        config_file: Path of the JSON proxy list. With ``None`` a fresh empty
            pool is always created.
        force: Rebuild the pool even if it was already initialised from
            ``config_file`` (e.g. after the file was edited).

    Returns:
        The global ``ProxyPool`` instance.
    """
    global _proxy_pool, _proxy_pool_config_file

    already_loaded = (
        _proxy_pool is not None
        and config_file is not None
        and config_file == _proxy_pool_config_file
        # An empty pool means the earlier load failed; try again.
        and bool(_proxy_pool.proxies)
    )
    if already_loaded and not force:
        return _proxy_pool

    _proxy_pool = ProxyPool(config_file)
    _proxy_pool_config_file = config_file
    return _proxy_pool