Merge remote-tracking branch 'upstream/feat/aws-provider-support'
This commit is contained in:
@@ -121,7 +121,7 @@ class SetupController:
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to download {url} caused by {e}. Retrying... ({max_retries - i - 1} attempts left)")
|
f"Failed to download {url} caused by {e}. Retrying... ({max_retries - i - 1} attempts left)")
|
||||||
if not downloaded:
|
if not downloaded:
|
||||||
raise requests.RequestException(f"Failed to download {url}. No retries left. Error: {e}")
|
raise requests.RequestException(f"Failed to download {url}. No retries left.")
|
||||||
|
|
||||||
form = MultipartEncoder({
|
form = MultipartEncoder({
|
||||||
"file_path": path,
|
"file_path": path,
|
||||||
|
|||||||
@@ -20,14 +20,13 @@ def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: b
|
|||||||
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
|
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
|
||||||
return VirtualBoxVMManager(), VirtualBoxProvider(region)
|
return VirtualBoxVMManager(), VirtualBoxProvider(region)
|
||||||
elif provider_name in ["aws", "amazon web services"]:
|
elif provider_name in ["aws", "amazon web services"]:
|
||||||
|
from desktop_env.providers.aws.manager import AWSVMManager
|
||||||
if use_proxy:
|
if use_proxy:
|
||||||
# Use proxy-enabled AWS provider
|
# Use proxy-enabled AWS provider
|
||||||
from desktop_env.providers.aws.manager_with_proxy import AWSVMManagerWithProxy
|
|
||||||
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
|
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
|
||||||
return AWSVMManagerWithProxy(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json")
|
return AWSVMManager(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json")
|
||||||
else:
|
else:
|
||||||
# Use regular AWS provider
|
# Use regular AWS provider
|
||||||
from desktop_env.providers.aws.manager import AWSVMManager
|
|
||||||
from desktop_env.providers.aws.provider import AWSProvider
|
from desktop_env.providers.aws.provider import AWSProvider
|
||||||
return AWSVMManager(), AWSProvider(region)
|
return AWSVMManager(), AWSProvider(region)
|
||||||
elif provider_name == "azure":
|
elif provider_name == "azure":
|
||||||
|
|||||||
@@ -18,11 +18,16 @@ if not os.getenv('AWS_SUBNET_ID') or not os.getenv('AWS_SECURITY_GROUP_ID'):
|
|||||||
|
|
||||||
from desktop_env.providers.base import VMManager
|
from desktop_env.providers.base import VMManager
|
||||||
|
|
||||||
|
# Import proxy-related modules only when needed
|
||||||
|
try:
|
||||||
|
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool
|
||||||
|
PROXY_SUPPORT_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
PROXY_SUPPORT_AVAILABLE = False
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager")
|
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
REGISTRY_PATH = '.aws_vms'
|
|
||||||
|
|
||||||
DEFAULT_REGION = "us-east-1"
|
DEFAULT_REGION = "us-east-1"
|
||||||
# todo: Add doc for the configuration of image, security group and network interface
|
# todo: Add doc for the configuration of image, security group and network interface
|
||||||
# todo: public the AMI images
|
# todo: public the AMI images
|
||||||
@@ -118,17 +123,55 @@ def _allocate_vm(region=DEFAULT_REGION):
|
|||||||
return instance_id
|
return instance_id
|
||||||
|
|
||||||
|
|
||||||
|
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
|
||||||
|
"""Allocate a VM with proxy configuration"""
|
||||||
|
if not PROXY_SUPPORT_AVAILABLE:
|
||||||
|
logger.warning("Proxy support not available, falling back to regular VM allocation")
|
||||||
|
return _allocate_vm(region)
|
||||||
|
|
||||||
|
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
|
||||||
|
|
||||||
|
# Initialize proxy pool if needed
|
||||||
|
if proxy_config_file:
|
||||||
|
init_proxy_pool(proxy_config_file)
|
||||||
|
|
||||||
|
# Get current proxy
|
||||||
|
proxy_pool = get_global_proxy_pool()
|
||||||
|
current_proxy = proxy_pool.get_next_proxy()
|
||||||
|
|
||||||
|
if current_proxy:
|
||||||
|
logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")
|
||||||
|
|
||||||
|
# Create provider instance
|
||||||
|
provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)
|
||||||
|
|
||||||
|
# Create new instance
|
||||||
|
instance_id = provider.create_instance_with_proxy(
|
||||||
|
image_id=IMAGE_ID_MAP[region],
|
||||||
|
instance_type=INSTANCE_TYPE,
|
||||||
|
security_groups=[os.getenv('AWS_SECURITY_GROUP_ID')],
|
||||||
|
subnet_id=os.getenv('AWS_SUBNET_ID')
|
||||||
|
)
|
||||||
|
|
||||||
|
return instance_id
|
||||||
|
|
||||||
|
|
||||||
class AWSVMManager(VMManager):
|
class AWSVMManager(VMManager):
|
||||||
"""
|
"""
|
||||||
AWS VM Manager for managing virtual machines on AWS.
|
AWS VM Manager for managing virtual machines on AWS.
|
||||||
|
|
||||||
AWS does not need to maintain a registry of VMs, as it can dynamically allocate and deallocate VMs.
|
AWS does not need to maintain a registry of VMs, as it can dynamically allocate and deallocate VMs.
|
||||||
This class remains the interface of VMManager for compatibility with other components.
|
This class supports both regular VM allocation and proxy-enabled VM allocation.
|
||||||
"""
|
"""
|
||||||
def __init__(self, registry_path=REGISTRY_PATH):
|
def __init__(self, proxy_config_file=None, **kwargs):
|
||||||
self.registry_path = registry_path
|
self.proxy_config_file = proxy_config_file
|
||||||
# self.lock = FileLock(".aws_lck", timeout=60)
|
# self.lock = FileLock(".aws_lck", timeout=60)
|
||||||
self.initialize_registry()
|
self.initialize_registry()
|
||||||
|
|
||||||
|
# Initialize proxy pool if proxy configuration is provided
|
||||||
|
if proxy_config_file and PROXY_SUPPORT_AVAILABLE:
|
||||||
|
init_proxy_pool(proxy_config_file)
|
||||||
|
logger.info(f"Proxy pool initialized with config: {proxy_config_file}")
|
||||||
|
|
||||||
def initialize_registry(self, **kwargs):
|
def initialize_registry(self, **kwargs):
|
||||||
pass
|
pass
|
||||||
@@ -164,6 +207,10 @@ class AWSVMManager(VMManager):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def get_vm_path(self, region=DEFAULT_REGION, **kwargs):
|
def get_vm_path(self, region=DEFAULT_REGION, **kwargs):
|
||||||
logger.info("Allocating a new VM in region: {}".format(region))
|
if self.proxy_config_file:
|
||||||
new_vm_path = _allocate_vm(region)
|
logger.info("Allocating a new VM with proxy configuration in region: {}".format(region))
|
||||||
|
new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)
|
||||||
|
else:
|
||||||
|
logger.info("Allocating a new VM in region: {}".format(region))
|
||||||
|
new_vm_path = _allocate_vm(region)
|
||||||
return new_vm_path
|
return new_vm_path
|
||||||
|
|||||||
@@ -1,329 +0,0 @@
|
|||||||
import os
|
|
||||||
from filelock import FileLock
|
|
||||||
import boto3
|
|
||||||
import psutil
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from desktop_env.providers.base import VMManager
|
|
||||||
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManagerWithProxy")
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
REGISTRY_PATH = '.aws_vms_proxy'
|
|
||||||
|
|
||||||
DEFAULT_REGION = "us-east-1"
|
|
||||||
IMAGE_ID_MAP = {
|
|
||||||
"us-east-1": "ami-05e7d7bd279ea4f14",
|
|
||||||
"ap-east-1": "ami-0c092a5b8be4116f5"
|
|
||||||
}
|
|
||||||
|
|
||||||
INSTANCE_TYPE = "t3.medium"
|
|
||||||
|
|
||||||
NETWORK_INTERFACE_MAP = {
|
|
||||||
"us-east-1": [
|
|
||||||
{
|
|
||||||
"SubnetId": "subnet-037edfff66c2eb894",
|
|
||||||
"AssociatePublicIpAddress": True,
|
|
||||||
"DeviceIndex": 0,
|
|
||||||
"Groups": [
|
|
||||||
"sg-0342574803206ee9c"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"ap-east-1": [
|
|
||||||
{
|
|
||||||
"SubnetId": "subnet-011060501be0b589c",
|
|
||||||
"AssociatePublicIpAddress": True,
|
|
||||||
"DeviceIndex": 0,
|
|
||||||
"Groups": [
|
|
||||||
"sg-090470e64df78f6eb"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
|
|
||||||
"""分配带有代理配置的VM"""
|
|
||||||
from .provider_with_proxy import AWSProviderWithProxy
|
|
||||||
|
|
||||||
# 初始化代理池(如果还没有初始化)
|
|
||||||
if proxy_config_file:
|
|
||||||
init_proxy_pool(proxy_config_file)
|
|
||||||
|
|
||||||
# 获取当前代理
|
|
||||||
proxy_pool = get_global_proxy_pool()
|
|
||||||
current_proxy = proxy_pool.get_next_proxy()
|
|
||||||
|
|
||||||
if current_proxy:
|
|
||||||
logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")
|
|
||||||
|
|
||||||
# 创建provider实例
|
|
||||||
provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)
|
|
||||||
|
|
||||||
# 创建新实例
|
|
||||||
instance_id = provider.create_instance_with_proxy(
|
|
||||||
image_id=IMAGE_ID_MAP[region],
|
|
||||||
instance_type=INSTANCE_TYPE,
|
|
||||||
security_groups=NETWORK_INTERFACE_MAP[region][0]["Groups"],
|
|
||||||
subnet_id=NETWORK_INTERFACE_MAP[region][0]["SubnetId"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return instance_id
|
|
||||||
|
|
||||||
|
|
||||||
class AWSVMManagerWithProxy(VMManager):
|
|
||||||
def __init__(self, registry_path=REGISTRY_PATH, proxy_config_file=None):
|
|
||||||
self.registry_path = registry_path
|
|
||||||
self.lock = FileLock(".aws_proxy_lck", timeout=60)
|
|
||||||
self.proxy_config_file = proxy_config_file
|
|
||||||
self.initialize_registry()
|
|
||||||
|
|
||||||
# 初始化代理池
|
|
||||||
if proxy_config_file:
|
|
||||||
init_proxy_pool(proxy_config_file)
|
|
||||||
logger.info(f"Proxy pool initialized with config: {proxy_config_file}")
|
|
||||||
|
|
||||||
def initialize_registry(self):
|
|
||||||
with self.lock:
|
|
||||||
if not os.path.exists(self.registry_path):
|
|
||||||
with open(self.registry_path, 'w') as file:
|
|
||||||
file.write('')
|
|
||||||
|
|
||||||
def add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None, lock_needed=True):
|
|
||||||
if lock_needed:
|
|
||||||
with self.lock:
|
|
||||||
self._add_vm(vm_path, region, proxy_info)
|
|
||||||
else:
|
|
||||||
self._add_vm(vm_path, region, proxy_info)
|
|
||||||
|
|
||||||
def _add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None):
|
|
||||||
with open(self.registry_path, 'r') as file:
|
|
||||||
lines = file.readlines()
|
|
||||||
|
|
||||||
# 格式: vm_path@region|status|proxy_host:proxy_port
|
|
||||||
vm_path_at_vm_region = f"{vm_path}@{region}"
|
|
||||||
proxy_str = ""
|
|
||||||
if proxy_info:
|
|
||||||
proxy_str = f"{proxy_info['host']}:{proxy_info['port']}"
|
|
||||||
|
|
||||||
new_line = f'{vm_path_at_vm_region}|free|{proxy_str}\n'
|
|
||||||
new_lines = lines + [new_line]
|
|
||||||
|
|
||||||
with open(self.registry_path, 'w') as file:
|
|
||||||
file.writelines(new_lines)
|
|
||||||
|
|
||||||
def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
|
|
||||||
if lock_needed:
|
|
||||||
with self.lock:
|
|
||||||
self._delete_vm(vm_path, region)
|
|
||||||
else:
|
|
||||||
self._delete_vm(vm_path, region)
|
|
||||||
|
|
||||||
def _delete_vm(self, vm_path, region=DEFAULT_REGION):
|
|
||||||
new_lines = []
|
|
||||||
with open(self.registry_path, 'r') as file:
|
|
||||||
lines = file.readlines()
|
|
||||||
for line in lines:
|
|
||||||
parts = line.strip().split('|')
|
|
||||||
if len(parts) >= 2:
|
|
||||||
vm_path_at_vm_region = parts[0]
|
|
||||||
if vm_path_at_vm_region == f"{vm_path}@{region}":
|
|
||||||
continue
|
|
||||||
new_lines.append(line)
|
|
||||||
|
|
||||||
with open(self.registry_path, 'w') as file:
|
|
||||||
file.writelines(new_lines)
|
|
||||||
|
|
||||||
def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True):
|
|
||||||
if lock_needed:
|
|
||||||
with self.lock:
|
|
||||||
self._occupy_vm(vm_path, pid, region)
|
|
||||||
else:
|
|
||||||
self._occupy_vm(vm_path, pid, region)
|
|
||||||
|
|
||||||
def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
|
|
||||||
new_lines = []
|
|
||||||
with open(self.registry_path, 'r') as file:
|
|
||||||
lines = file.readlines()
|
|
||||||
for line in lines:
|
|
||||||
parts = line.strip().split('|')
|
|
||||||
if len(parts) >= 2:
|
|
||||||
registered_vm_path = parts[0]
|
|
||||||
if registered_vm_path == f"{vm_path}@{region}":
|
|
||||||
proxy_str = parts[2] if len(parts) > 2 else ""
|
|
||||||
new_lines.append(f'{registered_vm_path}|{pid}|{proxy_str}\n')
|
|
||||||
else:
|
|
||||||
new_lines.append(line)
|
|
||||||
else:
|
|
||||||
new_lines.append(line)
|
|
||||||
|
|
||||||
with open(self.registry_path, 'w') as file:
|
|
||||||
file.writelines(new_lines)
|
|
||||||
|
|
||||||
def check_and_clean(self, lock_needed=True):
|
|
||||||
if lock_needed:
|
|
||||||
with self.lock:
|
|
||||||
self._check_and_clean()
|
|
||||||
else:
|
|
||||||
self._check_and_clean()
|
|
||||||
|
|
||||||
def _check_and_clean(self):
|
|
||||||
# Get active PIDs
|
|
||||||
active_pids = {p.pid for p in psutil.process_iter()}
|
|
||||||
|
|
||||||
new_lines = []
|
|
||||||
vm_path_at_vm_regions = {}
|
|
||||||
|
|
||||||
with open(self.registry_path, 'r') as file:
|
|
||||||
lines = file.readlines()
|
|
||||||
|
|
||||||
# Collect all VM paths and their regions
|
|
||||||
for line in lines:
|
|
||||||
parts = line.strip().split('|')
|
|
||||||
if len(parts) >= 2:
|
|
||||||
vm_path_at_vm_region = parts[0]
|
|
||||||
status = parts[1]
|
|
||||||
proxy_str = parts[2] if len(parts) > 2 else ""
|
|
||||||
|
|
||||||
vm_path, vm_region = vm_path_at_vm_region.split("@")
|
|
||||||
if vm_region not in vm_path_at_vm_regions:
|
|
||||||
vm_path_at_vm_regions[vm_region] = []
|
|
||||||
vm_path_at_vm_regions[vm_region].append((vm_path_at_vm_region, status, proxy_str))
|
|
||||||
|
|
||||||
# Process each region
|
|
||||||
for region, vm_info_list in vm_path_at_vm_regions.items():
|
|
||||||
ec2_client = boto3.client('ec2', region_name=region)
|
|
||||||
instance_ids = [vm_info[0].split('@')[0] for vm_info in vm_info_list]
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = ec2_client.describe_instances(InstanceIds=instance_ids)
|
|
||||||
reservations = response.get('Reservations', [])
|
|
||||||
|
|
||||||
terminated_ids = set()
|
|
||||||
stopped_ids = set()
|
|
||||||
active_ids = set()
|
|
||||||
|
|
||||||
for reservation in reservations:
|
|
||||||
for instance in reservation.get('Instances', []):
|
|
||||||
instance_id = instance.get('InstanceId')
|
|
||||||
instance_state = instance['State']['Name']
|
|
||||||
if instance_state in ['terminated', 'shutting-down']:
|
|
||||||
terminated_ids.add(instance_id)
|
|
||||||
elif instance_state == 'stopped':
|
|
||||||
stopped_ids.add(instance_id)
|
|
||||||
else:
|
|
||||||
active_ids.add(instance_id)
|
|
||||||
|
|
||||||
for vm_path_at_vm_region, status, proxy_str in vm_info_list:
|
|
||||||
vm_path = vm_path_at_vm_region.split('@')[0]
|
|
||||||
|
|
||||||
if vm_path in terminated_ids:
|
|
||||||
logger.info(f"VM {vm_path} not found or terminated, releasing it.")
|
|
||||||
continue
|
|
||||||
elif vm_path in stopped_ids:
|
|
||||||
logger.info(f"VM {vm_path} stopped, mark it as free")
|
|
||||||
new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')
|
|
||||||
continue
|
|
||||||
|
|
||||||
if status == "free":
|
|
||||||
new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n')
|
|
||||||
elif status.isdigit() and int(status) in active_pids:
|
|
||||||
new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n')
|
|
||||||
else:
|
|
||||||
new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error checking instances in region {region}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
with open(self.registry_path, 'w') as file:
|
|
||||||
file.writelines(new_lines)
|
|
||||||
|
|
||||||
def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True):
|
|
||||||
if lock_needed:
|
|
||||||
with self.lock:
|
|
||||||
return self._list_free_vms(region)
|
|
||||||
else:
|
|
||||||
return self._list_free_vms(region)
|
|
||||||
|
|
||||||
def _list_free_vms(self, region=DEFAULT_REGION):
|
|
||||||
free_vms = []
|
|
||||||
with open(self.registry_path, 'r') as file:
|
|
||||||
lines = file.readlines()
|
|
||||||
for line in lines:
|
|
||||||
parts = line.strip().split('|')
|
|
||||||
if len(parts) >= 2:
|
|
||||||
vm_path_at_vm_region = parts[0]
|
|
||||||
status = parts[1]
|
|
||||||
proxy_str = parts[2] if len(parts) > 2 else ""
|
|
||||||
|
|
||||||
vm_path, vm_region = vm_path_at_vm_region.split("@")
|
|
||||||
if status == "free" and vm_region == region:
|
|
||||||
free_vms.append((vm_path, status, proxy_str))
|
|
||||||
|
|
||||||
return free_vms
|
|
||||||
|
|
||||||
def get_vm_path(self, region=DEFAULT_REGION):
|
|
||||||
with self.lock:
|
|
||||||
if not AWSVMManagerWithProxy.checked_and_cleaned:
|
|
||||||
AWSVMManagerWithProxy.checked_and_cleaned = True
|
|
||||||
self._check_and_clean()
|
|
||||||
|
|
||||||
allocation_needed = False
|
|
||||||
with self.lock:
|
|
||||||
free_vms_paths = self._list_free_vms(region)
|
|
||||||
|
|
||||||
if len(free_vms_paths) == 0:
|
|
||||||
allocation_needed = True
|
|
||||||
else:
|
|
||||||
chosen_vm_path, _, proxy_str = free_vms_paths[0]
|
|
||||||
self._occupy_vm(chosen_vm_path, os.getpid(), region)
|
|
||||||
logger.info(f"Using existing VM {chosen_vm_path} with proxy: {proxy_str}")
|
|
||||||
return chosen_vm_path
|
|
||||||
|
|
||||||
if allocation_needed:
|
|
||||||
logger.info("No free virtual machine available. Generating a new one with proxy configuration...☕")
|
|
||||||
new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)
|
|
||||||
|
|
||||||
# 获取当前使用的代理信息
|
|
||||||
proxy_pool = get_global_proxy_pool()
|
|
||||||
current_proxy = proxy_pool.get_next_proxy()
|
|
||||||
proxy_info = None
|
|
||||||
if current_proxy:
|
|
||||||
proxy_info = {
|
|
||||||
'host': current_proxy.host,
|
|
||||||
'port': current_proxy.port
|
|
||||||
}
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
self._add_vm(new_vm_path, region, proxy_info)
|
|
||||||
self._occupy_vm(new_vm_path, os.getpid(), region)
|
|
||||||
return new_vm_path
|
|
||||||
|
|
||||||
def get_proxy_stats(self):
|
|
||||||
"""获取代理池统计信息"""
|
|
||||||
proxy_pool = get_global_proxy_pool()
|
|
||||||
return proxy_pool.get_stats()
|
|
||||||
|
|
||||||
def test_all_proxies(self):
|
|
||||||
"""测试所有代理"""
|
|
||||||
proxy_pool = get_global_proxy_pool()
|
|
||||||
return proxy_pool.test_all_proxies()
|
|
||||||
|
|
||||||
def force_rotate_proxy_for_vm(self, vm_path, region=DEFAULT_REGION):
|
|
||||||
"""为特定VM强制轮换代理"""
|
|
||||||
logger.info(f"Force rotating proxy for VM {vm_path}")
|
|
||||||
|
|
||||||
# 这里需要重新创建实例来应用新的代理配置
|
|
||||||
# 在实际应用中,可能需要保存当前状态并恢复
|
|
||||||
proxy_pool = get_global_proxy_pool()
|
|
||||||
new_proxy = proxy_pool.get_next_proxy()
|
|
||||||
|
|
||||||
if new_proxy:
|
|
||||||
logger.info(f"New proxy for VM {vm_path}: {new_proxy.host}:{new_proxy.port}")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.warning(f"No available proxy for VM {vm_path}")
|
|
||||||
return False
|
|
||||||
@@ -47,48 +47,62 @@ class AWSProviderWithProxy(Provider):
|
|||||||
proxy_url = self._format_proxy_url(self.current_proxy)
|
proxy_url = self._format_proxy_url(self.current_proxy)
|
||||||
|
|
||||||
user_data_script = f"""#!/bin/bash
|
user_data_script = f"""#!/bin/bash
|
||||||
# 配置系统代理
|
# Configure system proxy
|
||||||
echo 'export http_proxy={proxy_url}' >> /etc/environment
|
echo 'export http_proxy={proxy_url}' >> /etc/environment
|
||||||
echo 'export https_proxy={proxy_url}' >> /etc/environment
|
echo 'export https_proxy={proxy_url}' >> /etc/environment
|
||||||
echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment
|
echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment
|
||||||
echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment
|
echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment
|
||||||
|
|
||||||
# 配置apt代理
|
# Configure apt proxy
|
||||||
cat > /etc/apt/apt.conf.d/95proxy << EOF
|
cat > /etc/apt/apt.conf.d/95proxy << EOF
|
||||||
Acquire::http::Proxy "{proxy_url}";
|
Acquire::http::Proxy "{proxy_url}";
|
||||||
Acquire::https::Proxy "{proxy_url}";
|
Acquire::https::Proxy "{proxy_url}";
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
# 配置chrome/chromium代理
|
# Configure chrome/chromium proxy
|
||||||
mkdir -p /etc/opt/chrome/policies/managed
|
mkdir -p /etc/opt/chrome/policies/managed
|
||||||
cat > /etc/opt/chrome/policies/managed/proxy.json << EOF
|
cat > /etc/opt/chrome/policies/managed/proxy.json << EOF
|
||||||
{{
|
{{
|
||||||
"ProxyMode": "fixed_servers",
|
"ProxyMode": "fixed_servers",
|
||||||
"ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}"
|
"ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}"
|
||||||
}}
|
}}
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
# 配置firefox代理
|
# Configure chromium proxy (Ubuntu default)
|
||||||
mkdir -p /etc/firefox/policies
|
mkdir -p /etc/chromium/policies/managed
|
||||||
cat > /etc/firefox/policies/policies.json << EOF
|
cat > /etc/chromium/policies/managed/proxy.json << EOF
|
||||||
{{
|
{{
|
||||||
"policies": {{
|
"ProxyMode": "fixed_servers",
|
||||||
"Proxy": {{
|
"ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}"
|
||||||
"Mode": "manual",
|
}}
|
||||||
"HTTPProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
|
EOF
|
||||||
"HTTPSProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
|
|
||||||
"UseHTTPProxyForAllProtocols": true
|
|
||||||
}}
|
|
||||||
}}
|
|
||||||
}}
|
|
||||||
EOF
|
|
||||||
|
|
||||||
# 重新加载环境变量
|
# Configure firefox proxy - support multiple possible paths
|
||||||
source /etc/environment
|
for firefox_dir in /etc/firefox/policies /usr/lib/firefox/distribution/policies /etc/firefox-esr/policies; do
|
||||||
|
if [ -d "$(dirname "$firefox_dir")" ]; then
|
||||||
|
mkdir -p "$firefox_dir"
|
||||||
|
cat > "$firefox_dir/policies.json" << EOF
|
||||||
|
{{
|
||||||
|
"policies": {{
|
||||||
|
"Proxy": {{
|
||||||
|
"Mode": "manual",
|
||||||
|
"HTTPProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
|
||||||
|
"HTTPSProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
|
||||||
|
"UseHTTPProxyForAllProtocols": true
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
EOF
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
# 记录代理配置日志
|
# Reload environment variables
|
||||||
echo "$(date): Configured proxy {self.current_proxy.host}:{self.current_proxy.port}" >> /var/log/proxy-setup.log
|
source /etc/environment
|
||||||
"""
|
|
||||||
|
# Log proxy configuration
|
||||||
|
echo "$(date): Configured proxy {self.current_proxy.host}:{self.current_proxy.port}" >> /var/log/proxy-setup.log
|
||||||
|
"""
|
||||||
|
|
||||||
return base64.b64encode(user_data_script.encode()).decode()
|
return base64.b64encode(user_data_script.encode()).decode()
|
||||||
|
|
||||||
@@ -99,7 +113,7 @@ class AWSProviderWithProxy(Provider):
|
|||||||
else:
|
else:
|
||||||
return f"{proxy.protocol}://{proxy.host}:{proxy.port}"
|
return f"{proxy.protocol}://{proxy.host}:{proxy.port}"
|
||||||
|
|
||||||
def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
|
def start_emulator(self, path_to_vm: str, headless: bool, *args, **kwargs):
|
||||||
logger.info("Starting AWS VM with proxy configuration...")
|
logger.info("Starting AWS VM with proxy configuration...")
|
||||||
ec2_client = boto3.client('ec2', region_name=self.region)
|
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||||
|
|
||||||
|
|||||||
@@ -305,7 +305,25 @@ class OpenAICUAAgent:
|
|||||||
logger.info(f"Response: {response}")
|
logger.info(f"Response: {response}")
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"OpenAI API error: {str(e)},will retry in 1s...")
|
logger.error(f"OpenAI API error: {str(e)}")
|
||||||
|
new_screenshot = self.env._get_obs()
|
||||||
|
new_screenshot_base64 = base64.b64encode(new_screenshot["screenshot"]).decode('utf-8')
|
||||||
|
|
||||||
|
# Update the image in the last message based on its structure
|
||||||
|
last_message = self.cua_messages[-1]
|
||||||
|
if "output" in last_message:
|
||||||
|
# Computer call output message structure
|
||||||
|
last_message["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
|
||||||
|
elif "content" in last_message:
|
||||||
|
# User message structure - find and update the image content
|
||||||
|
for content_item in last_message["content"]:
|
||||||
|
if content_item.get("type") == "input_image":
|
||||||
|
content_item["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.warning("Unknown message structure, cannot update screenshot")
|
||||||
|
|
||||||
|
retry_count += 1
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
def _handle_item(self, item: Dict[str, Any]) -> Optional[Union[str, Dict[str, Any]]]:
|
def _handle_item(self, item: Dict[str, Any]) -> Optional[Union[str, Dict[str, Any]]]:
|
||||||
@@ -446,10 +464,7 @@ class OpenAICUAAgent:
|
|||||||
logger.warning("Empty text for type action")
|
logger.warning("Empty text for type action")
|
||||||
return "import pyautogui\n# Empty text, no action taken"
|
return "import pyautogui\n# Empty text, no action taken"
|
||||||
|
|
||||||
pattern = r"(?<!\\)'"
|
# Use repr() to properly escape the string content without double-escaping
|
||||||
text = re.sub(pattern, r"\\'", text)
|
|
||||||
|
|
||||||
# 使用三重引号来确保字符串内容不会破坏格式
|
|
||||||
pyautogui_code = f"""import pyautogui\npyautogui.typewrite({repr(text)})"""
|
pyautogui_code = f"""import pyautogui\npyautogui.typewrite({repr(text)})"""
|
||||||
logger.info(f"Pyautogui code: {pyautogui_code}")
|
logger.info(f"Pyautogui code: {pyautogui_code}")
|
||||||
return pyautogui_code
|
return pyautogui_code
|
||||||
|
|||||||
@@ -2,10 +2,13 @@
|
|||||||
# Do not write any secret keys or sensitive information here.
|
# Do not write any secret keys or sensitive information here.
|
||||||
|
|
||||||
# Monitor configuration
|
# Monitor configuration
|
||||||
TASK_CONFIG_PATH=../evaluation_examples/test_all_error.json
|
TASK_CONFIG_PATH=../evaluation_examples/test_all.json
|
||||||
EXAMPLES_BASE_PATH=../evaluation_examples/examples
|
EXAMPLES_BASE_PATH=../evaluation_examples/examples
|
||||||
RESULTS_BASE_PATH=../results_all_error/pyautogui/screenshot/computer-use-preview
|
RESULTS_BASE_PATH=../results_operator_aws
|
||||||
|
ACTION_SPACE=pyautogui
|
||||||
|
OBSERVATION_TYPE=screenshot
|
||||||
|
MODEL_NAME=computer-use-preview
|
||||||
MAX_STEPS=150
|
MAX_STEPS=150
|
||||||
FLASK_PORT=80
|
FLASK_PORT=80
|
||||||
FLASK_HOST=0.0.0.0
|
FLASK_HOST=0.0.0.0
|
||||||
FLASK_DEBUG=false
|
FLASK_DEBUG=true
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
version: '3'
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
monitor:
|
monitor:
|
||||||
build:
|
build:
|
||||||
@@ -9,10 +7,11 @@ services:
|
|||||||
- "${FLASK_PORT:-8080}:8080"
|
- "${FLASK_PORT:-8080}:8080"
|
||||||
volumes:
|
volumes:
|
||||||
- .:/app/monitor
|
- .:/app/monitor
|
||||||
- ../evaluation_examples:/app/evaluation_examples
|
- ${TASK_CONFIG_PATH:-../evaluation_examples/test_all.json}:/app/evaluation_examples/test.json
|
||||||
- ../results_operator_aws:/app/results_operator_aws
|
- ${EXAMPLES_BASE_PATH:-../evaluation_examples/examples}:/app/evaluation_examples/examples
|
||||||
|
- ${RESULTS_BASE_PATH:-../results_operator_aws}:/app/results
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
environment:
|
environment:
|
||||||
- FLASK_ENV=production
|
- MONITOR_IN_DOCKER=true
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|||||||
@@ -17,12 +17,26 @@ TASK_STATUS_CACHE = {}
|
|||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
# Load configuration from environment variables
|
MONITOR_IN_DOCKER = os.getenv("MONITOR_IN_DOCKER", "false").lower() == "true"
|
||||||
TASK_CONFIG_PATH = os.getenv("TASK_CONFIG_PATH", "../evaluation_examples/test_small.json")
|
|
||||||
EXAMPLES_BASE_PATH = os.getenv("EXAMPLES_BASE_PATH", "../evaluation_examples/examples")
|
if MONITOR_IN_DOCKER:
|
||||||
RESULTS_BASE_PATH = os.getenv("RESULTS_BASE_PATH", "../results_operator_aws/pyautogui/screenshot/computer-use-preview")
|
# If running in Docker, use default paths
|
||||||
|
TASK_CONFIG_PATH = "/app/evaluation_examples/test.json"
|
||||||
|
EXAMPLES_BASE_PATH = "/app/evaluation_examples/examples"
|
||||||
|
RESULTS_BASE_PATH = "/app/results"
|
||||||
|
else:
|
||||||
|
# Load configuration from environment variables
|
||||||
|
TASK_CONFIG_PATH = os.getenv("TASK_CONFIG_PATH", "../evaluation_examples/test_small.json")
|
||||||
|
EXAMPLES_BASE_PATH = os.getenv("EXAMPLES_BASE_PATH", "../evaluation_examples/examples")
|
||||||
|
RESULTS_BASE_PATH = os.getenv("RESULTS_BASE_PATH", "../results")
|
||||||
|
|
||||||
|
ACTION_SPACE=os.getenv("ACTION_SPACE", "pyautogui")
|
||||||
|
OBSERVATION_TYPE=os.getenv("OBSERVATION_TYPE", "screenshot")
|
||||||
|
MODEL_NAME=os.getenv("MODEL_NAME", "computer-use-preview")
|
||||||
MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
|
MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
|
||||||
|
|
||||||
|
RESULTS_PATH = os.path.join(RESULTS_BASE_PATH, ACTION_SPACE, OBSERVATION_TYPE, MODEL_NAME)
|
||||||
|
|
||||||
def load_task_list():
|
def load_task_list():
|
||||||
with open(TASK_CONFIG_PATH, 'r') as f:
|
with open(TASK_CONFIG_PATH, 'r') as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
@@ -35,7 +49,7 @@ def get_task_info(task_type, task_id):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_task_status(task_type, task_id):
|
def get_task_status(task_type, task_id):
|
||||||
result_dir = os.path.join(RESULTS_BASE_PATH, task_type, task_id)
|
result_dir = os.path.join(RESULTS_PATH, task_type, task_id)
|
||||||
|
|
||||||
if not os.path.exists(result_dir):
|
if not os.path.exists(result_dir):
|
||||||
return {
|
return {
|
||||||
@@ -167,7 +181,7 @@ def get_task_status_brief(task_type, task_id):
|
|||||||
if cache_key in TASK_STATUS_CACHE:
|
if cache_key in TASK_STATUS_CACHE:
|
||||||
return TASK_STATUS_CACHE[cache_key]
|
return TASK_STATUS_CACHE[cache_key]
|
||||||
|
|
||||||
result_dir = os.path.join(RESULTS_BASE_PATH, task_type, task_id)
|
result_dir = os.path.join(RESULTS_PATH, task_type, task_id)
|
||||||
|
|
||||||
if not os.path.exists(result_dir):
|
if not os.path.exists(result_dir):
|
||||||
return {
|
return {
|
||||||
@@ -367,7 +381,7 @@ def api_tasks_brief():
|
|||||||
@app.route('/task/<task_type>/<task_id>/screenshot/<path:filename>')
|
@app.route('/task/<task_type>/<task_id>/screenshot/<path:filename>')
|
||||||
def task_screenshot(task_type, task_id, filename):
|
def task_screenshot(task_type, task_id, filename):
|
||||||
"""Get task screenshot"""
|
"""Get task screenshot"""
|
||||||
screenshot_path = os.path.join(RESULTS_BASE_PATH, task_type, task_id, filename)
|
screenshot_path = os.path.join(RESULTS_PATH, task_type, task_id, filename)
|
||||||
if os.path.exists(screenshot_path):
|
if os.path.exists(screenshot_path):
|
||||||
return send_file(screenshot_path, mimetype='image/png')
|
return send_file(screenshot_path, mimetype='image/png')
|
||||||
else:
|
else:
|
||||||
@@ -376,7 +390,7 @@ def task_screenshot(task_type, task_id, filename):
|
|||||||
@app.route('/task/<task_type>/<task_id>/recording')
|
@app.route('/task/<task_type>/<task_id>/recording')
|
||||||
def task_recording(task_type, task_id):
|
def task_recording(task_type, task_id):
|
||||||
"""Get task recording video"""
|
"""Get task recording video"""
|
||||||
recording_path = os.path.join(RESULTS_BASE_PATH, task_type, task_id, "recording.mp4")
|
recording_path = os.path.join(RESULTS_PATH, task_type, task_id, "recording.mp4")
|
||||||
if os.path.exists(recording_path):
|
if os.path.exists(recording_path):
|
||||||
response = send_file(recording_path, mimetype='video/mp4')
|
response = send_file(recording_path, mimetype='video/mp4')
|
||||||
# Add headers to improve mobile compatibility
|
# Add headers to improve mobile compatibility
|
||||||
|
|||||||
@@ -237,7 +237,9 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
|||||||
shared_scores,
|
shared_scores,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
|
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
try:
|
try:
|
||||||
env.controller.end_recording(
|
env.controller.end_recording(
|
||||||
os.path.join(example_result_dir, "recording.mp4")
|
os.path.join(example_result_dir, "recording.mp4")
|
||||||
|
|||||||
Reference in New Issue
Block a user