diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index 6e76341..d27aa00 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -162,7 +162,7 @@ class DesktopEnv(gym.Env): # vmware, virtualbox are always used as the emulator starts from a dirty state if self.provider_name in {"docker", "aws", "gcp", "azure", "aliyun", "volcengine"}: self.is_environment_used = False - elif self.provider_name in {"vmware", "virtualbox"}: + elif self.provider_name in {"vmware", "virtualbox", "proxmox"}: self.is_environment_used = True else: raise ValueError(f"Invalid provider name: {self.provider_name}") diff --git a/desktop_env/providers/__init__.py b/desktop_env/providers/__init__.py index 1f465ff..1555d83 100644 --- a/desktop_env/providers/__init__.py +++ b/desktop_env/providers/__init__.py @@ -4,7 +4,7 @@ from desktop_env.providers.base import VMManager, Provider def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: bool = False): """ Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name. - + Args: provider_name (str): The name of the provider (e.g., "aws", "vmware", etc.) 
import logging
import os

from desktop_env.providers.base import VMManager

logger = logging.getLogger("desktopenv.providers.proxmox.ProxmoxVMManager")
logger.setLevel(logging.INFO)

# VM ID used when neither --path_to_vm nor PROXMOX_VM_ID supplies one.
DEFAULT_VM_ID = "102"


class ProxmoxVMManager(VMManager):
    """
    Simplified VM manager for Proxmox.

    Unlike VMware/VirtualBox, Proxmox VMs are pre-created on the server.
    This manager does not handle VM provisioning, downloading, or local
    registry management, so every registry hook below is a deliberate no-op.
    The VM ID is normally passed directly via --path_to_vm.
    """

    def __init__(self, registry_path=""):
        """Accept (and ignore) ``registry_path``; there is no local registry."""

    def initialize_registry(self, **kwargs):
        """No-op: Proxmox VMs are not tracked in a local registry."""

    def add_vm(self, vm_path, **kwargs):
        """No-op: VMs are created on the Proxmox server, not registered here."""

    def delete_vm(self, vm_path, **kwargs):
        """No-op: this manager never deletes VMs."""

    def occupy_vm(self, vm_path, pid, **kwargs):
        """No-op: no per-process VM reservation is kept."""

    def list_free_vms(self, **kwargs):
        """Return an empty list; free VMs are not tracked."""
        return []

    def check_and_clean(self, **kwargs):
        """No-op: there is no registry state to clean up."""

    def get_vm_path(self, os_type="Windows", region=None, screen_size=(1920, 1080), **kwargs):
        """Return the Proxmox VM ID to use when --path_to_vm was not given.

        The ID is read from the ``PROXMOX_VM_ID`` environment variable and
        falls back to ``DEFAULT_VM_ID`` when the variable is unset or empty.
        ``os_type``, ``region`` and ``screen_size`` are accepted for interface
        compatibility and ignored.

        Returns:
            The VM ID as a string of ASCII digits (e.g. "102").

        Raises:
            ValueError: If ``PROXMOX_VM_ID`` is not a plain numeric ID.
        """
        # `or` (rather than a .get default) also covers PROXMOX_VM_ID="".
        vmid = (os.environ.get("PROXMOX_VM_ID") or DEFAULT_VM_ID).strip()
        # The ID ends up inside `qm ...` commands executed through a remote
        # shell, so reject anything that is not a plain number up front.
        if not (vmid.isascii() and vmid.isdigit()):
            raise ValueError(f"PROXMOX_VM_ID must be a numeric Proxmox VM ID, got: {vmid!r}")
        logger.info(f"Using Proxmox VM ID: {vmid}")
        return vmid
import ipaddress
import json
import logging
import os
import shlex
import subprocess
import time
from typing import Optional

import requests

from desktop_env.providers.base import Provider

logger = logging.getLogger("desktopenv.providers.proxmox.ProxmoxProvider")
logger.setLevel(logging.INFO)

WAIT_TIME = 5
RETRY_INTERVAL = 3
MAX_WAIT_READY = 300  # seconds to wait for VM HTTP server to be ready


