Merge remote-tracking branch 'upstream/feat/aws-provider-support'

This commit is contained in:
yuanmengqi
2025-06-06 10:22:56 +00:00
10 changed files with 163 additions and 399 deletions

View File

@@ -121,7 +121,7 @@ class SetupController:
logger.error(
f"Failed to download {url} caused by {e}. Retrying... ({max_retries - i - 1} attempts left)")
if not downloaded:
raise requests.RequestException(f"Failed to download {url}. No retries left. Error: {e}")
raise requests.RequestException(f"Failed to download {url}. No retries left.")
form = MultipartEncoder({
"file_path": path,

View File

@@ -20,14 +20,13 @@ def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: b
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
return VirtualBoxVMManager(), VirtualBoxProvider(region)
elif provider_name in ["aws", "amazon web services"]:
from desktop_env.providers.aws.manager import AWSVMManager
if use_proxy:
# Use proxy-enabled AWS provider
from desktop_env.providers.aws.manager_with_proxy import AWSVMManagerWithProxy
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
return AWSVMManagerWithProxy(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json")
return AWSVMManager(proxy_config_file="dataimpulse_proxy_config.json"), AWSProviderWithProxy(region, proxy_config_file="dataimpulse_proxy_config.json")
else:
# Use regular AWS provider
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider
return AWSVMManager(), AWSProvider(region)
elif provider_name == "azure":

View File

@@ -18,11 +18,16 @@ if not os.getenv('AWS_SUBNET_ID') or not os.getenv('AWS_SECURITY_GROUP_ID'):
from desktop_env.providers.base import VMManager
# Import proxy-related modules only when needed
try:
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool
PROXY_SUPPORT_AVAILABLE = True
except ImportError:
PROXY_SUPPORT_AVAILABLE = False
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager")
logger.setLevel(logging.INFO)
REGISTRY_PATH = '.aws_vms'
DEFAULT_REGION = "us-east-1"
# todo: Add doc for the configuration of image, security group and network interface
# todo: public the AMI images
@@ -118,17 +123,55 @@ def _allocate_vm(region=DEFAULT_REGION):
return instance_id
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
"""Allocate a VM with proxy configuration"""
if not PROXY_SUPPORT_AVAILABLE:
logger.warning("Proxy support not available, falling back to regular VM allocation")
return _allocate_vm(region)
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
# Initialize proxy pool if needed
if proxy_config_file:
init_proxy_pool(proxy_config_file)
# Get current proxy
proxy_pool = get_global_proxy_pool()
current_proxy = proxy_pool.get_next_proxy()
if current_proxy:
logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")
# Create provider instance
provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)
# Create new instance
instance_id = provider.create_instance_with_proxy(
image_id=IMAGE_ID_MAP[region],
instance_type=INSTANCE_TYPE,
security_groups=[os.getenv('AWS_SECURITY_GROUP_ID')],
subnet_id=os.getenv('AWS_SUBNET_ID')
)
return instance_id
class AWSVMManager(VMManager):
"""
AWS VM Manager for managing virtual machines on AWS.
AWS does not need to maintain a registry of VMs, as it can dynamically allocate and deallocate VMs.
This class remains the interface of VMManager for compatibility with other components.
This class supports both regular VM allocation and proxy-enabled VM allocation.
"""
def __init__(self, registry_path=REGISTRY_PATH):
self.registry_path = registry_path
def __init__(self, proxy_config_file=None, **kwargs):
self.proxy_config_file = proxy_config_file
# self.lock = FileLock(".aws_lck", timeout=60)
self.initialize_registry()
# Initialize proxy pool if proxy configuration is provided
if proxy_config_file and PROXY_SUPPORT_AVAILABLE:
init_proxy_pool(proxy_config_file)
logger.info(f"Proxy pool initialized with config: {proxy_config_file}")
def initialize_registry(self, **kwargs):
pass
@@ -164,6 +207,10 @@ class AWSVMManager(VMManager):
pass
def get_vm_path(self, region=DEFAULT_REGION, **kwargs):
logger.info("Allocating a new VM in region: {}".format(region))
new_vm_path = _allocate_vm(region)
if self.proxy_config_file:
logger.info("Allocating a new VM with proxy configuration in region: {}".format(region))
new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)
else:
logger.info("Allocating a new VM in region: {}".format(region))
new_vm_path = _allocate_vm(region)
return new_vm_path

View File

@@ -1,329 +0,0 @@
import os
from filelock import FileLock
import boto3
import psutil
import logging
from desktop_env.providers.base import VMManager
from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool, init_proxy_pool
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManagerWithProxy")
logger.setLevel(logging.INFO)
REGISTRY_PATH = '.aws_vms_proxy'
DEFAULT_REGION = "us-east-1"
IMAGE_ID_MAP = {
"us-east-1": "ami-05e7d7bd279ea4f14",
"ap-east-1": "ami-0c092a5b8be4116f5"
}
INSTANCE_TYPE = "t3.medium"
NETWORK_INTERFACE_MAP = {
"us-east-1": [
{
"SubnetId": "subnet-037edfff66c2eb894",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-0342574803206ee9c"
]
}
],
"ap-east-1": [
{
"SubnetId": "subnet-011060501be0b589c",
"AssociatePublicIpAddress": True,
"DeviceIndex": 0,
"Groups": [
"sg-090470e64df78f6eb"
]
}
]
}
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
"""分配带有代理配置的VM"""
from .provider_with_proxy import AWSProviderWithProxy
# 初始化代理池(如果还没有初始化)
if proxy_config_file:
init_proxy_pool(proxy_config_file)
# 获取当前代理
proxy_pool = get_global_proxy_pool()
current_proxy = proxy_pool.get_next_proxy()
if current_proxy:
logger.info(f"Allocating VM with proxy: {current_proxy.host}:{current_proxy.port}")
# 创建provider实例
provider = AWSProviderWithProxy(region=region, proxy_config_file=proxy_config_file)
# 创建新实例
instance_id = provider.create_instance_with_proxy(
image_id=IMAGE_ID_MAP[region],
instance_type=INSTANCE_TYPE,
security_groups=NETWORK_INTERFACE_MAP[region][0]["Groups"],
subnet_id=NETWORK_INTERFACE_MAP[region][0]["SubnetId"]
)
return instance_id
class AWSVMManagerWithProxy(VMManager):
def __init__(self, registry_path=REGISTRY_PATH, proxy_config_file=None):
self.registry_path = registry_path
self.lock = FileLock(".aws_proxy_lck", timeout=60)
self.proxy_config_file = proxy_config_file
self.initialize_registry()
# 初始化代理池
if proxy_config_file:
init_proxy_pool(proxy_config_file)
logger.info(f"Proxy pool initialized with config: {proxy_config_file}")
def initialize_registry(self):
with self.lock:
if not os.path.exists(self.registry_path):
with open(self.registry_path, 'w') as file:
file.write('')
def add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None, lock_needed=True):
if lock_needed:
with self.lock:
self._add_vm(vm_path, region, proxy_info)
else:
self._add_vm(vm_path, region, proxy_info)
def _add_vm(self, vm_path, region=DEFAULT_REGION, proxy_info=None):
with open(self.registry_path, 'r') as file:
lines = file.readlines()
# 格式: vm_path@region|status|proxy_host:proxy_port
vm_path_at_vm_region = f"{vm_path}@{region}"
proxy_str = ""
if proxy_info:
proxy_str = f"{proxy_info['host']}:{proxy_info['port']}"
new_line = f'{vm_path_at_vm_region}|free|{proxy_str}\n'
new_lines = lines + [new_line]
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
if lock_needed:
with self.lock:
self._delete_vm(vm_path, region)
else:
self._delete_vm(vm_path, region)
def _delete_vm(self, vm_path, region=DEFAULT_REGION):
new_lines = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
parts = line.strip().split('|')
if len(parts) >= 2:
vm_path_at_vm_region = parts[0]
if vm_path_at_vm_region == f"{vm_path}@{region}":
continue
new_lines.append(line)
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True):
if lock_needed:
with self.lock:
self._occupy_vm(vm_path, pid, region)
else:
self._occupy_vm(vm_path, pid, region)
def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
new_lines = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
parts = line.strip().split('|')
if len(parts) >= 2:
registered_vm_path = parts[0]
if registered_vm_path == f"{vm_path}@{region}":
proxy_str = parts[2] if len(parts) > 2 else ""
new_lines.append(f'{registered_vm_path}|{pid}|{proxy_str}\n')
else:
new_lines.append(line)
else:
new_lines.append(line)
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def check_and_clean(self, lock_needed=True):
if lock_needed:
with self.lock:
self._check_and_clean()
else:
self._check_and_clean()
def _check_and_clean(self):
# Get active PIDs
active_pids = {p.pid for p in psutil.process_iter()}
new_lines = []
vm_path_at_vm_regions = {}
with open(self.registry_path, 'r') as file:
lines = file.readlines()
# Collect all VM paths and their regions
for line in lines:
parts = line.strip().split('|')
if len(parts) >= 2:
vm_path_at_vm_region = parts[0]
status = parts[1]
proxy_str = parts[2] if len(parts) > 2 else ""
vm_path, vm_region = vm_path_at_vm_region.split("@")
if vm_region not in vm_path_at_vm_regions:
vm_path_at_vm_regions[vm_region] = []
vm_path_at_vm_regions[vm_region].append((vm_path_at_vm_region, status, proxy_str))
# Process each region
for region, vm_info_list in vm_path_at_vm_regions.items():
ec2_client = boto3.client('ec2', region_name=region)
instance_ids = [vm_info[0].split('@')[0] for vm_info in vm_info_list]
try:
response = ec2_client.describe_instances(InstanceIds=instance_ids)
reservations = response.get('Reservations', [])
terminated_ids = set()
stopped_ids = set()
active_ids = set()
for reservation in reservations:
for instance in reservation.get('Instances', []):
instance_id = instance.get('InstanceId')
instance_state = instance['State']['Name']
if instance_state in ['terminated', 'shutting-down']:
terminated_ids.add(instance_id)
elif instance_state == 'stopped':
stopped_ids.add(instance_id)
else:
active_ids.add(instance_id)
for vm_path_at_vm_region, status, proxy_str in vm_info_list:
vm_path = vm_path_at_vm_region.split('@')[0]
if vm_path in terminated_ids:
logger.info(f"VM {vm_path} not found or terminated, releasing it.")
continue
elif vm_path in stopped_ids:
logger.info(f"VM {vm_path} stopped, mark it as free")
new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')
continue
if status == "free":
new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n')
elif status.isdigit() and int(status) in active_pids:
new_lines.append(f'{vm_path}@{region}|{status}|{proxy_str}\n')
else:
new_lines.append(f'{vm_path}@{region}|free|{proxy_str}\n')
except Exception as e:
logger.error(f"Error checking instances in region {region}: {e}")
continue
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)
def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True):
if lock_needed:
with self.lock:
return self._list_free_vms(region)
else:
return self._list_free_vms(region)
def _list_free_vms(self, region=DEFAULT_REGION):
free_vms = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
parts = line.strip().split('|')
if len(parts) >= 2:
vm_path_at_vm_region = parts[0]
status = parts[1]
proxy_str = parts[2] if len(parts) > 2 else ""
vm_path, vm_region = vm_path_at_vm_region.split("@")
if status == "free" and vm_region == region:
free_vms.append((vm_path, status, proxy_str))
return free_vms
def get_vm_path(self, region=DEFAULT_REGION):
with self.lock:
if not AWSVMManagerWithProxy.checked_and_cleaned:
AWSVMManagerWithProxy.checked_and_cleaned = True
self._check_and_clean()
allocation_needed = False
with self.lock:
free_vms_paths = self._list_free_vms(region)
if len(free_vms_paths) == 0:
allocation_needed = True
else:
chosen_vm_path, _, proxy_str = free_vms_paths[0]
self._occupy_vm(chosen_vm_path, os.getpid(), region)
logger.info(f"Using existing VM {chosen_vm_path} with proxy: {proxy_str}")
return chosen_vm_path
if allocation_needed:
logger.info("No free virtual machine available. Generating a new one with proxy configuration...☕")
new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)
# 获取当前使用的代理信息
proxy_pool = get_global_proxy_pool()
current_proxy = proxy_pool.get_next_proxy()
proxy_info = None
if current_proxy:
proxy_info = {
'host': current_proxy.host,
'port': current_proxy.port
}
with self.lock:
self._add_vm(new_vm_path, region, proxy_info)
self._occupy_vm(new_vm_path, os.getpid(), region)
return new_vm_path
def get_proxy_stats(self):
"""获取代理池统计信息"""
proxy_pool = get_global_proxy_pool()
return proxy_pool.get_stats()
def test_all_proxies(self):
"""测试所有代理"""
proxy_pool = get_global_proxy_pool()
return proxy_pool.test_all_proxies()
def force_rotate_proxy_for_vm(self, vm_path, region=DEFAULT_REGION):
"""为特定VM强制轮换代理"""
logger.info(f"Force rotating proxy for VM {vm_path}")
# 这里需要重新创建实例来应用新的代理配置
# 在实际应用中,可能需要保存当前状态并恢复
proxy_pool = get_global_proxy_pool()
new_proxy = proxy_pool.get_next_proxy()
if new_proxy:
logger.info(f"New proxy for VM {vm_path}: {new_proxy.host}:{new_proxy.port}")
return True
else:
logger.warning(f"No available proxy for VM {vm_path}")
return False

