diff --git a/desktop_env/providers/__init__.py b/desktop_env/providers/__init__.py index 19d9b98..792db58 100644 --- a/desktop_env/providers/__init__.py +++ b/desktop_env/providers/__init__.py @@ -20,14 +20,13 @@ def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: b from desktop_env.providers.virtualbox.provider import VirtualBoxProvider return VirtualBoxVMManager(), VirtualBoxProvider(region) elif provider_name in ["aws", "amazon web services"]: + from desktop_env.providers.aws.manager import AWSVMManager if use_proxy: # Use proxy-enabled AWS provider - from desktop_env.providers.aws.manager_with_proxy import AWSVMManagerWithProxy from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy - return AWSVMManagerWithProxy(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json") + return AWSVMManager(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json") else: # Use regular AWS provider - from desktop_env.providers.aws.manager import AWSVMManager from desktop_env.providers.aws.provider import AWSProvider return AWSVMManager(), AWSProvider(region) elif provider_name == "azure": diff --git a/desktop_env/providers/aws/manager.py b/desktop_env/providers/aws/manager.py index b9925a0..e28474a 100644 --- a/desktop_env/providers/aws/manager.py +++ b/desktop_env/providers/aws/manager.py @@ -18,11 +18,16 @@ if not os.getenv('AWS_SUBNET_ID') or not os.getenv('AWS_SECURITY_GROUP_ID'): from desktop_env.providers.base import VMManager +# Import proxy-related modules only when needed +try: + from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool + PROXY_SUPPORT_AVAILABLE = True +except ImportError: + PROXY_SUPPORT_AVAILABLE = False + logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager") logger.setLevel(logging.INFO) -REGISTRY_PATH = '.aws_vms' - DEFAULT_REGION = "us-east-1" # todo: Add doc for the configuration of image, security group and network interface # todo: public the AMI images @@ -118,17 +123,55 @@ def _allocate_vm(region=DEFAULT_REGION): return instance_id +def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None): + """Allocate a VM with proxy configuration""" + if not PROXY_SUPPORT_AVAILABLE: + logger.warning("Proxy support not available, falling back to regular VM allocation") + return _allocate_vm(region) + + from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy + + # Initialize proxy pool if needed + if proxy_config_file: + init_proxy_pool(proxy_config_file) + + # Get current proxy + proxy_pool = get_global_proxy_pool() + current_proxy = proxy_pool.get_next_proxy() + + if current_proxy: + logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}") + + # Create provider instance + provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file) + + # Create new instance + instance_id = provider.create_instance_with_proxy( + image_id=IMAGE_ID_MAP[region], + instance_type=INSTANCE_TYPE, + security_groups=[os.getenv('AWS_SECURITY_GROUP_ID')], + subnet_id=os.getenv('AWS_SUBNET_ID') + ) + + return instance_id + + class AWSVMManager(VMManager): """ AWS VM Manager for managing virtual machines on AWS. AWS does not need to maintain a registry of VMs, as it can dynamically allocate and deallocate VMs. - This class remains the interface of VMManager for compatibility with other components. + This class supports both regular VM allocation and proxy-enabled VM allocation. """ - def __init__(self, registry_path=REGISTRY_PATH): - self.registry_path = registry_path + def __init__(self, proxy_config_file=None, **kwargs): + self.proxy_config_file = proxy_config_file # self.lock = FileLock(".aws_lck", timeout=60) self.initialize_registry() + + # Initialize proxy pool if proxy configuration is provided + if proxy_config_file and PROXY_SUPPORT_AVAILABLE: + init_proxy_pool(proxy_config_file) + logger.info(f"Proxy pool initialized with config: {proxy_config_file}") def initialize_registry(self, **kwargs): pass @@ -164,6 +207,10 @@ class AWSVMManager(VMManager): pass def get_vm_path(self, region=DEFAULT_REGION, **kwargs): - logger.info("Allocating a new VM in region: {}".format(region)) - new_vm_path = _allocate_vm(region) + if self.proxy_config_file: + logger.info("Allocating a new VM with proxy configuration in region: {}".format(region)) + new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file) + else: + logger.info("Allocating a new VM in region: {}".format(region)) + new_vm_path = _allocate_vm(region) return new_vm_path diff --git a/desktop_env/providers/aws/manager_with_proxy.py b/desktop_env/providers/aws/manager_with_proxy.py deleted file mode 100644 index f4150ec..0000000 --- a/desktop_env/providers/aws/manager_with_proxy.py +++ /dev/null @@ -1,329 +0,0 @@ -import os -from filelock import FileLock -import boto3 -import psutil -import logging - -from desktop_env.providers.base import VMManager -from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool - -logger = logging.getLogger("desktopenv.providers.aws.AWSVMManagerWithProxy") -logger.setLevel(logging.INFO) - -REGISTRY_PATH = '.aws_vms_proxy' - -DEFAULT_REGION = "us-east-1" -IMAGE_ID_MAP = { - "us-east-1": "ami-05e7d7bd279ea4f14", - "ap-east-1": "ami-0c092a5b8be4116f5" -} - -INSTANCE_TYPE = "t3.medium" - -NETWORK_INTERFACE_MAP = { - "us-east-1": [ - { - "SubnetId": "subnet-037edfff66c2eb894", - "AssociatePublicIpAddress": True, - "DeviceIndex": 0, - "Groups": [ - "sg-0342574803206ee9c" - ] - } - ], - "ap-east-1": [ - { - "SubnetId": "subnet-011060501be0b589c", - "AssociatePublicIpAddress": True, - "DeviceIndex": 0, - "Groups": [ - "sg-090470e64df78f6eb" - ] - } - ] -} - - -def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None): - """分配带有代理配置的VM""" - from .provider_with_proxy import AWSProviderWithProxy - - # 初始化代理池(如果还没有初始化) - if proxy_config_file: - init_proxy_pool(proxy_config_file) - - # 获取当前代理 - proxy_pool = get_global_proxy_pool() - current_proxy = proxy_pool.get_next_proxy() - - if current_proxy: - logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}") - - # 创建provider实例 - provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file) - - # 创建新实例 - instance_id = provider.create_instance_with_proxy( - image_id=IMAGE_ID_MAP[region], - instance_type=INSTANCE_TYPE, - security_groups=NETWORK_INTERFACE_MAP[region][0]["Groups"], - subnet_id=NETWORK_INTERFACE_MAP[region][0]["SubnetId"] - ) - - return instance_id - - -class AWSVMManagerWithProxy(VMManager): - def __init__(self, registry_path=REGISTRY_PATH, proxy_config_file=None): - self.registry_path = registry_path - self.lock = FileLock(".aws_proxy_lck", timeout=60) - self.proxy_config_file = proxy_config_file - self.initialize_registry() - - # 初始化代理池 - if proxy_config_file: - init_proxy_pool(proxy_config_file) - logger.info(f"Proxy pool initialized with config: {proxy_config_file}") - - def initialize_registry(self): - with self.lock: - if not os.path.exists(self.registry_path): - with open(self.registry_path, 'w') as file: - file.write('') - - def add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None, lock_needed=True): - if lock_needed: - with self.lock: - self._add_vm(vm_path, region, proxy_info) - else: - self._add_vm(vm_path, region, proxy_info) - - def _add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None): - with open(self.registry_path, 'r') as file: - lines = file.readlines() - - # 格式: vm_path@region|status|proxy_host:proxy_port - vm_path_at_vm_region = f"{vm_path}@{region}" - proxy_str = "" - if proxy_info: - proxy_str = f"{proxy_info['host']}:{proxy_info['port']}" - - new_line = f'{vm_path_at_vm_region}|free|{proxy_str}\n' - new_lines = lines + [new_line] - - with open(self.registry_path, 'w') as file: - file.writelines(new_lines) - - def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True): - if lock_needed: - with self.lock: - self._delete_vm(vm_path, region) - else: - self._delete_vm(vm_path, region) - - def _delete_vm(self, vm_path, region=DEFAULT_REGION): - new_lines = [] - with open(self.registry_path, 'r') as file: - lines = file.readlines() - for line in lines: - parts = line.strip().split('|') - if len(parts) >= 2: - vm_path_at_vm_region = parts[0] - if vm_path_at_vm_region == f"{vm_path}@{region}": - continue - new_lines.append(line) - - with open(self.registry_path, 'w') as file: - file.writelines(new_lines) - - def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True): - if lock_needed: - with self.lock: - self._occupy_vm(vm_path, pid, region) - else: - self._occupy_vm(vm_path, pid, region) - - def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION): - new_lines = [] - with open(self.registry_path, 'r') as file: - lines = file.readlines() - for line in lines: - parts = line.strip().split('|') - if len(parts) >= 2: - registered_vm_path = parts[0] - if registered_vm_path == f"{vm_path}@{region}": - proxy_str = parts[2] if len(parts) > 2 else "" - new_lines.append(f'{registered_vm_path}|{pid}|{proxy_str}\n') - else: - new_lines.append(line) - else: - new_lines.append(line) - - with open(self.registry_path, 'w') as file: - file.writelines(new_lines) - - def check_and_clean(self, lock_needed=True): - if lock_needed: - with self.lock: - self._check_and_clean() - else: - self._check_and_clean() - - def _check_and_clean(self): - # Get active PIDs - active_pids = {p.pid for p in psutil.process_iter()} - - new_lines = [] - vm_path_at_vm_regions = {} - - with open(self.registry_path, 'r') as file: - lines = file.readlines() - - # Collect all VM paths and their regions - for line in lines: - parts = line.strip().split('|') - if len(parts) >= 2: - vm_path_at_vm_region = parts[0] - status = parts[1] - proxy_str = parts[2] if len(parts) > 2 else "" - - vm_path, vm_region = vm_path_at_vm_region.split("@") - if vm_region not in vm_path_at_vm_regions: - vm_path_at_vm_regions[vm_region] = [] - vm_path_at_vm_regions[vm_region].append((vm_path_at_vm_region, status, proxy_str)) - - # Process each region - for region, vm_info_list in vm_path_at_vm_regions.items(): - ec2_client = boto3.client('ec2', region_name=region) - instance_ids = [vm_info[0].split('@')[0] for vm_info in vm_info_list] - - try: - response = ec2_client.describe_instances(InstanceIds=instance_ids) - reservations = response.get('Reservations', []) - - terminated_ids = set() - stopped_ids = set() - active_ids = set() - - for reservation in reservations: - for instance in reservation.get('Instances', []): - instance_id = instance.get('InstanceId') - instance_state = instance['State']['Name'] - if instance_state in ['terminated', 'shutting-down']: - terminated_ids.add(instance_id) - elif instance_state == 'stopped': - stopped_ids.add(instance_id) - else: - active_ids.add(instance_id) - - for vm_path_at_vm_region, status, proxy_str in vm_info_list: - vm_path = vm_path_at_vm_region.split('@')[0] - - if vm_path in terminated_ids: - logger.info(f"VM {vm_path} not found or terminated, releasing it.") - continue - elif vm_path in stopped_ids: - logger.info(f"VM {vm_path} stopped, mark it as free") - new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n') - continue - - if status == "free": - new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n') - elif status.isdigit() and int(status) in active_pids: - new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n') - else: - new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n') - - except Exception as e: - logger.error(f"Error checking instances in region {region}: {e}") - continue - - with open(self.registry_path, 'w') as file: - file.writelines(new_lines) - - def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True): - if lock_needed: - with self.lock: - return self._list_free_vms(region) - else: - return self._list_free_vms(region) - - def _list_free_vms(self, region=DEFAULT_REGION): - free_vms = [] - with open(self.registry_path, 'r') as file: - lines = file.readlines() - for line in lines: - parts = line.strip().split('|') - if len(parts) >= 2: - vm_path_at_vm_region = parts[0] - status = parts[1] - proxy_str = parts[2] if len(parts) > 2 else "" - - vm_path, vm_region = vm_path_at_vm_region.split("@") - if status == "free" and vm_region == region: - free_vms.append((vm_path, status, proxy_str)) - - return free_vms - - def get_vm_path(self, region=DEFAULT_REGION): - with self.lock: - if not AWSVMManagerWithProxy.checked_and_cleaned: - AWSVMManagerWithProxy.checked_and_cleaned = True - self._check_and_clean() - - allocation_needed = False - with self.lock: - free_vms_paths = self._list_free_vms(region) - - if len(free_vms_paths) == 0: - allocation_needed = True - else: - chosen_vm_path, _, proxy_str = free_vms_paths[0] - self._occupy_vm(chosen_vm_path, os.getpid(), region) - logger.info(f"Using existing VM {chosen_vm_path} with proxy: {proxy_str}") - return chosen_vm_path - - if allocation_needed: - logger.info("No free virtual machine available. Generating a new one with proxy configuration...☕") - new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file) - - # 获取当前使用的代理信息 - proxy_pool = get_global_proxy_pool() - current_proxy = proxy_pool.get_next_proxy() - proxy_info = None - if current_proxy: - proxy_info = { - 'host': current_proxy.host, - 'port': current_proxy.port - } - - with self.lock: - self._add_vm(new_vm_path, region, proxy_info) - self._occupy_vm(new_vm_path, os.getpid(), region) - return new_vm_path - - def get_proxy_stats(self): - """获取代理池统计信息""" - proxy_pool = get_global_proxy_pool() - return proxy_pool.get_stats() - - def test_all_proxies(self): - """测试所有代理""" - proxy_pool = get_global_proxy_pool() - return proxy_pool.test_all_proxies() - - def force_rotate_proxy_for_vm(self, vm_path, region=DEFAULT_REGION): - """为特定VM强制轮换代理""" - logger.info(f"Force rotating proxy for VM {vm_path}") - - # 这里需要重新创建实例来应用新的代理配置 - # 在实际应用中,可能需要保存当前状态并恢复 - proxy_pool = get_global_proxy_pool() - new_proxy = proxy_pool.get_next_proxy() - - if new_proxy: - logger.info(f"New proxy for VM {vm_path}: {new_proxy.host}:{new_proxy.port}") - return True - else: - logger.warning(f"No available proxy for VM {vm_path}") - return False \ No newline at end of file diff --git a/mm_agents/openai_cua_agent.py b/mm_agents/openai_cua_agent.py index f3ecc13..d3fefbd 100644 --- a/mm_agents/openai_cua_agent.py +++ b/mm_agents/openai_cua_agent.py @@ -309,7 +309,21 @@ class OpenAICUAAgent: logger.error(f"OpenAI API error: {str(e)}") new_screenshot = self.env._get_obs() new_screenshot_base64 = base64.b64encode(new_screenshot["screenshot"]).decode('utf-8') - self.cua_messages[-1]["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}" + + # Update the image in the last message based on its structure + last_message = self.cua_messages[-1] + if "output" in last_message: + # Computer call output message structure + last_message["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}" + elif "content" in last_message: + # User message structure - find and update the image content + for content_item in last_message["content"]: + if content_item.get("type") == "input_image": + content_item["image_url"] = f"data:image/png;base64,{new_screenshot_base64}" + break + else: + logger.warning("Unknown message structure, cannot update screenshot") + retry_count += 1 time.sleep(1) raise Exception("Failed to make OpenAI API call after 3 retries") @@ -452,10 +466,7 @@ class OpenAICUAAgent: logger.warning("Empty text for type action") return "import pyautogui\n# Empty text, no action taken" - pattern = r"(?