refactor: remove AWSVMManagerWithProxy and integrate proxy support directly into AWSVMManager for streamlined VM allocation;

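Minor fix to openai_cua_agent: on API retry, refresh the screenshot according to the structure of the last message, and escape typed text with repr().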
This commit is contained in:
Timothyxxx
2025-06-06 02:55:50 +08:00
parent 8b7727d955
commit 8373f7cff2
4 changed files with 72 additions and 344 deletions

View File

@@ -20,14 +20,13 @@ def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: b
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
return VirtualBoxVMManager(), VirtualBoxProvider(region)
elif provider_name in ["aws", "amazon web services"]:
from desktop_env.providers.aws.manager import AWSVMManager
if use_proxy:
# Use proxy-enabled AWS provider
from desktop_env.providers.aws.manager_with_proxy import AWSVMManagerWithProxy
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
return AWSVMManagerWithProxy(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json")
return AWSVMManager(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json")
else:
# Use regular AWS provider
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider
return AWSVMManager(), AWSProvider(region)
elif provider_name == "azure":

View File

@@ -18,11 +18,16 @@ if not os.getenv('AWS_SUBNET_ID') or not os.getenv('AWS_SECURITY_GROUP_ID'):
from desktop_env.providers.base import VMManager
# Proxy support is optional: import the proxy pool helpers if they are available
# and record the outcome in PROXY_SUPPORT_AVAILABLE.
try:
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool
PROXY_SUPPORT_AVAILABLE = True
except ImportError:
PROXY_SUPPORT_AVAILABLE = False
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager")
logger.setLevel(logging.INFO)
REGISTRY_PATH = '.aws_vms'
DEFAULT_REGION = "us-east-1"
# TODO: Document how to configure the image, security group and network interface.
# TODO: Publish the AMI images.
@@ -118,17 +123,55 @@ def _allocate_vm(region=DEFAULT_REGION):
return instance_id
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
    """Allocate a new VM whose instance is created with proxy configuration.

    When the proxy modules could not be imported, this falls back to the
    regular ``_allocate_vm`` path and logs a warning.

    Args:
        region: AWS region to launch in; must be a key of ``IMAGE_ID_MAP``.
        proxy_config_file: Optional path to a proxy configuration file. When
            given, the global proxy pool is initialised from it before a
            proxy is requested.

    Returns:
        The value returned by ``create_instance_with_proxy`` (the new
        instance id), or the result of ``_allocate_vm`` on the fallback path.

    Raises:
        KeyError: If no image is configured for ``region``.
    """
    if not PROXY_SUPPORT_AVAILABLE:
        logger.warning("Proxy support not available, falling back to regular VM allocation")
        return _allocate_vm(region)

    # Fail fast, before the proxy pool is initialised or rotated, and with a
    # message that names the region instead of a bare key.
    try:
        image_id = IMAGE_ID_MAP[region]
    except KeyError:
        raise KeyError(
            f"No image configured for region {region!r}; "
            f"supported regions: {sorted(IMAGE_ID_MAP)}"
        ) from None

    # Imported lazily so that this module stays importable without proxy support.
    from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy

    # Initialize the proxy pool if a configuration file was supplied.
    if proxy_config_file:
        init_proxy_pool(proxy_config_file)

    # NOTE(review): the proxy fetched here is only logged; presumably the
    # provider selects the proxy it actually applies on its own - confirm
    # that both refer to the same proxy.
    current_proxy = get_global_proxy_pool().get_next_proxy()
    if current_proxy:
        logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")
    else:
        logger.warning("Proxy pool returned no proxy; the VM may be created without a usable proxy")

    provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)
    return provider.create_instance_with_proxy(
        image_id=image_id,
        instance_type=INSTANCE_TYPE,
        security_groups=[os.getenv('AWS_SECURITY_GROUP_ID')],
        subnet_id=os.getenv('AWS_SUBNET_ID'),
    )
class AWSVMManager(VMManager):
"""
AWS VM Manager for managing virtual machines on AWS.
AWS does not need to maintain a registry of VMs, as it can dynamically allocate and deallocate VMs.
This class remains the interface of VMManager for compatibility with other components.
This class supports both regular VM allocation and proxy-enabled VM allocation.
"""
def __init__(self, proxy_config_file=None, registry_path=REGISTRY_PATH, **kwargs):
    """Create the manager and, when requested, initialise the proxy pool.

    Args:
        proxy_config_file: Optional path to a proxy configuration file.
            When set, VMs are allocated through the proxy-enabled path.
        registry_path: Registry location, stored on the instance for
            callers that still read ``self.registry_path``.
        **kwargs: Accepted and ignored so that callers may pass extra
            manager options without breaking.
    """
    self.registry_path = registry_path
    self.proxy_config_file = proxy_config_file
    # No file lock is created here: initialize_registry() is a no-op for AWS.
    self.initialize_registry()

    if not proxy_config_file:
        return

    if PROXY_SUPPORT_AVAILABLE:
        init_proxy_pool(proxy_config_file)
        logger.info(f"Proxy pool initialized with config: {proxy_config_file}")
    else:
        # Make the silent downgrade visible at construction time.
        logger.warning(
            f"Proxy config {proxy_config_file} was provided but proxy support is not available"
        )
def initialize_registry(self, **kwargs):
    """Do nothing; present only to satisfy the VMManager interface.

    Any keyword arguments are accepted and ignored.
    """
    return None
@@ -164,6 +207,10 @@ class AWSVMManager(VMManager):
pass
def get_vm_path(self, region=DEFAULT_REGION, **kwargs):
    """Allocate exactly one new VM in ``region`` and return its identifier.

    The proxy-enabled allocation path is used when the manager was created
    with a proxy configuration file; otherwise a regular VM is allocated.

    Args:
        region: AWS region in which to allocate the VM.
        **kwargs: Accepted and ignored.

    Returns:
        The value returned by the allocation helper for the new VM.
    """
    # Only one of the two allocation helpers may run per call; allocating
    # unconditionally first would launch an extra instance that is never used.
    if self.proxy_config_file:
        logger.info("Allocating a new VM with proxy configuration in region: {}".format(region))
        return _allocate_vm_with_proxy(region, self.proxy_config_file)

    logger.info("Allocating a new VM in region: {}".format(region))
    return _allocate_vm(region)

View File

@@ -1,329 +0,0 @@
import os
from filelock import FileLock
import boto3
import psutil
import logging
from desktop_env.providers.base import VMManager
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool
# Module-level logger for the proxy-enabled AWS VM manager.
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManagerWithProxy")
logger.setLevel(logging.INFO)
# Default path of the registry file used by AWSVMManagerWithProxy.
REGISTRY_PATH = '.aws_vms_proxy'
# Region used whenever a caller does not pass one explicitly.
DEFAULT_REGION = "us-east-1"
# Region name -> AMI id passed as image_id when a new instance is created.
IMAGE_ID_MAP = {
"us-east-1": "ami-05e7d7bd279ea4f14",
"ap-east-1": "ami-0c092a5b8be4116f5"
}
# EC2 instance type passed as instance_type when a new instance is created.
INSTANCE_TYPE = "t3.medium"
# Region name -> list of network interface specifications. Only the first
# entry's "SubnetId" and "Groups" are read in this file.
# NOTE(review): the subnet and security group ids are hard-coded and
# presumably specific to one AWS account - confirm before reuse.
NETWORK_INTERFACE_MAP = {
"us-east-1": [
{
"SubnetId": "subnet-037edfff66c2eb894",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-0342574803206ee9c"
]
}
],
"ap-east-1": [
{
"SubnetId": "subnet-011060501be0b589c",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-090470e64df78f6eb"
]
}
]
}
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
    """Allocate a VM whose instance is created with proxy configuration.

    Args:
        region: AWS region to launch in; must be a key of both
            ``IMAGE_ID_MAP`` and ``NETWORK_INTERFACE_MAP``.
        proxy_config_file: Optional path to a proxy configuration file. When
            given, the global proxy pool is initialised from it before a
            proxy is requested.

    Returns:
        The value returned by ``create_instance_with_proxy`` (the new
        instance id).

    Raises:
        KeyError: If ``region`` is missing from the image or network maps.
    """
    # Fail fast, before the proxy pool is initialised or rotated, and with a
    # message that names the region instead of a bare key.
    try:
        image_id = IMAGE_ID_MAP[region]
        interface = NETWORK_INTERFACE_MAP[region][0]
    except KeyError:
        raise KeyError(
            f"Region {region!r} is not configured; "
            f"supported regions: {sorted(IMAGE_ID_MAP)}"
        ) from None

    from .provider_with_proxy import AWSProviderWithProxy

    # Initialize the proxy pool (if it has not been initialized yet).
    if proxy_config_file:
        init_proxy_pool(proxy_config_file)

    # Fetch the current proxy.
    # NOTE(review): the proxy fetched here is only logged; presumably the
    # provider selects the proxy it actually applies on its own - confirm
    # that both refer to the same proxy.
    current_proxy = get_global_proxy_pool().get_next_proxy()
    if current_proxy:
        logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")
    else:
        logger.warning("Proxy pool returned no proxy; the VM may be created without a usable proxy")

    # Create the provider and launch the new instance.
    provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)
    return provider.create_instance_with_proxy(
        image_id=image_id,
        instance_type=INSTANCE_TYPE,
        security_groups=interface["Groups"],
        subnet_id=interface["SubnetId"],
    )
class AWSVMManagerWithProxy(VMManager):
    """File-backed registry of AWS VMs that are created with a proxy.

    Every registry line has the form
    ``<instance_id>@<region>|<status>|<proxy_host>:<proxy_port>`` where
    ``<status>`` is either ``free`` or the PID of the process occupying the
    VM, and the proxy field may be empty. The public methods take the file
    lock by default; the underscore-prefixed variants expect the caller to
    hold it already.
    """

    # Class-level flag read and set by get_vm_path() so the registry
    # clean-up runs only on the first call. It must exist before that first
    # call, otherwise the lookup raises AttributeError.
    checked_and_cleaned = False

    def __init__(self, registry_path=REGISTRY_PATH, proxy_config_file=None):
        """Create the registry file if needed and initialise the proxy pool.

        Args:
            registry_path: Path of the registry file.
            proxy_config_file: Optional path to a proxy configuration file.
        """
        self.registry_path = registry_path
        self.lock = FileLock(".aws_proxy_lck", timeout=60)
        self.proxy_config_file = proxy_config_file
        self.initialize_registry()

        # Initialize the proxy pool.
        if proxy_config_file:
            init_proxy_pool(proxy_config_file)
            logger.info(f"Proxy pool initialized with config: {proxy_config_file}")

    def initialize_registry(self):
        """Create an empty registry file when none exists yet."""
        with self.lock:
            if not os.path.exists(self.registry_path):
                with open(self.registry_path, 'w') as file:
                    file.write('')

    def _read_registry(self):
        """Return all registry lines, including their trailing newlines."""
        with open(self.registry_path, 'r') as file:
            return file.readlines()

    def _write_registry(self, lines):
        """Overwrite the registry file with ``lines``."""
        with open(self.registry_path, 'w') as file:
            file.writelines(lines)

    def add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None, lock_needed=True):
        """Register ``vm_path`` as free; see ``_add_vm`` for the details."""
        if lock_needed:
            with self.lock:
                self._add_vm(vm_path, region, proxy_info)
        else:
            self._add_vm(vm_path, region, proxy_info)

    def _add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None):
        """Append a ``free`` entry for ``vm_path`` to the registry.

        Args:
            vm_path: Instance id of the VM.
            region: Region the VM lives in.
            proxy_info: Optional dict with ``'host'`` and ``'port'`` keys.
        """
        lines = self._read_registry()

        # Format: vm_path@region|status|proxy_host:proxy_port
        proxy_str = ""
        if proxy_info:
            proxy_str = f"{proxy_info['host']}:{proxy_info['port']}"

        lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')
        self._write_registry(lines)

    def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
        """Remove ``vm_path`` from the registry; see ``_delete_vm``."""
        if lock_needed:
            with self.lock:
                self._delete_vm(vm_path, region)
        else:
            self._delete_vm(vm_path, region)

    def _delete_vm(self, vm_path, region=DEFAULT_REGION):
        """Drop every registry entry matching ``vm_path@region``.

        All other lines are kept unchanged.
        """
        target = f"{vm_path}@{region}"
        remaining = []
        for line in self._read_registry():
            parts = line.strip().split('|')
            if len(parts) >= 2 and parts[0] == target:
                continue
            remaining.append(line)
        self._write_registry(remaining)

    def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True):
        """Mark ``vm_path`` as occupied by ``pid``; see ``_occupy_vm``."""
        if lock_needed:
            with self.lock:
                self._occupy_vm(vm_path, pid, region)
        else:
            self._occupy_vm(vm_path, pid, region)

    def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
        """Replace the status of ``vm_path@region`` with ``pid``.

        The proxy field of the entry is preserved; other lines are kept
        unchanged.
        """
        target = f"{vm_path}@{region}"
        new_lines = []
        for line in self._read_registry():
            parts = line.strip().split('|')
            if len(parts) >= 2 and parts[0] == target:
                proxy_str = parts[2] if len(parts) > 2 else ""
                new_lines.append(f'{target}|{pid}|{proxy_str}\n')
            else:
                new_lines.append(line)
        self._write_registry(new_lines)

    def check_and_clean(self, lock_needed=True):
        """Reconcile the registry with EC2; see ``_check_and_clean``."""
        if lock_needed:
            with self.lock:
                self._check_and_clean()
        else:
            self._check_and_clean()

    def _check_and_clean(self):
        """Rewrite the registry from current EC2 state and live PIDs.

        Terminated or shutting-down instances are dropped, stopped instances
        are marked ``free``, and entries whose occupying PID is no longer
        alive are marked ``free``. Lines with fewer than two fields are
        dropped.
        """
        # Get active PIDs
        active_pids = {p.pid for p in psutil.process_iter()}

        # Collect all VM entries, grouped by region.
        entries_by_region = {}
        for line in self._read_registry():
            parts = line.strip().split('|')
            if len(parts) < 2:
                continue
            status = parts[1]
            proxy_str = parts[2] if len(parts) > 2 else ""
            vm_path, vm_region = parts[0].split("@")
            entries_by_region.setdefault(vm_region, []).append((vm_path, status, proxy_str))

        new_lines = []

        # Process each region with one batched describe call.
        for region, entries in entries_by_region.items():
            ec2_client = boto3.client('ec2', region_name=region)
            instance_ids = [vm_path for vm_path, _, _ in entries]

            try:
                response = ec2_client.describe_instances(InstanceIds=instance_ids)

                terminated_ids = set()
                stopped_ids = set()
                for reservation in response.get('Reservations', []):
                    for instance in reservation.get('Instances', []):
                        instance_id = instance.get('InstanceId')
                        instance_state = instance['State']['Name']
                        if instance_state in ['terminated', 'shutting-down']:
                            terminated_ids.add(instance_id)
                        elif instance_state == 'stopped':
                            stopped_ids.add(instance_id)

                for vm_path, status, proxy_str in entries:
                    if vm_path in terminated_ids:
                        logger.info(f"VM {vm_path} not found or terminated, releasing it.")
                        continue
                    if vm_path in stopped_ids:
                        logger.info(f"VM {vm_path} stopped, mark it as free")
                        new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')
                        continue

                    # Keep the entry as is when it is free or its owner is
                    # still running; otherwise release it.
                    if status == "free" or (status.isdigit() and int(status) in active_pids):
                        new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n')
                    else:
                        new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')
            except Exception as e:
                # NOTE(review): entries of this region that were not written
                # before the error are removed from the registry - confirm
                # that dropping them is the intended handling.
                logger.error(f"Error checking instances in region {region}: {e}")
                continue

        self._write_registry(new_lines)

    def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True):
        """Return the free VMs of ``region``; see ``_list_free_vms``."""
        if lock_needed:
            with self.lock:
                return self._list_free_vms(region)
        else:
            return self._list_free_vms(region)

    def _list_free_vms(self, region=DEFAULT_REGION):
        """Return ``(vm_path, status, proxy_str)`` tuples for free VMs in ``region``."""
        free_vms = []
        for line in self._read_registry():
            parts = line.strip().split('|')
            if len(parts) < 2:
                continue
            status = parts[1]
            proxy_str = parts[2] if len(parts) > 2 else ""
            vm_path, vm_region = parts[0].split("@")
            if status == "free" and vm_region == region:
                free_vms.append((vm_path, status, proxy_str))
        return free_vms

    def get_vm_path(self, region=DEFAULT_REGION):
        """Return a VM for the current process, allocating one if necessary.

        A free VM registered for ``region`` is reused when available;
        otherwise a new proxy-enabled VM is allocated, registered and
        occupied by the current PID.

        Returns:
            The instance id of the chosen or newly created VM.
        """
        with self.lock:
            if not AWSVMManagerWithProxy.checked_and_cleaned:
                AWSVMManagerWithProxy.checked_and_cleaned = True
                self._check_and_clean()

        with self.lock:
            free_vms = self._list_free_vms(region)
            if free_vms:
                chosen_vm_path, _, proxy_str = free_vms[0]
                self._occupy_vm(chosen_vm_path, os.getpid(), region)
                logger.info(f"Using existing VM {chosen_vm_path} with proxy: {proxy_str}")
                return chosen_vm_path

        # The lock is released during allocation.
        logger.info("No free virtual machine available. Generating a new one with proxy configuration...☕")
        new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)

        # Fetch the proxy information to record for this VM.
        # NOTE(review): this asks the pool for the *next* proxy after the
        # allocation already did so; if the pool rotates, the recorded proxy
        # may differ from the one the VM was created with - confirm.
        current_proxy = get_global_proxy_pool().get_next_proxy()
        proxy_info = None
        if current_proxy:
            proxy_info = {
                'host': current_proxy.host,
                'port': current_proxy.port
            }

        with self.lock:
            self._add_vm(new_vm_path, region, proxy_info)
            self._occupy_vm(new_vm_path, os.getpid(), region)
        return new_vm_path

    def get_proxy_stats(self):
        """Return the statistics reported by the global proxy pool."""
        return get_global_proxy_pool().get_stats()

    def test_all_proxies(self):
        """Test all proxies in the global proxy pool and return the result."""
        return get_global_proxy_pool().test_all_proxies()

    def force_rotate_proxy_for_vm(self, vm_path, region=DEFAULT_REGION):
        """Request the next proxy from the pool on behalf of ``vm_path``.

        This only obtains and logs the new proxy. Applying it would require
        re-creating the instance, and in a real deployment the current state
        may need to be saved and restored.

        Returns:
            True when the pool returned a proxy, False otherwise.
        """
        logger.info(f"Force rotating proxy for VM {vm_path}")

        new_proxy = get_global_proxy_pool().get_next_proxy()
        if new_proxy:
            logger.info(f"New proxy for VM {vm_path}: {new_proxy.host}:{new_proxy.port}")
            return True

        logger.warning(f"No available proxy for VM {vm_path}")
        return False

View File

@@ -309,7 +309,21 @@ class OpenAICUAAgent:
logger.error(f"OpenAI API error: {str(e)}")
new_screenshot = self.env._get_obs()
new_screenshot_base64 = base64.b64encode(new_screenshot["screenshot"]).decode('utf-8')
self.cua_messages[-1]["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
# Update the image in the last message based on its structure
last_message = self.cua_messages[-1]
if "output" in last_message:
# Computer call output message structure
last_message["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
elif "content" in last_message:
# User message structure - find and update the image content
for content_item in last_message["content"]:
if content_item.get("type") == "input_image":
content_item["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
break
else:
logger.warning("Unknown message structure, cannot update screenshot")
retry_count += 1
time.sleep(1)
raise Exception("Failed to make OpenAI API call after 3 retries")
@@ -452,10 +466,7 @@ class OpenAICUAAgent:
logger.warning("Empty text for type action")
return "import pyautogui\n# Empty text, no action taken"
pattern = r"(?<!\\)'"
text = re.sub(pattern, r"\\'", text)
# 使用三重引号来确保字符串内容不会破坏格式
# Use repr() to properly escape the string content without double-escaping
pyautogui_code = f"""import pyautogui\npyautogui.typewrite({repr(text)})"""
logger.info(f"Pyautogui code: {pyautogui_code}")
return pyautogui_code