Initial commit
This commit is contained in:
16
scripts/core/__init__.py
Normal file
16
scripts/core/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
JADE Benchmark 核心模块
|
||||
包含VM控制、轨迹录制、评测等核心功能
|
||||
"""
|
||||
from .jade_env import JadeEnv
|
||||
from .recorder import Recorder, record_interactive
|
||||
from .evaluator import evaluate, load_xrd_data
|
||||
|
||||
__all__ = [
|
||||
'JadeEnv',
|
||||
'Recorder',
|
||||
'record_interactive',
|
||||
'evaluate',
|
||||
'load_xrd_data'
|
||||
]
|
||||
|
||||
178
scripts/core/evaluator.py
Normal file
178
scripts/core/evaluator.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
|
||||
def load_xrd_data(file_path):
|
||||
"""
|
||||
读取 XRD 导出的 txt 文件(如 background_result.txt),跳过头部的 Metadata。
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
data = []
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = line.split()
|
||||
try:
|
||||
values = [float(x) for x in parts]
|
||||
if len(values) >= 2:
|
||||
data.append(values[:2])
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not data:
|
||||
raise ValueError(f"在文件中未找到有效的数值数据: {file_path}")
|
||||
|
||||
return np.array(data)
|
||||
|
||||
def load_peak_report(file_path):
|
||||
"""
|
||||
专门解析 JADE Peak Search Report (.pid)
|
||||
提取表格部分的数值数据
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
peaks = []
|
||||
metadata = {}
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# 查找表格开始的位置(在含有 "2-Theta" 的行之后)
|
||||
table_started = False
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 提取元数据示例: JADE: Peak Search Report (72 Peaks, Max P/N = 19.3)
|
||||
if "Peak Search Report" in line:
|
||||
match = re.search(r"\((\d+) Peaks", line)
|
||||
if match:
|
||||
metadata['reported_peak_count'] = int(match.group(1))
|
||||
|
||||
# 识别表头
|
||||
if "2-Theta" in line and "d(" in line:
|
||||
table_started = True
|
||||
continue
|
||||
|
||||
if table_started:
|
||||
parts = line.split()
|
||||
# Peak 报告通常每行有 8-9 个字段
|
||||
if len(parts) >= 6:
|
||||
try:
|
||||
# 尝试将前几个字段转为 float
|
||||
# 0: 2-Theta, 1: d, 2: BG, 3: Height, 4: I%, 5: Area, 6: I%, 7: FWHM
|
||||
peak_data = [float(p) for p in parts[:8]]
|
||||
peaks.append(peak_data)
|
||||
except ValueError:
|
||||
# 如果转换失败,可能是说明文字或空行,跳过
|
||||
continue
|
||||
|
||||
if not peaks:
|
||||
raise ValueError(f"未能从 Peak 报告中解析出有效数据: {file_path}")
|
||||
|
||||
return np.array(peaks), metadata
|
||||
|
||||
def evaluate(gt_path, agent_path, tolerance=1e-4, mode="xrd_data"):
|
||||
"""
|
||||
对比 Ground Truth 和 Agent Output。
|
||||
支持两种模式:
|
||||
- xrd_data: 对比 (2-Theta, Intensity) 原始数据点
|
||||
- peak_report: 对比 Peak 搜索结果列表
|
||||
"""
|
||||
try:
|
||||
if mode == "peak_report":
|
||||
gt_data, gt_meta = load_peak_report(gt_path)
|
||||
agent_data, agent_meta = load_peak_report(agent_path)
|
||||
|
||||
# 对于 Peak 报告,我们主要关注 2-Theta 位置和 Height
|
||||
# 这里对比全表,但放宽容差,因为 Peak Search 的算法可能在不同环境下有极细微差异
|
||||
if gt_data.shape != agent_data.shape:
|
||||
# 如果数量不匹配,直接判定失败
|
||||
return 0, f"失败: Peak 数量不匹配。GT {len(gt_data)}, Agent {len(agent_data)}"
|
||||
else:
|
||||
gt_data = load_xrd_data(gt_path)
|
||||
agent_data = load_xrd_data(agent_path)
|
||||
|
||||
if gt_data.shape != agent_data.shape:
|
||||
return 0, f"失败: 数据维度不匹配。GT 形状 {gt_data.shape}, Agent 形状 {agent_data.shape}"
|
||||
|
||||
diff = np.abs(gt_data - agent_data)
|
||||
max_error = np.max(diff)
|
||||
|
||||
if max_error < tolerance:
|
||||
return 1, f"成功: 最大绝对误差 {max_error:.2e} < 阈值 {tolerance}"
|
||||
else:
|
||||
return 0, f"失败: 最大绝对误差 {max_error:.2e} 超过阈值 {tolerance}"
|
||||
|
||||
except Exception as e:
|
||||
return 0, f"错误: {str(e)}"
|
||||
|
||||
def evaluate_by_config(config_path):
|
||||
"""
|
||||
根据任务配置文件进行评测。
|
||||
"""
|
||||
# 兼容性处理:如果传入的是任务 ID 路径
|
||||
if not os.path.isabs(config_path) and not config_path.startswith('.'):
|
||||
# 尝试补全路径,例如 instructions/smoothing_001.json
|
||||
if not config_path.endswith('.json'):
|
||||
config_path = config_path + ".json"
|
||||
if not os.path.exists(config_path):
|
||||
# 尝试在任务目录下找
|
||||
pass
|
||||
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(config_path)))
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
eval_cfg = config.get('evaluation', {})
|
||||
# 处理相对路径
|
||||
gt_path = os.path.join(base_dir, eval_cfg['ground_truth'])
|
||||
agent_path = os.path.join(base_dir, eval_cfg['target_output'])
|
||||
tolerance = eval_cfg.get('tolerance', 1e-4)
|
||||
|
||||
# 自动识别模式
|
||||
mode = "xrd_data"
|
||||
if gt_path.lower().endswith('.pid') or eval_cfg.get('type') == 'peak_report':
|
||||
mode = "peak_report"
|
||||
# Peak 报告的默认容差放宽一些,因为算法可能受环境微小影响
|
||||
if 'tolerance' not in eval_cfg:
|
||||
tolerance = 1e-2
|
||||
|
||||
print(f"--- 正在执行评测: {config.get('id', 'unknown')} ---")
|
||||
print(f"指令: {config.get('instruction')}")
|
||||
print(f"模式: {mode}")
|
||||
|
||||
score, message = evaluate(gt_path, agent_path, tolerance, mode=mode)
|
||||
return score, message
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
|
||||
# 模式 A: Config 模式
|
||||
try:
|
||||
score, message = evaluate_by_config(sys.argv[1])
|
||||
except Exception as e:
|
||||
print(f"配置文件解析失败: {e}")
|
||||
sys.exit(1)
|
||||
elif len(sys.argv) >= 3:
|
||||
# 模式 B: 直接对比模式
|
||||
score, message = evaluate(sys.argv[1], sys.argv[2])
|
||||
else:
|
||||
print("用法:")
|
||||
print(" python scripts/evaluator.py instructions/smoothing_001.json")
|
||||
print(" python scripts/evaluator.py <gt_file_path> <agent_file_path>")
|
||||
sys.exit(1)
|
||||
|
||||
print("-" * 30)
|
||||
print(f"Score: {score}")
|
||||
print(f"Reason: {message}")
|
||||
print("-" * 30)
|
||||
sys.exit(0 if score == 1 else 1)
|
||||
514
scripts/core/jade_env.py
Normal file
514
scripts/core/jade_env.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""
|
||||
JADE Benchmark 环境控制器
|
||||
负责VM的重置、文件注入/收集、截图获取等操作
|
||||
"""
|
||||
import subprocess
|
||||
import time
|
||||
import os
|
||||
import requests
|
||||
from PIL import Image
|
||||
import io
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JadeEnv:
|
||||
"""轻量级JADE虚拟机环境控制器"""
|
||||
|
||||
def __init__(self, vmx_path, snapshot_name="Jade_Ready", vm_ip="192.168.116.129",
|
||||
vm_password=None, guest_username=None, guest_password=None):
|
||||
"""
|
||||
初始化JADE环境
|
||||
|
||||
Args:
|
||||
vmx_path: 虚拟机.vmx文件路径
|
||||
snapshot_name: 快照名称
|
||||
vm_ip: 虚拟机IP地址(用于HTTP通信)
|
||||
vm_password: 虚拟机文件加密密码(-vp参数)
|
||||
guest_username: 虚拟机内操作系统用户名(-gu参数)
|
||||
guest_password: 虚拟机内操作系统密码(-gp参数)
|
||||
"""
|
||||
self.vmx_path = vmx_path
|
||||
self.snapshot_name = snapshot_name
|
||||
self.vm_ip = vm_ip
|
||||
self.vm_url = f"http://{vm_ip}:5000"
|
||||
|
||||
# VMware认证参数
|
||||
self.vm_password = vm_password
|
||||
self.guest_username = guest_username
|
||||
self.guest_password = guest_password
|
||||
|
||||
# VMware Fusion路径(macOS)
|
||||
self.vmrun = "/Applications/VMware Fusion.app/Contents/Library/vmrun"
|
||||
|
||||
# 虚拟机内路径
|
||||
self.guest_desktop = r"C:\Users\lzy\Desktop"
|
||||
|
||||
logger.info(f"JadeEnv初始化: VM={os.path.basename(vmx_path)}, Snapshot={snapshot_name}")
|
||||
logger.info(f" 认证配置: vm_password={'已设置' if vm_password else '未设置'}, "
|
||||
f"guest_user={'已设置' if guest_username else '未设置'}, "
|
||||
f"guest_pass={'已设置' if guest_password else '未设置'}")
|
||||
|
||||
def _build_vmrun_cmd(self, *args):
|
||||
"""构建vmrun命令"""
|
||||
cmd = [self.vmrun, "-T", "fusion"]
|
||||
|
||||
# 添加认证参数
|
||||
if self.vm_password:
|
||||
cmd.extend(["-vp", self.vm_password])
|
||||
if self.guest_username:
|
||||
cmd.extend(["-gu", self.guest_username])
|
||||
if self.guest_password:
|
||||
cmd.extend(["-gp", self.guest_password])
|
||||
|
||||
cmd.extend(args)
|
||||
return cmd
|
||||
|
||||
def _run_vmrun(self, *args, check=True, timeout=30):
|
||||
"""执行vmrun命令"""
|
||||
cmd = self._build_vmrun_cmd(*args)
|
||||
# 打印完整命令(隐藏密码)
|
||||
cmd_display = []
|
||||
skip_next = False
|
||||
for i, part in enumerate(cmd):
|
||||
if skip_next:
|
||||
cmd_display.append("***")
|
||||
skip_next = False
|
||||
elif part in ["-vp", "-gp"]:
|
||||
cmd_display.append(part)
|
||||
skip_next = True
|
||||
else:
|
||||
cmd_display.append(part)
|
||||
logger.info(f"执行vmrun命令: {' '.join(cmd_display)}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if check and result.returncode != 0:
|
||||
error_msg = result.stderr or result.stdout
|
||||
raise RuntimeError(f"vmrun命令执行失败: {error_msg}")
|
||||
|
||||
return result
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(f"❌ vmrun命令超时({timeout}秒)")
|
||||
raise RuntimeError(f"vmrun命令执行超时({timeout}秒)")
|
||||
|
||||
def _detect_and_update_ip(self):
|
||||
"""
|
||||
检测VM的IP地址,如果变化则自动更新
|
||||
|
||||
Returns:
|
||||
bool: IP是否发生变化
|
||||
"""
|
||||
logger.info("🔍 检测VM IP地址...")
|
||||
|
||||
try:
|
||||
# 使用vmrun获取VM IP
|
||||
cmd = self._build_vmrun_cmd("getGuestIPAddress", self.vmx_path)
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
new_ip = result.stdout.strip()
|
||||
if new_ip and new_ip != "":
|
||||
if new_ip != self.vm_ip:
|
||||
logger.info(f"⚠️ IP地址已变化: {self.vm_ip} → {new_ip}")
|
||||
logger.info(f" 自动更新IP地址...")
|
||||
|
||||
# 更新实例变量
|
||||
self.vm_ip = new_ip
|
||||
self.vm_url = f"http://{new_ip}:5000"
|
||||
|
||||
# 更新配置文件
|
||||
try:
|
||||
import json
|
||||
from pathlib import Path
|
||||
# 获取项目根目录(jade_env.py在scripts/core/,向上3级到项目根目录)
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
config_path = project_root / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
config["network"]["vm_ip"] = new_ip
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(config, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"✅ 配置文件已更新: {config_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 更新配置文件失败: {e}(不影响使用)")
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.info(f"✅ IP地址未变化: {self.vm_ip}")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"⚠️ vmrun返回空IP地址")
|
||||
return False
|
||||
else:
|
||||
error_msg = result.stderr or result.stdout
|
||||
logger.warning(f"⚠️ 获取IP失败: {error_msg}(将使用配置中的IP)")
|
||||
return False
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(f"⚠️ 获取IP超时(将使用配置中的IP)")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 检测IP异常: {e}(将使用配置中的IP)")
|
||||
return False
|
||||
|
||||
def reset(self, wait_time=5):
|
||||
"""
|
||||
重置环境:恢复快照并启动虚拟机
|
||||
|
||||
Args:
|
||||
wait_time: 启动后等待时间(秒)
|
||||
"""
|
||||
logger.info(f"正在恢复快照: {self.snapshot_name}...")
|
||||
|
||||
try:
|
||||
# 1. 恢复快照
|
||||
self._run_vmrun("revertToSnapshot", self.vmx_path, self.snapshot_name)
|
||||
logger.info("✅ 快照恢复成功")
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
# 2. 启动虚拟机(如果未运行)
|
||||
logger.info("正在启动虚拟机...")
|
||||
result = self._run_vmrun("start", self.vmx_path, check=False)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info("✅ 虚拟机启动成功")
|
||||
else:
|
||||
# 可能已经在运行
|
||||
if "is already running" in result.stderr.lower():
|
||||
logger.info("✅ 虚拟机已在运行")
|
||||
else:
|
||||
logger.warning(f"启动虚拟机警告: {result.stderr}")
|
||||
|
||||
# 3. 等待系统稳定(快照恢复后agent_server已在运行)
|
||||
logger.info(f"等待系统稳定 ({wait_time}秒)...")
|
||||
time.sleep(wait_time)
|
||||
|
||||
# 4. 检测并更新IP地址(恢复快照后IP可能变化)
|
||||
self._detect_and_update_ip()
|
||||
|
||||
# 5. 验证HTTP服务可用
|
||||
self._wait_for_http_service()
|
||||
|
||||
logger.info("✅ 环境重置完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 环境重置失败: {e}")
|
||||
raise
|
||||
|
||||
def _wait_for_http_service(self, max_retries=10, retry_interval=3):
|
||||
"""等待agent_server.py HTTP服务可用"""
|
||||
logger.info(f"等待虚拟机HTTP服务... (URL: {self.vm_url})")
|
||||
|
||||
# 绕过代理(避免Clash等代理工具干扰局域网访问)
|
||||
proxies = {
|
||||
'http': None,
|
||||
'https': None
|
||||
}
|
||||
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
logger.debug(f"尝试连接: {self.vm_url}/screen_info (timeout=5秒, 不使用代理)")
|
||||
response = requests.get(f"{self.vm_url}/screen_info", timeout=5, proxies=proxies)
|
||||
logger.debug(f"收到响应: status_code={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
logger.info("✅ HTTP服务已就绪")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"HTTP状态码异常: {response.status_code}")
|
||||
except requests.exceptions.Timeout as e:
|
||||
logger.info(f"HTTP服务未就绪(超时),重试 {i+1}/{max_retries}...")
|
||||
if i < max_retries - 1:
|
||||
time.sleep(retry_interval)
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
logger.info(f"HTTP服务未就绪(连接失败: {str(e)[:50]}),重试 {i+1}/{max_retries}...")
|
||||
if i < max_retries - 1:
|
||||
time.sleep(retry_interval)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"HTTP请求异常: {type(e).__name__}: {str(e)[:100]}")
|
||||
if i < max_retries - 1:
|
||||
logger.info(f"重试 {i+1}/{max_retries}...")
|
||||
time.sleep(retry_interval)
|
||||
else:
|
||||
logger.error("❌ HTTP服务超时!")
|
||||
logger.error(f" 最后错误: {e}")
|
||||
return False
|
||||
|
||||
logger.error("❌ HTTP服务超时!请检查agent_server.py是否在VM中运行")
|
||||
logger.info(" 在VM中运行: python agent_server.py")
|
||||
return False
|
||||
|
||||
def inject_file(self, host_path, guest_filename=None):
|
||||
"""
|
||||
将文件从主机注入到虚拟机桌面
|
||||
|
||||
Args:
|
||||
host_path: 主机文件路径
|
||||
guest_filename: 虚拟机中的文件名(默认使用原文件名)
|
||||
"""
|
||||
if not os.path.exists(host_path):
|
||||
raise FileNotFoundError(f"源文件不存在: {host_path}")
|
||||
|
||||
if guest_filename is None:
|
||||
guest_filename = os.path.basename(host_path)
|
||||
|
||||
guest_path = f"{self.guest_desktop}\\{guest_filename}"
|
||||
|
||||
# 获取文件大小
|
||||
file_size = os.path.getsize(host_path)
|
||||
file_size_kb = file_size / 1024
|
||||
|
||||
logger.info(f"注入文件: {os.path.basename(host_path)} ({file_size_kb:.1f}KB) → 虚拟机桌面")
|
||||
logger.info(f" 源路径: {host_path}")
|
||||
logger.info(f" 目标路径: {guest_path}")
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
# 使用vmrun传输(30秒超时)
|
||||
self._run_vmrun(
|
||||
"copyFileFromHostToGuest",
|
||||
self.vmx_path,
|
||||
host_path,
|
||||
guest_path,
|
||||
timeout=30
|
||||
)
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"✅ 文件注入成功: {guest_filename} (耗时 {elapsed:.1f}秒)")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 文件注入失败: {e}")
|
||||
raise
|
||||
|
||||
def collect_file(self, guest_filename, host_path):
|
||||
"""
|
||||
从虚拟机桌面收集文件到主机
|
||||
|
||||
Args:
|
||||
guest_filename: 虚拟机桌面上的文件名
|
||||
host_path: 主机保存路径
|
||||
"""
|
||||
guest_path = f"{self.guest_desktop}\\{guest_filename}"
|
||||
|
||||
logger.info(f"收集文件: {guest_filename} → {os.path.basename(host_path)}")
|
||||
|
||||
try:
|
||||
# 确保目标目录存在
|
||||
os.makedirs(os.path.dirname(host_path), exist_ok=True)
|
||||
|
||||
# 方法1: 尝试使用vmrun
|
||||
try:
|
||||
self._run_vmrun(
|
||||
"copyFileFromGuestToHost",
|
||||
self.vmx_path,
|
||||
guest_path,
|
||||
host_path
|
||||
)
|
||||
logger.info(f"✅ 文件收集成功(vmrun): {guest_filename}")
|
||||
return
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"vmrun收集失败,尝试HTTP方式: {e}")
|
||||
|
||||
# 方法2: 尝试通过HTTP下载(备用)
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.vm_url}/download/{guest_filename}",
|
||||
timeout=10
|
||||
)
|
||||
if response.status_code == 200:
|
||||
with open(host_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
logger.info(f"✅ 文件收集成功(HTTP): {guest_filename}")
|
||||
else:
|
||||
raise RuntimeError(f"HTTP下载失败: {response.status_code}")
|
||||
except Exception as http_error:
|
||||
logger.error(f"❌ HTTP收集也失败: {http_error}")
|
||||
raise RuntimeError(f"文件收集失败(两种方法都失败)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 文件收集失败: {e}")
|
||||
raise
|
||||
|
||||
def get_screenshot(self, retry_with_ip_detect=True):
|
||||
"""
|
||||
获取虚拟机截图
|
||||
|
||||
Args:
|
||||
retry_with_ip_detect: 如果连接失败,是否尝试检测IP并重试
|
||||
|
||||
Returns:
|
||||
PIL.Image对象
|
||||
"""
|
||||
try:
|
||||
# 绕过代理
|
||||
proxies = {'http': None, 'https': None}
|
||||
response = requests.get(f"{self.vm_url}/screenshot", timeout=5, proxies=proxies)
|
||||
if response.status_code == 200:
|
||||
return Image.open(io.BytesIO(response.content))
|
||||
else:
|
||||
raise RuntimeError(f"截图失败: HTTP {response.status_code}")
|
||||
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
|
||||
if retry_with_ip_detect:
|
||||
logger.warning(f"⚠️ 截图连接失败,尝试检测并更新IP...")
|
||||
if self._detect_and_update_ip():
|
||||
# IP已更新,重试一次
|
||||
logger.info(f"🔄 使用新IP重试截图...")
|
||||
return self.get_screenshot(retry_with_ip_detect=False)
|
||||
logger.error(f"❌ 获取截图失败: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取截图失败: {e}")
|
||||
raise
|
||||
|
||||
def get_screen_info(self, retry_with_ip_detect=True):
|
||||
"""
|
||||
获取虚拟机屏幕信息(分辨率、DPI等)
|
||||
|
||||
Args:
|
||||
retry_with_ip_detect: 如果连接失败,是否尝试检测IP并重试
|
||||
|
||||
Returns:
|
||||
dict: 包含screen_width, screen_height, dpi_scale等信息
|
||||
"""
|
||||
try:
|
||||
# 绕过代理
|
||||
proxies = {'http': None, 'https': None}
|
||||
response = requests.get(f"{self.vm_url}/screen_info", timeout=5, proxies=proxies)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
raise RuntimeError(f"获取屏幕信息失败: HTTP {response.status_code}")
|
||||
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
|
||||
if retry_with_ip_detect:
|
||||
logger.warning(f"⚠️ 屏幕信息连接失败,尝试检测并更新IP...")
|
||||
if self._detect_and_update_ip():
|
||||
# IP已更新,重试一次
|
||||
logger.info(f"🔄 使用新IP重试获取屏幕信息...")
|
||||
return self.get_screen_info(retry_with_ip_detect=False)
|
||||
logger.error(f"❌ 获取屏幕信息失败: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取屏幕信息失败: {e}")
|
||||
raise
|
||||
|
||||
def list_desktop_files(self):
|
||||
"""
|
||||
列出虚拟机桌面文件(用于调试)
|
||||
|
||||
Returns:
|
||||
list: 文件名列表
|
||||
"""
|
||||
try:
|
||||
# 绕过代理
|
||||
proxies = {'http': None, 'https': None}
|
||||
response = requests.get(f"{self.vm_url}/list_desktop", timeout=5, proxies=proxies)
|
||||
if response.status_code == 200:
|
||||
return response.json().get('files', [])
|
||||
else:
|
||||
raise RuntimeError(f"列出文件失败: HTTP {response.status_code}")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 列出桌面文件失败: {e}")
|
||||
return []
|
||||
|
||||
def send_action(self, action_type, **params):
|
||||
"""
|
||||
发送动作到虚拟机(用于未来的Agent自动执行)
|
||||
|
||||
Args:
|
||||
action_type: 动作类型 (click/type/hotkey)
|
||||
**params: 动作参数
|
||||
"""
|
||||
try:
|
||||
# 绕过代理
|
||||
proxies = {'http': None, 'https': None}
|
||||
payload = {"type": action_type, **params}
|
||||
response = requests.post(
|
||||
f"{self.vm_url}/action",
|
||||
json=payload,
|
||||
timeout=5,
|
||||
proxies=proxies
|
||||
)
|
||||
if response.status_code == 200:
|
||||
logger.debug(f"动作执行成功: {action_type}")
|
||||
else:
|
||||
raise RuntimeError(f"动作执行失败: HTTP {response.status_code}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 发送动作失败: {e}")
|
||||
raise
|
||||
|
||||
def get_mouse_pos(self):
|
||||
"""
|
||||
从虚拟机获取当前鼠标物理坐标
|
||||
|
||||
Returns:
|
||||
tuple: (x, y) 物理坐标,失败返回 (None, None)
|
||||
"""
|
||||
try:
|
||||
# 绕过代理
|
||||
proxies = {'http': None, 'https': None}
|
||||
response = requests.get(f"{self.vm_url}/mouse_pos", timeout=2, proxies=proxies)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return data['x'], data['y']
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.debug(f"获取VM鼠标位置失败: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
print("JadeEnv 测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 配置(需要根据实际情况修改)
|
||||
VMX_PATH = "/Volumes/Castor/虚拟机/Jade_Win_11.vmwarevm/Windows 11 64 位 ARM 2.vmx"
|
||||
SNAPSHOT = "Jade_Ready"
|
||||
VM_PASSWORD = "lizhanyuan"
|
||||
|
||||
try:
|
||||
env = JadeEnv(
|
||||
vmx_path=VMX_PATH,
|
||||
snapshot_name=SNAPSHOT,
|
||||
vm_password=VM_PASSWORD,
|
||||
guest_username="lzy",
|
||||
guest_password="LIZHANYUAN"
|
||||
)
|
||||
|
||||
# 测试重置
|
||||
print("\n测试1: 重置环境")
|
||||
env.reset()
|
||||
|
||||
# 测试获取屏幕信息
|
||||
print("\n测试2: 获取屏幕信息")
|
||||
info = env.get_screen_info()
|
||||
print(f" 分辨率: {info['screen_width']}x{info['screen_height']}")
|
||||
print(f" DPI缩放: {info['dpi_scale']}")
|
||||
|
||||
# 测试列出桌面文件
|
||||
print("\n测试3: 列出桌面文件")
|
||||
files = env.list_desktop_files()
|
||||
print(f" 桌面文件: {files[:5]}..." if len(files) > 5 else f" 桌面文件: {files}")
|
||||
|
||||
print("\n✅ 所有测试通过!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
295
scripts/core/recorder.py
Normal file
295
scripts/core/recorder.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
轨迹录制器
|
||||
监听鼠标键盘事件,记录操作轨迹和截图
|
||||
"""
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pynput import mouse, keyboard
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Recorder:
|
||||
"""轨迹录制器 - Host端事件驱动录制"""
|
||||
|
||||
def __init__(self, jade_env, task_id, output_dir):
|
||||
"""
|
||||
初始化录制器
|
||||
|
||||
Args:
|
||||
jade_env: JadeEnv实例
|
||||
task_id: 任务ID
|
||||
output_dir: 输出目录(human_demo/)
|
||||
"""
|
||||
self.env = jade_env
|
||||
self.task_id = task_id
|
||||
self.output_dir = output_dir
|
||||
|
||||
# 创建输出目录
|
||||
self.screens_dir = os.path.join(output_dir, "screens")
|
||||
os.makedirs(self.screens_dir, exist_ok=True)
|
||||
|
||||
# 数据结构
|
||||
self.actions = []
|
||||
self.metadata = {}
|
||||
self.start_time = None
|
||||
self.screenshot_counter = 0
|
||||
|
||||
# 监听器
|
||||
self.mouse_listener = None
|
||||
self.keyboard_listener = None
|
||||
|
||||
# 状态
|
||||
self.is_recording = False
|
||||
|
||||
logger.info(f"录制器初始化: 任务={task_id}")
|
||||
|
||||
def start(self):
|
||||
"""开始录制"""
|
||||
if self.is_recording:
|
||||
logger.warning("录制已在进行中")
|
||||
return
|
||||
|
||||
self.is_recording = True
|
||||
self.start_time = time.time()
|
||||
|
||||
# 获取虚拟机屏幕信息
|
||||
try:
|
||||
screen_info = self.env.get_screen_info()
|
||||
self.metadata = {
|
||||
"task_id": self.task_id,
|
||||
"vm_resolution": [screen_info['screen_width'], screen_info['screen_height']],
|
||||
"vm_screenshot_resolution": [screen_info['screenshot_width'], screen_info['screenshot_height']],
|
||||
"vm_dpi_scale": screen_info['dpi_scale'],
|
||||
"recording_start": datetime.now().isoformat(),
|
||||
"recording_end": None
|
||||
}
|
||||
logger.info(f"虚拟机分辨率: {screen_info['screen_width']}x{screen_info['screen_height']}")
|
||||
logger.info(f"截图分辨率: {screen_info['screenshot_width']}x{screen_info['screenshot_height']}")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取屏幕信息失败: {e}")
|
||||
self.metadata = {
|
||||
"task_id": self.task_id,
|
||||
"recording_start": datetime.now().isoformat(),
|
||||
"recording_end": None
|
||||
}
|
||||
|
||||
# 记录初始截图
|
||||
self._capture_screenshot("initial")
|
||||
|
||||
# 启动监听器
|
||||
self.mouse_listener = mouse.Listener(
|
||||
on_click=self._on_mouse_click,
|
||||
on_scroll=self._on_mouse_scroll
|
||||
)
|
||||
self.keyboard_listener = keyboard.Listener(
|
||||
on_press=self._on_key_press
|
||||
)
|
||||
|
||||
self.mouse_listener.start()
|
||||
self.keyboard_listener.start()
|
||||
|
||||
logger.info("✅ 录制已启动")
|
||||
print("\n" + "=" * 60)
|
||||
print("🎥 录制进行中...")
|
||||
print("💡 提示:")
|
||||
print(" - 请在VMware窗口中操作JADE")
|
||||
print(" - 每次点击都会自动截图")
|
||||
print(" - 按 Ctrl+C 停止录制")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
def _on_mouse_click(self, x, y, button, pressed):
|
||||
"""鼠标点击事件处理"""
|
||||
if not self.is_recording or not pressed:
|
||||
return
|
||||
|
||||
# 核心修改:立刻从虚拟机获取真实物理坐标
|
||||
vm_x, vm_y = self.env.get_mouse_pos()
|
||||
|
||||
elapsed = time.time() - self.start_time
|
||||
|
||||
# 记录动作
|
||||
action = {
|
||||
"t": round(elapsed, 3),
|
||||
"type": "click",
|
||||
"button": str(button).replace("Button.", ""),
|
||||
"pos_host": [x, y], # Mac 逻辑坐标(留作参考)
|
||||
"pos_vm": [vm_x, vm_y] if vm_x is not None else None # 真实VM物理坐标
|
||||
}
|
||||
|
||||
# 截图
|
||||
screenshot_filename = self._capture_screenshot("click")
|
||||
action["screenshot"] = screenshot_filename
|
||||
|
||||
self.actions.append(action)
|
||||
|
||||
if vm_x is not None:
|
||||
logger.info(f"[{elapsed:.1f}s] 点击: VM({vm_x}, {vm_y}) [Host: {int(x)}, {int(y)}] {action['button']}")
|
||||
else:
|
||||
logger.info(f"[{elapsed:.1f}s] 点击: Host({int(x)}, {int(y)}) [VM获取失败] {action['button']}")
|
||||
|
||||
def _on_mouse_scroll(self, x, y, dx, dy):
|
||||
"""鼠标滚轮事件处理"""
|
||||
if not self.is_recording:
|
||||
return
|
||||
|
||||
elapsed = time.time() - self.start_time
|
||||
|
||||
action = {
|
||||
"t": round(elapsed, 3),
|
||||
"type": "scroll",
|
||||
"pos_host": [x, y],
|
||||
"delta": [dx, dy],
|
||||
"pos_vm": None
|
||||
}
|
||||
|
||||
self.actions.append(action)
|
||||
logger.debug(f"[{elapsed:.1f}s] 滚轮: ({x}, {y}) delta=({dx}, {dy})")
|
||||
|
||||
def _on_key_press(self, key):
|
||||
"""键盘按键事件处理"""
|
||||
if not self.is_recording:
|
||||
return
|
||||
|
||||
elapsed = time.time() - self.start_time
|
||||
|
||||
# 转换按键名称
|
||||
try:
|
||||
if hasattr(key, 'char') and key.char:
|
||||
key_name = key.char
|
||||
else:
|
||||
key_name = str(key).replace("Key.", "")
|
||||
except:
|
||||
key_name = str(key)
|
||||
|
||||
action = {
|
||||
"t": round(elapsed, 3),
|
||||
"type": "key",
|
||||
"key": key_name
|
||||
}
|
||||
|
||||
self.actions.append(action)
|
||||
logger.debug(f"[{elapsed:.1f}s] 按键: {key_name}")
|
||||
|
||||
def _capture_screenshot(self, tag=""):
|
||||
"""
|
||||
捕获截图
|
||||
|
||||
Args:
|
||||
tag: 标签(用于文件名)
|
||||
|
||||
Returns:
|
||||
str: 截图相对路径
|
||||
"""
|
||||
try:
|
||||
screenshot = self.env.get_screenshot()
|
||||
|
||||
# 生成文件名
|
||||
self.screenshot_counter += 1
|
||||
if tag:
|
||||
filename = f"{self.screenshot_counter:04d}_{tag}.png"
|
||||
else:
|
||||
filename = f"{self.screenshot_counter:04d}.png"
|
||||
|
||||
filepath = os.path.join(self.screens_dir, filename)
|
||||
screenshot.save(filepath)
|
||||
|
||||
logger.debug(f"截图保存: {filename}")
|
||||
return f"screens/{filename}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"截图失败: {e}")
|
||||
return None
|
||||
|
||||
def stop(self):
|
||||
"""停止录制"""
|
||||
if not self.is_recording:
|
||||
logger.warning("录制未在进行中")
|
||||
return
|
||||
|
||||
self.is_recording = False
|
||||
|
||||
# 停止监听器
|
||||
if self.mouse_listener:
|
||||
self.mouse_listener.stop()
|
||||
if self.keyboard_listener:
|
||||
self.keyboard_listener.stop()
|
||||
|
||||
# 记录结束截图
|
||||
self._capture_screenshot("final")
|
||||
|
||||
# 更新元数据
|
||||
self.metadata["recording_end"] = datetime.now().isoformat()
|
||||
self.metadata["total_duration"] = round(time.time() - self.start_time, 2)
|
||||
self.metadata["total_actions"] = len(self.actions)
|
||||
self.metadata["total_screenshots"] = self.screenshot_counter
|
||||
|
||||
logger.info("✅ 录制已停止")
|
||||
|
||||
def save(self):
|
||||
"""保存轨迹数据"""
|
||||
if self.is_recording:
|
||||
logger.warning("录制仍在进行,先停止录制")
|
||||
self.stop()
|
||||
|
||||
# 保存原始数据(未处理坐标)
|
||||
output_data = {
|
||||
"metadata": self.metadata,
|
||||
"actions": self.actions
|
||||
}
|
||||
|
||||
raw_path = os.path.join(self.output_dir, "actions_raw.json")
|
||||
|
||||
with open(raw_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"✅ 轨迹数据已保存: {raw_path}")
|
||||
logger.info(f" - 总动作数: {len(self.actions)}")
|
||||
logger.info(f" - 截图数: {self.screenshot_counter}")
|
||||
logger.info(f" - 总时长: {self.metadata.get('total_duration', 0):.1f}秒")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 录制统计:")
|
||||
print(f" 动作数: {len(self.actions)}")
|
||||
print(f" 截图数: {self.screenshot_counter}")
|
||||
print(f" 时长: {self.metadata.get('total_duration', 0):.1f}秒")
|
||||
print(f" 保存位置: {raw_path}")
|
||||
print("=" * 60)
|
||||
print("\n💡 下一步:运行坐标转换")
|
||||
print(f" python scripts/tools/process_trajectory.py {self.task_id}")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
|
||||
def record_interactive(jade_env, task_id, output_dir):
|
||||
"""
|
||||
交互式录制(带Ctrl+C停止)
|
||||
|
||||
Args:
|
||||
jade_env: JadeEnv实例
|
||||
task_id: 任务ID
|
||||
output_dir: 输出目录
|
||||
"""
|
||||
recorder = Recorder(jade_env, task_id, output_dir)
|
||||
recorder.start()
|
||||
|
||||
try:
|
||||
# 保持录制状态,直到Ctrl+C
|
||||
while recorder.is_recording:
|
||||
time.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n⏹ 收到停止信号...")
|
||||
finally:
|
||||
recorder.stop()
|
||||
recorder.save()
|
||||
|
||||
return recorder
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Recorder 独立测试模式")
|
||||
print("提示: 通常应该通过 collect_task.py 调用")
|
||||
|
||||
Reference in New Issue
Block a user