class ProxmoxProvider(Provider):
    """
    Proxmox VE provider that manages VMs via SSH to the Proxmox host,
    executing `qm` commands for VM lifecycle management.

    Configuration via environment variables:
        PROXMOX_SSH_HOST: SSH target (default: root@10.10.17.3)
        PROXMOX_VM_IP: Fallback VM IP if guest agent is unavailable (default: 10.10.17.10)
        PROXMOX_SSH_STRICT_HOST_KEY_CHECKING: value passed to ssh's
            StrictHostKeyChecking option (default: no)
    """

    def __init__(self, region: Optional[str] = None):
        super().__init__(region)
        self.ssh_host = os.environ.get("PROXMOX_SSH_HOST", "root@10.10.17.3")
        self.vm_ip_fallback = os.environ.get("PROXMOX_VM_IP", "10.10.17.10")
        # NOTE(security): "no" disables host-key verification, which exposes a
        # root SSH session to man-in-the-middle attacks. It stays the default
        # for backward compatibility; prefer "accept-new" or "yes".
        self.ssh_host_key_checking = os.environ.get("PROXMOX_SSH_STRICT_HOST_KEY_CHECKING", "no")
        # Guest-agent-reported IPs, keyed by VM ID. The static fallback is
        # never stored here so a later query can still discover the real IP.
        self._vm_ip_cache = {}

    @staticmethod
    def _fail(message: str, check: bool, cause: Optional[BaseException]) -> str:
        """Log a transport-level SSH failure; raise if ``check`` else return ""."""
        logger.error(message)
        if check:
            raise RuntimeError(message) from cause
        return ""

    def _ssh_exec(self, command: str, timeout: int = 120, check: bool = True) -> str:
        """Execute a command on the Proxmox host via SSH.

        Args:
            command: The command to run on the remote host. It is interpreted
                by the remote shell, so callers must quote any interpolated
                values (see ``shlex.quote``).
            timeout: Timeout in seconds.
            check: If True, raise on non-zero exit code, timeout, or failure
                to launch ssh. If False, such failures yield whatever stdout
                was produced (an empty string for timeouts/launch failures).

        Returns:
            stdout output as a stripped string.

        Raises:
            RuntimeError: If ``check`` is True and the command did not succeed.
        """
        ssh_cmd = [
            "ssh",
            "-o", f"StrictHostKeyChecking={self.ssh_host_key_checking}",
            "-o", "ConnectTimeout=10",
            "-o", "BatchMode=yes",
            self.ssh_host,
            command,
        ]
        logger.debug(f"SSH exec: {' '.join(ssh_cmd)}")
        try:
            result = subprocess.run(
                ssh_cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding="utf-8",
                errors="replace",  # never fail on stray non-UTF-8 bytes in output
                timeout=timeout,
            )
        except subprocess.TimeoutExpired as exc:
            return self._fail(f"SSH command timed out after {timeout}s: {command}", check, exc)
        except OSError as exc:  # e.g. the ssh binary is missing
            return self._fail(f"SSH execution error: {exc}", check, exc)

        if result.returncode != 0:
            message = f"SSH command failed (rc={result.returncode}): {result.stderr.strip()}"
            if check:
                logger.error(message)
                raise RuntimeError(message)
            # Callers passing check=False expect failures (e.g. guest agent
            # not running yet), so keep this out of the normal log level.
            logger.debug(message)
        return result.stdout.strip()

    def _get_vm_status(self, vmid: str) -> str:
        """Get the current status of a VM (e.g. 'running', 'stopped').

        Returns an empty string when the status could not be obtained, so
        that polling callers simply treat it as "not there yet".
        """
        try:
            output = self._ssh_exec(f"qm status {shlex.quote(str(vmid))}")
        except RuntimeError:
            # Already logged by _ssh_exec; status queries are best-effort.
            return ""
        # output format: "status: running"
        _, sep, status = output.partition(":")
        return status.strip() if sep else output.strip()

    def _wait_for_status(self, vmid: str, target_status: str, timeout: int = 120):
        """Poll VM status until it matches target_status.

        Returns:
            True if the status was reached within ``timeout`` seconds,
            False otherwise.
        """
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            status = self._get_vm_status(vmid)
            logger.info(f"VM {vmid} status: {status} (waiting for {target_status})")
            if status == target_status:
                return True
            time.sleep(RETRY_INTERVAL)
        logger.error(f"VM {vmid} did not reach status '{target_status}' within {timeout}s")
        return False

    def _wait_for_vm_ready(self, vm_ip: str, server_port: int = 5000, timeout: int = MAX_WAIT_READY):
        """Poll the VM's HTTP server until ``/screenshot`` answers with 200.

        Returns:
            True if the server answered within ``timeout`` seconds,
            False otherwise.
        """
        url = f"http://{vm_ip}:{server_port}/screenshot"
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                response = requests.get(url, timeout=(10, 10))
                if response.status_code == 200:
                    logger.info(f"VM HTTP server is ready at {url}")
                    return True
            except requests.RequestException as exc:
                # Expected while the guest is still booting.
                logger.debug(f"VM HTTP server not reachable yet: {exc}")
            logger.info(f"Waiting for VM HTTP server at {url}...")
            time.sleep(RETRY_INTERVAL)
        logger.error(f"VM HTTP server at {url} not ready within {timeout}s")
        return False

    def start_emulator(self, path_to_vm: str, headless: bool, os_type: str = "Windows"):
        """Start the Proxmox VM.

        Args:
            path_to_vm: The VM ID as a string (e.g. "102").
            headless: Ignored for Proxmox (VMs are always headless on server).
            os_type: OS type of the VM (unused here).

        Raises:
            RuntimeError: If `qm start` fails or the VM never reports
                'running'.
        """
        vmid = path_to_vm
        logger.info(f"Starting Proxmox VM {vmid}...")
        print(f"Starting Proxmox VM {vmid}...")

        if self._get_vm_status(vmid) == "running":
            logger.info(f"VM {vmid} is already running.")
        else:
            self._ssh_exec(f"qm start {shlex.quote(str(vmid))}")
            if not self._wait_for_status(vmid, "running", timeout=120):
                raise RuntimeError(f"Failed to start VM {vmid}")

        # Wait for Flask HTTP server inside VM to be ready
        vm_ip = self._resolve_vm_ip(vmid)
        if not self._wait_for_vm_ready(vm_ip):
            # Kept non-fatal (as before), but made visible at the call site.
            logger.warning(f"Continuing although VM {vmid} HTTP server at {vm_ip} is not ready.")

    @staticmethod
    def _pick_ipv4(interfaces) -> Optional[str]:
        """Return the first usable IPv4 address from guest-agent interface data.

        Loopback and link-local (169.254.0.0/16) addresses are skipped, as is
        every IPv6 address. Returns None when nothing usable is found.
        """
        if not isinstance(interfaces, list):
            return None
        for iface in interfaces:
            if not isinstance(iface, dict):
                continue
            for addr in iface.get("ip-addresses") or []:
                raw = addr.get("ip-address", "") if isinstance(addr, dict) else ""
                try:
                    ip = ipaddress.ip_address(raw)
                except ValueError:
                    continue
                if ip.version == 4 and not ip.is_loopback and not ip.is_link_local:
                    return str(ip)
        return None

    def _resolve_vm_ip(self, vmid: str) -> str:
        """Try to get VM IP via QEMU Guest Agent, fall back to env var."""
        key = str(vmid)
        cached = self._vm_ip_cache.get(key)
        if cached:
            return cached

        # Try QEMU Guest Agent
        output = self._ssh_exec(
            f"qm guest cmd {shlex.quote(key)} network-get-interfaces",
            timeout=15,
            check=False,
        )
        ip = None
        if output:
            try:
                ip = self._pick_ipv4(json.loads(output))
            except ValueError as exc:  # includes json.JSONDecodeError
                logger.debug(f"Guest agent query failed: {exc}")
        if ip:
            logger.info(f"Got VM {vmid} IP from guest agent: {ip}")
            self._vm_ip_cache[key] = ip
            return ip

        # Fallback to env var / default. Deliberately not cached: the guest
        # agent is usually not up yet right after boot, and caching the
        # fallback would hide the real address for the rest of the session.
        logger.info(f"Using fallback VM IP: {self.vm_ip_fallback}")
        return self.vm_ip_fallback

    def get_ip_address(self, path_to_vm: str) -> str:
        """Return the VM's IP address.

        Args:
            path_to_vm: The VM ID as a string.

        Returns:
            IP address string (e.g. "10.10.17.10").
        """
        return self._resolve_vm_ip(path_to_vm)

    def save_state(self, path_to_vm: str, snapshot_name: str):
        """Create a snapshot of the VM.

        Args:
            path_to_vm: The VM ID.
            snapshot_name: Name for the snapshot.

        Raises:
            RuntimeError: If `qm snapshot` fails.
        """
        vmid = path_to_vm
        logger.info(f"Creating snapshot '{snapshot_name}' for VM {vmid}...")
        self._ssh_exec(
            f"qm snapshot {shlex.quote(str(vmid))} {shlex.quote(snapshot_name)}",
            timeout=120,
        )
        time.sleep(WAIT_TIME)
        logger.info(f"Snapshot '{snapshot_name}' created for VM {vmid}.")

    def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
        """Stop the VM if it is running and roll it back to a snapshot.

        This method does not start the VM afterwards; call
        ``start_emulator`` to boot it again.

        Args:
            path_to_vm: The VM ID.
            snapshot_name: Name of the snapshot to revert to.

        Returns:
            The VM ID (path_to_vm).

        Raises:
            RuntimeError: If `qm stop` or `qm rollback` fails, so a failed
                rollback is never mistaken for a clean environment.
        """
        vmid = path_to_vm
        quoted_vmid = shlex.quote(str(vmid))
        logger.info(f"Reverting VM {vmid} to snapshot '{snapshot_name}'...")

        # Stop VM first if running
        if self._get_vm_status(vmid) == "running":
            self._ssh_exec(f"qm stop {quoted_vmid}", timeout=60)
            if not self._wait_for_status(vmid, "stopped", timeout=60):
                logger.warning(f"VM {vmid} did not report 'stopped'; attempting rollback anyway.")

        # Rollback to snapshot
        self._ssh_exec(f"qm rollback {quoted_vmid} {shlex.quote(snapshot_name)}", timeout=120)
        time.sleep(WAIT_TIME)

        # Clear IP cache since IP might change after rollback
        self._vm_ip_cache.pop(str(vmid), None)

        logger.info(f"VM {vmid} reverted to snapshot '{snapshot_name}'.")
        return path_to_vm

    def stop_emulator(self, path_to_vm: str, region=None, *args, **kwargs):
        """Stop the VM (best-effort; failures are logged, not raised).

        Args:
            path_to_vm: The VM ID.
            region: Unused; accepted for interface compatibility.
        """
        vmid = path_to_vm
        logger.info(f"Stopping Proxmox VM {vmid}...")
        if self._get_vm_status(vmid) == "stopped":
            logger.info(f"VM {vmid} is already stopped.")
            return

        # The address may differ on the next boot.
        self._vm_ip_cache.pop(str(vmid), None)
        try:
            self._ssh_exec(f"qm stop {shlex.quote(str(vmid))}", timeout=60)
        except RuntimeError:
            # Shutdown is cleanup; the failure is already logged by _ssh_exec.
            return
        # Only claim success when the VM really reports 'stopped';
        # _wait_for_status logs the error otherwise.
        if self._wait_for_status(vmid, "stopped", timeout=60):
            logger.info(f"VM {vmid} stopped.")
