Merge remote-tracking branch 'upstream/feat/aws-provider-support'

This commit is contained in:
yuanmengqi
2025-06-05 16:56:28 +00:00
5 changed files with 835 additions and 5 deletions

View File

@@ -54,13 +54,17 @@ class DesktopEnv(gym.Env):
"""
# Initialize VM manager and vitualization provider
self.region = region
self.provider_name = provider_name
# Default TODO:
self.server_port = 5000
self.chromium_port = 9222
self.vnc_port = 8006
self.vlc_port = 8080
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
# Initialize with default (no proxy) provider
self.current_use_proxy = False
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region, use_proxy=False)
self.os_type = os_type
@@ -149,6 +153,32 @@ class DesktopEnv(gym.Env):
self._step_no = 0
self.action_history.clear()
# Check and handle proxy requirement changes BEFORE starting emulator
if task_config is not None:
task_use_proxy = task_config.get("proxy", False)
if task_use_proxy != self.current_use_proxy:
logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}")
# Close current provider if it exists
if hasattr(self, 'provider') and self.provider:
try:
self.provider.stop_emulator(self.path_to_vm)
except Exception as e:
logger.warning(f"Failed to stop current provider: {e}")
# Create new provider with appropriate proxy setting
self.current_use_proxy = task_use_proxy
self.manager, self.provider = create_vm_manager_and_provider(
self.provider_name,
self.region,
use_proxy=task_use_proxy
)
if task_use_proxy:
logger.info("Using proxy-enabled AWS provider.")
else:
logger.info("Using regular AWS provider.")
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
self._revert_to_snapshot()
logger.info("Starting emulator...")
@@ -184,12 +214,17 @@ class DesktopEnv(gym.Env):
return self.controller.get_vm_screen_size()
def _set_task_info(self, task_config: Dict[str, Any]):
"""Set task info (proxy logic is handled in reset method)"""
self.task_id: str = task_config["id"]
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True)
self.instruction = task_config["instruction"]
self.config = task_config["config"] if "config" in task_config else []
self._set_evaluator_info(task_config)
def _set_evaluator_info(self, task_config: Dict[str, Any]):
"""Set evaluator information from task config"""
# evaluator dict
# func -> metric function string, or list of metric function strings
# conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"

View File

@@ -1,9 +1,14 @@
from desktop_env.providers.base import VMManager, Provider
def create_vm_manager_and_provider(provider_name: str, region: str):
def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: bool = False):
"""
Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name.
Args:
provider_name (str): The name of the provider (e.g., "aws", "vmware", etc.)
region (str): The region for the provider
use_proxy (bool): Whether to use proxy-enabled providers (currently only supported for AWS)
"""
provider_name = provider_name.lower().strip()
if provider_name == "vmware":
@@ -15,9 +20,16 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
return VirtualBoxVMManager(), VirtualBoxProvider(region)
elif provider_name in ["aws", "amazon web services"]:
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider
return AWSVMManager(), AWSProvider(region)
if use_proxy:
# Use proxy-enabled AWS provider
from desktop_env.providers.aws.manager_with_proxy import AWSVMManagerWithProxy
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
return AWSVMManagerWithProxy(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json")
else:
# Use regular AWS provider
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider
return AWSVMManager(), AWSProvider(region)
elif provider_name == "azure":
from desktop_env.providers.azure.manager import AzureVMManager
from desktop_env.providers.azure.provider import AzureProvider

View File

@@ -0,0 +1,329 @@
import os
from filelock import FileLock
import boto3
import psutil
import logging
from desktop_env.providers.base import VMManager
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManagerWithProxy")
logger.setLevel(logging.INFO)
REGISTRY_PATH = '.aws_vms_proxy'
DEFAULT_REGION = "us-east-1"
IMAGE_ID_MAP = {
"us-east-1": "ami-05e7d7bd279ea4f14",
"ap-east-1": "ami-0c092a5b8be4116f5"
}
INSTANCE_TYPE = "t3.medium"
NETWORK_INTERFACE_MAP = {
"us-east-1": [
{
"SubnetId": "subnet-037edfff66c2eb894",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-0342574803206ee9c"
]
}
],
"ap-east-1": [
{
"SubnetId": "subnet-011060501be0b589c",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-090470e64df78f6eb"
]
}
]
}
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
"""分配带有代理配置的VM"""
from .provider_with_proxy import AWSProviderWithProxy
# 初始化代理池(如果还没有初始化)
if proxy_config_file:
init_proxy_pool(proxy_config_file)
# 获取当前代理
proxy_pool = get_global_proxy_pool()
current_proxy = proxy_pool.get_next_proxy()
if current_proxy:
logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")
# 创建provider实例
provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)
# 创建新实例
instance_id = provider.create_instance_with_proxy(
image_id=IMAGE_ID_MAP[region],
instance_type=INSTANCE_TYPE,
security_groups=NETWORK_INTERFACE_MAP[region][0]["Groups"],
subnet_id=NETWORK_INTERFACE_MAP[region][0]["SubnetId"]
)
return instance_id
class AWSVMManagerWithProxy(VMManager):
def __init__(self, registry_path=REGISTRY_PATH, proxy_config_file=None):
self.registry_path = registry_path
self.lock = FileLock(".aws_proxy_lck", timeout=60)
self.proxy_config_file = proxy_config_file
self.initialize_registry()
# 初始化代理池
if proxy_config_file:
init_proxy_pool(proxy_config_file)
logger.info(f"Proxy pool initialized with config: {proxy_config_file}")
def initialize_registry(self):
with self.lock:
if not os.path.exists(self.registry_path):
with open(self.registry_path, 'w') as file:
file.write('')
def add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None, lock_needed=True):
if lock_needed:
with self.lock:
self._add_vm(vm_path, region, proxy_info)
else:
self._add_vm(vm_path, region, proxy_info)
def _add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None):
with open(self.registry_path, 'r') as file:
lines = file.readlines()
# 格式: vm_path@region|status|proxy_host:proxy_port
vm_path_at_vm_region = f"{vm_path}@{region}"
proxy_str = ""
if proxy_info:
proxy_str = f"{proxy_info['host']}:{proxy_info['port']}"
new_line = f'{vm_path_at_vm_region}|free|{proxy_str}\n'
new_lines = lines + [new_line]
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
if lock_needed:
with self.lock:
self._delete_vm(vm_path, region)
else:
self._delete_vm(vm_path, region)
def _delete_vm(self, vm_path, region=DEFAULT_REGION):
new_lines = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
parts = line.strip().split('|')
if len(parts) >= 2:
vm_path_at_vm_region = parts[0]
if vm_path_at_vm_region == f"{vm_path}@{region}":
continue
new_lines.append(line)
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True):
if lock_needed:
with self.lock:
self._occupy_vm(vm_path, pid, region)
else:
self._occupy_vm(vm_path, pid, region)
def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
new_lines = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
parts = line.strip().split('|')
if len(parts) >= 2:
registered_vm_path = parts[0]
if registered_vm_path == f"{vm_path}@{region}":
proxy_str = parts[2] if len(parts) > 2 else ""
new_lines.append(f'{registered_vm_path}|{pid}|{proxy_str}\n')
else:
new_lines.append(line)
else:
new_lines.append(line)
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def check_and_clean(self, lock_needed=True):
if lock_needed:
with self.lock:
self._check_and_clean()
else:
self._check_and_clean()
def _check_and_clean(self):
# Get active PIDs
active_pids = {p.pid for p in psutil.process_iter()}
new_lines = []
vm_path_at_vm_regions = {}
with open(self.registry_path, 'r') as file:
lines = file.readlines()
# Collect all VM paths and their regions
for line in lines:
parts = line.strip().split('|')
if len(parts) >= 2:
vm_path_at_vm_region = parts[0]
status = parts[1]
proxy_str = parts[2] if len(parts) > 2 else ""
vm_path, vm_region = vm_path_at_vm_region.split("@")
if vm_region not in vm_path_at_vm_regions:
vm_path_at_vm_regions[vm_region] = []
vm_path_at_vm_regions[vm_region].append((vm_path_at_vm_region, status, proxy_str))
# Process each region
for region, vm_info_list in vm_path_at_vm_regions.items():
ec2_client = boto3.client('ec2', region_name=region)
instance_ids = [vm_info[0].split('@')[0] for vm_info in vm_info_list]
try:
response = ec2_client.describe_instances(InstanceIds=instance_ids)
reservations = response.get('Reservations', [])
terminated_ids = set()
stopped_ids = set()
active_ids = set()
for reservation in reservations:
for instance in reservation.get('Instances', []):
instance_id = instance.get('InstanceId')
instance_state = instance['State']['Name']
if instance_state in ['terminated', 'shutting-down']:
terminated_ids.add(instance_id)
elif instance_state == 'stopped':
stopped_ids.add(instance_id)
else:
active_ids.add(instance_id)
for vm_path_at_vm_region, status, proxy_str in vm_info_list:
vm_path = vm_path_at_vm_region.split('@')[0]
if vm_path in terminated_ids:
logger.info(f"VM {vm_path} not found or terminated, releasing it.")
continue
elif vm_path in stopped_ids:
logger.info(f"VM {vm_path} stopped, mark it as free")
new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')
continue
if status == "free":
new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n')
elif status.isdigit() and int(status) in active_pids:
new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n')
else:
new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')
except Exception as e:
logger.error(f"Error checking instances in region {region}: {e}")
continue
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True):
if lock_needed:
with self.lock:
return self._list_free_vms(region)
else:
return self._list_free_vms(region)
def _list_free_vms(self, region=DEFAULT_REGION):
free_vms = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
parts = line.strip().split('|')
if len(parts) >= 2:
vm_path_at_vm_region = parts[0]
status = parts[1]
proxy_str = parts[2] if len(parts) > 2 else ""
vm_path, vm_region = vm_path_at_vm_region.split("@")
if status == "free" and vm_region == region:
free_vms.append((vm_path, status, proxy_str))
return free_vms
def get_vm_path(self, region=DEFAULT_REGION):
with self.lock:
if not AWSVMManagerWithProxy.checked_and_cleaned:
AWSVMManagerWithProxy.checked_and_cleaned = True
self._check_and_clean()
allocation_needed = False
with self.lock:
free_vms_paths = self._list_free_vms(region)
if len(free_vms_paths) == 0:
allocation_needed = True
else:
chosen_vm_path, _, proxy_str = free_vms_paths[0]
self._occupy_vm(chosen_vm_path, os.getpid(), region)
logger.info(f"Using existing VM {chosen_vm_path} with proxy: {proxy_str}")
return chosen_vm_path
if allocation_needed:
logger.info("No free virtual machine available. Generating a new one with proxy configuration...☕")
new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)
# 获取当前使用的代理信息
proxy_pool = get_global_proxy_pool()
current_proxy = proxy_pool.get_next_proxy()
proxy_info = None
if current_proxy:
proxy_info = {
'host': current_proxy.host,
'port': current_proxy.port
}
with self.lock:
self._add_vm(new_vm_path, region, proxy_info)
self._occupy_vm(new_vm_path, os.getpid(), region)
return new_vm_path
def get_proxy_stats(self):
"""获取代理池统计信息"""
proxy_pool = get_global_proxy_pool()
return proxy_pool.get_stats()
def test_all_proxies(self):
"""测试所有代理"""
proxy_pool = get_global_proxy_pool()
return proxy_pool.test_all_proxies()
def force_rotate_proxy_for_vm(self, vm_path, region=DEFAULT_REGION):
"""为特定VM强制轮换代理"""
logger.info(f"Force rotating proxy for VM {vm_path}")
# 这里需要重新创建实例来应用新的代理配置
# 在实际应用中,可能需要保存当前状态并恢复
proxy_pool = get_global_proxy_pool()
new_proxy = proxy_pool.get_next_proxy()
if new_proxy:
logger.info(f"New proxy for VM {vm_path}: {new_proxy.host}:{new_proxy.port}")
return True
else:
logger.warning(f"No available proxy for VM {vm_path}")
return False

View File

@@ -0,0 +1,261 @@
import boto3
from botocore.exceptions import ClientError
import base64
import logging
import json
from typing import Optional
from desktop_env.providers.base import Provider
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool, ProxyInfo
logger = logging.getLogger("desktopenv.providers.aws.AWSProviderWithProxy")
logger.setLevel(logging.INFO)
WAIT_DELAY = 15
MAX_ATTEMPTS = 10
class AWSProviderWithProxy(Provider):
def __init__(self, region: str = None, proxy_config_file: str = None):
super().__init__(region)
self.current_proxy: Optional[ProxyInfo] = None
# 初始化代理池
if proxy_config_file:
init_proxy_pool(proxy_config_file)
logger.info(f"Initialized proxy pool from {proxy_config_file}")
# 获取下一个可用代理
self._rotate_proxy()
def _rotate_proxy(self):
"""轮换到下一个可用代理"""
proxy_pool = get_global_proxy_pool()
self.current_proxy = proxy_pool.get_next_proxy()
if self.current_proxy:
logger.info(f"Switched to proxy: {self.current_proxy.host}:{self.current_proxy.port}")
else:
logger.warning("No proxy available, using direct connection")
def _generate_proxy_user_data(self) -> str:
"""生成包含代理配置的user data脚本"""
if not self.current_proxy:
return ""
proxy_url = self._format_proxy_url(self.current_proxy)
user_data_script = f"""#!/bin/bash
# 配置系统代理
echo 'export http_proxy={proxy_url}' >> /etc/environment
echo 'export https_proxy={proxy_url}' >> /etc/environment
echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment
echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment
# 配置apt代理
cat > /etc/apt/apt.conf.d/95proxy << EOF
Acquire::http::Proxy "{proxy_url}";
Acquire::https::Proxy "{proxy_url}";
EOF
# 配置chrome/chromium代理
mkdir -p /etc/opt/chrome/policies/managed
cat > /etc/opt/chrome/policies/managed/proxy.json << EOF
{{
"ProxyMode": "fixed_servers",
"ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}"
}}
EOF
# 配置firefox代理
mkdir -p /etc/firefox/policies
cat > /etc/firefox/policies/policies.json << EOF
{{
"policies": {{
"Proxy": {{
"Mode": "manual",
"HTTPProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
"HTTPSProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
"UseHTTPProxyForAllProtocols": true
}}
}}
}}
EOF
# 重新加载环境变量
source /etc/environment
# 记录代理配置日志
echo "$(date): Configured proxy {self.current_proxy.host}:{self.current_proxy.port}" >> /var/log/proxy-setup.log
"""
return base64.b64encode(user_data_script.encode()).decode()
def _format_proxy_url(self, proxy: ProxyInfo) -> str:
"""格式化代理URL"""
if proxy.username and proxy.password:
return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}"
else:
return f"{proxy.protocol}://{proxy.host}:{proxy.port}"
def start_emulator(self, path_to_vm: str, headless: bool):
logger.info("Starting AWS VM with proxy configuration...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
# 如果实例已经存在,直接启动
ec2_client.start_instances(InstanceIds=[path_to_vm])
logger.info(f"Instance {path_to_vm} is starting...")
# Wait for the instance to be in the 'running' state
waiter = ec2_client.get_waiter('instance_running')
waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
logger.info(f"Instance {path_to_vm} is now running.")
except ClientError as e:
logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}")
raise
def create_instance_with_proxy(self, image_id: str, instance_type: str,
security_groups: list, subnet_id: str) -> str:
"""创建带有代理配置的新实例"""
ec2_client = boto3.client('ec2', region_name=self.region)
user_data = self._generate_proxy_user_data()
run_instances_params = {
"MaxCount": 1,
"MinCount": 1,
"ImageId": image_id,
"InstanceType": instance_type,
"EbsOptimized": True,
"NetworkInterfaces": [
{
"SubnetId": subnet_id,
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": security_groups
}
]
}
if user_data:
run_instances_params["UserData"] = user_data
try:
response = ec2_client.run_instances(**run_instances_params)
instance_id = response['Instances'][0]['InstanceId']
logger.info(f"Created new instance {instance_id} with proxy configuration")
# 等待实例运行
logger.info(f"Waiting for instance {instance_id} to be running...")
ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
logger.info(f"Instance {instance_id} is ready.")
return instance_id
except ClientError as e:
logger.error(f"Failed to create instance with proxy: {str(e)}")
# 如果当前代理失败,尝试轮换代理
if self.current_proxy:
proxy_pool = get_global_proxy_pool()
proxy_pool.mark_proxy_failed(self.current_proxy)
self._rotate_proxy()
raise
def get_ip_address(self, path_to_vm: str) -> str:
logger.info("Getting AWS VM IP address...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
response = ec2_client.describe_instances(InstanceIds=[path_to_vm])
for reservation in response['Reservations']:
for instance in reservation['Instances']:
private_ip_address = instance.get('PrivateIpAddress', '')
return private_ip_address
return ''
except ClientError as e:
logger.error(f"Failed to retrieve private IP address for the instance {path_to_vm}: {str(e)}")
raise
def save_state(self, path_to_vm: str, snapshot_name: str):
logger.info("Saving AWS VM state...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
image_response = ec2_client.create_image(InstanceId=path_to_vm, Name=snapshot_name)
image_id = image_response['ImageId']
logger.info(f"AMI {image_id} created successfully from instance {path_to_vm}.")
return image_id
except ClientError as e:
logger.error(f"Failed to create AMI from the instance {path_to_vm}: {str(e)}")
raise
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
logger.info(f"Reverting AWS VM to snapshot: {snapshot_name}...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
# 获取原实例详情
instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm])
instance = instance_details['Reservations'][0]['Instances'][0]
security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']]
subnet_id = instance['SubnetId']
instance_type = instance['InstanceType']
# 终止旧实例
ec2_client.terminate_instances(InstanceIds=[path_to_vm])
logger.info(f"Old instance {path_to_vm} has been terminated.")
# 轮换到新的代理
self._rotate_proxy()
# 创建新实例
new_instance_id = self.create_instance_with_proxy(
snapshot_name, instance_type, security_groups, subnet_id
)
return new_instance_id
except ClientError as e:
logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}")
raise
def stop_emulator(self, path_to_vm, region=None):
logger.info(f"Stopping AWS VM {path_to_vm}...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
ec2_client.stop_instances(InstanceIds=[path_to_vm])
waiter = ec2_client.get_waiter('instance_stopped')
waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
logger.info(f"Instance {path_to_vm} has been stopped.")
except ClientError as e:
logger.error(f"Failed to stop the AWS VM {path_to_vm}: {str(e)}")
raise
def get_current_proxy_info(self) -> Optional[dict]:
"""获取当前代理信息"""
if self.current_proxy:
return {
'host': self.current_proxy.host,
'port': self.current_proxy.port,
'protocol': self.current_proxy.protocol,
'failed_count': self.current_proxy.failed_count
}
return None
def force_rotate_proxy(self):
"""强制轮换代理"""
logger.info("Force rotating proxy...")
if self.current_proxy:
proxy_pool = get_global_proxy_pool()
proxy_pool.mark_proxy_failed(self.current_proxy)
self._rotate_proxy()
def get_proxy_stats(self) -> dict:
"""获取代理池统计信息"""
proxy_pool = get_global_proxy_pool()
return proxy_pool.get_stats()

View File

@@ -0,0 +1,193 @@
import random
import requests
import logging
import time
from typing import List, Dict, Optional
from dataclasses import dataclass
from threading import Lock
import json
logger = logging.getLogger("desktopenv.providers.aws.ProxyPool")
logger.setLevel(logging.INFO)
@dataclass
class ProxyInfo:
host: str
port: int
username: Optional[str] = None
password: Optional[str] = None
protocol: str = "http" # http, https, socks5
failed_count: int = 0
last_used: float = 0
is_active: bool = True
class ProxyPool:
def __init__(self, config_file: str = None):
self.proxies: List[ProxyInfo] = []
self.current_index = 0
self.lock = Lock()
self.max_failures = 3 # 最大失败次数
self.cooldown_time = 300 # 5分钟冷却时间
if config_file:
self.load_proxies_from_file(config_file)
def load_proxies_from_file(self, config_file: str):
"""从配置文件加载代理列表"""
try:
with open(config_file, 'r') as f:
proxy_configs = json.load(f)
for config in proxy_configs:
proxy = ProxyInfo(
host=config['host'],
port=config['port'],
username=config.get('username'),
password=config.get('password'),
protocol=config.get('protocol', 'http')
)
self.proxies.append(proxy)
logger.info(f"Loaded {len(self.proxies)} proxies from {config_file}")
except Exception as e:
logger.error(f"Failed to load proxies from {config_file}: {e}")
def add_proxy(self, host: str, port: int, username: str = None,
password: str = None, protocol: str = "http"):
"""添加代理到池中"""
proxy = ProxyInfo(host=host, port=port, username=username,
password=password, protocol=protocol)
with self.lock:
self.proxies.append(proxy)
logger.info(f"Added proxy {host}:{port}")
def get_next_proxy(self) -> Optional[ProxyInfo]:
"""获取下一个可用的代理"""
with self.lock:
if not self.proxies:
return None
# 过滤掉失败次数过多的代理
active_proxies = [p for p in self.proxies if self._is_proxy_available(p)]
if not active_proxies:
logger.warning("No active proxies available")
return None
# 轮询选择代理
proxy = active_proxies[self.current_index % len(active_proxies)]
self.current_index += 1
proxy.last_used = time.time()
return proxy
def _is_proxy_available(self, proxy: ProxyInfo) -> bool:
"""检查代理是否可用"""
if not proxy.is_active:
return False
if proxy.failed_count >= self.max_failures:
# 检查是否过了冷却时间
if time.time() - proxy.last_used < self.cooldown_time:
return False
else:
# 重置失败计数
proxy.failed_count = 0
return True
def mark_proxy_failed(self, proxy: ProxyInfo):
"""标记代理失败"""
with self.lock:
proxy.failed_count += 1
if proxy.failed_count >= self.max_failures:
logger.warning(f"Proxy {proxy.host}:{proxy.port} marked as failed "
f"(failures: {proxy.failed_count})")
def mark_proxy_success(self, proxy: ProxyInfo):
"""标记代理成功"""
with self.lock:
proxy.failed_count = 0
def test_proxy(self, proxy: ProxyInfo, test_url: str = "http://httpbin.org/ip",
timeout: int = 10) -> bool:
"""测试代理是否正常工作"""
try:
proxy_url = self._format_proxy_url(proxy)
proxies = {
'http': proxy_url,
'https': proxy_url
}
response = requests.get(test_url, proxies=proxies, timeout=timeout)
if response.status_code == 200:
self.mark_proxy_success(proxy)
return True
else:
self.mark_proxy_failed(proxy)
return False
except Exception as e:
logger.debug(f"Proxy test failed for {proxy.host}:{proxy.port}: {e}")
self.mark_proxy_failed(proxy)
return False
def _format_proxy_url(self, proxy: ProxyInfo) -> str:
"""格式化代理URL"""
if proxy.username and proxy.password:
return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}"
else:
return f"{proxy.protocol}://{proxy.host}:{proxy.port}"
def get_proxy_dict(self, proxy: ProxyInfo) -> Dict[str, str]:
"""获取requests库使用的代理字典"""
proxy_url = self._format_proxy_url(proxy)
return {
'http': proxy_url,
'https': proxy_url
}
def test_all_proxies(self, test_url: str = "http://httpbin.org/ip"):
"""测试所有代理"""
logger.info("Testing all proxies...")
working_count = 0
for proxy in self.proxies:
if self.test_proxy(proxy, test_url):
working_count += 1
logger.info(f"✓ Proxy {proxy.host}:{proxy.port} is working")
else:
logger.warning(f"✗ Proxy {proxy.host}:{proxy.port} failed")
logger.info(f"Proxy test completed: {working_count}/{len(self.proxies)} working")
return working_count
def get_stats(self) -> Dict:
"""获取代理池统计信息"""
with self.lock:
total = len(self.proxies)
active = len([p for p in self.proxies if self._is_proxy_available(p)])
failed = len([p for p in self.proxies if p.failed_count >= self.max_failures])
return {
'total': total,
'active': active,
'failed': failed,
'success_rate': active / total if total > 0 else 0
}
# 全局代理池实例
_proxy_pool = None
def get_global_proxy_pool() -> ProxyPool:
"""获取全局代理池实例"""
global _proxy_pool
if _proxy_pool is None:
_proxy_pool = ProxyPool()
return _proxy_pool
def init_proxy_pool(config_file: str = None):
"""初始化全局代理池"""
global _proxy_pool
_proxy_pool = ProxyPool(config_file)
return _proxy_pool