Merge branch 'feat/aws-provider-support' of https://github.com/xlang-ai/OSWorld into feat/aws-provider-support
This commit is contained in:
@@ -54,13 +54,17 @@ class DesktopEnv(gym.Env):
|
|||||||
"""
|
"""
|
||||||
# Initialize VM manager and vitualization provider
|
# Initialize VM manager and vitualization provider
|
||||||
self.region = region
|
self.region = region
|
||||||
|
self.provider_name = provider_name
|
||||||
|
|
||||||
# Default TODO:
|
# Default TODO:
|
||||||
self.server_port = 5000
|
self.server_port = 5000
|
||||||
self.chromium_port = 9222
|
self.chromium_port = 9222
|
||||||
self.vnc_port = 8006
|
self.vnc_port = 8006
|
||||||
self.vlc_port = 8080
|
self.vlc_port = 8080
|
||||||
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
|
|
||||||
|
# Initialize with default (no proxy) provider
|
||||||
|
self.current_use_proxy = False
|
||||||
|
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region, use_proxy=False)
|
||||||
|
|
||||||
self.os_type = os_type
|
self.os_type = os_type
|
||||||
|
|
||||||
@@ -149,6 +153,32 @@ class DesktopEnv(gym.Env):
|
|||||||
self._step_no = 0
|
self._step_no = 0
|
||||||
self.action_history.clear()
|
self.action_history.clear()
|
||||||
|
|
||||||
|
# Check and handle proxy requirement changes BEFORE starting emulator
|
||||||
|
if task_config is not None:
|
||||||
|
task_use_proxy = task_config.get("proxy", False)
|
||||||
|
if task_use_proxy != self.current_use_proxy:
|
||||||
|
logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}")
|
||||||
|
|
||||||
|
# Close current provider if it exists
|
||||||
|
if hasattr(self, 'provider') and self.provider:
|
||||||
|
try:
|
||||||
|
self.provider.stop_emulator(self.path_to_vm)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to stop current provider: {e}")
|
||||||
|
|
||||||
|
# Create new provider with appropriate proxy setting
|
||||||
|
self.current_use_proxy = task_use_proxy
|
||||||
|
self.manager, self.provider = create_vm_manager_and_provider(
|
||||||
|
self.provider_name,
|
||||||
|
self.region,
|
||||||
|
use_proxy=task_use_proxy
|
||||||
|
)
|
||||||
|
|
||||||
|
if task_use_proxy:
|
||||||
|
logger.info("Using proxy-enabled AWS provider.")
|
||||||
|
else:
|
||||||
|
logger.info("Using regular AWS provider.")
|
||||||
|
|
||||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
||||||
self._revert_to_snapshot()
|
self._revert_to_snapshot()
|
||||||
logger.info("Starting emulator...")
|
logger.info("Starting emulator...")
|
||||||
@@ -184,12 +214,17 @@ class DesktopEnv(gym.Env):
|
|||||||
return self.controller.get_vm_screen_size()
|
return self.controller.get_vm_screen_size()
|
||||||
|
|
||||||
def _set_task_info(self, task_config: Dict[str, Any]):
|
def _set_task_info(self, task_config: Dict[str, Any]):
|
||||||
|
"""Set task info (proxy logic is handled in reset method)"""
|
||||||
self.task_id: str = task_config["id"]
|
self.task_id: str = task_config["id"]
|
||||||
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
|
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
|
||||||
os.makedirs(self.cache_dir, exist_ok=True)
|
os.makedirs(self.cache_dir, exist_ok=True)
|
||||||
self.instruction = task_config["instruction"]
|
self.instruction = task_config["instruction"]
|
||||||
self.config = task_config["config"] if "config" in task_config else []
|
self.config = task_config["config"] if "config" in task_config else []
|
||||||
|
|
||||||
|
self._set_evaluator_info(task_config)
|
||||||
|
|
||||||
|
def _set_evaluator_info(self, task_config: Dict[str, Any]):
|
||||||
|
"""Set evaluator information from task config"""
|
||||||
# evaluator dict
|
# evaluator dict
|
||||||
# func -> metric function string, or list of metric function strings
|
# func -> metric function string, or list of metric function strings
|
||||||
# conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"
|
# conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
from desktop_env.providers.base import VMManager, Provider
|
from desktop_env.providers.base import VMManager, Provider
|
||||||
|
|
||||||
|
|
||||||
def create_vm_manager_and_provider(provider_name: str, region: str):
|
def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: bool = False):
|
||||||
"""
|
"""
|
||||||
Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name.
|
Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_name (str): The name of the provider (e.g., "aws", "vmware", etc.)
|
||||||
|
region (str): The region for the provider
|
||||||
|
use_proxy (bool): Whether to use proxy-enabled providers (currently only supported for AWS)
|
||||||
"""
|
"""
|
||||||
provider_name = provider_name.lower().strip()
|
provider_name = provider_name.lower().strip()
|
||||||
if provider_name == "vmware":
|
if provider_name == "vmware":
|
||||||
@@ -16,8 +21,14 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
|
|||||||
return VirtualBoxVMManager(), VirtualBoxProvider(region)
|
return VirtualBoxVMManager(), VirtualBoxProvider(region)
|
||||||
elif provider_name in ["aws", "amazon web services"]:
|
elif provider_name in ["aws", "amazon web services"]:
|
||||||
from desktop_env.providers.aws.manager import AWSVMManager
|
from desktop_env.providers.aws.manager import AWSVMManager
|
||||||
from desktop_env.providers.aws.provider import AWSProvider
|
if use_proxy:
|
||||||
return AWSVMManager(), AWSProvider(region)
|
# Use proxy-enabled AWS provider
|
||||||
|
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
|
||||||
|
return AWSVMManager(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json")
|
||||||
|
else:
|
||||||
|
# Use regular AWS provider
|
||||||
|
from desktop_env.providers.aws.provider import AWSProvider
|
||||||
|
return AWSVMManager(), AWSProvider(region)
|
||||||
elif provider_name == "azure":
|
elif provider_name == "azure":
|
||||||
from desktop_env.providers.azure.manager import AzureVMManager
|
from desktop_env.providers.azure.manager import AzureVMManager
|
||||||
from desktop_env.providers.azure.provider import AzureProvider
|
from desktop_env.providers.azure.provider import AzureProvider
|
||||||
|
|||||||
@@ -18,11 +18,16 @@ if not os.getenv('AWS_SUBNET_ID') or not os.getenv('AWS_SECURITY_GROUP_ID'):
|
|||||||
|
|
||||||
from desktop_env.providers.base import VMManager
|
from desktop_env.providers.base import VMManager
|
||||||
|
|
||||||
|
# Import proxy-related modules only when needed
|
||||||
|
try:
|
||||||
|
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool
|
||||||
|
PROXY_SUPPORT_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
PROXY_SUPPORT_AVAILABLE = False
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager")
|
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
REGISTRY_PATH = '.aws_vms'
|
|
||||||
|
|
||||||
DEFAULT_REGION = "us-east-1"
|
DEFAULT_REGION = "us-east-1"
|
||||||
# todo: Add doc for the configuration of image, security group and network interface
|
# todo: Add doc for the configuration of image, security group and network interface
|
||||||
# todo: public the AMI images
|
# todo: public the AMI images
|
||||||
@@ -118,18 +123,56 @@ def _allocate_vm(region=DEFAULT_REGION):
|
|||||||
return instance_id
|
return instance_id
|
||||||
|
|
||||||
|
|
||||||
|
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
|
||||||
|
"""Allocate a VM with proxy configuration"""
|
||||||
|
if not PROXY_SUPPORT_AVAILABLE:
|
||||||
|
logger.warning("Proxy support not available, falling back to regular VM allocation")
|
||||||
|
return _allocate_vm(region)
|
||||||
|
|
||||||
|
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
|
||||||
|
|
||||||
|
# Initialize proxy pool if needed
|
||||||
|
if proxy_config_file:
|
||||||
|
init_proxy_pool(proxy_config_file)
|
||||||
|
|
||||||
|
# Get current proxy
|
||||||
|
proxy_pool = get_global_proxy_pool()
|
||||||
|
current_proxy = proxy_pool.get_next_proxy()
|
||||||
|
|
||||||
|
if current_proxy:
|
||||||
|
logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")
|
||||||
|
|
||||||
|
# Create provider instance
|
||||||
|
provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)
|
||||||
|
|
||||||
|
# Create new instance
|
||||||
|
instance_id = provider.create_instance_with_proxy(
|
||||||
|
image_id=IMAGE_ID_MAP[region],
|
||||||
|
instance_type=INSTANCE_TYPE,
|
||||||
|
security_groups=[os.getenv('AWS_SECURITY_GROUP_ID')],
|
||||||
|
subnet_id=os.getenv('AWS_SUBNET_ID')
|
||||||
|
)
|
||||||
|
|
||||||
|
return instance_id
|
||||||
|
|
||||||
|
|
||||||
class AWSVMManager(VMManager):
|
class AWSVMManager(VMManager):
|
||||||
"""
|
"""
|
||||||
AWS VM Manager for managing virtual machines on AWS.
|
AWS VM Manager for managing virtual machines on AWS.
|
||||||
|
|
||||||
AWS does not need to maintain a registry of VMs, as it can dynamically allocate and deallocate VMs.
|
AWS does not need to maintain a registry of VMs, as it can dynamically allocate and deallocate VMs.
|
||||||
This class remains the interface of VMManager for compatibility with other components.
|
This class supports both regular VM allocation and proxy-enabled VM allocation.
|
||||||
"""
|
"""
|
||||||
def __init__(self, registry_path=REGISTRY_PATH):
|
def __init__(self, proxy_config_file=None, **kwargs):
|
||||||
self.registry_path = registry_path
|
self.proxy_config_file = proxy_config_file
|
||||||
# self.lock = FileLock(".aws_lck", timeout=60)
|
# self.lock = FileLock(".aws_lck", timeout=60)
|
||||||
self.initialize_registry()
|
self.initialize_registry()
|
||||||
|
|
||||||
|
# Initialize proxy pool if proxy configuration is provided
|
||||||
|
if proxy_config_file and PROXY_SUPPORT_AVAILABLE:
|
||||||
|
init_proxy_pool(proxy_config_file)
|
||||||
|
logger.info(f"Proxy pool initialized with config: {proxy_config_file}")
|
||||||
|
|
||||||
def initialize_registry(self, **kwargs):
|
def initialize_registry(self, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -164,6 +207,10 @@ class AWSVMManager(VMManager):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def get_vm_path(self, region=DEFAULT_REGION, **kwargs):
|
def get_vm_path(self, region=DEFAULT_REGION, **kwargs):
|
||||||
logger.info("Allocating a new VM in region: {}".format(region))
|
if self.proxy_config_file:
|
||||||
new_vm_path = _allocate_vm(region)
|
logger.info("Allocating a new VM with proxy configuration in region: {}".format(region))
|
||||||
|
new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)
|
||||||
|
else:
|
||||||
|
logger.info("Allocating a new VM in region: {}".format(region))
|
||||||
|
new_vm_path = _allocate_vm(region)
|
||||||
return new_vm_path
|
return new_vm_path
|
||||||
|
|||||||
275
desktop_env/providers/aws/provider_with_proxy.py
Normal file
275
desktop_env/providers/aws/provider_with_proxy.py
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
import boto3
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from desktop_env.providers.base import Provider
|
||||||
|
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool, ProxyInfo
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.providers.aws.AWSProviderWithProxy")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
WAIT_DELAY = 15
|
||||||
|
MAX_ATTEMPTS = 10
|
||||||
|
|
||||||
|
|
||||||
|
class AWSProviderWithProxy(Provider):
|
||||||
|
|
||||||
|
def __init__(self, region: str = None, proxy_config_file: str = None):
|
||||||
|
super().__init__(region)
|
||||||
|
self.current_proxy: Optional[ProxyInfo] = None
|
||||||
|
|
||||||
|
# 初始化代理池
|
||||||
|
if proxy_config_file:
|
||||||
|
init_proxy_pool(proxy_config_file)
|
||||||
|
logger.info(f"Initialized proxy pool from {proxy_config_file}")
|
||||||
|
|
||||||
|
# 获取下一个可用代理
|
||||||
|
self._rotate_proxy()
|
||||||
|
|
||||||
|
def _rotate_proxy(self):
|
||||||
|
"""轮换到下一个可用代理"""
|
||||||
|
proxy_pool = get_global_proxy_pool()
|
||||||
|
self.current_proxy = proxy_pool.get_next_proxy()
|
||||||
|
|
||||||
|
if self.current_proxy:
|
||||||
|
logger.info(f"Switched to proxy: {self.current_proxy.host}:{self.current_proxy.port}")
|
||||||
|
else:
|
||||||
|
logger.warning("No proxy available, using direct connection")
|
||||||
|
|
||||||
|
def _generate_proxy_user_data(self) -> str:
|
||||||
|
"""生成包含代理配置的user data脚本"""
|
||||||
|
if not self.current_proxy:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
proxy_url = self._format_proxy_url(self.current_proxy)
|
||||||
|
|
||||||
|
user_data_script = f"""#!/bin/bash
|
||||||
|
# Configure system proxy
|
||||||
|
echo 'export http_proxy={proxy_url}' >> /etc/environment
|
||||||
|
echo 'export https_proxy={proxy_url}' >> /etc/environment
|
||||||
|
echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment
|
||||||
|
echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment
|
||||||
|
|
||||||
|
# Configure apt proxy
|
||||||
|
cat > /etc/apt/apt.conf.d/95proxy << EOF
|
||||||
|
Acquire::http::Proxy "{proxy_url}";
|
||||||
|
Acquire::https::Proxy "{proxy_url}";
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# Configure chrome/chromium proxy
|
||||||
|
mkdir -p /etc/opt/chrome/policies/managed
|
||||||
|
cat > /etc/opt/chrome/policies/managed/proxy.json << EOF
|
||||||
|
{{
|
||||||
|
"ProxyMode": "fixed_servers",
|
||||||
|
"ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}"
|
||||||
|
}}
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# Configure chromium proxy (Ubuntu default)
|
||||||
|
mkdir -p /etc/chromium/policies/managed
|
||||||
|
cat > /etc/chromium/policies/managed/proxy.json << EOF
|
||||||
|
{{
|
||||||
|
"ProxyMode": "fixed_servers",
|
||||||
|
"ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}"
|
||||||
|
}}
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# Configure firefox proxy - support multiple possible paths
|
||||||
|
for firefox_dir in /etc/firefox/policies /usr/lib/firefox/distribution/policies /etc/firefox-esr/policies; do
|
||||||
|
if [ -d "$(dirname "$firefox_dir")" ]; then
|
||||||
|
mkdir -p "$firefox_dir"
|
||||||
|
cat > "$firefox_dir/policies.json" << EOF
|
||||||
|
{{
|
||||||
|
"policies": {{
|
||||||
|
"Proxy": {{
|
||||||
|
"Mode": "manual",
|
||||||
|
"HTTPProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
|
||||||
|
"HTTPSProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
|
||||||
|
"UseHTTPProxyForAllProtocols": true
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
EOF
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Reload environment variables
|
||||||
|
source /etc/environment
|
||||||
|
|
||||||
|
# Log proxy configuration
|
||||||
|
echo "$(date): Configured proxy {self.current_proxy.host}:{self.current_proxy.port}" >> /var/log/proxy-setup.log
|
||||||
|
"""
|
||||||
|
|
||||||
|
return base64.b64encode(user_data_script.encode()).decode()
|
||||||
|
|
||||||
|
def _format_proxy_url(self, proxy: ProxyInfo) -> str:
|
||||||
|
"""格式化代理URL"""
|
||||||
|
if proxy.username and proxy.password:
|
||||||
|
return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}"
|
||||||
|
else:
|
||||||
|
return f"{proxy.protocol}://{proxy.host}:{proxy.port}"
|
||||||
|
|
||||||
|
def start_emulator(self, path_to_vm: str, headless: bool):
|
||||||
|
logger.info("Starting AWS VM with proxy configuration...")
|
||||||
|
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 如果实例已经存在,直接启动
|
||||||
|
ec2_client.start_instances(InstanceIds=[path_to_vm])
|
||||||
|
logger.info(f"Instance {path_to_vm} is starting...")
|
||||||
|
|
||||||
|
# Wait for the instance to be in the 'running' state
|
||||||
|
waiter = ec2_client.get_waiter('instance_running')
|
||||||
|
waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
|
||||||
|
logger.info(f"Instance {path_to_vm} is now running.")
|
||||||
|
|
||||||
|
except ClientError as e:
|
||||||
|
logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def create_instance_with_proxy(self, image_id: str, instance_type: str,
|
||||||
|
security_groups: list, subnet_id: str) -> str:
|
||||||
|
"""创建带有代理配置的新实例"""
|
||||||
|
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||||
|
|
||||||
|
user_data = self._generate_proxy_user_data()
|
||||||
|
|
||||||
|
run_instances_params = {
|
||||||
|
"MaxCount": 1,
|
||||||
|
"MinCount": 1,
|
||||||
|
"ImageId": image_id,
|
||||||
|
"InstanceType": instance_type,
|
||||||
|
"EbsOptimized": True,
|
||||||
|
"NetworkInterfaces": [
|
||||||
|
{
|
||||||
|
"SubnetId": subnet_id,
|
||||||
|
"AssociatePublicIpAddress": True,
|
||||||
|
"DeviceIndex": 0,
|
||||||
|
"Groups": security_groups
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
if user_data:
|
||||||
|
run_instances_params["UserData"] = user_data
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = ec2_client.run_instances(**run_instances_params)
|
||||||
|
instance_id = response['Instances'][0]['InstanceId']
|
||||||
|
|
||||||
|
logger.info(f"Created new instance {instance_id} with proxy configuration")
|
||||||
|
|
||||||
|
# 等待实例运行
|
||||||
|
logger.info(f"Waiting for instance {instance_id} to be running...")
|
||||||
|
ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
|
||||||
|
logger.info(f"Instance {instance_id} is ready.")
|
||||||
|
|
||||||
|
return instance_id
|
||||||
|
|
||||||
|
except ClientError as e:
|
||||||
|
logger.error(f"Failed to create instance with proxy: {str(e)}")
|
||||||
|
# 如果当前代理失败,尝试轮换代理
|
||||||
|
if self.current_proxy:
|
||||||
|
proxy_pool = get_global_proxy_pool()
|
||||||
|
proxy_pool.mark_proxy_failed(self.current_proxy)
|
||||||
|
self._rotate_proxy()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_ip_address(self, path_to_vm: str) -> str:
|
||||||
|
logger.info("Getting AWS VM IP address...")
|
||||||
|
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = ec2_client.describe_instances(InstanceIds=[path_to_vm])
|
||||||
|
for reservation in response['Reservations']:
|
||||||
|
for instance in reservation['Instances']:
|
||||||
|
private_ip_address = instance.get('PrivateIpAddress', '')
|
||||||
|
return private_ip_address
|
||||||
|
return ''
|
||||||
|
except ClientError as e:
|
||||||
|
logger.error(f"Failed to retrieve private IP address for the instance {path_to_vm}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def save_state(self, path_to_vm: str, snapshot_name: str):
|
||||||
|
logger.info("Saving AWS VM state...")
|
||||||
|
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_response = ec2_client.create_image(InstanceId=path_to_vm, Name=snapshot_name)
|
||||||
|
image_id = image_response['ImageId']
|
||||||
|
logger.info(f"AMI {image_id} created successfully from instance {path_to_vm}.")
|
||||||
|
return image_id
|
||||||
|
except ClientError as e:
|
||||||
|
logger.error(f"Failed to create AMI from the instance {path_to_vm}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
|
||||||
|
logger.info(f"Reverting AWS VM to snapshot: {snapshot_name}...")
|
||||||
|
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取原实例详情
|
||||||
|
instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm])
|
||||||
|
instance = instance_details['Reservations'][0]['Instances'][0]
|
||||||
|
security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']]
|
||||||
|
subnet_id = instance['SubnetId']
|
||||||
|
instance_type = instance['InstanceType']
|
||||||
|
|
||||||
|
# 终止旧实例
|
||||||
|
ec2_client.terminate_instances(InstanceIds=[path_to_vm])
|
||||||
|
logger.info(f"Old instance {path_to_vm} has been terminated.")
|
||||||
|
|
||||||
|
# 轮换到新的代理
|
||||||
|
self._rotate_proxy()
|
||||||
|
|
||||||
|
# 创建新实例
|
||||||
|
new_instance_id = self.create_instance_with_proxy(
|
||||||
|
snapshot_name, instance_type, security_groups, subnet_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_instance_id
|
||||||
|
|
||||||
|
except ClientError as e:
|
||||||
|
logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def stop_emulator(self, path_to_vm, region=None):
|
||||||
|
logger.info(f"Stopping AWS VM {path_to_vm}...")
|
||||||
|
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||||
|
|
||||||
|
try:
|
||||||
|
ec2_client.stop_instances(InstanceIds=[path_to_vm])
|
||||||
|
waiter = ec2_client.get_waiter('instance_stopped')
|
||||||
|
waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
|
||||||
|
logger.info(f"Instance {path_to_vm} has been stopped.")
|
||||||
|
except ClientError as e:
|
||||||
|
logger.error(f"Failed to stop the AWS VM {path_to_vm}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_current_proxy_info(self) -> Optional[dict]:
|
||||||
|
"""获取当前代理信息"""
|
||||||
|
if self.current_proxy:
|
||||||
|
return {
|
||||||
|
'host': self.current_proxy.host,
|
||||||
|
'port': self.current_proxy.port,
|
||||||
|
'protocol': self.current_proxy.protocol,
|
||||||
|
'failed_count': self.current_proxy.failed_count
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
def force_rotate_proxy(self):
|
||||||
|
"""强制轮换代理"""
|
||||||
|
logger.info("Force rotating proxy...")
|
||||||
|
if self.current_proxy:
|
||||||
|
proxy_pool = get_global_proxy_pool()
|
||||||
|
proxy_pool.mark_proxy_failed(self.current_proxy)
|
||||||
|
self._rotate_proxy()
|
||||||
|
|
||||||
|
def get_proxy_stats(self) -> dict:
|
||||||
|
"""获取代理池统计信息"""
|
||||||
|
proxy_pool = get_global_proxy_pool()
|
||||||
|
return proxy_pool.get_stats()
|
||||||
193
desktop_env/providers/aws/proxy_pool.py
Normal file
193
desktop_env/providers/aws/proxy_pool.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
import random
|
||||||
|
import requests
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from threading import Lock
|
||||||
|
import json
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.providers.aws.ProxyPool")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProxyInfo:
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
username: Optional[str] = None
|
||||||
|
password: Optional[str] = None
|
||||||
|
protocol: str = "http" # http, https, socks5
|
||||||
|
failed_count: int = 0
|
||||||
|
last_used: float = 0
|
||||||
|
is_active: bool = True
|
||||||
|
|
||||||
|
class ProxyPool:
|
||||||
|
def __init__(self, config_file: str = None):
|
||||||
|
self.proxies: List[ProxyInfo] = []
|
||||||
|
self.current_index = 0
|
||||||
|
self.lock = Lock()
|
||||||
|
self.max_failures = 3 # 最大失败次数
|
||||||
|
self.cooldown_time = 300 # 5分钟冷却时间
|
||||||
|
|
||||||
|
if config_file:
|
||||||
|
self.load_proxies_from_file(config_file)
|
||||||
|
|
||||||
|
def load_proxies_from_file(self, config_file: str):
|
||||||
|
"""从配置文件加载代理列表"""
|
||||||
|
try:
|
||||||
|
with open(config_file, 'r') as f:
|
||||||
|
proxy_configs = json.load(f)
|
||||||
|
|
||||||
|
for config in proxy_configs:
|
||||||
|
proxy = ProxyInfo(
|
||||||
|
host=config['host'],
|
||||||
|
port=config['port'],
|
||||||
|
username=config.get('username'),
|
||||||
|
password=config.get('password'),
|
||||||
|
protocol=config.get('protocol', 'http')
|
||||||
|
)
|
||||||
|
self.proxies.append(proxy)
|
||||||
|
|
||||||
|
logger.info(f"Loaded {len(self.proxies)} proxies from {config_file}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load proxies from {config_file}: {e}")
|
||||||
|
|
||||||
|
def add_proxy(self, host: str, port: int, username: str = None,
|
||||||
|
password: str = None, protocol: str = "http"):
|
||||||
|
"""添加代理到池中"""
|
||||||
|
proxy = ProxyInfo(host=host, port=port, username=username,
|
||||||
|
password=password, protocol=protocol)
|
||||||
|
with self.lock:
|
||||||
|
self.proxies.append(proxy)
|
||||||
|
logger.info(f"Added proxy {host}:{port}")
|
||||||
|
|
||||||
|
def get_next_proxy(self) -> Optional[ProxyInfo]:
|
||||||
|
"""获取下一个可用的代理"""
|
||||||
|
with self.lock:
|
||||||
|
if not self.proxies:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 过滤掉失败次数过多的代理
|
||||||
|
active_proxies = [p for p in self.proxies if self._is_proxy_available(p)]
|
||||||
|
|
||||||
|
if not active_proxies:
|
||||||
|
logger.warning("No active proxies available")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 轮询选择代理
|
||||||
|
proxy = active_proxies[self.current_index % len(active_proxies)]
|
||||||
|
self.current_index += 1
|
||||||
|
proxy.last_used = time.time()
|
||||||
|
|
||||||
|
return proxy
|
||||||
|
|
||||||
|
def _is_proxy_available(self, proxy: ProxyInfo) -> bool:
|
||||||
|
"""检查代理是否可用"""
|
||||||
|
if not proxy.is_active:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if proxy.failed_count >= self.max_failures:
|
||||||
|
# 检查是否过了冷却时间
|
||||||
|
if time.time() - proxy.last_used < self.cooldown_time:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# 重置失败计数
|
||||||
|
proxy.failed_count = 0
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def mark_proxy_failed(self, proxy: ProxyInfo):
|
||||||
|
"""标记代理失败"""
|
||||||
|
with self.lock:
|
||||||
|
proxy.failed_count += 1
|
||||||
|
if proxy.failed_count >= self.max_failures:
|
||||||
|
logger.warning(f"Proxy {proxy.host}:{proxy.port} marked as failed "
|
||||||
|
f"(failures: {proxy.failed_count})")
|
||||||
|
|
||||||
|
def mark_proxy_success(self, proxy: ProxyInfo):
|
||||||
|
"""标记代理成功"""
|
||||||
|
with self.lock:
|
||||||
|
proxy.failed_count = 0
|
||||||
|
|
||||||
|
def test_proxy(self, proxy: ProxyInfo, test_url: str = "http://httpbin.org/ip",
|
||||||
|
timeout: int = 10) -> bool:
|
||||||
|
"""测试代理是否正常工作"""
|
||||||
|
try:
|
||||||
|
proxy_url = self._format_proxy_url(proxy)
|
||||||
|
proxies = {
|
||||||
|
'http': proxy_url,
|
||||||
|
'https': proxy_url
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.get(test_url, proxies=proxies, timeout=timeout)
|
||||||
|
if response.status_code == 200:
|
||||||
|
self.mark_proxy_success(proxy)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.mark_proxy_failed(proxy)
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Proxy test failed for {proxy.host}:{proxy.port}: {e}")
|
||||||
|
self.mark_proxy_failed(proxy)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _format_proxy_url(self, proxy: ProxyInfo) -> str:
|
||||||
|
"""格式化代理URL"""
|
||||||
|
if proxy.username and proxy.password:
|
||||||
|
return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}"
|
||||||
|
else:
|
||||||
|
return f"{proxy.protocol}://{proxy.host}:{proxy.port}"
|
||||||
|
|
||||||
|
def get_proxy_dict(self, proxy: ProxyInfo) -> Dict[str, str]:
|
||||||
|
"""获取requests库使用的代理字典"""
|
||||||
|
proxy_url = self._format_proxy_url(proxy)
|
||||||
|
return {
|
||||||
|
'http': proxy_url,
|
||||||
|
'https': proxy_url
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_all_proxies(self, test_url: str = "http://httpbin.org/ip"):
|
||||||
|
"""测试所有代理"""
|
||||||
|
logger.info("Testing all proxies...")
|
||||||
|
working_count = 0
|
||||||
|
|
||||||
|
for proxy in self.proxies:
|
||||||
|
if self.test_proxy(proxy, test_url):
|
||||||
|
working_count += 1
|
||||||
|
logger.info(f"✓ Proxy {proxy.host}:{proxy.port} is working")
|
||||||
|
else:
|
||||||
|
logger.warning(f"✗ Proxy {proxy.host}:{proxy.port} failed")
|
||||||
|
|
||||||
|
logger.info(f"Proxy test completed: {working_count}/{len(self.proxies)} working")
|
||||||
|
return working_count
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict:
|
||||||
|
"""获取代理池统计信息"""
|
||||||
|
with self.lock:
|
||||||
|
total = len(self.proxies)
|
||||||
|
active = len([p for p in self.proxies if self._is_proxy_available(p)])
|
||||||
|
failed = len([p for p in self.proxies if p.failed_count >= self.max_failures])
|
||||||
|
|
||||||
|
return {
|
||||||
|
'total': total,
|
||||||
|
'active': active,
|
||||||
|
'failed': failed,
|
||||||
|
'success_rate': active / total if total > 0 else 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# 全局代理池实例
|
||||||
|
_proxy_pool = None
|
||||||
|
|
||||||
|
def get_global_proxy_pool() -> ProxyPool:
|
||||||
|
"""获取全局代理池实例"""
|
||||||
|
global _proxy_pool
|
||||||
|
if _proxy_pool is None:
|
||||||
|
_proxy_pool = ProxyPool()
|
||||||
|
return _proxy_pool
|
||||||
|
|
||||||
|
def init_proxy_pool(config_file: str = None):
|
||||||
|
"""初始化全局代理池"""
|
||||||
|
global _proxy_pool
|
||||||
|
_proxy_pool = ProxyPool(config_file)
|
||||||
|
return _proxy_pool
|
||||||
@@ -309,7 +309,21 @@ class OpenAICUAAgent:
|
|||||||
logger.error(f"OpenAI API error: {str(e)}")
|
logger.error(f"OpenAI API error: {str(e)}")
|
||||||
new_screenshot = self.env._get_obs()
|
new_screenshot = self.env._get_obs()
|
||||||
new_screenshot_base64 = base64.b64encode(new_screenshot["screenshot"]).decode('utf-8')
|
new_screenshot_base64 = base64.b64encode(new_screenshot["screenshot"]).decode('utf-8')
|
||||||
self.cua_messages[-1]["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
|
|
||||||
|
# Update the image in the last message based on its structure
|
||||||
|
last_message = self.cua_messages[-1]
|
||||||
|
if "output" in last_message:
|
||||||
|
# Computer call output message structure
|
||||||
|
last_message["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
|
||||||
|
elif "content" in last_message:
|
||||||
|
# User message structure - find and update the image content
|
||||||
|
for content_item in last_message["content"]:
|
||||||
|
if content_item.get("type") == "input_image":
|
||||||
|
content_item["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.warning("Unknown message structure, cannot update screenshot")
|
||||||
|
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
raise Exception("Failed to make OpenAI API call after 3 retries")
|
raise Exception("Failed to make OpenAI API call after 3 retries")
|
||||||
@@ -452,10 +466,7 @@ class OpenAICUAAgent:
|
|||||||
logger.warning("Empty text for type action")
|
logger.warning("Empty text for type action")
|
||||||
return "import pyautogui\n# Empty text, no action taken"
|
return "import pyautogui\n# Empty text, no action taken"
|
||||||
|
|
||||||
pattern = r"(?<!\\)'"
|
# Use repr() to properly escape the string content without double-escaping
|
||||||
text = re.sub(pattern, r"\\'", text)
|
|
||||||
|
|
||||||
# 使用三重引号来确保字符串内容不会破坏格式
|
|
||||||
pyautogui_code = f"""import pyautogui\npyautogui.typewrite({repr(text)})"""
|
pyautogui_code = f"""import pyautogui\npyautogui.typewrite({repr(text)})"""
|
||||||
logger.info(f"Pyautogui code: {pyautogui_code}")
|
logger.info(f"Pyautogui code: {pyautogui_code}")
|
||||||
return pyautogui_code
|
return pyautogui_code
|
||||||
|
|||||||
Reference in New Issue
Block a user