#!/bin/bash
# =============================================================================
# Jade-BenchMark-MVP one-click evaluation script (Proxmox remote VM edition)
# =============================================================================

# ---------- Proxmox configuration ----------
# SSH address of the Proxmox host (format: user@host). A value already
# exported by the caller wins; otherwise the default below is used.
export PROXMOX_SSH_HOST="${PROXMOX_SSH_HOST:-root@10.10.17.3}"

# LAN IP of the VM (must be reachable from the machine running this script).
export PROXMOX_VM_IP="${PROXMOX_VM_IP:-10.10.17.10}"

# ---------- LLM API configuration ----------
# OpenAI-compatible proxy (used for both the agent model and the eval model).
# SECURITY: the API key must come from the caller's environment and must never
# be committed to the repository. The key that used to be hard-coded here
# should be considered leaked and be revoked/rotated.
if [ -z "${OPENAI_API_KEY:-}" ]; then
    echo "❌ 未设置 OPENAI_API_KEY,请先执行: export OPENAI_API_KEY=<your key>" >&2
    exit 1
fi
export OPENAI_API_KEY
export OPENAI_BASE_URL="${OPENAI_BASE_URL:-https://vip.apiyi.com/v1}"

# ---------- Evaluation parameters (edit as needed) ----------
PROVIDER="proxmox"
VM_ID="102"                                        # VM ID on the Proxmox host
MODEL="gpt-5.2-chat-latest"                        # agent model
EVAL_MODEL="gemini-3.1-pro-preview"                # evaluation model
MAX_STEPS=50                                       # max steps per task (public evaluation guide recommends 50)
SLEEP_AFTER_EXEC=3                                 # seconds to wait after each executed step
TEMPERATURE=0.5                                    # sampling temperature (lower = more stable / reproducible)
TOP_P=0.9                                          # nucleus sampling
MAX_TOKENS=16384                                   # max output tokens of the model
MAX_TRAJECTORY_LENGTH=3                            # number of past trajectory steps kept
OBSERVATION_TYPE="screenshot_a11y_tree"            # observation type
ACTION_SPACE="pyautogui"                           # action space
SCREEN_WIDTH=1920                                  # screen width
SCREEN_HEIGHT=1080                                 # screen height
RESULT_DIR="${RESULT_DIR:-/Volumes/Castor/课题/results}"   # results output directory (override via env)
TEST_META="evaluation_examples/test_curated.json"  # list of evaluation tasks
DOMAIN="jade"                                      # evaluation domain
SNAPSHOT_NAME="snapshot"                           # snapshot name (must be created beforehand)
INJECT_STEPS=false                                 # inject tutorial steps into the agent prompt (baseline: no)
# NOTE(review): SNAPSHOT_NAME is only used by the preflight check below and is
# not passed to run.py — confirm run.py's default snapshot name matches it.

