diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index 4a24a55..2a13afa 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -54,13 +54,17 @@ class DesktopEnv(gym.Env): """ # Initialize VM manager and vitualization provider self.region = region + self.provider_name = provider_name # Default TODO: self.server_port = 5000 self.chromium_port = 9222 self.vnc_port = 8006 self.vlc_port = 8080 - self.manager, self.provider = create_vm_manager_and_provider(provider_name, region) + + # Initialize with default (no proxy) provider + self.current_use_proxy = False + self.manager, self.provider = create_vm_manager_and_provider(provider_name, region, use_proxy=False) self.os_type = os_type @@ -149,6 +153,32 @@ class DesktopEnv(gym.Env): self._step_no = 0 self.action_history.clear() + # Check and handle proxy requirement changes BEFORE starting emulator + if task_config is not None: + task_use_proxy = task_config.get("proxy", False) + if task_use_proxy != self.current_use_proxy: + logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}") + + # Close current provider if it exists + if hasattr(self, 'provider') and self.provider: + try: + self.provider.stop_emulator(self.path_to_vm) + except Exception as e: + logger.warning(f"Failed to stop current provider: {e}") + + # Create new provider with appropriate proxy setting + self.current_use_proxy = task_use_proxy + self.manager, self.provider = create_vm_manager_and_provider( + self.provider_name, + self.region, + use_proxy=task_use_proxy + ) + + if task_use_proxy: + logger.info("Using proxy-enabled AWS provider.") + else: + logger.info("Using regular AWS provider.") + logger.info("Reverting to snapshot to {}...".format(self.snapshot_name)) self._revert_to_snapshot() logger.info("Starting emulator...") @@ -184,12 +214,17 @@ class DesktopEnv(gym.Env): return self.controller.get_vm_screen_size() def _set_task_info(self, task_config: Dict[str, Any]): + """Set 
task info (proxy logic is handled in reset method)""" self.task_id: str = task_config["id"] self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id) os.makedirs(self.cache_dir, exist_ok=True) self.instruction = task_config["instruction"] self.config = task_config["config"] if "config" in task_config else [] + + self._set_evaluator_info(task_config) + def _set_evaluator_info(self, task_config: Dict[str, Any]): + """Set evaluator information from task config""" # evaluator dict # func -> metric function string, or list of metric function strings # conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or" diff --git a/desktop_env/providers/__init__.py b/desktop_env/providers/__init__.py index d1359a1..19d9b98 100644 --- a/desktop_env/providers/__init__.py +++ b/desktop_env/providers/__init__.py @@ -1,9 +1,14 @@ from desktop_env.providers.base import VMManager, Provider -def create_vm_manager_and_provider(provider_name: str, region: str): +def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: bool = False): """ Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name. + + Args: + provider_name (str): The name of the provider (e.g., "aws", "vmware", etc.) 
+ region (str): The region for the provider + use_proxy (bool): Whether to use proxy-enabled providers (currently only supported for AWS) """ provider_name = provider_name.lower().strip() if provider_name == "vmware": @@ -15,9 +20,16 @@ def create_vm_manager_and_provider(provider_name: str, region: str): from desktop_env.providers.virtualbox.provider import VirtualBoxProvider return VirtualBoxVMManager(), VirtualBoxProvider(region) elif provider_name in ["aws", "amazon web services"]: - from desktop_env.providers.aws.manager import AWSVMManager - from desktop_env.providers.aws.provider import AWSProvider - return AWSVMManager(), AWSProvider(region) + if use_proxy: + # Use proxy-enabled AWS provider + from desktop_env.providers.aws.manager_with_proxy import AWSVMManagerWithProxy + from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy + return AWSVMManagerWithProxy(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json") + else: + # Use regular AWS provider + from desktop_env.providers.aws.manager import AWSVMManager + from desktop_env.providers.aws.provider import AWSProvider + return AWSVMManager(), AWSProvider(region) elif provider_name == "azure": from desktop_env.providers.azure.manager import AzureVMManager from desktop_env.providers.azure.provider import AzureProvider diff --git a/desktop_env/providers/aws/manager_with_proxy.py b/desktop_env/providers/aws/manager_with_proxy.py new file mode 100644 index 0000000..f4150ec --- /dev/null +++ b/desktop_env/providers/aws/manager_with_proxy.py @@ -0,0 +1,329 @@ +import os +from filelock import FileLock +import boto3 +import psutil +import logging + +from desktop_env.providers.base import VMManager +from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool + +logger = logging.getLogger("desktopenv.providers.aws.AWSVMManagerWithProxy") +logger.setLevel(logging.INFO) + +REGISTRY_PATH = 
'.aws_vms_proxy' + +DEFAULT_REGION = "us-east-1" +IMAGE_ID_MAP = { + "us-east-1": "ami-05e7d7bd279ea4f14", + "ap-east-1": "ami-0c092a5b8be4116f5" +} + +INSTANCE_TYPE = "t3.medium" + +NETWORK_INTERFACE_MAP = { + "us-east-1": [ + { + "SubnetId": "subnet-037edfff66c2eb894", + "AssociatePublicIpAddress": True, + "DeviceIndex": 0, + "Groups": [ + "sg-0342574803206ee9c" + ] + } + ], + "ap-east-1": [ + { + "SubnetId": "subnet-011060501be0b589c", + "AssociatePublicIpAddress": True, + "DeviceIndex": 0, + "Groups": [ + "sg-090470e64df78f6eb" + ] + } + ] +} + + +def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None): + """分配带有代理配置的VM""" + from .provider_with_proxy import AWSProviderWithProxy + + # 初始化代理池(如果还没有初始化) + if proxy_config_file: + init_proxy_pool(proxy_config_file) + + # 获取当前代理 + proxy_pool = get_global_proxy_pool() + current_proxy = proxy_pool.get_next_proxy() + + if current_proxy: + logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}") + + # 创建provider实例 + provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file) + + # 创建新实例 + instance_id = provider.create_instance_with_proxy( + image_id=IMAGE_ID_MAP[region], + instance_type=INSTANCE_TYPE, + security_groups=NETWORK_INTERFACE_MAP[region][0]["Groups"], + subnet_id=NETWORK_INTERFACE_MAP[region][0]["SubnetId"] + ) + + return instance_id + + +class AWSVMManagerWithProxy(VMManager): + def __init__(self, registry_path=REGISTRY_PATH, proxy_config_file=None): + self.registry_path = registry_path + self.lock = FileLock(".aws_proxy_lck", timeout=60) + self.proxy_config_file = proxy_config_file + self.initialize_registry() + + # 初始化代理池 + if proxy_config_file: + init_proxy_pool(proxy_config_file) + logger.info(f"Proxy pool initialized with config: {proxy_config_file}") + + def initialize_registry(self): + with self.lock: + if not os.path.exists(self.registry_path): + with open(self.registry_path, 'w') as file: + file.write('') + + def add_vm(self, 
vm_path, region=DEFAULT_REGION, proxy_info=None, lock_needed=True): + if lock_needed: + with self.lock: + self._add_vm(vm_path, region, proxy_info) + else: + self._add_vm(vm_path, region, proxy_info) + + def _add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None): + with open(self.registry_path, 'r') as file: + lines = file.readlines() + + # 格式: vm_path@region|status|proxy_host:proxy_port + vm_path_at_vm_region = f"{vm_path}@{region}" + proxy_str = "" + if proxy_info: + proxy_str = f"{proxy_info['host']}:{proxy_info['port']}" + + new_line = f'{vm_path_at_vm_region}|free|{proxy_str}\n' + new_lines = lines + [new_line] + + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True): + if lock_needed: + with self.lock: + self._delete_vm(vm_path, region) + else: + self._delete_vm(vm_path, region) + + def _delete_vm(self, vm_path, region=DEFAULT_REGION): + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + parts = line.strip().split('|') + if len(parts) >= 2: + vm_path_at_vm_region = parts[0] + if vm_path_at_vm_region == f"{vm_path}@{region}": + continue + new_lines.append(line) + + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True): + if lock_needed: + with self.lock: + self._occupy_vm(vm_path, pid, region) + else: + self._occupy_vm(vm_path, pid, region) + + def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION): + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + parts = line.strip().split('|') + if len(parts) >= 2: + registered_vm_path = parts[0] + if registered_vm_path == f"{vm_path}@{region}": + proxy_str = parts[2] if len(parts) > 2 else "" + new_lines.append(f'{registered_vm_path}|{pid}|{proxy_str}\n') + else: + new_lines.append(line) + else: + 
new_lines.append(line) + + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def check_and_clean(self, lock_needed=True): + if lock_needed: + with self.lock: + self._check_and_clean() + else: + self._check_and_clean() + + def _check_and_clean(self): + # Get active PIDs + active_pids = {p.pid for p in psutil.process_iter()} + + new_lines = [] + vm_path_at_vm_regions = {} + + with open(self.registry_path, 'r') as file: + lines = file.readlines() + + # Collect all VM paths and their regions + for line in lines: + parts = line.strip().split('|') + if len(parts) >= 2: + vm_path_at_vm_region = parts[0] + status = parts[1] + proxy_str = parts[2] if len(parts) > 2 else "" + + vm_path, vm_region = vm_path_at_vm_region.split("@") + if vm_region not in vm_path_at_vm_regions: + vm_path_at_vm_regions[vm_region] = [] + vm_path_at_vm_regions[vm_region].append((vm_path_at_vm_region, status, proxy_str)) + + # Process each region + for region, vm_info_list in vm_path_at_vm_regions.items(): + ec2_client = boto3.client('ec2', region_name=region) + instance_ids = [vm_info[0].split('@')[0] for vm_info in vm_info_list] + + try: + response = ec2_client.describe_instances(InstanceIds=instance_ids) + reservations = response.get('Reservations', []) + + terminated_ids = set() + stopped_ids = set() + active_ids = set() + + for reservation in reservations: + for instance in reservation.get('Instances', []): + instance_id = instance.get('InstanceId') + instance_state = instance['State']['Name'] + if instance_state in ['terminated', 'shutting-down']: + terminated_ids.add(instance_id) + elif instance_state == 'stopped': + stopped_ids.add(instance_id) + else: + active_ids.add(instance_id) + + for vm_path_at_vm_region, status, proxy_str in vm_info_list: + vm_path = vm_path_at_vm_region.split('@')[0] + + if vm_path in terminated_ids: + logger.info(f"VM {vm_path} not found or terminated, releasing it.") + continue + elif vm_path in stopped_ids: + logger.info(f"VM {vm_path} 
stopped, mark it as free") + new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n') + continue + + if status == "free": + new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n') + elif status.isdigit() and int(status) in active_pids: + new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n') + else: + new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n') + + except Exception as e: + logger.error(f"Error checking instances in region {region}: {e}") + continue + + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True): + if lock_needed: + with self.lock: + return self._list_free_vms(region) + else: + return self._list_free_vms(region) + + def _list_free_vms(self, region=DEFAULT_REGION): + free_vms = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + parts = line.strip().split('|') + if len(parts) >= 2: + vm_path_at_vm_region = parts[0] + status = parts[1] + proxy_str = parts[2] if len(parts) > 2 else "" + + vm_path, vm_region = vm_path_at_vm_region.split("@") + if status == "free" and vm_region == region: + free_vms.append((vm_path, status, proxy_str)) + + return free_vms + + def get_vm_path(self, region=DEFAULT_REGION): + with self.lock: + if not AWSVMManagerWithProxy.checked_and_cleaned: + AWSVMManagerWithProxy.checked_and_cleaned = True + self._check_and_clean() + + allocation_needed = False + with self.lock: + free_vms_paths = self._list_free_vms(region) + + if len(free_vms_paths) == 0: + allocation_needed = True + else: + chosen_vm_path, _, proxy_str = free_vms_paths[0] + self._occupy_vm(chosen_vm_path, os.getpid(), region) + logger.info(f"Using existing VM {chosen_vm_path} with proxy: {proxy_str}") + return chosen_vm_path + + if allocation_needed: + logger.info("No free virtual machine available. 
Generating a new one with proxy configuration...☕") + new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file) + + # 获取当前使用的代理信息 + proxy_pool = get_global_proxy_pool() + current_proxy = proxy_pool.get_next_proxy() + proxy_info = None + if current_proxy: + proxy_info = { + 'host': current_proxy.host, + 'port': current_proxy.port + } + + with self.lock: + self._add_vm(new_vm_path, region, proxy_info) + self._occupy_vm(new_vm_path, os.getpid(), region) + return new_vm_path + + def get_proxy_stats(self): + """获取代理池统计信息""" + proxy_pool = get_global_proxy_pool() + return proxy_pool.get_stats() + + def test_all_proxies(self): + """测试所有代理""" + proxy_pool = get_global_proxy_pool() + return proxy_pool.test_all_proxies() + + def force_rotate_proxy_for_vm(self, vm_path, region=DEFAULT_REGION): + """为特定VM强制轮换代理""" + logger.info(f"Force rotating proxy for VM {vm_path}") + + # 这里需要重新创建实例来应用新的代理配置 + # 在实际应用中,可能需要保存当前状态并恢复 + proxy_pool = get_global_proxy_pool() + new_proxy = proxy_pool.get_next_proxy() + + if new_proxy: + logger.info(f"New proxy for VM {vm_path}: {new_proxy.host}:{new_proxy.port}") + return True + else: + logger.warning(f"No available proxy for VM {vm_path}") + return False \ No newline at end of file diff --git a/desktop_env/providers/aws/provider_with_proxy.py b/desktop_env/providers/aws/provider_with_proxy.py new file mode 100644 index 0000000..309e71b --- /dev/null +++ b/desktop_env/providers/aws/provider_with_proxy.py @@ -0,0 +1,261 @@ +import boto3 +from botocore.exceptions import ClientError +import base64 +import logging +import json +from typing import Optional + +from desktop_env.providers.base import Provider +from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool, ProxyInfo + +logger = logging.getLogger("desktopenv.providers.aws.AWSProviderWithProxy") +logger.setLevel(logging.INFO) + +WAIT_DELAY = 15 +MAX_ATTEMPTS = 10 + + +class AWSProviderWithProxy(Provider): + + def __init__(self, region: str = None, 
proxy_config_file: str = None): + super().__init__(region) + self.current_proxy: Optional[ProxyInfo] = None + + # 初始化代理池 + if proxy_config_file: + init_proxy_pool(proxy_config_file) + logger.info(f"Initialized proxy pool from {proxy_config_file}") + + # 获取下一个可用代理 + self._rotate_proxy() + + def _rotate_proxy(self): + """轮换到下一个可用代理""" + proxy_pool = get_global_proxy_pool() + self.current_proxy = proxy_pool.get_next_proxy() + + if self.current_proxy: + logger.info(f"Switched to proxy: {self.current_proxy.host}:{self.current_proxy.port}") + else: + logger.warning("No proxy available, using direct connection") + + def _generate_proxy_user_data(self) -> str: + """生成包含代理配置的user data脚本""" + if not self.current_proxy: + return "" + + proxy_url = self._format_proxy_url(self.current_proxy) + + user_data_script = f"""#!/bin/bash +# 配置系统代理 +echo 'export http_proxy={proxy_url}' >> /etc/environment +echo 'export https_proxy={proxy_url}' >> /etc/environment +echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment +echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment + +# 配置apt代理 +cat > /etc/apt/apt.conf.d/95proxy << EOF +Acquire::http::Proxy "{proxy_url}"; +Acquire::https::Proxy "{proxy_url}"; +EOF + +# 配置chrome/chromium代理 +mkdir -p /etc/opt/chrome/policies/managed +cat > /etc/opt/chrome/policies/managed/proxy.json << EOF +{{ + "ProxyMode": "fixed_servers", + "ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}" +}} +EOF + +# 配置firefox代理 +mkdir -p /etc/firefox/policies +cat > /etc/firefox/policies/policies.json << EOF +{{ + "policies": {{ + "Proxy": {{ + "Mode": "manual", + "HTTPProxy": "{self.current_proxy.host}:{self.current_proxy.port}", + "HTTPSProxy": "{self.current_proxy.host}:{self.current_proxy.port}", + "UseHTTPProxyForAllProtocols": true + }} + }} +}} +EOF + +# 重新加载环境变量 +source /etc/environment + +# 记录代理配置日志 +echo "$(date): Configured proxy {self.current_proxy.host}:{self.current_proxy.port}" >> /var/log/proxy-setup.log +""" + + return 
base64.b64encode(user_data_script.encode()).decode() + + def _format_proxy_url(self, proxy: ProxyInfo) -> str: + """格式化代理URL""" + if proxy.username and proxy.password: + return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}" + else: + return f"{proxy.protocol}://{proxy.host}:{proxy.port}" + + def start_emulator(self, path_to_vm: str, headless: bool): + logger.info("Starting AWS VM with proxy configuration...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + # 如果实例已经存在,直接启动 + ec2_client.start_instances(InstanceIds=[path_to_vm]) + logger.info(f"Instance {path_to_vm} is starting...") + + # Wait for the instance to be in the 'running' state + waiter = ec2_client.get_waiter('instance_running') + waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS}) + logger.info(f"Instance {path_to_vm} is now running.") + + except ClientError as e: + logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}") + raise + + def create_instance_with_proxy(self, image_id: str, instance_type: str, + security_groups: list, subnet_id: str) -> str: + """创建带有代理配置的新实例""" + ec2_client = boto3.client('ec2', region_name=self.region) + + user_data = self._generate_proxy_user_data() + + run_instances_params = { + "MaxCount": 1, + "MinCount": 1, + "ImageId": image_id, + "InstanceType": instance_type, + "EbsOptimized": True, + "NetworkInterfaces": [ + { + "SubnetId": subnet_id, + "AssociatePublicIpAddress": True, + "DeviceIndex": 0, + "Groups": security_groups + } + ] + } + + if user_data: + run_instances_params["UserData"] = user_data + + try: + response = ec2_client.run_instances(**run_instances_params) + instance_id = response['Instances'][0]['InstanceId'] + + logger.info(f"Created new instance {instance_id} with proxy configuration") + + # 等待实例运行 + logger.info(f"Waiting for instance {instance_id} to be running...") + ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id]) 
+ logger.info(f"Instance {instance_id} is ready.") + + return instance_id + + except ClientError as e: + logger.error(f"Failed to create instance with proxy: {str(e)}") + # 如果当前代理失败,尝试轮换代理 + if self.current_proxy: + proxy_pool = get_global_proxy_pool() + proxy_pool.mark_proxy_failed(self.current_proxy) + self._rotate_proxy() + raise + + def get_ip_address(self, path_to_vm: str) -> str: + logger.info("Getting AWS VM IP address...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + response = ec2_client.describe_instances(InstanceIds=[path_to_vm]) + for reservation in response['Reservations']: + for instance in reservation['Instances']: + private_ip_address = instance.get('PrivateIpAddress', '') + return private_ip_address + return '' + except ClientError as e: + logger.error(f"Failed to retrieve private IP address for the instance {path_to_vm}: {str(e)}") + raise + + def save_state(self, path_to_vm: str, snapshot_name: str): + logger.info("Saving AWS VM state...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + image_response = ec2_client.create_image(InstanceId=path_to_vm, Name=snapshot_name) + image_id = image_response['ImageId'] + logger.info(f"AMI {image_id} created successfully from instance {path_to_vm}.") + return image_id + except ClientError as e: + logger.error(f"Failed to create AMI from the instance {path_to_vm}: {str(e)}") + raise + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + logger.info(f"Reverting AWS VM to snapshot: {snapshot_name}...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + # 获取原实例详情 + instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm]) + instance = instance_details['Reservations'][0]['Instances'][0] + security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']] + subnet_id = instance['SubnetId'] + instance_type = instance['InstanceType'] + + # 终止旧实例 + ec2_client.terminate_instances(InstanceIds=[path_to_vm]) + 
logger.info(f"Old instance {path_to_vm} has been terminated.") + + # 轮换到新的代理 + self._rotate_proxy() + + # 创建新实例 + new_instance_id = self.create_instance_with_proxy( + snapshot_name, instance_type, security_groups, subnet_id + ) + + return new_instance_id + + except ClientError as e: + logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}") + raise + + def stop_emulator(self, path_to_vm, region=None): + logger.info(f"Stopping AWS VM {path_to_vm}...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + ec2_client.stop_instances(InstanceIds=[path_to_vm]) + waiter = ec2_client.get_waiter('instance_stopped') + waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS}) + logger.info(f"Instance {path_to_vm} has been stopped.") + except ClientError as e: + logger.error(f"Failed to stop the AWS VM {path_to_vm}: {str(e)}") + raise + + def get_current_proxy_info(self) -> Optional[dict]: + """获取当前代理信息""" + if self.current_proxy: + return { + 'host': self.current_proxy.host, + 'port': self.current_proxy.port, + 'protocol': self.current_proxy.protocol, + 'failed_count': self.current_proxy.failed_count + } + return None + + def force_rotate_proxy(self): + """强制轮换代理""" + logger.info("Force rotating proxy...") + if self.current_proxy: + proxy_pool = get_global_proxy_pool() + proxy_pool.mark_proxy_failed(self.current_proxy) + self._rotate_proxy() + + def get_proxy_stats(self) -> dict: + """获取代理池统计信息""" + proxy_pool = get_global_proxy_pool() + return proxy_pool.get_stats() \ No newline at end of file diff --git a/desktop_env/providers/aws/proxy_pool.py b/desktop_env/providers/aws/proxy_pool.py new file mode 100644 index 0000000..812df18 --- /dev/null +++ b/desktop_env/providers/aws/proxy_pool.py @@ -0,0 +1,193 @@ +import random +import requests +import logging +import time +from typing import List, Dict, Optional +from dataclasses import dataclass +from threading import Lock 
import json

