Merge branch 'feat/aws-provider-support' of https://github.com/xlang-ai/OSWorld into feat/aws-provider-support

This commit is contained in:
adlsdztony
2025-06-06 02:56:26 +00:00
6 changed files with 588 additions and 16 deletions

View File

@@ -54,13 +54,17 @@ class DesktopEnv(gym.Env):
"""
# Initialize VM manager and vitualization provider
self.region = region
self.provider_name = provider_name
# Default TODO:
self.server_port = 5000
self.chromium_port = 9222
self.vnc_port = 8006
self.vlc_port = 8080
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
# Initialize with default (no proxy) provider
self.current_use_proxy = False
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region, use_proxy=False)
self.os_type = os_type
@@ -149,6 +153,32 @@ class DesktopEnv(gym.Env):
self._step_no = 0
self.action_history.clear()
# Check and handle proxy requirement changes BEFORE starting emulator
if task_config is not None:
task_use_proxy = task_config.get("proxy", False)
if task_use_proxy != self.current_use_proxy:
logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}")
# Close current provider if it exists
if hasattr(self, 'provider') and self.provider:
try:
self.provider.stop_emulator(self.path_to_vm)
except Exception as e:
logger.warning(f"Failed to stop current provider: {e}")
# Create new provider with appropriate proxy setting
self.current_use_proxy = task_use_proxy
self.manager, self.provider = create_vm_manager_and_provider(
self.provider_name,
self.region,
use_proxy=task_use_proxy
)
if task_use_proxy:
logger.info("Using proxy-enabled AWS provider.")
else:
logger.info("Using regular AWS provider.")
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
self._revert_to_snapshot()
logger.info("Starting emulator...")
@@ -184,12 +214,17 @@ class DesktopEnv(gym.Env):
return self.controller.get_vm_screen_size()
def _set_task_info(self, task_config: Dict[str, Any]):
"""Set task info (proxy logic is handled in reset method)"""
self.task_id: str = task_config["id"]
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True)
self.instruction = task_config["instruction"]
self.config = task_config["config"] if "config" in task_config else []
self._set_evaluator_info(task_config)
def _set_evaluator_info(self, task_config: Dict[str, Any]):
"""Set evaluator information from task config"""
# evaluator dict
# func -> metric function string, or list of metric function strings
# conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"

View File

@@ -1,9 +1,14 @@
from desktop_env.providers.base import VMManager, Provider
def create_vm_manager_and_provider(provider_name: str, region: str):
def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: bool = False):
"""
Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name.
Args:
provider_name (str): The name of the provider (e.g., "aws", "vmware", etc.)
region (str): The region for the provider
use_proxy (bool): Whether to use proxy-enabled providers (currently only supported for AWS)
"""
provider_name = provider_name.lower().strip()
if provider_name == "vmware":
@@ -16,8 +21,14 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
return VirtualBoxVMManager(), VirtualBoxProvider(region)
elif provider_name in ["aws", "amazon web services"]:
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider
return AWSVMManager(), AWSProvider(region)
if use_proxy:
# Use proxy-enabled AWS provider
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
return AWSVMManager(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json")
else:
# Use regular AWS provider
from desktop_env.providers.aws.provider import AWSProvider
return AWSVMManager(), AWSProvider(region)
elif provider_name == "azure":
from desktop_env.providers.azure.manager import AzureVMManager
from desktop_env.providers.azure.provider import AzureProvider

View File

@@ -18,11 +18,16 @@ if not os.getenv('AWS_SUBNET_ID') or not os.getenv('AWS_SECURITY_GROUP_ID'):
from desktop_env.providers.base import VMManager
# Import proxy-related modules only when needed
try:
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool
PROXY_SUPPORT_AVAILABLE = True
except ImportError:
PROXY_SUPPORT_AVAILABLE = False
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager")
logger.setLevel(logging.INFO)
REGISTRY_PATH = '.aws_vms'
DEFAULT_REGION = "us-east-1"
# todo: Add doc for the configuration of image, security group and network interface
# todo: public the AMI images
@@ -118,17 +123,55 @@ def _allocate_vm(region=DEFAULT_REGION):
return instance_id
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
"""Allocate a VM with proxy configuration"""
if not PROXY_SUPPORT_AVAILABLE:
logger.warning("Proxy support not available, falling back to regular VM allocation")
return _allocate_vm(region)
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
# Initialize proxy pool if needed
if proxy_config_file:
init_proxy_pool(proxy_config_file)
# Get current proxy
proxy_pool = get_global_proxy_pool()
current_proxy = proxy_pool.get_next_proxy()
if current_proxy:
logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")
# Create provider instance
provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)
# Create new instance
instance_id = provider.create_instance_with_proxy(
image_id=IMAGE_ID_MAP[region],
instance_type=INSTANCE_TYPE,
security_groups=[os.getenv('AWS_SECURITY_GROUP_ID')],
subnet_id=os.getenv('AWS_SUBNET_ID')
)
return instance_id
class AWSVMManager(VMManager):
"""
AWS VM Manager for managing virtual machines on AWS.
AWS does not need to maintain a registry of VMs, as it can dynamically allocate and deallocate VMs.
This class remains the interface of VMManager for compatibility with other components.
This class supports both regular VM allocation and proxy-enabled VM allocation.
"""
def __init__(self, registry_path=REGISTRY_PATH):
self.registry_path = registry_path
def __init__(self, proxy_config_file=None, **kwargs):
self.proxy_config_file = proxy_config_file
# self.lock = FileLock(".aws_lck", timeout=60)
self.initialize_registry()
# Initialize proxy pool if proxy configuration is provided
if proxy_config_file and PROXY_SUPPORT_AVAILABLE:
init_proxy_pool(proxy_config_file)
logger.info(f"Proxy pool initialized with config: {proxy_config_file}")
def initialize_registry(self, **kwargs):
pass
@@ -164,6 +207,10 @@ class AWSVMManager(VMManager):
pass
def get_vm_path(self, region=DEFAULT_REGION, **kwargs):
logger.info("Allocating a new VM in region: {}".format(region))
new_vm_path = _allocate_vm(region)
if self.proxy_config_file:
logger.info("Allocating a new VM with proxy configuration in region: {}".format(region))
new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)
else:
logger.info("Allocating a new VM in region: {}".format(region))
new_vm_path = _allocate_vm(region)
return new_vm_path

View File

@@ -0,0 +1,275 @@
import boto3
from botocore.exceptions import ClientError
import base64
import logging
import json
from typing import Optional
from desktop_env.providers.base import Provider
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool, ProxyInfo
logger = logging.getLogger("desktopenv.providers.aws.AWSProviderWithProxy")
logger.setLevel(logging.INFO)
WAIT_DELAY = 15
MAX_ATTEMPTS = 10
class AWSProviderWithProxy(Provider):
def __init__(self, region: str = None, proxy_config_file: str = None):
super().__init__(region)
self.current_proxy: Optional[ProxyInfo] = None
# 初始化代理池
if proxy_config_file:
init_proxy_pool(proxy_config_file)
logger.info(f"Initialized proxy pool from {proxy_config_file}")
# 获取下一个可用代理
self._rotate_proxy()
def _rotate_proxy(self):
"""轮换到下一个可用代理"""
proxy_pool = get_global_proxy_pool()
self.current_proxy = proxy_pool.get_next_proxy()
if self.current_proxy:
logger.info(f"Switched to proxy: {self.current_proxy.host}:{self.current_proxy.port}")
else:
logger.warning("No proxy available, using direct connection")
def _generate_proxy_user_data(self) -> str:
"""生成包含代理配置的user data脚本"""
if not self.current_proxy:
return ""
proxy_url = self._format_proxy_url(self.current_proxy)
user_data_script = f"""#!/bin/bash
# Configure system proxy
echo 'export http_proxy={proxy_url}' >> /etc/environment
echo 'export https_proxy={proxy_url}' >> /etc/environment
echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment
echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment
# Configure apt proxy
cat > /etc/apt/apt.conf.d/95proxy << EOF
Acquire::http::Proxy "{proxy_url}";
Acquire::https::Proxy "{proxy_url}";
EOF
# Configure chrome/chromium proxy
mkdir -p /etc/opt/chrome/policies/managed
cat > /etc/opt/chrome/policies/managed/proxy.json << EOF
{{
"ProxyMode": "fixed_servers",
"ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}"
}}
EOF
# Configure chromium proxy (Ubuntu default)
mkdir -p /etc/chromium/policies/managed
cat > /etc/chromium/policies/managed/proxy.json << EOF
{{
"ProxyMode": "fixed_servers",
"ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}"
}}
EOF
# Configure firefox proxy - support multiple possible paths
for firefox_dir in /etc/firefox/policies /usr/lib/firefox/distribution/policies /etc/firefox-esr/policies; do
if [ -d "$(dirname "$firefox_dir")" ]; then
mkdir -p "$firefox_dir"
cat > "$firefox_dir/policies.json" << EOF
{{
"policies": {{
"Proxy": {{
"Mode": "manual",
"HTTPProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
"HTTPSProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
"UseHTTPProxyForAllProtocols": true
}}
}}
}}
EOF
break
fi
done
# Reload environment variables
source /etc/environment
# Log proxy configuration
echo "$(date): Configured proxy {self.current_proxy.host}:{self.current_proxy.port}" >> /var/log/proxy-setup.log
"""
return base64.b64encode(user_data_script.encode()).decode()
def _format_proxy_url(self, proxy: ProxyInfo) -> str:
"""格式化代理URL"""
if proxy.username and proxy.password:
return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}"
else:
return f"{proxy.protocol}://{proxy.host}:{proxy.port}"
def start_emulator(self, path_to_vm: str, headless: bool):
logger.info("Starting AWS VM with proxy configuration...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
# 如果实例已经存在,直接启动
ec2_client.start_instances(InstanceIds=[path_to_vm])
logger.info(f"Instance {path_to_vm} is starting...")
# Wait for the instance to be in the 'running' state
waiter = ec2_client.get_waiter('instance_running')
waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
logger.info(f"Instance {path_to_vm} is now running.")
except ClientError as e:
logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}")
raise
def create_instance_with_proxy(self, image_id: str, instance_type: str,
security_groups: list, subnet_id: str) -> str:
"""创建带有代理配置的新实例"""
ec2_client = boto3.client('ec2', region_name=self.region)
user_data = self._generate_proxy_user_data()
run_instances_params = {
"MaxCount": 1,
"MinCount": 1,
"ImageId": image_id,
"InstanceType": instance_type,
"EbsOptimized": True,
"NetworkInterfaces": [
{
"SubnetId": subnet_id,
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": security_groups
}
]
}
if user_data:
run_instances_params["UserData"] = user_data
try:
response = ec2_client.run_instances(**run_instances_params)
instance_id = response['Instances'][0]['InstanceId']
logger.info(f"Created new instance {instance_id} with proxy configuration")
# 等待实例运行
logger.info(f"Waiting for instance {instance_id} to be running...")
ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
logger.info(f"Instance {instance_id} is ready.")
return instance_id
except ClientError as e:
logger.error(f"Failed to create instance with proxy: {str(e)}")
# 如果当前代理失败,尝试轮换代理
if self.current_proxy:
proxy_pool = get_global_proxy_pool()
proxy_pool.mark_proxy_failed(self.current_proxy)
self._rotate_proxy()
raise
def get_ip_address(self, path_to_vm: str) -> str:
logger.info("Getting AWS VM IP address...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
response = ec2_client.describe_instances(InstanceIds=[path_to_vm])
for reservation in response['Reservations']:
for instance in reservation['Instances']:
private_ip_address = instance.get('PrivateIpAddress', '')
return private_ip_address
return ''
except ClientError as e:
logger.error(f"Failed to retrieve private IP address for the instance {path_to_vm}: {str(e)}")
raise
def save_state(self, path_to_vm: str, snapshot_name: str):
logger.info("Saving AWS VM state...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
image_response = ec2_client.create_image(InstanceId=path_to_vm, Name=snapshot_name)
image_id = image_response['ImageId']
logger.info(f"AMI {image_id} created successfully from instance {path_to_vm}.")
return image_id
except ClientError as e:
logger.error(f"Failed to create AMI from the instance {path_to_vm}: {str(e)}")
raise
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
logger.info(f"Reverting AWS VM to snapshot: {snapshot_name}...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
# 获取原实例详情
instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm])
instance = instance_details['Reservations'][0]['Instances'][0]
security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']]
subnet_id = instance['SubnetId']
instance_type = instance['InstanceType']
# 终止旧实例
ec2_client.terminate_instances(InstanceIds=[path_to_vm])
logger.info(f"Old instance {path_to_vm} has been terminated.")
# 轮换到新的代理
self._rotate_proxy()
# 创建新实例
new_instance_id = self.create_instance_with_proxy(
snapshot_name, instance_type, security_groups, subnet_id
)
return new_instance_id
except ClientError as e:
logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}")
raise
def stop_emulator(self, path_to_vm, region=None):
logger.info(f"Stopping AWS VM {path_to_vm}...")
ec2_client = boto3.client('ec2', region_name=self.region)
try:
ec2_client.stop_instances(InstanceIds=[path_to_vm])
waiter = ec2_client.get_waiter('instance_stopped')
waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
logger.info(f"Instance {path_to_vm} has been stopped.")
except ClientError as e:
logger.error(f"Failed to stop the AWS VM {path_to_vm}: {str(e)}")
raise
def get_current_proxy_info(self) -> Optional[dict]:
"""获取当前代理信息"""
if self.current_proxy:
return {
'host': self.current_proxy.host,
'port': self.current_proxy.port,
'protocol': self.current_proxy.protocol,
'failed_count': self.current_proxy.failed_count
}
return None
def force_rotate_proxy(self):
"""强制轮换代理"""
logger.info("Force rotating proxy...")
if self.current_proxy:
proxy_pool = get_global_proxy_pool()
proxy_pool.mark_proxy_failed(self.current_proxy)
self._rotate_proxy()
def get_proxy_stats(self) -> dict:
"""获取代理池统计信息"""
proxy_pool = get_global_proxy_pool()
return proxy_pool.get_stats()

View File

@@ -0,0 +1,193 @@
import random
import requests
import logging
import time
from typing import List, Dict, Optional
from dataclasses import dataclass
from threading import Lock
import json
logger = logging.getLogger("desktopenv.providers.aws.ProxyPool")
logger.setLevel(logging.INFO)
@dataclass
class ProxyInfo:
host: str
port: int
username: Optional[str] = None
password: Optional[str] = None
protocol: str = "http" # http, https, socks5
failed_count: int = 0
last_used: float = 0
is_active: bool = True
class ProxyPool:
def __init__(self, config_file: str = None):
self.proxies: List[ProxyInfo] = []
self.current_index = 0
self.lock = Lock()
self.max_failures = 3 # 最大失败次数
self.cooldown_time = 300 # 5分钟冷却时间
if config_file:
self.load_proxies_from_file(config_file)
def load_proxies_from_file(self, config_file: str):
"""从配置文件加载代理列表"""
try:
with open(config_file, 'r') as f:
proxy_configs = json.load(f)
for config in proxy_configs:
proxy = ProxyInfo(
host=config['host'],
port=config['port'],
username=config.get('username'),
password=config.get('password'),
protocol=config.get('protocol', 'http')
)
self.proxies.append(proxy)
logger.info(f"Loaded {len(self.proxies)} proxies from {config_file}")
except Exception as e:
logger.error(f"Failed to load proxies from {config_file}: {e}")
def add_proxy(self, host: str, port: int, username: str = None,
password: str = None, protocol: str = "http"):
"""添加代理到池中"""
proxy = ProxyInfo(host=host, port=port, username=username,
password=password, protocol=protocol)
with self.lock:
self.proxies.append(proxy)
logger.info(f"Added proxy {host}:{port}")
def get_next_proxy(self) -> Optional[ProxyInfo]:
"""获取下一个可用的代理"""
with self.lock:
if not self.proxies:
return None
# 过滤掉失败次数过多的代理
active_proxies = [p for p in self.proxies if self._is_proxy_available(p)]
if not active_proxies:
logger.warning("No active proxies available")
return None
# 轮询选择代理
proxy = active_proxies[self.current_index % len(active_proxies)]
self.current_index += 1
proxy.last_used = time.time()
return proxy
def _is_proxy_available(self, proxy: ProxyInfo) -> bool:
"""检查代理是否可用"""
if not proxy.is_active:
return False
if proxy.failed_count >= self.max_failures:
# 检查是否过了冷却时间
if time.time() - proxy.last_used < self.cooldown_time:
return False
else:
# 重置失败计数
proxy.failed_count = 0
return True
def mark_proxy_failed(self, proxy: ProxyInfo):
"""标记代理失败"""
with self.lock:
proxy.failed_count += 1
if proxy.failed_count >= self.max_failures:
logger.warning(f"Proxy {proxy.host}:{proxy.port} marked as failed "
f"(failures: {proxy.failed_count})")
def mark_proxy_success(self, proxy: ProxyInfo):
"""标记代理成功"""
with self.lock:
proxy.failed_count = 0
def test_proxy(self, proxy: ProxyInfo, test_url: str = "http://httpbin.org/ip",
timeout: int = 10) -> bool:
"""测试代理是否正常工作"""
try:
proxy_url = self._format_proxy_url(proxy)
proxies = {
'http': proxy_url,
'https': proxy_url
}
response = requests.get(test_url, proxies=proxies, timeout=timeout)
if response.status_code == 200:
self.mark_proxy_success(proxy)
return True
else:
self.mark_proxy_failed(proxy)
return False
except Exception as e:
logger.debug(f"Proxy test failed for {proxy.host}:{proxy.port}: {e}")
self.mark_proxy_failed(proxy)
return False
def _format_proxy_url(self, proxy: ProxyInfo) -> str:
"""格式化代理URL"""
if proxy.username and proxy.password:
return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}"
else:
return f"{proxy.protocol}://{proxy.host}:{proxy.port}"
def get_proxy_dict(self, proxy: ProxyInfo) -> Dict[str, str]:
"""获取requests库使用的代理字典"""
proxy_url = self._format_proxy_url(proxy)
return {
'http': proxy_url,
'https': proxy_url
}
def test_all_proxies(self, test_url: str = "http://httpbin.org/ip"):
"""测试所有代理"""
logger.info("Testing all proxies...")
working_count = 0
for proxy in self.proxies:
if self.test_proxy(proxy, test_url):
working_count += 1
logger.info(f"✓ Proxy {proxy.host}:{proxy.port} is working")
else:
logger.warning(f"✗ Proxy {proxy.host}:{proxy.port} failed")
logger.info(f"Proxy test completed: {working_count}/{len(self.proxies)} working")
return working_count
def get_stats(self) -> Dict:
"""获取代理池统计信息"""
with self.lock:
total = len(self.proxies)
active = len([p for p in self.proxies if self._is_proxy_available(p)])
failed = len([p for p in self.proxies if p.failed_count >= self.max_failures])
return {
'total': total,
'active': active,
'failed': failed,
'success_rate': active / total if total > 0 else 0
}
# 全局代理池实例
_proxy_pool = None
def get_global_proxy_pool() -> ProxyPool:
"""获取全局代理池实例"""
global _proxy_pool
if _proxy_pool is None:
_proxy_pool = ProxyPool()
return _proxy_pool
def init_proxy_pool(config_file: str = None):
"""初始化全局代理池"""
global _proxy_pool
_proxy_pool = ProxyPool(config_file)
return _proxy_pool

View File

@@ -309,7 +309,21 @@ class OpenAICUAAgent:
logger.error(f"OpenAI API error: {str(e)}")
new_screenshot = self.env._get_obs()
new_screenshot_base64 = base64.b64encode(new_screenshot["screenshot"]).decode('utf-8')
self.cua_messages[-1]["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
# Update the image in the last message based on its structure
last_message = self.cua_messages[-1]
if "output" in last_message:
# Computer call output message structure
last_message["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
elif "content" in last_message:
# User message structure - find and update the image content
for content_item in last_message["content"]:
if content_item.get("type") == "input_image":
content_item["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
break
else:
logger.warning("Unknown message structure, cannot update screenshot")
retry_count += 1
time.sleep(1)
raise Exception("Failed to make OpenAI API call after 3 retries")
@@ -452,10 +466,7 @@ class OpenAICUAAgent:
logger.warning("Empty text for type action")
return "import pyautogui\n# Empty text, no action taken"
pattern = r"(?<!\\)'"
text = re.sub(pattern, r"\\'", text)
# 使用三重引号来确保字符串内容不会破坏格式
# Use repr() to properly escape the string content without double-escaping
pyautogui_code = f"""import pyautogui\npyautogui.typewrite({repr(text)})"""
logger.info(f"Pyautogui code: {pyautogui_code}")
return pyautogui_code