# ---------- Preflight checks ----------
echo "=== 预检查 ==="

# Check SSH connectivity
echo -n "SSH 到 Proxmox... "
if ssh -o BatchMode=yes -o ConnectTimeout=5 "${PROXMOX_SSH_HOST}" "echo ok" 2>/dev/null | grep -q "ok"; then
    echo "✅ 连接成功"
else
    echo "❌ SSH 连接失败,请确认:"
    echo " 1. 已执行 ssh-copy-id ${PROXMOX_SSH_HOST}"
    echo " 2. Proxmox 主机可达"
    exit 1
fi

# Check VM status
echo -n "VM ${VM_ID} 状态... "
VM_STATUS=$(ssh -o BatchMode=yes "${PROXMOX_SSH_HOST}" "qm status ${VM_ID}" 2>/dev/null)
echo "${VM_STATUS}"

# Check the Flask server inside the VM
echo -n "Flask Server (${PROXMOX_VM_IP}:5000)... "
if curl -s --connect-timeout 5 "http://${PROXMOX_VM_IP}:5000/screenshot" -o /dev/null -w "%{http_code}" | grep -q "200"; then
    echo "✅ 可访问"
else
    echo "⚠️ 不可访问(VM 可能未启动或 Flask 未运行,评测启动时会自动启动 VM)"