logger = logging.getLogger("desktopenv.providers.aws.ProxyPool")
logger.setLevel(logging.INFO)


@dataclass
class ProxyInfo:
    """Connection details and health bookkeeping for a single proxy."""

    host: str
    port: int
    username: Optional[str] = None
    password: Optional[str] = None
    protocol: str = "http"  # http, https, socks5
    failed_count: int = 0  # failures recorded since the last success/reset
    last_used: float = 0  # time.time() of the last hand-out by the pool
    is_active: bool = True  # inactive proxies are never selected
    last_failed: float = 0  # time.time() of the most recent recorded failure


class ProxyPool:
    """Round-robin pool of proxies with failure counting and a cooldown.

    A proxy whose ``failed_count`` reaches ``max_failures`` is skipped until
    ``cooldown_time`` seconds have passed since it was last handed out or
    last failed (whichever is later); after that its failure count is reset
    and it becomes selectable again.
    """

    def __init__(self, config_file: Optional[str] = None):
        """Create an empty pool and optionally load proxies from *config_file*.

        Args:
            config_file: Path to a JSON file holding a list of proxy objects
                (``host``, ``port`` and optional ``username``, ``password``,
                ``protocol``). ``None`` creates an empty pool.
        """
        self.proxies: List[ProxyInfo] = []
        self.current_index = 0
        self.lock = Lock()
        self.max_failures = 3  # failures before a proxy is put on cooldown
        self.cooldown_time = 300  # cooldown length in seconds (5 minutes)
        # Remembered so init_proxy_pool() can tell whether the live pool was
        # built from the same configuration.
        self.config_file = config_file

        if config_file:
            self.load_proxies_from_file(config_file)

    def load_proxies_from_file(self, config_file: str):
        """Append the proxies described in the JSON file *config_file*.

        Loading is best-effort: an unreadable or malformed file is logged and
        leaves the pool unchanged; a malformed entry is logged and skipped
        while the remaining entries are still loaded.
        """
        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                proxy_configs = json.load(f)
        except (OSError, ValueError) as e:
            # json.JSONDecodeError is a ValueError subclass.
            logger.error(f"Failed to load proxies from {config_file}: {e}")
            return

        if not isinstance(proxy_configs, list):
            logger.error(f"Failed to load proxies from {config_file}: "
                         f"expected a JSON list of proxy objects")
            return

        loaded = []
        for config in proxy_configs:
            try:
                loaded.append(ProxyInfo(
                    host=config['host'],
                    port=config['port'],
                    username=config.get('username'),
                    password=config.get('password'),
                    protocol=config.get('protocol', 'http')
                ))
            except (KeyError, TypeError, AttributeError) as e:
                logger.error(f"Failed to load proxies from {config_file}: {e}")

        # Publish in one step so concurrent readers never see a partial list.
        with self.lock:
            self.proxies.extend(loaded)
            total = len(self.proxies)

        logger.info(f"Loaded {total} proxies from {config_file}")

    def add_proxy(self, host: str, port: int, username: str = None,
                  password: str = None, protocol: str = "http"):
        """Add a single proxy to the pool."""
        proxy = ProxyInfo(host=host, port=port, username=username,
                          password=password, protocol=protocol)
        with self.lock:
            self.proxies.append(proxy)
            logger.info(f"Added proxy {host}:{port}")

    def get_next_proxy(self) -> Optional[ProxyInfo]:
        """Return the next available proxy in round-robin order.

        Returns:
            The selected proxy (its ``last_used`` is updated), or ``None``
            when the pool is empty or every proxy is inactive / cooling down.
        """
        with self.lock:
            if not self.proxies:
                return None

            # Drop proxies that are disabled or still cooling down.
            active_proxies = [p for p in self.proxies if self._is_proxy_available(p)]

            if not active_proxies:
                logger.warning("No active proxies available")
                return None

            # Round-robin over the currently available proxies.
            proxy = active_proxies[self.current_index % len(active_proxies)]
            self.current_index += 1
            proxy.last_used = time.time()

            return proxy

    def _is_proxy_available(self, proxy: ProxyInfo) -> bool:
        """Tell whether *proxy* may be handed out right now.

        Side effect: once the cooldown of a failed proxy has elapsed, its
        ``failed_count`` is reset to 0. Callers must hold ``self.lock``.
        """
        if not proxy.is_active:
            return False

        if proxy.failed_count >= self.max_failures:
            # Measure the cooldown from the latest activity. Using last_used
            # alone would let a proxy that failed without ever being handed
            # out (last_used == 0) skip the cooldown entirely.
            cooldown_start = max(proxy.last_used, proxy.last_failed)
            if time.time() - cooldown_start < self.cooldown_time:
                return False
            # Cooldown elapsed: give the proxy a fresh start.
            proxy.failed_count = 0

        return True

    def mark_proxy_failed(self, proxy: ProxyInfo):
        """Record one failure for *proxy* and remember when it happened."""
        with self.lock:
            proxy.failed_count += 1
            proxy.last_failed = time.time()
            if proxy.failed_count >= self.max_failures:
                logger.warning(f"Proxy {proxy.host}:{proxy.port} marked as failed "
                               f"(failures: {proxy.failed_count})")

    def mark_proxy_success(self, proxy: ProxyInfo):
        """Clear the failure count of *proxy* after a successful use."""
        with self.lock:
            proxy.failed_count = 0

    def test_proxy(self, proxy: ProxyInfo, test_url: str = "http://httpbin.org/ip",
                   timeout: int = 10) -> bool:
        """Probe *proxy* by fetching *test_url* through it.

        Returns:
            ``True`` on an HTTP 200 response (failure count cleared),
            ``False`` otherwise (one failure recorded).
        """
        try:
            proxy_url = self._format_proxy_url(proxy)
            proxies = {
                'http': proxy_url,
                'https': proxy_url
            }

            response = requests.get(test_url, proxies=proxies, timeout=timeout)
            if response.status_code == 200:
                self.mark_proxy_success(proxy)
                return True
            else:
                self.mark_proxy_failed(proxy)
                return False

        except Exception as e:
            # Deliberately broad: any error while probing counts as a failed
            # proxy; a health check must never crash the caller.
            logger.debug(f"Proxy test failed for {proxy.host}:{proxy.port}: {e}")
            self.mark_proxy_failed(proxy)
            return False

    def _format_proxy_url(self, proxy: ProxyInfo) -> str:
        """Build ``protocol://[user:pass@]host:port`` for *proxy*."""
        if proxy.username and proxy.password:
            return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}"
        else:
            return f"{proxy.protocol}://{proxy.host}:{proxy.port}"

    def get_proxy_dict(self, proxy: ProxyInfo) -> Dict[str, str]:
        """Return the ``proxies`` mapping expected by the requests library."""
        proxy_url = self._format_proxy_url(proxy)
        return {
            'http': proxy_url,
            'https': proxy_url
        }

    def test_all_proxies(self, test_url: str = "http://httpbin.org/ip"):
        """Probe every proxy in the pool and return how many are working."""
        logger.info("Testing all proxies...")
        working_count = 0

        # Iterate a snapshot: test_proxy() takes the (non-reentrant) lock
        # itself, and other threads may append while the probes run.
        with self.lock:
            candidates = list(self.proxies)

        for proxy in candidates:
            if self.test_proxy(proxy, test_url):
                working_count += 1
                logger.info(f"✓ Proxy {proxy.host}:{proxy.port} is working")
            else:
                logger.warning(f"✗ Proxy {proxy.host}:{proxy.port} failed")

        logger.info(f"Proxy test completed: {working_count}/{len(candidates)} working")
        return working_count

    def get_stats(self) -> Dict:
        """Return ``total`` / ``active`` / ``failed`` counts and ``success_rate``.

        ``active`` is computed with the same availability check used for
        selection, so proxies whose cooldown has elapsed are reset here too.
        """
        with self.lock:
            total = len(self.proxies)
            active = len([p for p in self.proxies if self._is_proxy_available(p)])
            failed = len([p for p in self.proxies if p.failed_count >= self.max_failures])

            return {
                'total': total,
                'active': active,
                'failed': failed,
                'success_rate': active / total if total > 0 else 0
            }


# Process-wide proxy pool instance, created lazily.
_proxy_pool: Optional[ProxyPool] = None


def get_global_proxy_pool() -> ProxyPool:
    """Return the process-wide pool, creating an empty one on first use."""
    global _proxy_pool
    if _proxy_pool is None:
        _proxy_pool = ProxyPool()
    return _proxy_pool


def init_proxy_pool(config_file: Optional[str] = None, *, force: bool = False):
    """Initialise the process-wide pool from *config_file* and return it.

    Calling this again with the same *config_file* while the existing pool
    already holds proxies returns that pool unchanged, so repeated
    initialisation does not wipe the round-robin position or the failure
    bookkeeping.

    Args:
        config_file: JSON proxy list to load; ``None`` yields an empty pool.
        force: Rebuild the pool even if it was already created from the same
            file (e.g. after the file changed on disk).
    """
    global _proxy_pool
    reusable = (
        not force
        and _proxy_pool is not None
        and config_file is not None
        and _proxy_pool.config_file == config_file
        and bool(_proxy_pool.proxies)
    )
    if not reusable:
        _proxy_pool = ProxyPool(config_file)
    return _proxy_pool