View File

@@ -47,48 +47,62 @@ class AWSProviderWithProxy(Provider):
proxy_url = self._format_proxy_url(self.current_proxy)
user_data_script = f"""#!/bin/bash
# 配置系统代理
echo 'export http_proxy={proxy_url}' >> /etc/environment
echo 'export https_proxy={proxy_url}' >> /etc/environment
echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment
echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment
# Configure system proxy
echo 'export http_proxy={proxy_url}' >> /etc/environment
echo 'export https_proxy={proxy_url}' >> /etc/environment
echo 'export HTTP_PROXY={proxy_url}' >> /etc/environment
echo 'export HTTPS_PROXY={proxy_url}' >> /etc/environment
# 配置apt代理
cat > /etc/apt/apt.conf.d/95proxy << EOF
Acquire::http::Proxy "{proxy_url}";
Acquire::https::Proxy "{proxy_url}";
EOF
# Configure apt proxy
cat > /etc/apt/apt.conf.d/95proxy << EOF
Acquire::http::Proxy "{proxy_url}";
Acquire::https::Proxy "{proxy_url}";
EOF
# 配置chrome/chromium代理
mkdir -p /etc/opt/chrome/policies/managed
cat > /etc/opt/chrome/policies/managed/proxy.json << EOF
{{
"ProxyMode": "fixed_servers",
"ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}"
}}
EOF
# Configure chrome/chromium proxy
mkdir -p /etc/opt/chrome/policies/managed
cat > /etc/opt/chrome/policies/managed/proxy.json << EOF
{{
"ProxyMode": "fixed_servers",
"ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}"
}}
EOF
# 配置firefox代理
mkdir -p /etc/firefox/policies
cat > /etc/firefox/policies/policies.json << EOF
{{
"policies": {{
"Proxy": {{
"Mode": "manual",
"HTTPProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
"HTTPSProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
"UseHTTPProxyForAllProtocols": true
}}
}}
}}
EOF
# Configure chromium proxy (Ubuntu default)
mkdir -p /etc/chromium/policies/managed
cat > /etc/chromium/policies/managed/proxy.json << EOF
{{
"ProxyMode": "fixed_servers",
"ProxyServer": "{self.current_proxy.host}:{self.current_proxy.port}"
}}
EOF
# 重新加载环境变量
source /etc/environment
# Configure firefox proxy - support multiple possible paths
for firefox_dir in /etc/firefox/policies /usr/lib/firefox/distribution/policies /etc/firefox-esr/policies; do
if [ -d "$(dirname "$firefox_dir")" ]; then
mkdir -p "$firefox_dir"
cat > "$firefox_dir/policies.json" << EOF
{{
"policies": {{
"Proxy": {{
"Mode": "manual",
"HTTPProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
"HTTPSProxy": "{self.current_proxy.host}:{self.current_proxy.port}",
"UseHTTPProxyForAllProtocols": true
}}
}}
}}
EOF
break
fi
done
# 记录代理配置日志
echo "$(date): Configured proxy {self.current_proxy.host}:{self.current_proxy.port}" >> /var/log/proxy-setup.log
"""
# Reload environment variables
source /etc/environment
# Log proxy configuration
echo "$(date): Configured proxy {self.current_proxy.host}:{self.current_proxy.port}" >> /var/log/proxy-setup.log
"""
return base64.b64encode(user_data_script.encode()).decode()
@@ -99,7 +113,7 @@ class AWSProviderWithProxy(Provider):
else:
return f"{proxy.protocol}://{proxy.host}:{proxy.port}"
def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
def start_emulator(self, path_to_vm: str, headless: bool, *args, **kwargs):
logger.info("Starting AWS VM with proxy configuration...")
ec2_client = boto3.client('ec2', region_name=self.region)

View File

@@ -305,7 +305,25 @@ class OpenAICUAAgent:
logger.info(f"Response: {response}")
return response
except Exception as e:
logger.error(f"OpenAI API error: {str(e)}will retry in 1s...")
logger.error(f"OpenAI API error: {str(e)}")
new_screenshot = self.env._get_obs()
new_screenshot_base64 = base64.b64encode(new_screenshot["screenshot"]).decode('utf-8')
# Update the image in the last message based on its structure
last_message = self.cua_messages[-1]
if "output" in last_message:
# Computer call output message structure
last_message["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
elif "content" in last_message:
# User message structure - find and update the image content
for content_item in last_message["content"]:
if content_item.get("type") == "input_image":
content_item["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
break
else:
logger.warning("Unknown message structure, cannot update screenshot")
retry_count += 1
time.sleep(1)
def _handle_item(self, item: Dict[str, Any]) -> Optional[Union[str, Dict[str, Any]]]:
@@ -446,10 +464,7 @@ class OpenAICUAAgent:
logger.warning("Empty text for type action")
return "import pyautogui\n# Empty text, no action taken"
pattern = r"(?<!\\)'"
text = re.sub(pattern, r"\\'", text)
# 使用三重引号来确保字符串内容不会破坏格式
# Use repr() to properly escape the string content without double-escaping
pyautogui_code = f"""import pyautogui\npyautogui.typewrite({repr(text)})"""
logger.info(f"Pyautogui code: {pyautogui_code}")
return pyautogui_code

View File

@@ -2,10 +2,13 @@
# Do not write any secret keys or sensitive information here.
# Monitor configuration
TASK_CONFIG_PATH=../evaluation_examples/test_all_error.json
TASK_CONFIG_PATH=../evaluation_examples/test_all.json
EXAMPLES_BASE_PATH=../evaluation_examples/examples
RESULTS_BASE_PATH=../results_all_error/pyautogui/screenshot/computer-use-preview
RESULTS_BASE_PATH=../results_operator_aws
ACTION_SPACE=pyautogui
OBSERVATION_TYPE=screenshot
MODEL_NAME=computer-use-preview
MAX_STEPS=150
FLASK_PORT=80
FLASK_HOST=0.0.0.0
FLASK_DEBUG=false
FLASK_DEBUG=true

View File

@@ -1,5 +1,3 @@
version: '3'
services:
monitor:
build:
@@ -9,10 +7,11 @@ services:
- "${FLASK_PORT:-8080}:8080"
volumes:
- .:/app/monitor
- ../evaluation_examples:/app/evaluation_examples
- ../results_operator_aws:/app/results_operator_aws
- ${TASK_CONFIG_PATH:-../evaluation_examples/test_all.json}:/app/evaluation_examples/test.json
- ${EXAMPLES_BASE_PATH:-../evaluation_examples/examples}:/app/evaluation_examples/examples
- ${RESULTS_BASE_PATH:-../results_operator_aws}:/app/results
env_file:
- .env
environment:
- FLASK_ENV=production
- MONITOR_IN_DOCKER=true
restart: unless-stopped

View File

@@ -17,12 +17,26 @@ TASK_STATUS_CACHE = {}
app = Flask(__name__)
# Load configuration from environment variables
TASK_CONFIG_PATH = os.getenv("TASK_CONFIG_PATH", "../evaluation_examples/test_small.json")
EXAMPLES_BASE_PATH = os.getenv("EXAMPLES_BASE_PATH", "../evaluation_examples/examples")
RESULTS_BASE_PATH = os.getenv("RESULTS_BASE_PATH", "../results_operator_aws/pyautogui/screenshot/computer-use-preview")
MONITOR_IN_DOCKER = os.getenv("MONITOR_IN_DOCKER", "false").lower() == "true"
if MONITOR_IN_DOCKER:
# If running in Docker, use default paths
TASK_CONFIG_PATH = "/app/evaluation_examples/test.json"
EXAMPLES_BASE_PATH = "/app/evaluation_examples/examples"
RESULTS_BASE_PATH = "/app/results"
else:
# Load configuration from environment variables
TASK_CONFIG_PATH = os.getenv("TASK_CONFIG_PATH", "../evaluation_examples/test_small.json")
EXAMPLES_BASE_PATH = os.getenv("EXAMPLES_BASE_PATH", "../evaluation_examples/examples")
RESULTS_BASE_PATH = os.getenv("RESULTS_BASE_PATH", "../results")
ACTION_SPACE=os.getenv("ACTION_SPACE", "pyautogui")
OBSERVATION_TYPE=os.getenv("OBSERVATION_TYPE", "screenshot")
MODEL_NAME=os.getenv("MODEL_NAME", "computer-use-preview")
MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
RESULTS_PATH = os.path.join(RESULTS_BASE_PATH, ACTION_SPACE, OBSERVATION_TYPE, MODEL_NAME)
def load_task_list():
with open(TASK_CONFIG_PATH, 'r') as f:
return json.load(f)
@@ -35,7 +49,7 @@ def get_task_info(task_type, task_id):
return None
def get_task_status(task_type, task_id):
result_dir = os.path.join(RESULTS_BASE_PATH, task_type, task_id)
result_dir = os.path.join(RESULTS_PATH, task_type, task_id)
if not os.path.exists(result_dir):
return {
@@ -167,7 +181,7 @@ def get_task_status_brief(task_type, task_id):
if cache_key in TASK_STATUS_CACHE:
return TASK_STATUS_CACHE[cache_key]
result_dir = os.path.join(RESULTS_BASE_PATH, task_type, task_id)
result_dir = os.path.join(RESULTS_PATH, task_type, task_id)
if not os.path.exists(result_dir):
return {
@@ -367,7 +381,7 @@ def api_tasks_brief():
@app.route('/task/<task_type>/<task_id>/screenshot/<path:filename>')
def task_screenshot(task_type, task_id, filename):
"""Get task screenshot"""
screenshot_path = os.path.join(RESULTS_BASE_PATH, task_type, task_id, filename)
screenshot_path = os.path.join(RESULTS_PATH, task_type, task_id, filename)
if os.path.exists(screenshot_path):
return send_file(screenshot_path, mimetype='image/png')
else:
@@ -376,7 +390,7 @@ def task_screenshot(task_type, task_id, filename):
@app.route('/task/<task_type>/<task_id>/recording')
def task_recording(task_type, task_id):
"""Get task recording video"""
recording_path = os.path.join(RESULTS_BASE_PATH, task_type, task_id, "recording.mp4")
recording_path = os.path.join(RESULTS_PATH, task_type, task_id, "recording.mp4")
if os.path.exists(recording_path):
response = send_file(recording_path, mimetype='video/mp4')
# Add headers to improve mobile compatibility

View File

@@ -237,7 +237,9 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
shared_scores,
)
except Exception as e:
import traceback
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")