fi

# Check that the snapshot exists
echo -n "快照 '${SNAPSHOT_NAME}'... "
SNAPSHOTS=$(ssh -o BatchMode=yes "${PROXMOX_SSH_HOST}" "qm listsnapshot ${VM_ID}" 2>/dev/null)
# -w/-F: match the name as a whole fixed-string word, so e.g. "snapshot2"
# is not mistaken for "snapshot".
if echo "${SNAPSHOTS}" | grep -qwF -- "${SNAPSHOT_NAME}"; then
    echo "✅ 已存在"
else
    echo "⚠️ 未找到快照 '${SNAPSHOT_NAME}'。"
    echo " 批量评测需要快照来回滚环境。可以现在创建:"
    echo " ssh ${PROXMOX_SSH_HOST} \"qm snapshot ${VM_ID} ${SNAPSHOT_NAME}\""
    read -p " 是否现在创建快照?(y/N) " -n 1 -r
    echo
    if [[ $REPLY =~ ^[Yy]$ ]]; then
        echo " 正在创建快照..."
        # Only report success when the remote command actually succeeded.
        if ssh "${PROXMOX_SSH_HOST}" "qm snapshot ${VM_ID} ${SNAPSHOT_NAME}"; then
            echo " ✅ 快照已创建"
        else
            echo " ❌ 快照创建失败" >&2
            exit 1
        fi
    fi
fi

echo ""
echo "=== 开始评测 ==="
echo "Provider: ${PROVIDER}"
echo "VM ID: ${VM_ID}"
echo "VM IP: ${PROXMOX_VM_IP}"
echo "Model: ${MODEL}"
echo "Eval: ${EVAL_MODEL}"
echo "Domain: ${DOMAIN}"
echo "Results: ${RESULT_DIR}"
echo "Inject: ${INJECT_STEPS}"
echo ""

# ---------- Run the evaluation ----------
# Build the inject_steps flag
if [ "${INJECT_STEPS}" = true ]; then
    INJECT_STEPS_FLAG="--inject_steps"
else
    INJECT_STEPS_FLAG="--no_inject_steps"
fi

python run.py \
    --provider_name "${PROVIDER}" \
    --path_to_vm "${VM_ID}" \
    --observation_type "${OBSERVATION_TYPE}" \
    --action_space "${ACTION_SPACE}" \
    --model "${MODEL}" \
    --eval_model "${EVAL_MODEL}" \
    --temperature "${TEMPERATURE}" \
    --top_p "${TOP_P}" \
    --max_tokens "${MAX_TOKENS}" \
    --max_trajectory_length "${MAX_TRAJECTORY_LENGTH}" \
    --screen_width "${SCREEN_WIDTH}" \
    --screen_height "${SCREEN_HEIGHT}" \
    --sleep_after_execution "${SLEEP_AFTER_EXEC}" \
    --max_steps "${MAX_STEPS}" \
    --result_dir "${RESULT_DIR}" \
    --test_all_meta_path "${TEST_META}" \
    --domain "${DOMAIN}" \
    "${INJECT_STEPS_FLAG}"