From bfd0a7ad0d05b1a24178ca90e3a3d4a12deea68b Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Fri, 6 Jun 2025 00:36:21 +0800 Subject: [PATCH 1/3] feat: implement proxy management for AWS VM provider and enhance task configuration handling --- desktop_env/desktop_env.py | 37 +- desktop_env/providers/__init__.py | 20 +- .../providers/aws/manager_with_proxy.py | 329 ++++++++++++++++++ .../providers/aws/provider_with_proxy.py | 261 ++++++++++++++ desktop_env/providers/aws/proxy_pool.py | 193 ++++++++++ 5 files changed, 835 insertions(+), 5 deletions(-) create mode 100644 desktop_env/providers/aws/manager_with_proxy.py create mode 100644 desktop_env/providers/aws/provider_with_proxy.py create mode 100644 desktop_env/providers/aws/proxy_pool.py diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index 4a24a55..2a13afa 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -54,13 +54,17 @@ class DesktopEnv(gym.Env): """ # Initialize VM manager and vitualization provider self.region = region + self.provider_name = provider_name # Default TODO: self.server_port = 5000 self.chromium_port = 9222 self.vnc_port = 8006 self.vlc_port = 8080 - self.manager, self.provider = create_vm_manager_and_provider(provider_name, region) + + # Initialize with default (no proxy) provider + self.current_use_proxy = False + self.manager, self.provider = create_vm_manager_and_provider(provider_name, region, use_proxy=False) self.os_type = os_type @@ -149,6 +153,32 @@ class DesktopEnv(gym.Env): self._step_no = 0 self.action_history.clear() + # Check and handle proxy requirement changes BEFORE starting emulator + if task_config is not None: + task_use_proxy = task_config.get("proxy", False) + if task_use_proxy != self.current_use_proxy: + logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}") + + # Close current provider if it exists + if hasattr(self, 'provider') and self.provider: + try: + self.provider.stop_emulator(self.path_to_vm) + except Exception as e: + logger.warning(f"Failed to stop current provider: {e}") + + # Create new provider with appropriate proxy setting + self.current_use_proxy = task_use_proxy + self.manager, self.provider = create_vm_manager_and_provider( + self.provider_name, + self.region, + use_proxy=task_use_proxy + ) + + if task_use_proxy: + logger.info("Using proxy-enabled AWS provider.") + else: + logger.info("Using regular AWS provider.") + logger.info("Reverting to snapshot to {}...".format(self.snapshot_name)) self._revert_to_snapshot() logger.info("Starting emulator...") @@ -184,12 +214,17 @@ class DesktopEnv(gym.Env): return self.controller.get_vm_screen_size() def _set_task_info(self, task_config: Dict[str, Any]): + """Set task info (proxy logic is handled in reset method)""" self.task_id: str = task_config["id"] self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id) os.makedirs(self.cache_dir, exist_ok=True) self.instruction = task_config["instruction"] self.config = task_config["config"] if "config" in task_config else [] + + self._set_evaluator_info(task_config) + def _set_evaluator_info(self, task_config: Dict[str, Any]): + """Set evaluator information from task config""" # evaluator dict # func -> metric function string, or list of metric function strings # conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or" diff --git a/desktop_env/providers/__init__.py b/desktop_env/providers/__init__.py index d1359a1..19d9b98 100644 --- a/desktop_env/providers/__init__.py +++ b/desktop_env/providers/__init__.py @@ -1,9 +1,14 @@ from desktop_env.providers.base import VMManager, Provider -def create_vm_manager_and_provider(provider_name: str, region: str): +def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: bool = False): """ Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name. + + Args: + provider_name (str): The name of the provider (e.g., "aws", "vmware", etc.) + region (str): The region for the provider + use_proxy (bool): Whether to use proxy-enabled providers (currently only supported for AWS) """ provider_name = provider_name.lower().strip() if provider_name == "vmware": @@ -15,9 +20,16 @@ def create_vm_manager_and_provider(provider_name: str, region: str): from desktop_env.providers.virtualbox.provider import VirtualBoxProvider return VirtualBoxVMManager(), VirtualBoxProvider(region) elif provider_name in ["aws", "amazon web services"]: - from desktop_env.providers.aws.manager import AWSVMManager - from desktop_env.providers.aws.provider import AWSProvider - return AWSVMManager(), AWSProvider(region) + if use_proxy: + # Use proxy-enabled AWS provider + from desktop_env.providers.aws.manager_with_proxy import AWSVMManagerWithProxy + from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy + return AWSVMManagerWithProxy(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json") + else: + # Use regular AWS provider + from desktop_env.providers.aws.manager import AWSVMManager + from desktop_env.providers.aws.provider import AWSProvider + return AWSVMManager(), AWSProvider(region) elif provider_name == "azure": from desktop_env.providers.azure.manager import AzureVMManager from desktop_env.providers.azure.provider import AzureProvider diff --git a/desktop_env/providers/aws/manager_with_proxy.py b/desktop_env/providers/aws/manager_with_proxy.py new file mode 100644 index 0000000..f4150ec --- /dev/null +++ b/desktop_env/providers/aws/manager_with_proxy.py @@ -0,0 +1,329 @@ +import os +from filelock import FileLock +import boto3 +import psutil +import logging + +from desktop_env.providers.base import VMManager +from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool + +logger = logging.getLogger("desktopenv.providers.aws.AWSVMManagerWithProxy") +logger.setLevel(logging.INFO) + +REGISTRY_PATH = '.aws_vms_proxy' + +DEFAULT_REGION = "us-east-1" +IMAGE_ID_MAP = { + "us-east-1": "ami-05e7d7bd279ea4f14", + "ap-east-1": "ami-0c092a5b8be4116f5" +} + +INSTANCE_TYPE = "t3.medium" + +NETWORK_INTERFACE_MAP = { + "us-east-1": [ + { + "SubnetId": "subnet-037edfff66c2eb894", + "AssociatePublicIpAddress": True, + "DeviceIndex": 0, + "Groups": [ + "sg-0342574803206ee9c" + ] + } + ], + "ap-east-1": [ + { + "SubnetId": "subnet-011060501be0b589c", + "AssociatePublicIpAddress": True, + "DeviceIndex": 0, + "Groups": [ + "sg-090470e64df78f6eb" + ] + } + ] +} + + +def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None): + """分配带有代理配置的VM""" + from .provider_with_proxy import AWSProviderWithProxy + + # 初始化代理池(如果还没有初始化) + if proxy_config_file: + init_proxy_pool(proxy_config_file) + + # 获取当前代理 + proxy_pool = get_global_proxy_pool() + current_proxy = proxy_pool.get_next_proxy() + + if current_proxy: + logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}") + + # 创建provider实例 + provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file) + + # 创建新实例 + instance_id = provider.create_instance_with_proxy( + image_id=IMAGE_ID_MAP[region], + instance_type=INSTANCE_TYPE, + security_groups=NETWORK_INTERFACE_MAP[region][0]["Groups"], + subnet_id=NETWORK_INTERFACE_MAP[region][0]["SubnetId"] + ) + + return instance_id + + +class AWSVMManagerWithProxy(VMManager): + def __init__(self, registry_path=REGISTRY_PATH, proxy_config_file=None): + self.registry_path = registry_path + self.lock = FileLock(".aws_proxy_lck", timeout=60) + self.proxy_config_file = proxy_config_file + self.initialize_registry() + + # 初始化代理池 + if proxy_config_file: + init_proxy_pool(proxy_config_file) + logger.info(f"Proxy pool initialized with config: {proxy_config_file}") + + def initialize_registry(self): + with self.lock: + if not os.path.exists(self.registry_path): + with open(self.registry_path, 'w') as file: + file.write('') + + def add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None, lock_needed=True): + if lock_needed: + with self.lock: + self._add_vm(vm_path, region, proxy_info) + else: + self._add_vm(vm_path, region, proxy_info) + + def _add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None): + with open(self.registry_path, 'r') as file: + lines = file.readlines() + + # 格式: vm_path@region|status|proxy_host:proxy_port + vm_path_at_vm_region = f"{vm_path}@{region}" + proxy_str = "" + if proxy_info: + proxy_str = f"{proxy_info['host']}:{proxy_info['port']}" + + new_line = f'{vm_path_at_vm_region}|free|{proxy_str}\n' + new_lines = lines + [new_line] + + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True): + if lock_needed: + with self.lock: + self._delete_vm(vm_path, region) + else: + self._delete_vm(vm_path, region) + + def _delete_vm(self, vm_path, region=DEFAULT_REGION): + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + parts = line.strip().split('|') + if len(parts) >= 2: + vm_path_at_vm_region = parts[0] + if vm_path_at_vm_region == f"{vm_path}@{region}": + continue + new_lines.append(line) + + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True): + if lock_needed: + with self.lock: + self._occupy_vm(vm_path, pid, region) + else: + self._occupy_vm(vm_path, pid, region) + + def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION): + new_lines = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + parts = line.strip().split('|') + if len(parts) >= 2: + registered_vm_path = parts[0] + if registered_vm_path == f"{vm_path}@{region}": + proxy_str = parts[2] if len(parts) > 2 else "" + new_lines.append(f'{registered_vm_path}|{pid}|{proxy_str}\n') + else: + new_lines.append(line) + else: + new_lines.append(line) + + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def check_and_clean(self, lock_needed=True): + if lock_needed: + with self.lock: + self._check_and_clean() + else: + self._check_and_clean() + + def _check_and_clean(self): + # Get active PIDs + active_pids = {p.pid for p in psutil.process_iter()} + + new_lines = [] + vm_path_at_vm_regions = {} + + with open(self.registry_path, 'r') as file: + lines = file.readlines() + + # Collect all VM paths and their regions + for line in lines: + parts = line.strip().split('|') + if len(parts) >= 2: + vm_path_at_vm_region = parts[0] + status = parts[1] + proxy_str = parts[2] if len(parts) > 2 else "" + + vm_path, vm_region = vm_path_at_vm_region.split("@") + if vm_region not in vm_path_at_vm_regions: + vm_path_at_vm_regions[vm_region] = [] + vm_path_at_vm_regions[vm_region].append((vm_path_at_vm_region, status, proxy_str)) + + # Process each region + for region, vm_info_list in vm_path_at_vm_regions.items(): + ec2_client = boto3.client('ec2', region_name=region) + instance_ids = [vm_info[0].split('@')[0] for vm_info in vm_info_list] + + try: + response = ec2_client.describe_instances(InstanceIds=instance_ids) + reservations = response.get('Reservations', []) + + terminated_ids = set() + stopped_ids = set() + active_ids = set() + + for reservation in reservations: + for instance in reservation.get('Instances', []): + instance_id = instance.get('InstanceId') + instance_state = instance['State']['Name'] + if instance_state in ['terminated', 'shutting-down']: + terminated_ids.add(instance_id) + elif instance_state == 'stopped': + stopped_ids.add(instance_id) + else: + active_ids.add(instance_id) + + for vm_path_at_vm_region, status, proxy_str in vm_info_list: + vm_path = vm_path_at_vm_region.split('@')[0] + + if vm_path in terminated_ids: + logger.info(f"VM {vm_path} not found or terminated, releasing it.") + continue + elif vm_path in stopped_ids: + logger.info(f"VM {vm_path} stopped, mark it as free") + new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n') + continue + + if status == "free": + new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n') + elif status.isdigit() and int(status) in active_pids: + new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n') + else: + new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n') + + except Exception as e: + logger.error(f"Error checking instances in region {region}: {e}") + continue + + with open(self.registry_path, 'w') as file: + file.writelines(new_lines) + + def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True): + if lock_needed: + with self.lock: + return self._list_free_vms(region) + else: + return self._list_free_vms(region) + + def _list_free_vms(self, region=DEFAULT_REGION): + free_vms = [] + with open(self.registry_path, 'r') as file: + lines = file.readlines() + for line in lines: + parts = line.strip().split('|') + if len(parts) >= 2: + vm_path_at_vm_region = parts[0] + status = parts[1] + proxy_str = parts[2] if len(parts) > 2 else "" + + vm_path, vm_region = vm_path_at_vm_region.split("@") + if status == "free" and vm_region == region: + free_vms.append((vm_path, status, proxy_str)) + + return free_vms + + def get_vm_path(self, region=DEFAULT_REGION): + with self.lock: + if not AWSVMManagerWithProxy.checked_and_cleaned: + AWSVMManagerWithProxy.checked_and_cleaned = True + self._check_and_clean() + + allocation_needed = False + with self.lock: + free_vms_paths = self._list_free_vms(region) + + if len(free_vms_paths) == 0: + allocation_needed = True + else: + chosen_vm_path, _, proxy_str = free_vms_paths[0] + self._occupy_vm(chosen_vm_path, os.getpid(), region) + logger.info(f"Using existing VM {chosen_vm_path} with proxy: {proxy_str}") + return chosen_vm_path + + if allocation_needed: + logger.info("No free virtual machine available. Generating a new one with proxy configuration...☕") + new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file) + + # 获取当前使用的代理信息 + proxy_pool = get_global_proxy_pool() + current_proxy = proxy_pool.get_next_proxy() + proxy_info = None + if current_proxy: + proxy_info = { + 'host': current_proxy.host, + 'port': current_proxy.port + } + + with self.lock: + self._add_vm(new_vm_path, region, proxy_info) + self._occupy_vm(new_vm_path, os.getpid(), region) + return new_vm_path + + def get_proxy_stats(self): + """获取代理池统计信息""" + proxy_pool = get_global_proxy_pool() + return proxy_pool.get_stats() + + def test_all_proxies(self): + """测试所有代理""" + proxy_pool = get_global_proxy_pool() + return proxy_pool.test_all_proxies() + + def force_rotate_proxy_for_vm(self, vm_path, region=DEFAULT_REGION): + """为特定VM强制轮换代理""" + logger.info(f"Force rotating proxy for VM {vm_path}") + + # 这里需要重新创建实例来应用新的代理配置 + # 在实际应用中,可能需要保存当前状态并恢复 + proxy_pool = get_global_proxy_pool() + new_proxy = proxy_pool.get_next_proxy() + + if new_proxy: + logger.info(f"New proxy for VM {vm_path}: {new_proxy.host}:{new_proxy.port}") + return True + else: + logger.warning(f"No available proxy for VM {vm_path}") + return False \ No newline at end of file diff --git a/desktop_env/providers/aws/provider_with_proxy.py b/desktop_env/providers/aws/provider_with_proxy.py new file mode 100644 index 0000000..309e71b --- /dev/null +++ b/desktop_env/providers/aws/provider_with_proxy.py @@ -0,0 +1,261 @@ +import boto3 +from botocore.exceptions import ClientError +import base64 +import logging +import json +from typing import Optional + +from desktop_env.providers.base import Provider +from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool, ProxyInfo + +logger = logging.getLogger("desktopenv.providers.aws.AWSProviderWithProxy") +logger.setLevel(logging.INFO) + +WAIT_DELAY = 15 +MAX_ATTEMPTS = 10 + + +class AWSProviderWithProxy(Provider): + + def __init__(self, region: str = None, proxy_config_file: str = None): + super().__init__(region) + self.current_proxy: Optional[ProxyInfo] = None + + # 初始化代理池 + if proxy_config_file: + init_proxy_pool(proxy_config_file) + logger.info(f"Initialized proxy pool from {proxy_config_file}") + + # 获取下一个可用代理 + self._rotate_proxy() + + def _rotate_proxy(self): + """轮换到下一个可用代理""" + proxy_pool = get_global_proxy_pool() + self.current_proxy = proxy_pool.get_next_proxy() + + if self.current_proxy: + logger.info(f"Switched to proxy: {self.current_proxy.host}:{self.current_proxy.port}") + else: + logger.warning("No proxy available, using direct connection") + + def _generate_proxy_user_data(self) -> str: + """生成包含代理配置的user data脚本""" + if not self.current_proxy: + return "" + + proxy_url = self._format_proxy_url(self.current_proxy) + + user_data_script = f"""#!/bin/bash +# 配置系统代理 +echo 'export http_proxy={proxy_url}' >> /etc/environment +echo 'export https_proxy={proxy_url}' >> /etc/environment +echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment +echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment + +# 配置apt代理 +cat > /etc/apt/apt.conf.d/95proxy << EOF +Acquire::http::Proxy "{proxy_url}"; +Acquire::https::Proxy "{proxy_url}"; +EOF + +# 配置chrome/chromium代理 +mkdir -p /etc/opt/chrome/policies/managed +cat > /etc/opt/chrome/policies/managed/proxy.json << EOF +{{ + "ProxyMode": "fixed_servers", + "ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}" +}} +EOF + +# 配置firefox代理 +mkdir -p /etc/firefox/policies +cat > /etc/firefox/policies/policies.json << EOF +{{ + "policies": {{ + "Proxy": {{ + "Mode": "manual", + "HTTPProxy": "{self.current_proxy.host}:{self.current_proxy.port}", + "HTTPSProxy": "{self.current_proxy.host}:{self.current_proxy.port}", + "UseHTTPProxyForAllProtocols": true + }} + }} +}} +EOF + +# 重新加载环境变量 +source /etc/environment + +# 记录代理配置日志 +echo "$(date): Configured proxy {self.current_proxy.host}:{self.current_proxy.port}" >> /var/log/proxy-setup.log +""" + + return base64.b64encode(user_data_script.encode()).decode() + + def _format_proxy_url(self, proxy: ProxyInfo) -> str: + """格式化代理URL""" + if proxy.username and proxy.password: + return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}" + else: + return f"{proxy.protocol}://{proxy.host}:{proxy.port}" + + def start_emulator(self, path_to_vm: str, headless: bool): + logger.info("Starting AWS VM with proxy configuration...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + # 如果实例已经存在,直接启动 + ec2_client.start_instances(InstanceIds=[path_to_vm]) + logger.info(f"Instance {path_to_vm} is starting...") + + # Wait for the instance to be in the 'running' state + waiter = ec2_client.get_waiter('instance_running') + waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS}) + logger.info(f"Instance {path_to_vm} is now running.") + + except ClientError as e: + logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}") + raise + + def create_instance_with_proxy(self, image_id: str, instance_type: str, + security_groups: list, subnet_id: str) -> str: + """创建带有代理配置的新实例""" + ec2_client = boto3.client('ec2', region_name=self.region) + + user_data = self._generate_proxy_user_data() + + run_instances_params = { + "MaxCount": 1, + "MinCount": 1, + "ImageId": image_id, + "InstanceType": instance_type, + "EbsOptimized": True, + "NetworkInterfaces": [ + { + "SubnetId": subnet_id, + "AssociatePublicIpAddress": True, + "DeviceIndex": 0, + "Groups": security_groups + } + ] + } + + if user_data: + run_instances_params["UserData"] = user_data + + try: + response = ec2_client.run_instances(**run_instances_params) + instance_id = response['Instances'][0]['InstanceId'] + + logger.info(f"Created new instance {instance_id} with proxy configuration") + + # 等待实例运行 + logger.info(f"Waiting for instance {instance_id} to be running...") + ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id]) + logger.info(f"Instance {instance_id} is ready.") + + return instance_id + + except ClientError as e: + logger.error(f"Failed to create instance with proxy: {str(e)}") + # 如果当前代理失败,尝试轮换代理 + if self.current_proxy: + proxy_pool = get_global_proxy_pool() + proxy_pool.mark_proxy_failed(self.current_proxy) + self._rotate_proxy() + raise + + def get_ip_address(self, path_to_vm: str) -> str: + logger.info("Getting AWS VM IP address...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + response = ec2_client.describe_instances(InstanceIds=[path_to_vm]) + for reservation in response['Reservations']: + for instance in reservation['Instances']: + private_ip_address = instance.get('PrivateIpAddress', '') + return private_ip_address + return '' + except ClientError as e: + logger.error(f"Failed to retrieve private IP address for the instance {path_to_vm}: {str(e)}") + raise + + def save_state(self, path_to_vm: str, snapshot_name: str): + logger.info("Saving AWS VM state...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + image_response = ec2_client.create_image(InstanceId=path_to_vm, Name=snapshot_name) + image_id = image_response['ImageId'] + logger.info(f"AMI {image_id} created successfully from instance {path_to_vm}.") + return image_id + except ClientError as e: + logger.error(f"Failed to create AMI from the instance {path_to_vm}: {str(e)}") + raise + + def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): + logger.info(f"Reverting AWS VM to snapshot: {snapshot_name}...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + # 获取原实例详情 + instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm]) + instance = instance_details['Reservations'][0]['Instances'][0] + security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']] + subnet_id = instance['SubnetId'] + instance_type = instance['InstanceType'] + + # 终止旧实例 + ec2_client.terminate_instances(InstanceIds=[path_to_vm]) + logger.info(f"Old instance {path_to_vm} has been terminated.") + + # 轮换到新的代理 + self._rotate_proxy() + + # 创建新实例 + new_instance_id = self.create_instance_with_proxy( + snapshot_name, instance_type, security_groups, subnet_id + ) + + return new_instance_id + + except ClientError as e: + logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}") + raise + + def stop_emulator(self, path_to_vm, region=None): + logger.info(f"Stopping AWS VM {path_to_vm}...") + ec2_client = boto3.client('ec2', region_name=self.region) + + try: + ec2_client.stop_instances(InstanceIds=[path_to_vm]) + waiter = ec2_client.get_waiter('instance_stopped') + waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS}) + logger.info(f"Instance {path_to_vm} has been stopped.") + except ClientError as e: + logger.error(f"Failed to stop the AWS VM {path_to_vm}: {str(e)}") + raise + + def get_current_proxy_info(self) -> Optional[dict]: + """获取当前代理信息""" + if self.current_proxy: + return { + 'host': self.current_proxy.host, + 'port': self.current_proxy.port, + 'protocol': self.current_proxy.protocol, + 'failed_count': self.current_proxy.failed_count + } + return None + + def force_rotate_proxy(self): + """强制轮换代理""" + logger.info("Force rotating proxy...") + if self.current_proxy: + proxy_pool = get_global_proxy_pool() + proxy_pool.mark_proxy_failed(self.current_proxy) + self._rotate_proxy() + + def get_proxy_stats(self) -> dict: + """获取代理池统计信息""" + proxy_pool = get_global_proxy_pool() + return proxy_pool.get_stats() \ No newline at end of file diff --git a/desktop_env/providers/aws/proxy_pool.py b/desktop_env/providers/aws/proxy_pool.py new file mode 100644 index 0000000..812df18 --- /dev/null +++ b/desktop_env/providers/aws/proxy_pool.py @@ -0,0 +1,193 @@ +import random +import requests +import logging +import time +from typing import List, Dict, Optional +from dataclasses import dataclass +from threading import Lock +import json + +logger = logging.getLogger("desktopenv.providers.aws.ProxyPool") +logger.setLevel(logging.INFO) + +@dataclass +class ProxyInfo: + host: str + port: int + username: Optional[str] = None + password: Optional[str] = None + protocol: str = "http" # http, https, socks5 + failed_count: int = 0 + last_used: float = 0 + is_active: bool = True + +class ProxyPool: + def __init__(self, config_file: str = None): + self.proxies: List[ProxyInfo] = [] + self.current_index = 0 + self.lock = Lock() + self.max_failures = 3 # 最大失败次数 + self.cooldown_time = 300 # 5分钟冷却时间 + + if config_file: + self.load_proxies_from_file(config_file) + + def load_proxies_from_file(self, config_file: str): + """从配置文件加载代理列表""" + try: + with open(config_file, 'r') as f: + proxy_configs = json.load(f) + + for config in proxy_configs: + proxy = ProxyInfo( + host=config['host'], + port=config['port'], + username=config.get('username'), + password=config.get('password'), + protocol=config.get('protocol', 'http') + ) + self.proxies.append(proxy) + + logger.info(f"Loaded {len(self.proxies)} proxies from {config_file}") + except Exception as e: + logger.error(f"Failed to load proxies from {config_file}: {e}") + + def add_proxy(self, host: str, port: int, username: str = None, + password: str = None, protocol: str = "http"): + """添加代理到池中""" + proxy = ProxyInfo(host=host, port=port, username=username, + password=password, protocol=protocol) + with self.lock: + self.proxies.append(proxy) + logger.info(f"Added proxy {host}:{port}") + + def get_next_proxy(self) -> Optional[ProxyInfo]: + """获取下一个可用的代理""" + with self.lock: + if not self.proxies: + return None + + # 过滤掉失败次数过多的代理 + active_proxies = [p for p in self.proxies if self._is_proxy_available(p)] + + if not active_proxies: + logger.warning("No active proxies available") + return None + + # 轮询选择代理 + proxy = active_proxies[self.current_index % len(active_proxies)] + self.current_index += 1 + proxy.last_used = time.time() + + return proxy + + def _is_proxy_available(self, proxy: ProxyInfo) -> bool: + """检查代理是否可用""" + if not proxy.is_active: + return False + + if proxy.failed_count >= self.max_failures: + # 检查是否过了冷却时间 + if time.time() - proxy.last_used < self.cooldown_time: + return False + else: + # 重置失败计数 + proxy.failed_count = 0 + + return True + + def mark_proxy_failed(self, proxy: ProxyInfo): + """标记代理失败""" + with self.lock: + proxy.failed_count += 1 + if proxy.failed_count >= self.max_failures: + logger.warning(f"Proxy {proxy.host}:{proxy.port} marked as failed " + f"(failures: {proxy.failed_count})") + + def mark_proxy_success(self, proxy: ProxyInfo): + """标记代理成功""" + with self.lock: + proxy.failed_count = 0 + + def test_proxy(self, proxy: ProxyInfo, test_url: str = "http://httpbin.org/ip", + timeout: int = 10) -> bool: + """测试代理是否正常工作""" + try: + proxy_url = self._format_proxy_url(proxy) + proxies = { + 'http': proxy_url, + 'https': proxy_url + } + + response = requests.get(test_url, proxies=proxies, timeout=timeout) + if response.status_code == 200: + self.mark_proxy_success(proxy) + return True + else: + self.mark_proxy_failed(proxy) + return False + + except Exception as e: + logger.debug(f"Proxy test failed for {proxy.host}:{proxy.port}: {e}") + self.mark_proxy_failed(proxy) + return False + + def _format_proxy_url(self, proxy: ProxyInfo) -> str: + """格式化代理URL""" + if proxy.username and proxy.password: + return f"{proxy.protocol}://{proxy.username}:{proxy.password}@{proxy.host}:{proxy.port}" + else: + return f"{proxy.protocol}://{proxy.host}:{proxy.port}" + + def get_proxy_dict(self, proxy: ProxyInfo) -> Dict[str, str]: + """获取requests库使用的代理字典""" + proxy_url = self._format_proxy_url(proxy) + return { + 'http': proxy_url, + 'https': proxy_url + } + + def test_all_proxies(self, test_url: str = "http://httpbin.org/ip"): + """测试所有代理""" + logger.info("Testing all proxies...") + working_count = 0 + + for proxy in self.proxies: + if self.test_proxy(proxy, test_url): + working_count += 1 + logger.info(f"✓ Proxy {proxy.host}:{proxy.port} is working") + else: + logger.warning(f"✗ Proxy {proxy.host}:{proxy.port} failed") + + logger.info(f"Proxy test completed: {working_count}/{len(self.proxies)} working") + return working_count + + def get_stats(self) -> Dict: + """获取代理池统计信息""" + with self.lock: + total = len(self.proxies) + active = len([p for p in self.proxies if self._is_proxy_available(p)]) + failed = len([p for p in self.proxies if p.failed_count >= self.max_failures]) + + return { + 'total': total, + 'active': active, + 'failed': failed, + 'success_rate': active / total if total > 0 else 0 + } + +# 全局代理池实例 +_proxy_pool = None + +def get_global_proxy_pool() -> ProxyPool: + """获取全局代理池实例""" + global _proxy_pool + if _proxy_pool is None: + _proxy_pool = ProxyPool() + return _proxy_pool + +def init_proxy_pool(config_file: str = None): + """初始化全局代理池""" + global _proxy_pool + _proxy_pool = ProxyPool(config_file) + return _proxy_pool \ No newline at end of file From 8b7727d955443e50f0980214b95370c28c9f3473 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Fri, 6 Jun 2025 02:39:16 +0800 Subject: [PATCH 2/3] refactor: update proxy configuration script for AWSProviderWithProxy to enhance clarity and support multiple Firefox paths --- .../providers/aws/provider_with_proxy.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/desktop_env/providers/aws/provider_with_proxy.py b/desktop_env/providers/aws/provider_with_proxy.py index 309e71b..9472433 100644 --- a/desktop_env/providers/aws/provider_with_proxy.py +++ b/desktop_env/providers/aws/provider_with_proxy.py @@ -47,19 +47,19 @@ class AWSProviderWithProxy(Provider): proxy_url = self._format_proxy_url(self.current_proxy) user_data_script = f"""#!/bin/bash -# 配置系统代理 +# Configure system proxy echo 'export http_proxy={proxy_url}' >> /etc/environment echo 'export https_proxy={proxy_url}' >> /etc/environment echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment -# 配置apt代理 +# Configure apt proxy cat > /etc/apt/apt.conf.d/95proxy << EOF Acquire::http::Proxy "{proxy_url}"; Acquire::https::Proxy "{proxy_url}"; EOF -# 配置chrome/chromium代理 +# Configure chrome/chromium proxy mkdir -p /etc/opt/chrome/policies/managed cat > /etc/opt/chrome/policies/managed/proxy.json << EOF {{ @@ -68,9 +68,20 @@ cat > /etc/opt/chrome/policies/managed/proxy.json << EOF }} EOF -# 配置firefox代理 -mkdir -p /etc/firefox/policies -cat > /etc/firefox/policies/policies.json << EOF +# Configure chromium proxy (Ubuntu default) +mkdir -p /etc/chromium/policies/managed +cat > /etc/chromium/policies/managed/proxy.json << EOF +{{ + "ProxyMode": "fixed_servers", + "ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}" +}} +EOF + +# Configure firefox proxy - support multiple possible paths +for firefox_dir in /etc/firefox/policies /usr/lib/firefox/distribution/policies /etc/firefox-esr/policies; do + if [ -d "$(dirname "$firefox_dir")" ]; then + mkdir -p "$firefox_dir" + cat > "$firefox_dir/policies.json" << EOF {{ "policies": {{ "Proxy": {{ @@ -82,11 +93,14 @@ cat > /etc/firefox/policies/policies.json << EOF }} }} EOF + break + fi +done -# 重新加载环境变量 +# Reload environment variables source /etc/environment -# 记录代理配置日志 +# Log proxy configuration echo "$(date): Configured proxy {self.current_proxy.host}:{self.current_proxy.port}" >> /var/log/proxy-setup.log """ From 8373f7cff2f126c00ca09209bc91efa4acbaf8fb Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Fri, 6 Jun 2025 02:55:50 +0800 Subject: [PATCH 3/3] refactor: remove AWSVMManagerWithProxy and integrate proxy support directly into AWSVMManager for streamlined VM allocation; minor fix on openai_cua_agent --- desktop_env/providers/__init__.py | 5 +- desktop_env/providers/aws/manager.py | 61 +++- .../providers/aws/manager_with_proxy.py | 329 ------------------ mm_agents/openai_cua_agent.py | 21 +- 4 files changed, 72 insertions(+), 344 deletions(-) delete mode 100644 desktop_env/providers/aws/manager_with_proxy.py diff --git a/desktop_env/providers/__init__.py b/desktop_env/providers/__init__.py index 19d9b98..792db58 100644 --- a/desktop_env/providers/__init__.py +++ b/desktop_env/providers/__init__.py @@ -20,14 +20,13 @@ def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: b from desktop_env.providers.virtualbox.provider import VirtualBoxProvider return VirtualBoxVMManager(), VirtualBoxProvider(region) elif provider_name in ["aws", "amazon web services"]: + from desktop_env.providers.aws.manager import AWSVMManager if use_proxy: # Use proxy-enabled AWS provider - from desktop_env.providers.aws.manager_with_proxy import AWSVMManagerWithProxy from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy - return AWSVMManagerWithProxy(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json") + return AWSVMManager(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json") else: # Use regular AWS provider - from desktop_env.providers.aws.manager import AWSVMManager from desktop_env.providers.aws.provider import AWSProvider return AWSVMManager(), AWSProvider(region) elif provider_name == "azure": diff --git a/desktop_env/providers/aws/manager.py b/desktop_env/providers/aws/manager.py index b9925a0..e28474a 100644 --- a/desktop_env/providers/aws/manager.py +++ b/desktop_env/providers/aws/manager.py @@ -18,11 +18,16 @@ if not os.getenv('AWS_SUBNET_ID') or not os.getenv('AWS_SECURITY_GROUP_ID'): from desktop_env.providers.base import VMManager +# Import proxy-related modules only when needed +try: + from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool + PROXY_SUPPORT_AVAILABLE = True +except ImportError: + PROXY_SUPPORT_AVAILABLE = False + logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager") logger.setLevel(logging.INFO) -REGISTRY_PATH = '.aws_vms' - DEFAULT_REGION = "us-east-1" # todo: Add doc for the configuration of image, security group and network interface # todo: public the AMI images @@ -118,17 +123,55 @@ def _allocate_vm(region=DEFAULT_REGION): return instance_id +def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None): + """Allocate a VM with proxy configuration""" + if not PROXY_SUPPORT_AVAILABLE: + logger.warning("Proxy support not available, falling back to regular VM allocation") + return _allocate_vm(region) + + from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy + + # Initialize proxy pool if needed + if proxy_config_file: + init_proxy_pool(proxy_config_file) + + # Get current proxy + proxy_pool = get_global_proxy_pool() + current_proxy = proxy_pool.get_next_proxy() + + if current_proxy: + logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}") + + # Create provider instance + provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file) + + # Create new instance + instance_id = provider.create_instance_with_proxy( + image_id=IMAGE_ID_MAP[region], + instance_type=INSTANCE_TYPE, + security_groups=[os.getenv('AWS_SECURITY_GROUP_ID')], + subnet_id=os.getenv('AWS_SUBNET_ID') + ) + + return instance_id + + class AWSVMManager(VMManager): """ AWS VM Manager for managing virtual machines on AWS. AWS does not need to maintain a registry of VMs, as it can dynamically allocate and deallocate VMs. - This class remains the interface of VMManager for compatibility with other components. + This class supports both regular VM allocation and proxy-enabled VM allocation. """ - def __init__(self, registry_path=REGISTRY_PATH): - self.registry_path = registry_path + def __init__(self, proxy_config_file=None, **kwargs): + self.proxy_config_file = proxy_config_file # self.lock = FileLock(".aws_lck", timeout=60) self.initialize_registry() + + # Initialize proxy pool if proxy configuration is provided + if proxy_config_file and PROXY_SUPPORT_AVAILABLE: + init_proxy_pool(proxy_config_file) + logger.info(f"Proxy pool initialized with config: {proxy_config_file}") def initialize_registry(self, **kwargs): pass @@ -164,6 +207,10 @@ class AWSVMManager(VMManager): pass def get_vm_path(self, region=DEFAULT_REGION, **kwargs): - logger.info("Allocating a new VM in region: {}".format(region)) - new_vm_path = _allocate_vm(region) + if self.proxy_config_file: + logger.info("Allocating a new VM with proxy configuration in region: {}".format(region)) + new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file) + else: + logger.info("Allocating a new VM in region: {}".format(region)) + new_vm_path = _allocate_vm(region) return new_vm_path diff --git a/desktop_env/providers/aws/manager_with_proxy.py b/desktop_env/providers/aws/manager_with_proxy.py deleted file mode 100644 index f4150ec..0000000 --- a/desktop_env/providers/aws/manager_with_proxy.py +++ /dev/null @@ -1,329 +0,0 @@ -import os -from filelock import FileLock -import boto3 -import psutil -import logging - -from desktop_env.providers.base import VMManager -from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool - -logger = logging.getLogger("desktopenv.providers.aws.AWSVMManagerWithProxy") -logger.setLevel(logging.INFO) - -REGISTRY_PATH = '.aws_vms_proxy' - -DEFAULT_REGION = "us-east-1" -IMAGE_ID_MAP = { - "us-east-1": "ami-05e7d7bd279ea4f14", - "ap-east-1": "ami-0c092a5b8be4116f5" -} - -INSTANCE_TYPE = "t3.medium" - -NETWORK_INTERFACE_MAP = { - "us-east-1": [ - { - "SubnetId": "subnet-037edfff66c2eb894", - "AssociatePublicIpAddress": True, - "DeviceIndex": 0, - "Groups": [ - "sg-0342574803206ee9c" - ] - } - ], - "ap-east-1": [ - { - "SubnetId": "subnet-011060501be0b589c", - "AssociatePublicIpAddress": True, - "DeviceIndex": 0, - "Groups": [ - "sg-090470e64df78f6eb" - ] - } - ] -} - - -def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None): - """分配带有代理配置的VM""" - from .provider_with_proxy import AWSProviderWithProxy - - # 初始化代理池(如果还没有初始化) - if proxy_config_file: - init_proxy_pool(proxy_config_file) - - # 获取当前代理 - proxy_pool = get_global_proxy_pool() - current_proxy = proxy_pool.get_next_proxy() - - if current_proxy: - logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}") - - # 创建provider实例 - provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file) - - # 创建新实例 - instance_id = provider.create_instance_with_proxy( - image_id=IMAGE_ID_MAP[region], - instance_type=INSTANCE_TYPE, - security_groups=NETWORK_INTERFACE_MAP[region][0]["Groups"], - subnet_id=NETWORK_INTERFACE_MAP[region][0]["SubnetId"] - ) - - return instance_id - - -class AWSVMManagerWithProxy(VMManager): - def __init__(self, registry_path=REGISTRY_PATH, proxy_config_file=None): - self.registry_path = registry_path - self.lock = FileLock(".aws_proxy_lck", timeout=60) - self.proxy_config_file = proxy_config_file - self.initialize_registry() - - # 初始化代理池 - if proxy_config_file: - init_proxy_pool(proxy_config_file) - logger.info(f"Proxy pool initialized with config: {proxy_config_file}") - - def initialize_registry(self): - with self.lock: - if not os.path.exists(self.registry_path): - with open(self.registry_path, 'w') as file: - file.write('') - - def add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None, lock_needed=True): - if lock_needed: - with self.lock: - self._add_vm(vm_path, region, proxy_info) - else: - self._add_vm(vm_path, region, proxy_info) - - def _add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None): - with open(self.registry_path, 'r') as file: - lines = file.readlines() - - # 格式: vm_path@region|status|proxy_host:proxy_port - vm_path_at_vm_region = f"{vm_path}@{region}" - proxy_str = "" - if proxy_info: - proxy_str = f"{proxy_info['host']}:{proxy_info['port']}" - - new_line = f'{vm_path_at_vm_region}|free|{proxy_str}\n' - new_lines = lines + [new_line] - - with open(self.registry_path, 'w') as file: - file.writelines(new_lines) - - def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True): - if lock_needed: - with self.lock: - self._delete_vm(vm_path, region) - else: - self._delete_vm(vm_path, region) - - def _delete_vm(self, vm_path, region=DEFAULT_REGION): - new_lines = [] - with open(self.registry_path, 'r') as file: - lines = file.readlines() - for line in lines: - parts = line.strip().split('|') - if len(parts) >= 2: - vm_path_at_vm_region = parts[0] - if vm_path_at_vm_region == f"{vm_path}@{region}": - continue - new_lines.append(line) - - with open(self.registry_path, 'w') as file: - file.writelines(new_lines) - - def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True): - if lock_needed: - with self.lock: - self._occupy_vm(vm_path, pid, region) - else: - self._occupy_vm(vm_path, pid, region) - - def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION): - new_lines = [] - with open(self.registry_path, 'r') as file: - lines = file.readlines() - for line in lines: - parts = line.strip().split('|') - if len(parts) >= 2: - registered_vm_path = parts[0] - if registered_vm_path == f"{vm_path}@{region}": - proxy_str = parts[2] if len(parts) > 2 else "" - new_lines.append(f'{registered_vm_path}|{pid}|{proxy_str}\n') - else: - new_lines.append(line) - else: - new_lines.append(line) - - with open(self.registry_path, 'w') as file: - file.writelines(new_lines) - - def check_and_clean(self, lock_needed=True): - if lock_needed: - with self.lock: - self._check_and_clean() - else: - self._check_and_clean() - - def _check_and_clean(self): - # Get active PIDs - active_pids = {p.pid for p in psutil.process_iter()} - - new_lines = [] - vm_path_at_vm_regions = {} - - with open(self.registry_path, 'r') as file: - lines = file.readlines() - - # Collect all VM paths and their regions - for line in lines: - parts = line.strip().split('|') - if len(parts) >= 2: - vm_path_at_vm_region = parts[0] - status = parts[1] - proxy_str = parts[2] if len(parts) > 2 else "" - - vm_path, vm_region = vm_path_at_vm_region.split("@") - if vm_region not in vm_path_at_vm_regions: - vm_path_at_vm_regions[vm_region] = [] - vm_path_at_vm_regions[vm_region].append((vm_path_at_vm_region, status, proxy_str)) - - # Process each region - for region, vm_info_list in vm_path_at_vm_regions.items(): - ec2_client = boto3.client('ec2', region_name=region) - instance_ids = [vm_info[0].split('@')[0] for vm_info in vm_info_list] - - try: - response = ec2_client.describe_instances(InstanceIds=instance_ids) - reservations = response.get('Reservations', []) - - terminated_ids = set() - stopped_ids = set() - active_ids = set() - - for reservation in reservations: - for instance in reservation.get('Instances', []): - instance_id = instance.get('InstanceId') - instance_state = instance['State']['Name'] - if instance_state in ['terminated', 'shutting-down']: - terminated_ids.add(instance_id) - elif instance_state == 'stopped': - stopped_ids.add(instance_id) - else: - active_ids.add(instance_id) - - for vm_path_at_vm_region, status, proxy_str in vm_info_list: - vm_path = vm_path_at_vm_region.split('@')[0] - - if vm_path in terminated_ids: - logger.info(f"VM {vm_path} not found or terminated, releasing it.") - continue - elif vm_path in stopped_ids: - logger.info(f"VM {vm_path} stopped, mark it as free") - new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n') - continue - - if status == "free": - new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n') - elif status.isdigit() and int(status) in active_pids: - new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n') - else: - new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n') - - except Exception as e: - logger.error(f"Error checking instances in region {region}: {e}") - continue - - with open(self.registry_path, 'w') as file: - file.writelines(new_lines) - - def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True): - if lock_needed: - with self.lock: - return self._list_free_vms(region) - else: - return self._list_free_vms(region) - - def _list_free_vms(self, region=DEFAULT_REGION): - free_vms = [] - with open(self.registry_path, 'r') as file: - lines = file.readlines() - for line in lines: - parts = line.strip().split('|') - if len(parts) >= 2: - vm_path_at_vm_region = parts[0] - status = parts[1] - proxy_str = parts[2] if len(parts) > 2 else "" - - vm_path, vm_region = vm_path_at_vm_region.split("@") - if status == "free" and vm_region == region: - free_vms.append((vm_path, status, proxy_str)) - - return free_vms - - def get_vm_path(self, region=DEFAULT_REGION): - with self.lock: - if not AWSVMManagerWithProxy.checked_and_cleaned: - AWSVMManagerWithProxy.checked_and_cleaned = True - self._check_and_clean() - - allocation_needed = False - with self.lock: - free_vms_paths = self._list_free_vms(region) - - if len(free_vms_paths) == 0: - allocation_needed = True - else: - chosen_vm_path, _, proxy_str = free_vms_paths[0] - self._occupy_vm(chosen_vm_path, os.getpid(), region) - logger.info(f"Using existing VM {chosen_vm_path} with proxy: {proxy_str}") - return chosen_vm_path - - if allocation_needed: - logger.info("No free virtual machine available. Generating a new one with proxy configuration...☕") - new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file) - - # 获取当前使用的代理信息 - proxy_pool = get_global_proxy_pool() - current_proxy = proxy_pool.get_next_proxy() - proxy_info = None - if current_proxy: - proxy_info = { - 'host': current_proxy.host, - 'port': current_proxy.port - } - - with self.lock: - self._add_vm(new_vm_path, region, proxy_info) - self._occupy_vm(new_vm_path, os.getpid(), region) - return new_vm_path - - def get_proxy_stats(self): - """获取代理池统计信息""" - proxy_pool = get_global_proxy_pool() - return proxy_pool.get_stats() - - def test_all_proxies(self): - """测试所有代理""" - proxy_pool = get_global_proxy_pool() - return proxy_pool.test_all_proxies() - - def force_rotate_proxy_for_vm(self, vm_path, region=DEFAULT_REGION): - """为特定VM强制轮换代理""" - logger.info(f"Force rotating proxy for VM {vm_path}") - - # 这里需要重新创建实例来应用新的代理配置 - # 在实际应用中,可能需要保存当前状态并恢复 - proxy_pool = get_global_proxy_pool() - new_proxy = proxy_pool.get_next_proxy() - - if new_proxy: - logger.info(f"New proxy for VM {vm_path}: {new_proxy.host}:{new_proxy.port}") - return True - else: - logger.warning(f"No available proxy for VM {vm_path}") - return False \ No newline at end of file diff --git a/mm_agents/openai_cua_agent.py b/mm_agents/openai_cua_agent.py index f3ecc13..d3fefbd 100644 --- a/mm_agents/openai_cua_agent.py +++ b/mm_agents/openai_cua_agent.py @@ -309,7 +309,21 @@ class OpenAICUAAgent: logger.error(f"OpenAI API error: {str(e)}") new_screenshot = self.env._get_obs() new_screenshot_base64 = base64.b64encode(new_screenshot["screenshot"]).decode('utf-8') - self.cua_messages[-1]["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}" + + # Update the image in the last message based on its structure + last_message = self.cua_messages[-1] + if "output" in last_message: + # Computer call output message structure + last_message["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}" + elif "content" in last_message: + # User message structure - find and update the image content + for content_item in last_message["content"]: + if content_item.get("type") == "input_image": + content_item["image_url"] = f"data:image/png;base64,{new_screenshot_base64}" + break + else: + logger.warning("Unknown message structure, cannot update screenshot") + retry_count += 1 time.sleep(1) raise Exception("Failed to make OpenAI API call after 3 retries") @@ -452,10 +466,7 @@ class OpenAICUAAgent: logger.warning("Empty text for type action") return "import pyautogui\n# Empty text, no action taken" - pattern = r"(?