Initial commit

This commit is contained in:
2026-01-12 18:30:12 +08:00
commit 214e15c04c
102 changed files with 27857 additions and 0 deletions

4
scripts/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
"""
JADE Benchmark Scripts Package
"""

16
scripts/core/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
"""
JADE Benchmark 核心模块
包含VM控制、轨迹录制、评测等核心功能
"""
from .jade_env import JadeEnv
from .recorder import Recorder, record_interactive
from .evaluator import evaluate, load_xrd_data
__all__ = [
'JadeEnv',
'Recorder',
'record_interactive',
'evaluate',
'load_xrd_data'
]

178
scripts/core/evaluator.py Normal file
View File

@@ -0,0 +1,178 @@
import json
import numpy as np
import os
import sys
import re
def load_xrd_data(file_path):
    """Load the numeric data points from an exported XRD text file.

    Lines that are blank, or that contain any token which is not a float
    (for example metadata header lines), are ignored. From every remaining
    line with at least two numbers, the first two values are kept.

    Args:
        file_path: Path of the exported text file.

    Returns:
        numpy.ndarray: One row per accepted line, two columns per row.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If no line yields at least two numeric values.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    rows = []
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as handle:
        for raw in handle:
            tokens = raw.strip().split()
            if not tokens:
                continue
            try:
                numbers = [float(tok) for tok in tokens]
            except ValueError:
                # A single non-numeric token disqualifies the whole line.
                continue
            if len(numbers) >= 2:
                rows.append(numbers[:2])

    if not rows:
        raise ValueError(f"在文件中未找到有效的数值数据: {file_path}")
    return np.array(rows)
def load_peak_report(file_path):
    """Parse the numeric peak table of a JADE Peak Search Report (.pid).

    The table is taken to start after the header line that contains both
    ``2-Theta`` and ``d(``. Every following line with at least six
    whitespace-separated fields whose first (up to) eight fields are all
    floats becomes one peak row; other lines are skipped.

    Args:
        file_path: Path of the peak report.

    Returns:
        tuple: ``(peaks, metadata)`` where ``peaks`` is a 2-D
        ``numpy.ndarray`` and ``metadata`` may contain
        ``reported_peak_count`` parsed from the report title line.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If no peak row could be parsed.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    peaks = []
    metadata = {}
    table_started = False
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            # Title line example:
            #   JADE: Peak Search Report (72 Peaks, Max P/N = 19.3)
            if "Peak Search Report" in line:
                match = re.search(r"\((\d+) Peaks", line)
                if match:
                    metadata['reported_peak_count'] = int(match.group(1))
            # Column header marks the start of the table.
            if "2-Theta" in line and "d(" in line:
                table_started = True
                continue
            if not table_started:
                continue
            parts = line.split()
            if len(parts) < 6:
                continue
            try:
                # Columns per the original note:
                # 0: 2-Theta, 1: d, 2: BG, 3: Height, 4: I%, 5: Area, 6: I%, 7: FWHM
                peaks.append([float(p) for p in parts[:8]])
            except ValueError:
                # Explanatory text or otherwise non-numeric line.
                continue

    if not peaks:
        raise ValueError(f"未能从 Peak 报告中解析出有效数据: {file_path}")

    # Rows may carry 6, 7 or 8 numeric fields. Trim every row to the shortest
    # width so the result is always a rectangular 2-D array instead of a
    # ragged list that numpy cannot convert.
    width = min(len(row) for row in peaks)
    return np.array([row[:width] for row in peaks]), metadata
def evaluate(gt_path, agent_path, tolerance=1e-4, mode="xrd_data"):
    """Compare the agent output file against the ground-truth file.

    Args:
        gt_path: Path of the ground-truth file.
        agent_path: Path of the file produced by the agent.
        tolerance: The maximum absolute element-wise error must be strictly
            below this value for the comparison to pass.
        mode: ``"peak_report"`` compares parsed peak tables; any other value
            compares raw XRD data points.

    Returns:
        tuple: ``(score, message)`` with ``score`` 1 on success and 0 on
        mismatch or on any error (the error text is put in ``message``).
    """
    try:
        peak_mode = mode == "peak_report"
        if peak_mode:
            gt_data, _gt_meta = load_peak_report(gt_path)
            agent_data, _agent_meta = load_peak_report(agent_path)
        else:
            gt_data = load_xrd_data(gt_path)
            agent_data = load_xrd_data(agent_path)

        # Arrays of different shape cannot be compared element-wise.
        if gt_data.shape != agent_data.shape:
            if peak_mode:
                return 0, f"失败: Peak 数量不匹配。GT {len(gt_data)}, Agent {len(agent_data)}"
            return 0, f"失败: 数据维度不匹配。GT 形状 {gt_data.shape}, Agent 形状 {agent_data.shape}"

        max_error = np.max(np.abs(gt_data - agent_data))
        if max_error < tolerance:
            return 1, f"成功: 最大绝对误差 {max_error:.2e} < 阈值 {tolerance}"
        return 0, f"失败: 最大绝对误差 {max_error:.2e} 超过阈值 {tolerance}"
    except Exception as e:
        # Any loader or comparison failure counts as a failed evaluation.
        return 0, f"错误: {str(e)}"
def evaluate_by_config(config_path):
    """Run an evaluation described by a task configuration JSON file.

    The ``evaluation`` section of the config supplies ``ground_truth`` and
    ``target_output`` (joined onto the directory two levels above the config
    file), plus optional ``tolerance`` and ``type``.

    Args:
        config_path: Path of the configuration file. A path that is neither
            absolute nor starts with ``.`` gets ``.json`` appended when the
            suffix is missing.

    Returns:
        tuple: ``(score, message)`` as returned by ``evaluate``.
    """
    # Compatibility: allow a bare task id / path without the .json suffix.
    bare_reference = not (os.path.isabs(config_path) or config_path.startswith('.'))
    if bare_reference and not config_path.endswith('.json'):
        config_path = config_path + ".json"

    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(config_path)))
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    eval_cfg = config.get('evaluation', {})

    # Paths in the config are joined onto base_dir.
    gt_path = os.path.join(base_dir, eval_cfg['ground_truth'])
    agent_path = os.path.join(base_dir, eval_cfg['target_output'])

    # Pick the comparison mode from the file suffix or the explicit type.
    is_peak_report = gt_path.lower().endswith('.pid') or eval_cfg.get('type') == 'peak_report'
    if is_peak_report:
        mode = "peak_report"
        # Peak reports get a looser default tolerance when none is configured.
        tolerance = eval_cfg.get('tolerance', 1e-2)
    else:
        mode = "xrd_data"
        tolerance = eval_cfg.get('tolerance', 1e-4)

    print(f"--- 正在执行评测: {config.get('id', 'unknown')} ---")
    print(f"指令: {config.get('instruction')}")
    print(f"模式: {mode}")
    score, message = evaluate(gt_path, agent_path, tolerance, mode=mode)
    return score, message
if __name__ == "__main__":
    # Command-line entry point. Exit status is 0 only when the score is 1.
    cli_args = sys.argv[1:]
    if len(cli_args) == 1 and cli_args[0].endswith('.json'):
        # Mode A: evaluate from a task configuration file.
        try:
            score, message = evaluate_by_config(cli_args[0])
        except Exception as e:
            print(f"配置文件解析失败: {e}")
            sys.exit(1)
    elif len(cli_args) >= 2:
        # Mode B: compare two files directly (extra arguments are ignored).
        score, message = evaluate(cli_args[0], cli_args[1])
    else:
        # The usage text now names this script's actual location.
        print("用法:")
        print(" python scripts/core/evaluator.py instructions/smoothing_001.json")
        print(" python scripts/core/evaluator.py <gt_file_path> <agent_file_path>")
        sys.exit(1)

    print("-" * 30)
    print(f"Score: {score}")
    print(f"Reason: {message}")
    print("-" * 30)
    sys.exit(0 if score == 1 else 1)

514
scripts/core/jade_env.py Normal file
View File

@@ -0,0 +1,514 @@
"""
JADE Benchmark 环境控制器
负责VM的重置、文件注入/收集、截图获取等操作
"""
import subprocess
import time
import os
import requests
from PIL import Image
import io
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class JadeEnv:
    """Lightweight controller for the JADE virtual machine environment.

    Two control channels are used: the VMware ``vmrun`` command line
    (snapshot revert, power-on, host/guest file copy) and the HTTP agent
    server inside the guest on port 5000 (screenshots, screen info, actions).
    """

    # Passed to every ``requests`` call so proxy tools (e.g. Clash) do not
    # intercept traffic to the VM's LAN address.
    _NO_PROXY = {'http': None, 'https': None}

    # vmrun flags whose following argument is a password; masked in logs.
    _SECRET_FLAGS = ("-vp", "-gp")

    def __init__(self, vmx_path, snapshot_name="Jade_Ready", vm_ip="192.168.116.129",
                 vm_password=None, guest_username=None, guest_password=None):
        """Store the VM connection settings.

        Args:
            vmx_path: Path of the virtual machine ``.vmx`` file.
            snapshot_name: Name of the snapshot that ``reset`` reverts to.
            vm_ip: Guest IP address used for HTTP communication.
            vm_password: VM encryption password (vmrun ``-vp``).
            guest_username: Guest OS user name (vmrun ``-gu``).
            guest_password: Guest OS password (vmrun ``-gp``).
        """
        self.vmx_path = vmx_path
        self.snapshot_name = snapshot_name
        self.vm_ip = vm_ip
        self.vm_url = f"http://{vm_ip}:5000"
        # VMware authentication parameters
        self.vm_password = vm_password
        self.guest_username = guest_username
        self.guest_password = guest_password
        # vmrun location of VMware Fusion on macOS
        self.vmrun = "/Applications/VMware Fusion.app/Contents/Library/vmrun"
        # Desktop directory inside the guest
        self.guest_desktop = r"C:\Users\lzy\Desktop"
        logger.info(f"JadeEnv初始化: VM={os.path.basename(vmx_path)}, Snapshot={snapshot_name}")
        logger.info(f" 认证配置: vm_password={'已设置' if vm_password else '未设置'}, "
                    f"guest_user={'已设置' if guest_username else '未设置'}, "
                    f"guest_pass={'已设置' if guest_password else '未设置'}")

    def _build_vmrun_cmd(self, *args):
        """Return the vmrun argument list: auth flags first, then ``args``."""
        cmd = [self.vmrun, "-T", "fusion"]
        if self.vm_password:
            cmd.extend(["-vp", self.vm_password])
        if self.guest_username:
            cmd.extend(["-gu", self.guest_username])
        if self.guest_password:
            cmd.extend(["-gp", self.guest_password])
        cmd.extend(args)
        return cmd

    @staticmethod
    def _mask_secrets(cmd):
        """Return a copy of ``cmd`` with each password argument shown as ``***``."""
        masked = []
        hide_next = False
        for part in cmd:
            if hide_next:
                masked.append("***")
                hide_next = False
            else:
                masked.append(part)
                hide_next = part in JadeEnv._SECRET_FLAGS
        return masked

    def _run_vmrun(self, *args, check=True, timeout=30):
        """Run a vmrun command and return the ``CompletedProcess``.

        Args:
            *args: vmrun sub-command and its arguments.
            check: Raise ``RuntimeError`` on a non-zero exit code.
            timeout: Seconds before the command is aborted.

        Raises:
            RuntimeError: On timeout, or on failure when ``check`` is true.
        """
        cmd = self._build_vmrun_cmd(*args)
        # Log the full command with passwords hidden.
        logger.info(f"执行vmrun命令: {' '.join(self._mask_secrets(cmd))}")
        try:
            result = subprocess.run(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=timeout
            )
        except subprocess.TimeoutExpired:
            logger.error(f"❌ vmrun命令超时{timeout}秒)")
            raise RuntimeError(f"vmrun命令执行超时{timeout}秒)")
        if check and result.returncode != 0:
            error_msg = result.stderr or result.stdout
            raise RuntimeError(f"vmrun命令执行失败: {error_msg}")
        return result

    def _persist_vm_ip(self, new_ip):
        """Best-effort write of ``new_ip`` to ``config.json``; failures only warn."""
        try:
            import json
            from pathlib import Path
            # Three levels up from this file; the original comment states
            # jade_env.py lives in scripts/core/ below the project root.
            project_root = Path(__file__).parent.parent.parent
            config_path = project_root / "config.json"
            if config_path.exists():
                with open(config_path, 'r', encoding='utf-8') as f:
                    config = json.load(f)
                config["network"]["vm_ip"] = new_ip
                with open(config_path, 'w', encoding='utf-8') as f:
                    json.dump(config, f, indent=2, ensure_ascii=False)
                logger.info(f"✅ 配置文件已更新: {config_path}")
        except Exception as e:
            logger.warning(f"⚠️ 更新配置文件失败: {e}(不影响使用)")

    def _detect_and_update_ip(self):
        """Query the guest IP through vmrun and adopt it when it has changed.

        Returns:
            bool: True if the IP changed (``vm_ip`` and ``vm_url`` were
            updated), False otherwise, including on any failure.
        """
        logger.info("🔍 检测VM IP地址...")
        try:
            cmd = self._build_vmrun_cmd("getGuestIPAddress", self.vmx_path)
            result = subprocess.run(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=10
            )
            if result.returncode != 0:
                error_msg = result.stderr or result.stdout
                logger.warning(f"⚠️ 获取IP失败: {error_msg}将使用配置中的IP")
                return False

            new_ip = result.stdout.strip()
            if not new_ip:
                logger.warning(f"⚠️ vmrun返回空IP地址")
                return False
            if new_ip == self.vm_ip:
                logger.info(f"✅ IP地址未变化: {self.vm_ip}")
                return False

            # A separator between old and new address keeps the log readable.
            logger.info(f"⚠️ IP地址已变化: {self.vm_ip} → {new_ip}")
            logger.info(f" 自动更新IP地址...")
            self.vm_ip = new_ip
            self.vm_url = f"http://{new_ip}:5000"
            self._persist_vm_ip(new_ip)
            return True
        except subprocess.TimeoutExpired:
            logger.warning(f"⚠️ 获取IP超时将使用配置中的IP")
            return False
        except Exception as e:
            logger.warning(f"⚠️ 检测IP异常: {e}将使用配置中的IP")
            return False

    def reset(self, wait_time=5):
        """Revert to the snapshot, start the VM and wait for the agent server.

        Args:
            wait_time: Seconds to wait after the start command.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        logger.info(f"正在恢复快照: {self.snapshot_name}...")
        try:
            # 1. Revert to the snapshot.
            self._run_vmrun("revertToSnapshot", self.vmx_path, self.snapshot_name)
            logger.info("✅ 快照恢复成功")
            time.sleep(3)

            # 2. Start the VM; a failure here is not fatal.
            logger.info("正在启动虚拟机...")
            result = self._run_vmrun("start", self.vmx_path, check=False)
            if result.returncode == 0:
                logger.info("✅ 虚拟机启动成功")
            else:
                # Inspect both streams, in line with how _run_vmrun falls
                # back to stdout when stderr is empty.
                start_output = f"{result.stderr or ''}{result.stdout or ''}"
                if "is already running" in start_output.lower():
                    logger.info("✅ 虚拟机已在运行")
                else:
                    logger.warning(f"启动虚拟机警告: {result.stderr or result.stdout}")

            # 3. Let the system settle.
            logger.info(f"等待系统稳定 ({wait_time}秒)...")
            time.sleep(wait_time)

            # 4. The IP may differ after reverting the snapshot.
            self._detect_and_update_ip()

            # 5. Probe the HTTP service (its boolean result is not checked).
            self._wait_for_http_service()
            logger.info("✅ 环境重置完成")
        except Exception as e:
            logger.error(f"❌ 环境重置失败: {e}")
            raise

    def _wait_for_http_service(self, max_retries=10, retry_interval=3):
        """Poll ``/screen_info`` until it answers with HTTP 200.

        Args:
            max_retries: Maximum number of attempts.
            retry_interval: Seconds to sleep between attempts.

        Returns:
            bool: True once the service answered 200, False otherwise.
        """
        logger.info(f"等待虚拟机HTTP服务... (URL: {self.vm_url})")
        for attempt in range(1, max_retries + 1):
            is_last = attempt == max_retries
            try:
                logger.debug(f"尝试连接: {self.vm_url}/screen_info (timeout=5秒, 不使用代理)")
                response = requests.get(f"{self.vm_url}/screen_info", timeout=5,
                                        proxies=self._NO_PROXY)
                logger.debug(f"收到响应: status_code={response.status_code}")
                if response.status_code == 200:
                    logger.info("✅ HTTP服务已就绪")
                    return True
                logger.warning(f"HTTP状态码异常: {response.status_code}")
            except requests.exceptions.Timeout:
                logger.info(f"HTTP服务未就绪超时重试 {attempt}/{max_retries}...")
            except requests.exceptions.ConnectionError as e:
                logger.info(f"HTTP服务未就绪连接失败: {str(e)[:50]}),重试 {attempt}/{max_retries}...")
            except requests.exceptions.RequestException as e:
                logger.warning(f"HTTP请求异常: {type(e).__name__}: {str(e)[:100]}")
                if is_last:
                    logger.error("❌ HTTP服务超时")
                    logger.error(f" 最后错误: {e}")
                    return False
                logger.info(f"重试 {attempt}/{max_retries}...")
            # Pause before every further attempt, including after a non-200
            # answer, so the retry budget is not consumed instantly.
            if not is_last:
                time.sleep(retry_interval)
        logger.error("❌ HTTP服务超时请检查agent_server.py是否在VM中运行")
        logger.info(" 在VM中运行: python agent_server.py")
        return False

    def inject_file(self, host_path, guest_filename=None):
        """Copy a host file onto the guest desktop via vmrun.

        Args:
            host_path: Path of the file on the host.
            guest_filename: File name in the guest (defaults to the host
                file's base name).

        Raises:
            FileNotFoundError: If ``host_path`` does not exist.
            RuntimeError: If the vmrun copy fails or times out.
        """
        if not os.path.exists(host_path):
            raise FileNotFoundError(f"源文件不存在: {host_path}")
        if guest_filename is None:
            guest_filename = os.path.basename(host_path)
        guest_path = f"{self.guest_desktop}\\{guest_filename}"

        file_size_kb = os.path.getsize(host_path) / 1024
        logger.info(f"注入文件: {os.path.basename(host_path)} ({file_size_kb:.1f}KB) → 虚拟机桌面")
        logger.info(f" 源路径: {host_path}")
        logger.info(f" 目标路径: {guest_path}")
        try:
            start_time = time.time()
            self._run_vmrun(
                "copyFileFromHostToGuest",
                self.vmx_path,
                host_path,
                guest_path,
                timeout=30
            )
            elapsed = time.time() - start_time
            logger.info(f"✅ 文件注入成功: {guest_filename} (耗时 {elapsed:.1f}秒)")
        except Exception as e:
            logger.error(f"❌ 文件注入失败: {e}")
            raise

    def collect_file(self, guest_filename, host_path):
        """Copy a file from the guest desktop to the host.

        vmrun is tried first; if it fails, the agent server's ``/download``
        endpoint is used as a fallback.

        Args:
            guest_filename: File name on the guest desktop.
            host_path: Destination path on the host.

        Raises:
            RuntimeError: If both transfer methods fail.
        """
        guest_path = f"{self.guest_desktop}\\{guest_filename}"
        logger.info(f"收集文件: {guest_filename} → {os.path.basename(host_path)}")
        try:
            # os.makedirs("") raises, so only create a directory when the
            # destination path actually has one.
            target_dir = os.path.dirname(host_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            # Method 1: vmrun
            try:
                self._run_vmrun(
                    "copyFileFromGuestToHost",
                    self.vmx_path,
                    guest_path,
                    host_path
                )
                logger.info(f"✅ 文件收集成功vmrun: {guest_filename}")
                return
            except RuntimeError as e:
                logger.warning(f"vmrun收集失败尝试HTTP方式: {e}")

            # Method 2: HTTP download, bypassing proxies like every other
            # request to the VM.
            try:
                response = requests.get(
                    f"{self.vm_url}/download/{guest_filename}",
                    timeout=10,
                    proxies=self._NO_PROXY
                )
                if response.status_code != 200:
                    raise RuntimeError(f"HTTP下载失败: {response.status_code}")
                with open(host_path, 'wb') as f:
                    f.write(response.content)
                logger.info(f"✅ 文件收集成功HTTP: {guest_filename}")
            except Exception as http_error:
                logger.error(f"❌ HTTP收集也失败: {http_error}")
                raise RuntimeError(f"文件收集失败(两种方法都失败)") from http_error
        except Exception as e:
            logger.error(f"❌ 文件收集失败: {e}")
            raise

    def get_screenshot(self, retry_with_ip_detect=True):
        """Fetch a screenshot of the guest.

        Args:
            retry_with_ip_detect: On a connection error or timeout, re-detect
                the VM IP and retry once if it changed.

        Returns:
            PIL.Image: The decoded screenshot.

        Raises:
            RuntimeError: If the server answers with a non-200 status.
        """
        try:
            response = requests.get(f"{self.vm_url}/screenshot", timeout=5,
                                    proxies=self._NO_PROXY)
            if response.status_code == 200:
                return Image.open(io.BytesIO(response.content))
            raise RuntimeError(f"截图失败: HTTP {response.status_code}")
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
            if retry_with_ip_detect:
                logger.warning(f"⚠️ 截图连接失败尝试检测并更新IP...")
                if self._detect_and_update_ip():
                    # IP changed: retry exactly once against the new address.
                    logger.info(f"🔄 使用新IP重试截图...")
                    return self.get_screenshot(retry_with_ip_detect=False)
            logger.error(f"❌ 获取截图失败: {e}")
            raise
        except Exception as e:
            logger.error(f"❌ 获取截图失败: {e}")
            raise

    def get_screen_info(self, retry_with_ip_detect=True):
        """Fetch the guest screen information from ``/screen_info``.

        Args:
            retry_with_ip_detect: On a connection error or timeout, re-detect
                the VM IP and retry once if it changed.

        Returns:
            dict: The decoded JSON response.

        Raises:
            RuntimeError: If the server answers with a non-200 status.
        """
        try:
            response = requests.get(f"{self.vm_url}/screen_info", timeout=5,
                                    proxies=self._NO_PROXY)
            if response.status_code == 200:
                return response.json()
            raise RuntimeError(f"获取屏幕信息失败: HTTP {response.status_code}")
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
            if retry_with_ip_detect:
                logger.warning(f"⚠️ 屏幕信息连接失败尝试检测并更新IP...")
                if self._detect_and_update_ip():
                    # IP changed: retry exactly once against the new address.
                    logger.info(f"🔄 使用新IP重试获取屏幕信息...")
                    return self.get_screen_info(retry_with_ip_detect=False)
            logger.error(f"❌ 获取屏幕信息失败: {e}")
            raise
        except Exception as e:
            logger.error(f"❌ 获取屏幕信息失败: {e}")
            raise

    def list_desktop_files(self):
        """List the files on the guest desktop (debugging aid).

        Returns:
            list: File names; an empty list on any failure.
        """
        try:
            response = requests.get(f"{self.vm_url}/list_desktop", timeout=5,
                                    proxies=self._NO_PROXY)
            if response.status_code == 200:
                return response.json().get('files', [])
            raise RuntimeError(f"列出文件失败: HTTP {response.status_code}")
        except Exception as e:
            logger.warning(f"⚠️ 列出桌面文件失败: {e}")
            return []

    def send_action(self, action_type, **params):
        """Send an input action to the guest agent server.

        Args:
            action_type: Action type (click/type/hotkey).
            **params: Action parameters merged into the JSON payload.

        Raises:
            RuntimeError: If the server answers with a non-200 status.
        """
        try:
            payload = {"type": action_type, **params}
            response = requests.post(
                f"{self.vm_url}/action",
                json=payload,
                timeout=5,
                proxies=self._NO_PROXY
            )
            if response.status_code == 200:
                logger.debug(f"动作执行成功: {action_type}")
            else:
                raise RuntimeError(f"动作执行失败: HTTP {response.status_code}")
        except Exception as e:
            logger.error(f"❌ 发送动作失败: {e}")
            raise

    def get_mouse_pos(self):
        """Return the guest's current mouse position.

        Returns:
            tuple: ``(x, y)``, or ``(None, None)`` on any failure.
        """
        try:
            response = requests.get(f"{self.vm_url}/mouse_pos", timeout=2,
                                    proxies=self._NO_PROXY)
            if response.status_code == 200:
                data = response.json()
                return data['x'], data['y']
            return None, None
        except Exception as e:
            logger.debug(f"获取VM鼠标位置失败: {e}")
            return None, None
if __name__ == "__main__":
    # Manual smoke test against a real VM.
    print("JadeEnv 测试")
    print("=" * 60)

    # Configuration. Credentials are read from the environment instead of
    # being committed to source control; set JADE_VM_PASSWORD and
    # JADE_GUEST_PASSWORD before running.
    VMX_PATH = os.environ.get(
        "JADE_VMX_PATH",
        "/Volumes/Castor/虚拟机/Jade_Win_11.vmwarevm/Windows 11 64 位 ARM 2.vmx"
    )
    SNAPSHOT = os.environ.get("JADE_SNAPSHOT", "Jade_Ready")
    VM_PASSWORD = os.environ.get("JADE_VM_PASSWORD")
    GUEST_USERNAME = os.environ.get("JADE_GUEST_USERNAME", "lzy")
    GUEST_PASSWORD = os.environ.get("JADE_GUEST_PASSWORD")

    try:
        env = JadeEnv(
            vmx_path=VMX_PATH,
            snapshot_name=SNAPSHOT,
            vm_password=VM_PASSWORD,
            guest_username=GUEST_USERNAME,
            guest_password=GUEST_PASSWORD
        )

        # Test 1: reset
        print("\n测试1: 重置环境")
        env.reset()

        # Test 2: screen info
        print("\n测试2: 获取屏幕信息")
        info = env.get_screen_info()
        print(f" 分辨率: {info['screen_width']}x{info['screen_height']}")
        print(f" DPI缩放: {info['dpi_scale']}")

        # Test 3: desktop listing (at most five names are shown)
        print("\n测试3: 列出桌面文件")
        files = env.list_desktop_files()
        print(f" 桌面文件: {files[:5]}..." if len(files) > 5 else f" 桌面文件: {files}")

        print("\n✅ 所有测试通过!")
    except Exception as e:
        print(f"\n❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()

295
scripts/core/recorder.py Normal file
View File

@@ -0,0 +1,295 @@
"""
轨迹录制器
监听鼠标键盘事件,记录操作轨迹和截图
"""
import time
import json
import os
from datetime import datetime
from pynput import mouse, keyboard
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Recorder:
    """Trajectory recorder: event-driven recording on the host side.

    Listens to mouse and keyboard events with pynput, stores each event as an
    action dict, and takes a VM screenshot on every click.
    """

    def __init__(self, jade_env, task_id, output_dir):
        """Prepare the output directories and empty recording state.

        Args:
            jade_env: JadeEnv instance used for screenshots and mouse queries.
            task_id: Task identifier stored in the metadata.
            output_dir: Output directory; ``screens/`` is created inside it.
        """
        self.env = jade_env
        self.task_id = task_id
        self.output_dir = output_dir
        # Output directory for screenshots
        self.screens_dir = os.path.join(output_dir, "screens")
        os.makedirs(self.screens_dir, exist_ok=True)
        # Recorded data
        self.actions = []
        self.metadata = {}
        self.start_time = None
        self.screenshot_counter = 0
        # Listeners (created in start())
        self.mouse_listener = None
        self.keyboard_listener = None
        # State
        self.is_recording = False
        logger.info(f"录制器初始化: 任务={task_id}")

    def start(self):
        """Start recording: collect metadata, take the initial screenshot and
        start the mouse and keyboard listeners."""
        if self.is_recording:
            logger.warning("录制已在进行中")
            return
        self.is_recording = True
        self.start_time = time.time()

        # Screen info is optional; fall back to minimal metadata on failure.
        try:
            screen_info = self.env.get_screen_info()
            self.metadata = {
                "task_id": self.task_id,
                "vm_resolution": [screen_info['screen_width'], screen_info['screen_height']],
                "vm_screenshot_resolution": [screen_info['screenshot_width'], screen_info['screenshot_height']],
                "vm_dpi_scale": screen_info['dpi_scale'],
                "recording_start": datetime.now().isoformat(),
                "recording_end": None
            }
            logger.info(f"虚拟机分辨率: {screen_info['screen_width']}x{screen_info['screen_height']}")
            logger.info(f"截图分辨率: {screen_info['screenshot_width']}x{screen_info['screenshot_height']}")
        except Exception as e:
            logger.warning(f"获取屏幕信息失败: {e}")
            self.metadata = {
                "task_id": self.task_id,
                "recording_start": datetime.now().isoformat(),
                "recording_end": None
            }

        # Initial screenshot
        self._capture_screenshot("initial")

        # Start the listeners
        self.mouse_listener = mouse.Listener(
            on_click=self._on_mouse_click,
            on_scroll=self._on_mouse_scroll
        )
        self.keyboard_listener = keyboard.Listener(
            on_press=self._on_key_press
        )
        self.mouse_listener.start()
        self.keyboard_listener.start()
        logger.info("✅ 录制已启动")
        print("\n" + "=" * 60)
        print("🎥 录制进行中...")
        print("💡 提示:")
        print(" - 请在VMware窗口中操作JADE")
        print(" - 每次点击都会自动截图")
        print(" - 按 Ctrl+C 停止录制")
        print("=" * 60 + "\n")

    def _on_mouse_click(self, x, y, button, pressed):
        """Record a mouse button press (releases are ignored) with a screenshot."""
        if not self.is_recording or not pressed:
            return
        # Ask the VM for the pointer position right away; the host
        # coordinates are kept for reference only.
        vm_x, vm_y = self.env.get_mouse_pos()
        elapsed = time.time() - self.start_time
        action = {
            "t": round(elapsed, 3),
            "type": "click",
            "button": str(button).replace("Button.", ""),
            "pos_host": [x, y],
            "pos_vm": [vm_x, vm_y] if vm_x is not None else None
        }
        action["screenshot"] = self._capture_screenshot("click")
        self.actions.append(action)
        if vm_x is not None:
            logger.info(f"[{elapsed:.1f}s] 点击: VM({vm_x}, {vm_y}) [Host: {int(x)}, {int(y)}] {action['button']}")
        else:
            logger.info(f"[{elapsed:.1f}s] 点击: Host({int(x)}, {int(y)}) [VM获取失败] {action['button']}")

    def _on_mouse_scroll(self, x, y, dx, dy):
        """Record a scroll event (host coordinates only, no screenshot)."""
        if not self.is_recording:
            return
        elapsed = time.time() - self.start_time
        action = {
            "t": round(elapsed, 3),
            "type": "scroll",
            "pos_host": [x, y],
            "delta": [dx, dy],
            "pos_vm": None
        }
        self.actions.append(action)
        logger.debug(f"[{elapsed:.1f}s] 滚轮: ({x}, {y}) delta=({dx}, {dy})")

    @staticmethod
    def _key_name(key):
        """Return the printable character of ``key``, else its name without ``Key.``."""
        try:
            if hasattr(key, 'char') and key.char:
                return key.char
            return str(key).replace("Key.", "")
        except Exception:
            # Narrower than the former bare ``except:`` so KeyboardInterrupt
            # and SystemExit are no longer swallowed.
            return str(key)

    def _on_key_press(self, key):
        """Record a key press event."""
        if not self.is_recording:
            return
        elapsed = time.time() - self.start_time
        key_name = self._key_name(key)
        action = {
            "t": round(elapsed, 3),
            "type": "key",
            "key": key_name
        }
        self.actions.append(action)
        logger.debug(f"[{elapsed:.1f}s] 按键: {key_name}")

    def _capture_screenshot(self, tag=""):
        """Take a VM screenshot and save it under ``screens/``.

        Args:
            tag: Optional label appended to the file name.

        Returns:
            str: Path relative to the output directory, or None on failure.
        """
        try:
            screenshot = self.env.get_screenshot()
            index = self.screenshot_counter + 1
            if tag:
                filename = f"{index:04d}_{tag}.png"
            else:
                filename = f"{index:04d}.png"
            screenshot.save(os.path.join(self.screens_dir, filename))
            # Advance the counter only after the file was written, so a
            # failed save does not inflate the screenshot total.
            self.screenshot_counter = index
            logger.debug(f"截图保存: {filename}")
            return f"screens/{filename}"
        except Exception as e:
            logger.error(f"截图失败: {e}")
            return None

    def stop(self):
        """Stop the listeners, take the final screenshot and finish the metadata."""
        if not self.is_recording:
            logger.warning("录制未在进行中")
            return
        self.is_recording = False
        if self.mouse_listener:
            self.mouse_listener.stop()
        if self.keyboard_listener:
            self.keyboard_listener.stop()
        # Final screenshot
        self._capture_screenshot("final")
        self.metadata["recording_end"] = datetime.now().isoformat()
        self.metadata["total_duration"] = round(time.time() - self.start_time, 2)
        self.metadata["total_actions"] = len(self.actions)
        self.metadata["total_screenshots"] = self.screenshot_counter
        logger.info("✅ 录制已停止")

    def save(self):
        """Write metadata and actions to ``actions_raw.json`` and print a summary.

        A recording that is still running is stopped first.
        """
        if self.is_recording:
            logger.warning("录制仍在进行,先停止录制")
            self.stop()
        # Raw data: coordinates are stored as recorded.
        output_data = {
            "metadata": self.metadata,
            "actions": self.actions
        }
        raw_path = os.path.join(self.output_dir, "actions_raw.json")
        with open(raw_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        logger.info(f"✅ 轨迹数据已保存: {raw_path}")
        logger.info(f" - 总动作数: {len(self.actions)}")
        logger.info(f" - 截图数: {self.screenshot_counter}")
        logger.info(f" - 总时长: {self.metadata.get('total_duration', 0):.1f}")
        print("\n" + "=" * 60)
        print("📊 录制统计:")
        print(f" 动作数: {len(self.actions)}")
        print(f" 截图数: {self.screenshot_counter}")
        print(f" 时长: {self.metadata.get('total_duration', 0):.1f}")
        print(f" 保存位置: {raw_path}")
        print("=" * 60)
        print("\n💡 下一步:运行坐标转换")
        print(f" python scripts/tools/process_trajectory.py {self.task_id}")
        print("=" * 60 + "\n")
def record_interactive(jade_env, task_id, output_dir):
    """Record interactively until the user presses Ctrl+C.

    Args:
        jade_env: JadeEnv instance.
        task_id: Task identifier.
        output_dir: Output directory.

    Returns:
        Recorder: The recorder, after it has been stopped and saved.
    """
    session = Recorder(jade_env, task_id, output_dir)
    session.start()
    try:
        # Idle while the listeners collect events.
        while session.is_recording:
            time.sleep(0.1)
    except KeyboardInterrupt:
        print("\n\n⏹ 收到停止信号...")
    finally:
        # Always stop and persist, whichever way the loop ended.
        session.stop()
        session.save()
    return session
if __name__ == "__main__":
    # Running the module directly only prints a hint.
    for banner_line in ("Recorder 独立测试模式", "提示: 通常应该通过 collect_task.py 调用"):
        print(banner_line)

View File

@@ -0,0 +1,182 @@
# 运行在 Windows 虚拟机内部
from flask import Flask, request, send_file
import pyautogui
import io
import os
import subprocess
import ctypes
import time
app = Flask(__name__)
# Windows DPI scale factor
def get_dpi_scale():
    """Return the Windows DPI scale factor, or 1.0 when it cannot be read.

    The value of ``shcore.GetScaleFactorForDevice(0)`` is divided by 100.
    """
    try:
        scale_factor = ctypes.windll.shcore.GetScaleFactorForDevice(0) / 100.0
        return scale_factor
    except Exception:
        # Narrower than the former bare ``except:``; still covers the lookup
        # failing (e.g. ``ctypes.windll`` missing). Fall back to no scaling.
        return 1.0
# Actual screen resolution
def get_screen_size():
    """Return ``(width, height)`` of the screen.

    Uses ``user32.GetSystemMetrics`` and falls back to ``pyautogui.size()``
    if that call fails.
    """
    try:
        user32 = ctypes.windll.user32
        width = user32.GetSystemMetrics(0)   # SM_CXSCREEN
        height = user32.GetSystemMetrics(1)  # SM_CYSCREEN
        return width, height
    except Exception:
        # Narrower than the former bare ``except:`` so KeyboardInterrupt and
        # SystemExit propagate.
        return pyautogui.size()
# Screen metrics are probed once at import time and reused by the routes below.
DPI_SCALE = get_dpi_scale()
SCREEN_WIDTH, SCREEN_HEIGHT = get_screen_size()
# Echo the detected values to the server console.
print(f"检测到DPI缩放比例: {DPI_SCALE}")
print(f"实际屏幕分辨率: {SCREEN_WIDTH} x {SCREEN_HEIGHT}")
# Screenshot resolution (used for coordinate conversion)
def get_screenshot_size():
    """Take a screenshot and return its ``(width, height)`` in pixels."""
    shot_width, shot_height = pyautogui.screenshot().size
    return shot_width, shot_height
# One screenshot is taken at import time to record its pixel size for the startup banner.
SCREENSHOT_WIDTH, SCREENSHOT_HEIGHT = get_screenshot_size()
print(f"截图分辨率: {SCREENSHOT_WIDTH} x {SCREENSHOT_HEIGHT}")
# 1. Screen capture
@app.route('/screenshot', methods=['GET'])
def screenshot():
    """Return the current screen as a PNG image."""
    png_buffer = io.BytesIO()
    pyautogui.screenshot().save(png_buffer, 'PNG')
    # Rewind so send_file streams the image from its first byte.
    png_buffer.seek(0)
    return send_file(png_buffer, mimetype='image/png')
# Resolution information (for debugging)
@app.route('/screen_info', methods=['GET'])
def screen_info():
    """Return screen and screenshot resolution, used to debug coordinate conversion."""
    shot_w, shot_h = get_screenshot_size()
    # Screen-to-screenshot ratio per axis; 1.0 when the screenshot size is 0.
    ratio_x = SCREEN_WIDTH / shot_w if shot_w > 0 else 1.0
    ratio_y = SCREEN_HEIGHT / shot_h if shot_h > 0 else 1.0
    return {
        "screen_width": SCREEN_WIDTH,
        "screen_height": SCREEN_HEIGHT,
        "screenshot_width": shot_w,
        "screenshot_height": shot_h,
        "dpi_scale": DPI_SCALE,
        "scale_ratio_x": ratio_x,
        "scale_ratio_y": ratio_y
    }
# 2. Execute an action
@app.route('/action', methods=['POST'])
def action():
    """Execute a ``click``, ``type`` or ``hotkey`` action described by a JSON body.

    Returns ``{"status": "success"}`` on success. A malformed request or an
    unsupported action type yields HTTP 400 and an execution error HTTP 500,
    so a client that only checks the status code still sees the failure.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'type' not in data:
        return {"status": "error", "msg": "Request body must be a JSON object with a 'type' field"}, 400
    try:
        action_type = data['type']
        if action_type == 'click':
            # The screenshot size is re-read on every call because it may change.
            screenshot_w, screenshot_h = get_screenshot_size()
            x = data['x']
            y = data['y']
            # Incoming coordinates are in screenshot space; scale them to
            # screen space when the two resolutions differ.
            scale_x = SCREEN_WIDTH / screenshot_w if screenshot_w > 0 else 1.0
            scale_y = SCREEN_HEIGHT / screenshot_h if screenshot_h > 0 else 1.0
            actual_x = int(x * scale_x)
            actual_y = int(y * scale_y)
            print(f"收到坐标: ({x}, {y}) -> 转换后: ({actual_x}, {actual_y}) [缩放比例: {scale_x:.2f}, {scale_y:.2f}]")
            pyautogui.click(x=actual_x, y=actual_y)
        elif action_type == 'type':
            pyautogui.write(data['text'])
        elif action_type == 'hotkey':
            pyautogui.hotkey(*data['keys'])  # e.g. ['ctrl', 's']
        else:
            # Previously an unknown type was silently reported as success.
            return {"status": "error", "msg": f"Unsupported action type: {action_type}"}, 400
        return {"status": "success"}
    except Exception as e:
        return {"status": "error", "msg": str(e)}, 500
# Current mouse position (helper for host-side recording)
@app.route('/mouse_pos', methods=['GET'])
def mouse_pos():
    """Return the VM's current mouse position and a timestamp."""
    try:
        pointer_x, pointer_y = pyautogui.position()
    except Exception as e:
        return {"status": "error", "msg": str(e)}, 500
    return {
        "status": "success",
        "x": int(pointer_x),
        "y": int(pointer_y),
        "timestamp": time.time()
    }
# 3. [Important] Initialise the environment
@app.route('/reset', methods=['POST'])
def reset():
    """Soft reset: force-kill the Jade process.

    Per the original note, this is weaker than reverting a snapshot but more
    convenient for debugging; a snapshot-based reset has to be triggered from
    the host (Mac) side through vmrun.
    """
    kill_command = "taskkill /f /im jade.exe"
    os.system(kill_command)
    return {"status": "reset_done"}
# 4. List desktop files (for debugging)
@app.route('/list_desktop', methods=['GET'])
def list_desktop():
    """Return the names of the entries on the current user's desktop."""
    try:
        desktop_dir = os.path.expanduser(r"~\Desktop")
        if not os.path.exists(desktop_dir):
            return {"status": "error", "msg": "Desktop path not found"}
        return {
            "status": "success",
            "files": os.listdir(desktop_dir),
            "desktop_path": desktop_dir
        }
    except Exception as e:
        return {"status": "error", "msg": str(e)}
# 5. Download a desktop file (fallback file-collection method)
@app.route('/download/<filename>', methods=['GET'])
def download_file(filename):
    """Send a file from the current user's desktop.

    Serves as the fallback for vmrun file transfer. Only plain file names are
    accepted: this server listens on all interfaces, so names carrying a
    directory or drive component (e.g. ``..\\secret.txt``) are rejected
    rather than joined onto the desktop path.
    """
    try:
        # On Windows, os.path.basename strips both separator styles and drive
        # prefixes, so any difference signals a path component in the name.
        if filename in ("", ".", "..") or os.path.basename(filename) != filename:
            return {"status": "error", "msg": f"Invalid filename: {filename}"}, 400
        desktop = os.path.expanduser(r"~\Desktop")
        filepath = os.path.join(desktop, filename)
        if not os.path.isfile(filepath):
            return {"status": "error", "msg": f"File not found: {filename}"}, 404
        return send_file(filepath, as_attachment=True, download_name=filename)
    except Exception as e:
        return {"status": "error", "msg": str(e)}, 500
if __name__ == '__main__':
    # Startup banner followed by the endpoint overview.
    separator = "=" * 60
    print("\n" + separator)
    print("JADE Agent Server 启动")
    print(separator)
    print(f"监听地址: 0.0.0.0:5000")
    print(f"屏幕分辨率: {SCREEN_WIDTH}x{SCREEN_HEIGHT}")
    print(f"截图分辨率: {SCREENSHOT_WIDTH}x{SCREENSHOT_HEIGHT}")
    print(f"DPI缩放: {DPI_SCALE}")
    print(separator)
    print("\n可用接口:")
    for endpoint_line in (
        " GET /screenshot - 获取屏幕截图",
        " GET /screen_info - 获取屏幕信息",
        " POST /action - 执行动作",
        " POST /reset - 重置环境",
        " GET /list_desktop - 列出桌面文件",
        " GET /download/<file> - 下载桌面文件",
    ):
        print(endpoint_line)
    print(separator + "\n")
    # Listen on 0.0.0.0 so the host can reach the server from outside the VM.
    app.run(host='0.0.0.0', port=5000)

View File

@@ -0,0 +1,4 @@
@echo off
rem Launch the in-VM agent server from its project directory.
rem /d also switches the drive when the current drive is not C:.
cd /d C:\Users\lzy\workplace\OSWorld\desktop_env\server\
rem "call" is required: starting another batch script without it never
rem returns here, so the python line below would not run.
rem NOTE(review): the activate path looks wrong for a Windows venv (usually
rem .venv\Scripts\activate) - confirm against the actual directory layout.
call ..venv\bin\activate
python agent_server.py

View File

@@ -0,0 +1,213 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
批量创建任务
从JSON定义文件批量创建任务
用法:
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json
"""
import json
import os
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from scripts.tools.init_task import init_task
def _apply_task_definition(task_config, task_def):
    """Merge one batch definition entry into a loaded task.json dict, in place."""
    # Input file: bare names are resolved below ../../data/source/.
    source_file = task_def.get("source_file", "DEMO01.MDI")
    if not os.path.isabs(source_file) and not source_file.startswith("../"):
        source_file = f"../../data/source/{source_file}"
    filename = os.path.basename(source_file)
    output_filename = task_def.get("output_filename", "result.txt")

    task_config["input"] = {
        "source_file": source_file,
        "inject_to": f"C:\\Users\\lzy\\Desktop\\{filename}"
    }
    task_config["output"] = {
        "expected_file": output_filename,
        "collect_from": f"C:\\Users\\lzy\\Desktop\\{output_filename}"
    }

    # setdefault: a task.json without an "evaluation" section no longer
    # fails the whole task with a KeyError.
    eval_method = task_def.get("evaluation_method", "xrd_data_compare")
    task_config.setdefault("evaluation", {})["method"] = eval_method

    # Optional pass-through fields.
    for optional_key in ("tutorial_source", "notes"):
        if task_def.get(optional_key):
            task_config[optional_key] = task_def[optional_key]
    return task_config


def _print_batch_summary(results):
    """Print the success / skipped / failed overview of a batch run."""
    print("\n" + "=" * 60)
    print("📊 批量创建总结")
    print("=" * 60)
    print(f"✅ 成功: {len(results['success'])}")
    for task_id in results["success"]:
        print(f" - {task_id}")
    if results["skipped"]:
        print(f"\n⏭️ 跳过: {len(results['skipped'])}")
        for task_id in results["skipped"]:
            print(f" - {task_id}")
    if results["failed"]:
        print(f"\n❌ 失败: {len(results['failed'])}")
        for task_id, reason in results["failed"]:
            print(f" - {task_id}: {reason}")
    print("=" * 60)


def batch_create_tasks(definition_file, project_root=".", force=False, skip_existing=True):
    """Create tasks in bulk from a JSON definition file.

    Args:
        definition_file: Path of the task definition JSON file.
        project_root: Project root; tasks live in ``<project_root>/tasks/<id>``.
        force: Re-create tasks whose directory already exists.
        skip_existing: When not forcing, report existing tasks as skipped
            (True) or as failed (False).

    Returns:
        dict: ``{"success": [ids], "skipped": [ids], "failed": [(id, reason)]}``.
    """
    with open(definition_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    tasks = data.get("tasks", [])
    category = data.get("category", "unknown")
    tutorial_source = data.get("tutorial_source", "")

    print("=" * 60)
    print(f"📚 批量创建任务: {category}")
    if tutorial_source:
        print(f" 教程来源: {tutorial_source}")
    print(f" 任务数量: {len(tasks)}")
    print("=" * 60)

    results = {
        "success": [],
        "skipped": [],
        "failed": []
    }
    for i, task_def in enumerate(tasks, 1):
        task_id = task_def["id"]
        print(f"\n[{i}/{len(tasks)}] 处理任务: {task_id}")
        print("-" * 60)

        # Existing task directories are skipped or reported unless forced.
        task_dir = os.path.join(project_root, "tasks", task_id)
        if os.path.exists(task_dir) and not force:
            if skip_existing:
                print(f"⏭️ 跳过(已存在): {task_id}")
                results["skipped"].append(task_id)
            else:
                print(f"⚠️ 任务已存在: {task_id}")
                print(" 使用 --force 强制覆盖,或设置 skip_existing=True")
                results["failed"].append((task_id, "已存在"))
            continue

        try:
            # init_task creates the task directory structure.
            success = init_task(
                task_id=task_id,
                project_root=project_root,
                force=force,
                category=task_def.get("category", category),
                difficulty=task_def.get("difficulty", "easy"),
                instruction=task_def.get("instruction", "")
            )
            if not success:
                results["failed"].append((task_id, "初始化失败"))
                continue

            task_json_path = os.path.join(task_dir, "task.json")
            if not os.path.exists(task_json_path):
                print(f"❌ task.json 未创建: {task_json_path}")
                results["failed"].append((task_id, "配置文件未创建"))
                continue

            with open(task_json_path, 'r', encoding='utf-8') as f:
                task_config = json.load(f)
            _apply_task_definition(task_config, task_def)
            with open(task_json_path, 'w', encoding='utf-8') as f:
                json.dump(task_config, f, ensure_ascii=False, indent=2)
            print(f"✅ 任务创建成功: {task_id}")
            results["success"].append(task_id)
        except Exception as e:
            # One broken task must not abort the whole batch.
            print(f"❌ 创建任务失败: {task_id}")
            print(f" 错误: {e}")
            import traceback
            traceback.print_exc()
            results["failed"].append((task_id, str(e)))

    _print_batch_summary(results)
    return results
def main():
    """Command-line entry point: parse arguments, run the batch and exit.

    The exit status is 1 when the definition file is missing or any task
    failed, 0 otherwise.
    """
    import argparse

    cli = argparse.ArgumentParser(
        description="批量创建任务",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
# 批量创建任务(跳过已存在的)
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json
# 强制覆盖已存在的任务
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json --force
# 不跳过已存在的任务(遇到已存在就失败)
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json --no-skip-existing
"""
    )
    cli.add_argument("definition_file", help="任务定义JSON文件路径")
    cli.add_argument("--project-root", default=".", help="项目根目录")
    cli.add_argument("--force", action="store_true", help="强制覆盖已存在的任务")
    cli.add_argument("--no-skip-existing", action="store_true", help="不跳过已存在的任务(遇到就失败)")
    options = cli.parse_args()

    if not os.path.exists(options.definition_file):
        print(f"❌ 定义文件不存在: {options.definition_file}")
        sys.exit(1)

    outcome = batch_create_tasks(
        options.definition_file,
        options.project_root,
        force=options.force,
        skip_existing=not options.no_skip_existing
    )
    # Non-zero exit status when at least one task failed.
    sys.exit(1 if outcome["failed"] else 0)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,265 @@
"""
任务数据采集入口
整合环境控制、轨迹录制、文件收集的完整流程
"""
import os
import sys
import argparse
import json
import logging
# 添加父目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.jade_env import JadeEnv
from core.recorder import record_interactive
from utils.config_loader import load_config, get_vm_config, get_network_config
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
def load_default_config():
    """Return the VM connection settings used when no CLI override is given.

    The values are read through ``load_config`` / ``get_vm_config`` /
    ``get_network_config``.  If anything on that path raises, a warning is
    logged and a built-in fallback dictionary is returned instead.

    Returns:
        dict: keys ``vmx_path``, ``snapshot_name``, ``vm_ip``,
        ``vm_password``, ``guest_username`` and ``guest_password``.
    """
    try:
        loaded = load_config()
        vm_section = get_vm_config(loaded)
        net_section = get_network_config(loaded)
        settings = {
            "vmx_path": vm_section.get('vmx_path'),
            "snapshot_name": vm_section.get('snapshot_name', 'Jade_Ready'),
            "vm_ip": net_section.get('vm_ip'),
            "vm_password": vm_section.get('vm_password'),
            "guest_username": vm_section.get('guest_username'),
            "guest_password": vm_section.get('guest_password')
        }
        return settings
    except Exception as e:
        logger.warning(f"⚠️ 无法加载config.json: {e}")
        logger.info(" 使用硬编码配置")

    # NOTE(review): the fallback below is specific to one machine and embeds
    # plaintext credentials; consider moving it out of the source tree.
    return {
        "vmx_path": "/Volumes/Castor/虚拟机/Jade_Win_11.vmwarevm/Windows 11 64 位 ARM 2.vmx",
        "snapshot_name": "Jade_Ready",
        "vm_ip": "192.168.116.129",
        "vm_password": "lizhanyuan",
        "guest_username": "lzy",
        "guest_password": "LIZHANYUAN"
    }
def load_task_config(task_id, project_root="."):
    """Read ``tasks/<task_id>/task.json`` under ``project_root``.

    Args:
        task_id: Name of the task directory below ``tasks/``.
        project_root: Directory that contains the ``tasks`` folder.

    Returns:
        The parsed JSON content, or ``None`` (after logging an error) when
        the file does not exist.
    """
    config_file = os.path.join(project_root, "tasks", task_id, "task.json")
    if os.path.exists(config_file):
        with open(config_file, 'r', encoding='utf-8') as handle:
            return json.load(handle)

    logger.error(f"❌ 任务配置文件不存在: {config_file}")
    logger.info(" 请先创建任务目录和task.json")
    return None
def mode_reset(env, task_config, project_root="."):
    """Mode 1: reset the VM and inject the task's input file.

    Args:
        env: Environment object providing ``reset()`` and
            ``inject_file(host_path, guest_filename)``.
        task_config: Parsed task.json; ``id`` is read, and ``input`` (with
            ``source_file`` / ``inject_to``) is used when present.
        project_root: Directory that contains the ``tasks`` folder.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("🔄 模式: 重置环境")
    print(rule)

    # Step 1: bring the VM back to a known state.
    env.reset()

    # Step 2: copy the input file into the guest, if the task declares one.
    if 'input' in task_config:
        input_config = task_config['input']
        source_file = input_config.get('source_file')
    else:
        input_config, source_file = None, None

    if source_file:
        if not os.path.isabs(source_file):
            # Relative paths are resolved against the task directory.
            task_dir = os.path.join(project_root, "tasks", task_config['id'])
            source_file = os.path.normpath(os.path.join(task_dir, source_file))
        source_file = os.path.abspath(source_file)

        if not os.path.exists(source_file):
            logger.warning(f"⚠️ 输入文件不存在: {source_file}")
        else:
            inject_to = input_config.get('inject_to', '')
            # inject_to is a Windows path, so split on backslashes to get
            # the file name; fall back to the host file name otherwise.
            if inject_to:
                guest_filename = inject_to.split('\\')[-1]
            else:
                guest_filename = os.path.basename(source_file)
            env.inject_file(source_file, guest_filename)

    print("\n✅ 环境准备完成!")
    print(rule)
    print("💡 下一步：开始录制操作")
    print(f" python scripts/tools/collect_task.py {task_config['id']} --mode record")
    print(rule + "\n")
def mode_record(env, task_config, project_root="."):
    """Mode 2: record a human demonstration for the task.

    The recording is delegated to ``record_interactive`` and written to
    ``tasks/<id>/human_demo`` under ``project_root``.

    Args:
        env: Environment object passed through to ``record_interactive``.
        task_config: Parsed task.json; ``id`` is required, ``instruction``
            is shown when present.
        project_root: Directory that contains the ``tasks`` folder.
    """
    task_id = task_config['id']
    demo_dir = os.path.join(project_root, "tasks", task_id, "human_demo")
    rule = "=" * 60

    print("\n" + rule)
    print("🎥 模式: 录制轨迹")
    print(rule)
    print(f"任务: {task_config.get('instruction', 'N/A')}")
    print(rule)

    # Hand over to the interactive recorder.
    record_interactive(env, task_id, demo_dir)

    print("\n💡 下一步：处理坐标转换")
    print(f" python scripts/tools/process_trajectory.py {task_id}")
    print(rule + "\n")
def mode_collect(env, task_config, project_root="."):
    """Mode 3: fetch the task's output file into ``ground_truth``.

    Args:
        env: Environment object providing
            ``collect_file(guest_filename, host_path)``.
        task_config: Parsed task.json; ``id`` is required and
            ``output.expected_file`` names the file to collect.
        project_root: Directory that contains the ``tasks`` folder.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("📦 模式: 收集结果文件")
    print(rule)

    task_id = task_config['id']

    if 'output' not in task_config:
        logger.warning("⚠️ 任务配置中未指定output")
    else:
        expected_file = task_config['output'].get('expected_file')
        if not expected_file:
            logger.warning("⚠️ 任务配置中未指定expected_file")
        else:
            # Destination on the host: tasks/<id>/ground_truth/<expected_file>
            gt_dir = os.path.join(project_root, "tasks", task_id, "ground_truth")
            os.makedirs(gt_dir, exist_ok=True)
            host_path = os.path.join(gt_dir, expected_file)
            env.collect_file(expected_file, host_path)
            print(f"\n✅ 文件已保存到: {host_path}")

    print("\n" + rule)
    print("💡 下一步：验证评测")
    print(f" python scripts/tools/run_eval.py {task_id}")
    print(rule + "\n")
def mode_full(env, task_config, project_root="."):
    """Mode 4: run reset, record and collect back to back.

    The user is asked to press Enter before the record step and again
    before the collect step.

    Args:
        env: Environment object shared by the three steps.
        task_config: Parsed task.json.
        project_root: Directory that contains the ``tasks`` folder.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("🔄 模式: 完整采集流程")
    print(rule)

    # Step 1: reset (no confirmation needed).
    mode_reset(env, task_config, project_root)

    # Steps 2 and 3: each waits for the user before running.
    interactive_steps = (
        ("\n按Enter键开始录制...", mode_record),
        ("\n按Enter键收集结果...", mode_collect),
    )
    for prompt, step in interactive_steps:
        input(prompt)
        step(env, task_config, project_root)

    print("\n✅ 完整采集流程完成!")
def main():
    """Command-line entry point of the task data collection tool.

    Loads the task's task.json, merges default VM settings with the
    ``--vmx`` / ``--snapshot`` / ``--vm-ip`` overrides, builds a ``JadeEnv``
    and runs the handler selected by ``--mode``.

    Exits with status 1 when the task configuration is missing, when the
    user interrupts, or when any other exception is raised.
    """
    usage_examples = """
使用示例:
# 完整流程（推荐）
python scripts/collect_task.py smoothing_001 --mode full
# 分步执行
python scripts/collect_task.py smoothing_001 --mode reset # 1. 重置并注入文件
python scripts/collect_task.py smoothing_001 --mode record # 2. 录制操作
python scripts/collect_task.py smoothing_001 --mode collect # 3. 收集结果
"""
    cli = argparse.ArgumentParser(
        description="JADE Benchmark 任务数据采集工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument("task_id", help="任务ID对应tasks/目录下的子目录名）")
    cli.add_argument(
        "--mode",
        choices=["reset", "record", "collect", "full"],
        default="full",
        help="采集模式默认full",
    )
    cli.add_argument("--project-root", default=".", help="项目根目录")
    cli.add_argument("--vmx", help="虚拟机.vmx文件路径覆盖默认配置")
    cli.add_argument("--snapshot", help="快照名称（覆盖默认配置）")
    cli.add_argument("--vm-ip", help="虚拟机IP地址覆盖默认配置")
    args = cli.parse_args()

    # Task configuration is mandatory.
    task_config = load_task_config(args.task_id, args.project_root)
    if not task_config:
        sys.exit(1)

    # Defaults first, then whatever the command line overrides.
    config = load_default_config()
    cli_overrides = (
        ('vmx_path', args.vmx),
        ('snapshot_name', args.snapshot),
        ('vm_ip', args.vm_ip),
    )
    for key, value in cli_overrides:
        if value:
            config[key] = value

    try:
        logger.info("初始化JADE环境...")
        env = JadeEnv(
            vmx_path=config['vmx_path'],
            snapshot_name=config['snapshot_name'],
            vm_ip=config['vm_ip'],
            vm_password=config.get('vm_password'),
            guest_username=config.get('guest_username'),
            guest_password=config.get('guest_password')
        )

        # Dispatch on the selected mode.
        handlers = {
            "reset": mode_reset,
            "record": mode_record,
            "collect": mode_collect,
            "full": mode_full,
        }
        handler = handlers.get(args.mode)
        if handler is not None:
            handler(env, task_config, args.project_root)
    except KeyboardInterrupt:
        print("\n\n⏹ 操作已取消")
        sys.exit(1)
    except Exception as e:
        logger.error(f"❌ 错误: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()

223
scripts/tools/detect_vm_ip.py Executable file
View File

@@ -0,0 +1,223 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
自动检测VM的IP地址
使用vmrun getGuestIPAddress命令获取VM的当前IP
"""
import subprocess
import sys
import json
import os
from pathlib import Path
try:
import requests
except ImportError:
requests = None
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from scripts.utils.config_loader import load_config
def get_vm_ip(vmx_path, vm_password=None, vmrun_path=None, timeout=10):
    """Ask ``vmrun getGuestIPAddress`` for the guest's current IP address.

    Args:
        vmx_path: Path to the virtual machine's .vmx file.
        vm_password: Encryption password of the VM files (optional); passed
            to vmrun with ``-vp`` when given.
        vmrun_path: Path to the ``vmrun`` executable.  Defaults to the
            VMware Fusion location used previously, so existing callers are
            unaffected.
        timeout: Seconds to wait for vmrun before giving up.

    Returns:
        str: The IP address printed by vmrun, or ``None`` when the command
        fails, times out, cannot be started, or prints nothing.
    """
    vmrun = vmrun_path or "/Applications/VMware Fusion.app/Contents/Library/vmrun"

    # Build the command as an argument list (no shell involved).
    cmd = [vmrun, "-T", "fusion"]
    if vm_password:
        cmd.extend(["-vp", vm_password])
    cmd.extend(["getGuestIPAddress", vmx_path])

    try:
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=timeout
        )
    except subprocess.TimeoutExpired:
        print(f"❌ 获取IP超时")
        return None
    except Exception as e:
        # Includes a missing vmrun executable (FileNotFoundError).
        print(f"❌ 获取IP异常: {e}")
        return None

    if result.returncode != 0:
        error_msg = result.stderr or result.stdout
        print(f"❌ 获取IP失败: {error_msg}")
        return None

    ip = result.stdout.strip()
    if not ip:
        print(f"⚠️ vmrun返回空IP地址")
        return None
    return ip
def update_config_ip(new_ip, config_path="config.json"):
    """Write ``new_ip`` into ``network.vm_ip`` of the project's config file.

    Args:
        new_ip: The IP address to store.
        config_path: Config file path, joined onto the module-level
            ``project_root``.

    Returns:
        bool: ``True`` when the file already holds ``new_ip`` or was updated,
        ``False`` when the file is missing or cannot be read or written.
    """
    config_path = os.path.join(project_root, config_path)
    if not os.path.exists(config_path):
        print(f"❌ 配置文件不存在: {config_path}")
        return False

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Create the section when it is absent, so a config without a
        # "network" block is updated rather than failing with KeyError.
        network = config.setdefault("network", {})
        old_ip = network.get("vm_ip", "未知")
        if old_ip == new_ip:
            print(f"✅ IP地址未变化: {new_ip}")
            return True

        network["vm_ip"] = new_ip
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, ensure_ascii=False)

        # Separate the two addresses so the message is readable.
        print(f"✅ 已更新IP地址: {old_ip} → {new_ip}")
        return True
    except Exception as e:
        print(f"❌ 更新配置文件失败: {e}")
        return False
def test_http_connection(ip, port=5000, timeout=5):
    """Check whether ``http://<ip>:<port>/screen_info`` answers with HTTP 200.

    Args:
        ip: IP address of the VM.
        port: Port number of the HTTP service.
        timeout: Request timeout in seconds.

    Returns:
        bool: ``True`` only for a 200 response; ``False`` when ``requests``
        is not installed, on timeout, on connection failure, on any other
        exception, or on a non-200 status.
    """
    try:
        import requests
    except ImportError:
        print(f"⚠️ requests模块未安装跳过HTTP连接测试")
        return False

    url = f"http://{ip}:{port}/screen_info"
    # Explicitly disable proxies for this request.
    no_proxy = {'http': None, 'https': None}

    try:
        response = requests.get(url, timeout=timeout, proxies=no_proxy)
        reachable = response.status_code == 200
        if reachable:
            print(f"✅ HTTP服务连接成功: {url}")
        else:
            print(f"⚠️ HTTP服务响应异常: 状态码 {response.status_code}")
        return reachable
    except requests.exceptions.Timeout:
        print(f"⚠️ HTTP服务连接超时: {url}")
    except requests.exceptions.ConnectionError as e:
        print(f"⚠️ HTTP服务连接失败: {url}")
        print(f" 错误: {str(e)[:100]}")
    except Exception as e:
        print(f"⚠️ HTTP连接异常: {e}")
    return False
def main():
    """Detect the VM's current IP, probe its HTTP service, offer to save it.

    Steps:
        1. Load ``vmware.vmx_path``, ``vmware.vm_password`` and
           ``network.vm_ip`` from the config (exit status 1 on failure).
        2. Query the IP with ``get_vm_ip`` (exit status 1 when unavailable).
        3. Probe the agent server with ``test_http_connection``.
        4. If the IP changed and the service answers, ask whether to write
           the new IP back with ``update_config_ip``.
    """
    print("=" * 60)
    print("🔍 检测VM IP地址")
    print("=" * 60)

    # Load configuration.
    try:
        config = load_config()
        vmx_path = config["vmware"]["vmx_path"]
        vm_password = config["vmware"].get("vm_password")
        current_ip = config["network"].get("vm_ip", "未知")
    except Exception as e:
        print(f"❌ 加载配置失败: {e}")
        sys.exit(1)

    port = config["network"].get("agent_server_port", 5000)

    print(f"\n📋 当前配置:")
    print(f" VM路径: {os.path.basename(vmx_path)}")
    print(f" 当前IP: {current_ip}")
    print(f" 端口: {port}")

    # Query the VM's IP.
    print(f"\n🔍 正在获取VM IP地址...")
    vm_ip = get_vm_ip(vmx_path, vm_password)
    if not vm_ip:
        print("\n❌ 无法获取VM IP地址")
        print(" 可能原因:")
        print(" 1. VM未运行")
        print(" 2. VM网络未配置")
        print(" 3. vmrun命令执行失败")
        sys.exit(1)
    print(f"✅ 检测到VM IP: {vm_ip}")

    # Probe the HTTP service.
    print(f"\n🔗 测试HTTP连接 (端口 {port})...")
    http_ok = test_http_connection(vm_ip, port)

    if vm_ip == current_ip:
        print(f"\n✅ IP地址未变化")
        if not http_ok:
            print(f"\n⚠️ 但HTTP服务不可用请检查VM中的agent_server.py")
    else:
        # Separate the two addresses so the message is readable.
        print(f"\n⚠️ IP地址已变化: {current_ip} → {vm_ip}")
        if http_ok:
            print(f"\n❓ 是否更新配置文件? (y/n): ", end="")
            try:
                choice = input().strip().lower()
            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin is closed (non-interactive run); treat it
                # like a cancel instead of crashing with a traceback.
                print(f"\n\n⚠️ 用户取消")
            else:
                if choice == 'y':
                    if update_config_ip(vm_ip):
                        print(f"\n✅ 配置已更新!")
                    else:
                        print(f"\n❌ 配置更新失败")
                else:
                    print(f"\n⏭️ 跳过更新")
        else:
            print(f"\n⚠️ HTTP服务不可用请检查:")
            print(f" 1. VM中是否运行了 agent_server.py?")
            print(f" 2. 端口 {port} 是否被占用?")
            print(f" 3. 防火墙是否阻止了连接?")

    print("=" * 60)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,271 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
从教程信息快速生成任务定义
用法:
python scripts/tools/extract_task_from_tutorial.py
"""
import json
import os
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from scripts.tools.init_task import init_task
# 任务类别和难度映射
CATEGORY_MAP = {
"1": "basic_processing",
"2": "peak_analysis",
"3": "phase_identification",
"4": "crystal_parameters",
"5": "calibration",
"6": "advanced_analysis",
}
DIFFICULTY_MAP = {
"1": "easy",
"2": "medium",
"3": "hard",
}
# 常见任务模板
TASK_TEMPLATES = {
"basic_processing": {
"open_file": "请打开桌面上的 {filename} 文件。",
"smooth": "请打开桌面上的 {filename} 文件,进行平滑处理 (Smoothing),然后将处理后的曲线导出为 ASCII (.txt) 文件并命名为 {output}",
"background": "请打开桌面上的 {filename} 文件,进行背景扣除 (Background Removal),然后将处理后的曲线导出为 ASCII (.txt) 文件并命名为 {output}",
"export": "请打开桌面上的 {filename} 文件,将当前曲线导出为 ASCII (.txt) 文件并命名为 {output}",
},
"peak_analysis": {
"peak_search": "请打开桌面上的 {filename} 文件,进行寻峰操作 (Peak Search),并导出寻峰结果文件 {output}",
"peak_separation": "请打开桌面上的 {filename} 文件,进行多峰分离操作 (Peak Separation),并导出结果文件 {output}",
"peak_fitting": "请打开桌面上的 {filename} 文件,进行峰形拟合 (Peak Fitting),并导出结果文件 {output}",
},
"phase_identification": {
"phase_search": "请打开桌面上的 {filename} 文件,进行物相检索 (Phase Search),并导出检索结果文件 {output}",
"quantitative": "请打开桌面上的 {filename} 文件,进行物相定量分析 (Quantitative Analysis),并导出结果文件 {output}",
},
"crystal_parameters": {
"lattice_constant": "请打开桌面上的 {filename} 文件,精确测定晶格常数 (Lattice Constant),并导出结果文件 {output}",
"crystal_size": "请打开桌面上的 {filename} 文件使用Scherrer公式计算晶粒大小 (Crystal Size),并导出结果文件 {output}",
"stress": "请打开桌面上的 {filename} 文件,进行残余应力分析 (Stress Analysis),并导出结果文件 {output}",
"crystallinity": "请打开桌面上的 {filename} 文件,计算结晶化度 (Crystallinity),并导出结果文件 {output}",
},
}
def print_category_menu():
    """Print the numbered list of task categories (choices 1-6)."""
    menu_lines = (
        "\n📚 任务类别:",
        " 1. basic_processing (基础处理)",
        " 2. peak_analysis (峰分析)",
        " 3. phase_identification (物相检索)",
        " 4. crystal_parameters (晶体参数)",
        " 5. calibration (校正)",
        " 6. advanced_analysis (高级分析)",
    )
    for line in menu_lines:
        print(line)
def print_difficulty_menu():
    """Print the numbered list of difficulty levels (choices 1-3)."""
    menu_lines = (
        "\n📊 难度等级:",
        " 1. easy (简单3-5步操作)",
        " 2. medium (中等5-10步操作)",
        " 3. hard (困难10+步操作)",
    )
    for line in menu_lines:
        print(line)
def get_user_input():
    """Interactively collect the fields needed to define a new task.

    Prompts, in order, for the task ID, category, difficulty, source file,
    output file name, an optional instruction template, the instruction text
    and an optional tutorial source.

    Returns:
        dict | None: A dict with the keys ``task_id``, ``category``,
        ``difficulty``, ``instruction``, ``source_file``,
        ``output_filename`` and ``tutorial_source``; ``None`` when the task
        ID or instruction is empty or the category/difficulty choice is
        invalid.
    """
    print("=" * 60)
    print("🎯 从教程提取任务 - 快速生成工具")
    print("=" * 60)

    # Task ID
    task_id = input("\n📝 任务ID (例如: peak_search_001): ").strip()
    if not task_id:
        print("❌ 任务ID不能为空")
        return None

    # Category
    print_category_menu()
    category_choice = input("\n选择类别 (1-6): ").strip()
    category = CATEGORY_MAP.get(category_choice)
    if not category:
        print("❌ 无效的类别选择")
        return None

    # Difficulty
    print_difficulty_menu()
    difficulty_choice = input("\n选择难度 (1-3): ").strip()
    difficulty = DIFFICULTY_MAP.get(difficulty_choice)
    if not difficulty:
        print("❌ 无效的难度选择")
        return None

    # Input file (defaults to DEMO01.MDI)
    print("\n📁 输入文件配置:")
    source_file = input(" 源文件路径 (相对于data/source/, 例如: DEMO01.MDI): ").strip()
    if not source_file:
        source_file = "DEMO01.MDI"

    # Output file (defaults to result.txt)
    print("\n📤 输出文件配置:")
    output_filename = input(" 输出文件名 (例如: result.txt): ").strip()
    if not output_filename:
        output_filename = "result.txt"

    # Optional instruction template, offered only for categories that have one.
    task_type = None
    if category in TASK_TEMPLATES:
        templates = TASK_TEMPLATES[category]
        template_keys = list(templates)
        print(f"\n📋 可用任务模板 ({category}):")
        for i, key in enumerate(template_keys, 1):
            print(f" {i}. {key}")
        use_template = input("\n使用模板? (y/n, 默认n): ").strip().lower()
        if use_template == 'y':
            template_choice = input(f"选择模板 (1-{len(templates)}): ").strip()
            try:
                choice_number = int(template_choice)
            except ValueError:
                choice_number = 0
            # Validate the range explicitly: list indexing alone would accept
            # 0 and negative numbers and silently pick a template from the
            # end of the list.
            if 1 <= choice_number <= len(template_keys):
                task_type = template_keys[choice_number - 1]
            else:
                print("⚠️ 无效的模板选择，将使用自定义指令")

    # Instruction
    if task_type and category in TASK_TEMPLATES:
        # Fill the chosen template and let the user confirm or replace it.
        template = TASK_TEMPLATES[category][task_type]
        instruction = template.format(
            filename=os.path.basename(source_file),
            output=output_filename
        )
        print(f"\n✅ 生成的指令 (模板): {instruction}")
        confirm = input("使用此指令? (y/n, 默认y): ").strip().lower()
        if confirm == 'n':
            instruction = input("\n📝 自定义指令: ").strip()
    else:
        # Free-form instruction.
        instruction = input("\n📝 任务指令 (中文描述): ").strip()
    if not instruction:
        print("❌ 指令不能为空")
        return None

    # Tutorial source (optional)
    tutorial_source = input("\n📚 教程来源 (可选，例如: 教程(1)): ").strip()

    return {
        "task_id": task_id,
        "category": category,
        "difficulty": difficulty,
        "instruction": instruction,
        "source_file": source_file,
        "output_filename": output_filename,
        "tutorial_source": tutorial_source,
    }
def create_task_from_info(info):
    """Create a task directory from the collected info and fill its task.json.

    Calls ``init_task`` to create the skeleton, then rewrites the ``input``
    and ``output`` sections of the generated task.json and records the
    tutorial source when one was given.

    Args:
        info: Dict with ``task_id``, ``category``, ``difficulty``,
            ``instruction``, ``source_file``, ``output_filename`` and,
            optionally, ``tutorial_source``.

    Returns:
        bool: ``True`` when the task was created and configured, ``False``
        when initialization was refused, task.json is missing afterwards, or
        an exception occurred.
    """
    task_id = info["task_id"]
    category = info["category"]
    difficulty = info["difficulty"]
    instruction = info["instruction"]

    # Source file: relative paths are assumed to live under data/source/.
    source_file = info["source_file"]
    if not os.path.isabs(source_file):
        source_file = f"../../data/source/{source_file}"

    # Guest-side paths on the VM desktop.
    filename = os.path.basename(source_file)
    inject_to = f"C:\\Users\\lzy\\Desktop\\{filename}"
    output_filename = info["output_filename"]
    collect_from = f"C:\\Users\\lzy\\Desktop\\{output_filename}"

    print(f"\n🚀 正在创建任务: {task_id}")
    print(f" 类别: {category}")
    print(f" 难度: {difficulty}")
    print(f" 源文件: {source_file}")
    print(f" 输出文件: {output_filename}")

    try:
        created = init_task(
            task_id=task_id,
            category=category,
            difficulty=difficulty,
            instruction=instruction,
            project_root=str(project_root)
        )
        # init_task returns False when it refuses to touch an existing task.
        # Stop here so that task's task.json is not silently rewritten and
        # reported as a successful creation.
        if created is False:
            print(f"❌ 任务初始化失败: {task_id}")
            return False

        task_json_path = project_root / "tasks" / task_id / "task.json"
        if not task_json_path.exists():
            print(f"❌ 任务目录创建失败: {task_json_path}")
            return False

        with open(task_json_path, 'r', encoding='utf-8') as f:
            task_config = json.load(f)

        # Input / output configuration.
        task_config["input"] = {
            "source_file": source_file,
            "inject_to": inject_to
        }
        task_config["output"] = {
            "expected_file": output_filename,
            "collect_from": collect_from
        }
        # Tutorial source, only when provided.
        if info.get("tutorial_source"):
            task_config["tutorial_source"] = info["tutorial_source"]

        with open(task_json_path, 'w', encoding='utf-8') as f:
            json.dump(task_config, f, ensure_ascii=False, indent=2)

        print(f"\n✅ 任务创建成功!")
        print(f" 任务目录: tasks/{task_id}/")
        print(f" 配置文件: tasks/{task_id}/task.json")
        print(f"\n📝 下一步:")
        print(f" 1. 检查并完善 task.json")
        print(f" 2. 运行: python scripts/tools/collect_task.py {task_id} --mode full")
        return True
    except Exception as e:
        print(f"❌ 创建任务时出错: {e}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Run the interactive flow: collect the task info, then create the task.

    A Ctrl-C is reported as a user cancel; any other exception is printed
    together with its traceback.
    """
    try:
        info = get_user_input()
        if not info:
            print("\n❌ 任务创建取消")
        else:
            create_task_from_info(info)
    except KeyboardInterrupt:
        print("\n\n⚠️ 用户取消操作")
    except Exception as e:
        print(f"\n❌ 发生错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()

178
scripts/tools/init_task.py Normal file
View File

@@ -0,0 +1,178 @@
"""
任务初始化工具
快速创建新任务的目录结构和配置文件模板
"""
import os
import json
import sys
import argparse
# Default content of a freshly created task.json.  init_task copies this
# dict and fills in "id" (and optionally category/difficulty/instruction).
TASK_JSON_TEMPLATE = {
    "id": "",
    "category": "basic_processing",  # basic_processing, peak_analysis, phase_identification, compound_tasks
    "difficulty": "easy",  # easy, medium, hard
    "instruction": "请填写任务指令",
    # Host-side source file (a relative path) and where it is placed in the guest.
    "input": {
        "source_file": "../../data/source/DEMO01.MDI",
        "inject_to": "C:\\Users\\lzy\\Desktop\\DEMO01.MDI"
    },
    # File name expected from the task and its guest-side location.
    "output": {
        "expected_file": "result.txt",
        "collect_from": "C:\\Users\\lzy\\Desktop\\result.txt"
    },
    # Evaluation settings: method name, the two files to compare, tolerance.
    "evaluation": {
        "method": "xrd_data_compare",
        "ground_truth": "ground_truth/result.txt",
        "target_output": "agent_output/result.txt",
        "tolerance": 1e-4
    }
}
def init_task(task_id, project_root=".", force=False, category=None, difficulty=None, instruction=None):
    """Create the directory skeleton, task.json and README.md for a new task.

    Args:
        task_id: Task ID; also the directory name under ``tasks/``.
        project_root: Directory that contains the ``tasks`` folder.
        force: When ``True``, proceed even if the task directory already
            exists (task.json and README.md are rewritten).
        category: Optional category override for task.json.
        difficulty: Optional difficulty override for task.json.
        instruction: Optional instruction override for task.json.

    Returns:
        bool: ``False`` when the task directory exists and ``force`` is not
        set, ``True`` after the task was initialized.
    """
    import copy  # local import: only needed to clone the template

    task_dir = os.path.join(project_root, "tasks", task_id)

    # Refuse to touch an existing task unless forced.
    if os.path.exists(task_dir) and not force:
        print(f"❌ 任务目录已存在: {task_dir}")
        print(" 使用 --force 参数强制覆盖")
        return False

    print(f"创建任务: {task_id}")
    print("=" * 60)

    # Directory skeleton.
    directories = [
        task_dir,
        os.path.join(task_dir, "ground_truth"),
        os.path.join(task_dir, "human_demo"),
        os.path.join(task_dir, "human_demo", "screens"),
        os.path.join(task_dir, "agent_output")
    ]
    for directory in directories:
        os.makedirs(directory, exist_ok=True)
        print(f"✅ 创建目录: {os.path.relpath(directory, project_root)}")

    # task.json: deep-copy the template so its nested dicts (input, output,
    # evaluation) are never shared with the module-level constant.
    task_config = copy.deepcopy(TASK_JSON_TEMPLATE)
    task_config["id"] = task_id
    if category:
        task_config["category"] = category
    if difficulty:
        task_config["difficulty"] = difficulty
    if instruction:
        task_config["instruction"] = instruction

    task_json_path = os.path.join(task_dir, "task.json")
    with open(task_json_path, 'w', encoding='utf-8') as f:
        json.dump(task_config, f, indent=2, ensure_ascii=False)
    print(f"✅ 创建配置: {os.path.relpath(task_json_path, project_root)}")

    # README.md: the commands point at scripts/tools/, matching the hint
    # printed at the end of this function.
    readme_content = f"""# 任务: {task_id}
## 任务信息
- **ID**: {task_id}
- **类别**: {task_config['category']}
- **难度**: {task_config['difficulty']}
## 指令
{task_config['instruction']}
## 数据采集状态
- [ ] 环境重置与文件注入
- [ ] 操作轨迹录制
- [ ] 结果文件收集
- [ ] 坐标转换处理
- [ ] 评测验证
## 采集命令
```bash
# 完整流程
python scripts/tools/collect_task.py {task_id} --mode full
# 分步执行
python scripts/tools/collect_task.py {task_id} --mode reset
python scripts/tools/collect_task.py {task_id} --mode record
python scripts/tools/collect_task.py {task_id} --mode collect
python scripts/tools/process_trajectory.py {task_id}
python scripts/tools/run_eval.py {task_id}
```
## 文件结构
```
{task_id}/
├── task.json # 任务配置
├── ground_truth/ # 标准答案输出
├── human_demo/ # 人类操作轨迹
│ ├── actions_raw.json # 原始轨迹（未转换坐标）
│ ├── actions.json # 处理后轨迹（已转换坐标）
│ └── screens/ # 截图序列
└── agent_output/ # Agent输出评测时使用
```
"""
    readme_path = os.path.join(task_dir, "README.md")
    with open(readme_path, 'w', encoding='utf-8') as f:
        f.write(readme_content)
    print(f"✅ 创建说明: {os.path.relpath(readme_path, project_root)}")

    print("=" * 60)
    print("✅ 任务初始化完成!")
    print("\n📝 下一步:")
    print(f" 1. 编辑任务配置: {task_json_path}")
    print(f" 2. 确保输入文件存在：例如 {task_config['input']['source_file']}")
    print(f" 3. 开始数据采集: python scripts/tools/collect_task.py {task_id}")
    print("=" * 60 + "\n")
    return True
def main():
    """Command-line entry point: initialize one task and exit.

    Exits with status 0 when ``init_task`` succeeds and 1 otherwise.
    """
    usage_examples = """
使用示例:
# 创建新任务
python scripts/init_task.py smoothing_001
# 强制覆盖已存在的任务
python scripts/init_task.py smoothing_001 --force
"""
    cli = argparse.ArgumentParser(
        description="初始化新任务",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument("task_id", help="任务ID建议格式: category_序号")
    cli.add_argument("--project-root", default=".", help="项目根目录")
    cli.add_argument("--force", action="store_true", help="强制覆盖已存在的任务")
    options = cli.parse_args()

    created = init_task(options.task_id, options.project_root, options.force)
    if created:
        sys.exit(0)
    sys.exit(1)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,165 @@
"""
轨迹数据后处理
将录制的原始Host坐标转换为VM内坐标
"""
import json
import os
import sys
import argparse
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def process_trajectory(task_id, project_root=".", force=False,
                       scale_x_adjust=1.0, scale_y_adjust=1.0,
                       offset_x=0, offset_y=0):
    """Convert recorded host coordinates of a trajectory into VM coordinates.

    Reads ``tasks/<task_id>/human_demo/actions_raw.json``, adds a ``pos_vm``
    entry to every action that has a non-empty ``pos_host``, records the
    conversion parameters under ``metadata['processed']`` and writes the
    result to ``actions.json`` in the same directory.

    Args:
        task_id: Task ID.
        project_root: Directory that contains the ``tasks`` folder.
        force: Overwrite an existing ``actions.json``.
        scale_x_adjust: Extra multiplier applied to the X scale.
        scale_y_adjust: Extra multiplier applied to the Y scale.
        offset_x: Offset added to converted X coordinates.
        offset_y: Offset added to converted Y coordinates.

    Returns:
        bool: ``False`` when the raw file is missing or the processed file
        exists and ``force`` is not set; ``True`` after a successful run.
    """
    from datetime import datetime  # local import: only used for the timestamp

    # Paths
    task_dir = os.path.join(project_root, "tasks", task_id)
    human_demo_dir = os.path.join(task_dir, "human_demo")
    raw_path = os.path.join(human_demo_dir, "actions_raw.json")
    processed_path = os.path.join(human_demo_dir, "actions.json")

    # Preconditions
    if not os.path.exists(raw_path):
        logger.error(f"❌ 原始轨迹文件不存在: {raw_path}")
        logger.info(" 请先运行: python scripts/collect_task.py <task_id> --mode record")
        return False
    if os.path.exists(processed_path) and not force:
        logger.warning(f"⚠️ 处理后的文件已存在: {processed_path}")
        logger.info(" 使用 --force 参数强制覆盖")
        return False

    # Load the raw recording.
    logger.info(f"读取原始轨迹: {raw_path}")
    with open(raw_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    metadata = data['metadata']
    actions = data['actions']
    logger.info(f"任务ID: {metadata['task_id']}")
    logger.info(f"动作数: {len(actions)}")

    # Scale = VM resolution / screenshot resolution, times the manual
    # adjustment.  Without resolution metadata the base scale is 1.0.
    if 'vm_resolution' in metadata and 'vm_screenshot_resolution' in metadata:
        vm_w, vm_h = metadata['vm_resolution']
        screenshot_w, screenshot_h = metadata['vm_screenshot_resolution']
        scale_x = (vm_w / screenshot_w) * scale_x_adjust
        scale_y = (vm_h / screenshot_h) * scale_y_adjust
        logger.info(f"VM分辨率: {vm_w}x{vm_h}")
        logger.info(f"截图分辨率: {screenshot_w}x{screenshot_h}")
        logger.info(f"转换比例: X={scale_x:.3f}, Y={scale_y:.3f}")
        if scale_x_adjust != 1.0 or scale_y_adjust != 1.0:
            logger.info(f"应用调整系数: X={scale_x_adjust}, Y={scale_y_adjust}")
        if offset_x != 0 or offset_y != 0:
            logger.info(f"应用偏移调整: X={offset_x}, Y={offset_y}")
    else:
        logger.warning("⚠️ 元数据缺少分辨率信息使用默认比例1.0")
        scale_x = 1.0 * scale_x_adjust
        scale_y = 1.0 * scale_y_adjust

    # Convert every action that carries a host position.
    converted_count = 0
    for action in actions:
        pos_host = action.get('pos_host')
        if pos_host:
            host_x, host_y = pos_host
            action['pos_vm'] = [
                int(host_x * scale_x + offset_x),
                int(host_y * scale_y + offset_y),
            ]
            converted_count += 1
    logger.info(f"✅ 坐标转换完成: {converted_count}/{len(actions)} 个动作")

    # Record how the conversion was done.
    metadata['processed'] = {
        "processed_at": datetime.now().isoformat(),
        "scale_x": scale_x,
        "scale_y": scale_y,
        "offset_x": offset_x,
        "offset_y": offset_y,
        "converted_actions": converted_count
    }

    # Save
    logger.info(f"保存处理后的轨迹: {processed_path}")
    with open(processed_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    logger.info("✅ 处理完成!")

    # Preview of the first five converted clicks.  Clicks without a usable
    # position (pos_host missing or None, hence no pos_vm) are skipped so the
    # preview cannot crash after the file has already been written.
    print("\n" + "=" * 60)
    print("📊 坐标转换示例前5个点击:")
    print("-" * 60)
    click_count = 0
    for action in actions:
        pos_host = action.get('pos_host')
        pos_vm = action.get('pos_vm')
        if action.get('type') != 'click' or not pos_host or not pos_vm:
            continue
        host_x, host_y = pos_host
        vm_x, vm_y = pos_vm
        print(f" Host({int(host_x):4d}, {int(host_y):4d}) → VM({int(vm_x):4d}, {int(vm_y):4d})")
        click_count += 1
        if click_count >= 5:
            break
    print("=" * 60)
    print("\n💡 下一步：可视化验证（可选）")
    print(f" python scripts/visualize_trajectory.py {task_id}")
    print("=" * 60 + "\n")
    return True
def main():
    """Command-line entry point for trajectory post-processing.

    Exits with status 0 when ``process_trajectory`` succeeds and 1 otherwise.
    """
    cli = argparse.ArgumentParser(description="处理轨迹数据，转换坐标")
    cli.add_argument("task_id", help="任务ID")
    cli.add_argument("--project-root", default=".", help="项目根目录")
    cli.add_argument("--force", action="store_true", help="强制覆盖已有文件")
    # Manual calibration knobs, all optional.
    cli.add_argument("--scale-x", type=float, default=1.0, help="X轴缩放调整系数")
    cli.add_argument("--scale-y", type=float, default=1.0, help="Y轴缩放调整系数")
    cli.add_argument("--offset-x", type=int, default=0, help="X轴偏移调整")
    cli.add_argument("--offset-y", type=int, default=0, help="Y轴偏移调整")
    options = cli.parse_args()

    ok = process_trajectory(
        task_id=options.task_id,
        project_root=options.project_root,
        force=options.force,
        scale_x_adjust=options.scale_x,
        scale_y_adjust=options.scale_y,
        offset_x=options.offset_x,
        offset_y=options.offset_y,
    )
    if ok:
        sys.exit(0)
    sys.exit(1)


if __name__ == "__main__":
    main()

261
scripts/tools/run_eval.py Normal file
View File

@@ -0,0 +1,261 @@
"""
评测入口脚本
支持单任务或批量评测
"""
import os
import sys
import json
import argparse
import logging
from datetime import datetime
# 添加父目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.evaluator import evaluate
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
def evaluate_task(task_id, project_root=".", verbose=True):
    """Evaluate one task by comparing agent output against ground truth.

    Args:
        task_id: Task ID (directory name under ``tasks/``).
        project_root: Directory that contains the ``tasks`` folder.
        verbose: Print the task header and the result.

    Returns:
        tuple: ``(score, message, details)``.  ``details`` is an empty dict
        when the task config or one of the compared files is missing, or
        when the evaluation raised.
    """
    task_dir = os.path.join(project_root, "tasks", task_id)
    task_json_path = os.path.join(task_dir, "task.json")

    # The task must have a configuration file.
    if not os.path.exists(task_json_path):
        logger.error(f"❌ 任务配置不存在: {task_json_path}")
        return 0, "任务配置不存在", {}

    with open(task_json_path, 'r', encoding='utf-8') as f:
        task_config = json.load(f)

    rule = "=" * 60
    if verbose:
        print("\n" + rule)
        print(f"📝 评测任务: {task_id}")
        print(rule)
        print(f"类别: {task_config.get('category', 'N/A')}")
        print(f"难度: {task_config.get('difficulty', 'N/A')}")
        print(f"指令: {task_config.get('instruction', 'N/A')}")
        print(rule)

    # Evaluation settings; file paths are relative to the task directory.
    eval_config = task_config.get('evaluation', {})
    method = eval_config.get('method', 'xrd_data_compare')
    gt_path = os.path.join(task_dir, eval_config.get('ground_truth', ''))
    agent_path = os.path.join(task_dir, eval_config.get('target_output', ''))
    tolerance = eval_config.get('tolerance', 1e-4)

    if not os.path.exists(gt_path):
        logger.error(f"❌ Ground truth文件不存在: {gt_path}")
        return 0, "Ground truth文件不存在", {}
    if not os.path.exists(agent_path):
        logger.error(f"❌ Agent输出文件不存在: {agent_path}")
        return 0, "Agent输出文件不存在", {}

    # Evaluation method name -> mode understood by evaluate().
    evaluator_modes = {
        'xrd_data_compare': "xrd_data",
        'peak_report_compare': "peak_report",
    }

    try:
        if method in evaluator_modes:
            score, message = evaluate(gt_path, agent_path, tolerance, mode=evaluator_modes[method])
        else:
            logger.warning(f"⚠️ 未知的评测方法: {method}")
            score, message = 0, f"未知的评测方法: {method}"

        details = {
            "task_id": task_id,
            "method": method,
            "ground_truth": gt_path,
            "agent_output": agent_path,
            "tolerance": tolerance,
            "timestamp": datetime.now().isoformat()
        }

        if verbose:
            print(f"\n📊 评测结果:")
            print(f" Score: {score}")
            print(f" {message}")
            print(rule + "\n")

        return score, message, details
    except Exception as e:
        logger.error(f"❌ 评测失败: {e}")
        import traceback
        traceback.print_exc()
        return 0, f"评测失败: {str(e)}", {}
def evaluate_batch(task_ids, project_root=".", output_file=None):
    """Evaluate several tasks and print (and optionally save) a summary.

    Args:
        task_ids: List of task IDs.
        project_root: Directory that contains the ``tasks`` folder.
        output_file: Optional path of a JSON file that receives the summary
            and the per-task results.

    Returns:
        list: One dict per task with ``task_id``, ``score``, ``message`` and
        the detail fields returned by ``evaluate_task``.
    """
    print("\n" + "=" * 60)
    print("📊 批量评测")
    print("=" * 60)
    print(f"任务数: {len(task_ids)}")
    print("=" * 60 + "\n")

    results = []
    total_score = 0
    for i, task_id in enumerate(task_ids, 1):
        print(f"\n[{i}/{len(task_ids)}] 评测: {task_id}")
        score, message, details = evaluate_task(task_id, project_root, verbose=False)
        results.append({
            "task_id": task_id,
            "score": score,
            "message": message,
            **details
        })
        total_score += score
        status = "✅ 通过" if score == 1 else "❌ 失败"
        print(f" {status}: {message}")

    # Statistics.  Both ratios are guarded so an empty task list produces a
    # zero summary instead of a ZeroDivisionError.
    task_count = len(task_ids)
    pass_count = sum(1 for r in results if r['score'] == 1)
    pass_rate = pass_count / task_count * 100 if task_count else 0
    average_score = total_score / task_count if task_count else 0

    print("\n" + "=" * 60)
    print("📈 评测统计")
    print("=" * 60)
    print(f"总任务数: {task_count}")
    print(f"通过数: {pass_count}")
    print(f"失败数: {task_count - pass_count}")
    print(f"通过率: {pass_rate:.1f}%")
    print(f"平均分: {average_score:.2f}")
    print("=" * 60 + "\n")

    # Persist the results when requested.
    if output_file:
        output_data = {
            "timestamp": datetime.now().isoformat(),
            "total_tasks": task_count,
            "pass_count": pass_count,
            "pass_rate": pass_rate,
            "results": results
        }
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"📄 详细结果已保存到: {output_file}\n")

    return results
def discover_tasks(project_root="."):
    """List every task that has a task.json under ``<project_root>/tasks``.

    Args:
        project_root: Directory that contains the ``tasks`` folder.

    Returns:
        list: Sorted task IDs; empty when the ``tasks`` folder is missing.
    """
    tasks_dir = os.path.join(project_root, "tasks")
    if not os.path.exists(tasks_dir):
        return []

    # A task is a sub-directory that contains a task.json file.
    return sorted(
        entry
        for entry in os.listdir(tasks_dir)
        if os.path.isdir(os.path.join(tasks_dir, entry))
        and os.path.exists(os.path.join(tasks_dir, entry, "task.json"))
    )
def main():
    """Command-line entry point of the evaluation tool.

    Evaluates the task IDs given on the command line, or every discovered
    task with ``--all``.  A single task is evaluated verbosely; several tasks
    go through ``evaluate_batch``.

    Exits with status 0 only when every evaluated task scored 1; exits with
    status 1 otherwise, when nothing was selected, on Ctrl-C, or on error.
    """
    usage_examples = """
使用示例:
# 评测单个任务
python scripts/run_eval.py smoothing_001
# 评测多个任务
python scripts/run_eval.py smoothing_001 peak_search_001
# 评测所有任务
python scripts/run_eval.py --all
# 保存结果到文件
python scripts/run_eval.py --all --output results.json
"""
    cli = argparse.ArgumentParser(
        description="JADE Benchmark 评测工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument("task_ids", nargs="*", help="任务ID列表")
    cli.add_argument("--all", action="store_true", help="评测所有任务")
    cli.add_argument("--project-root", default=".", help="项目根目录")
    cli.add_argument("--output", help="结果输出文件JSON格式")
    args = cli.parse_args()

    # Decide which tasks to evaluate.
    if args.all:
        task_ids = discover_tasks(args.project_root)
        if not task_ids:
            logger.error("❌ 未找到任何任务")
            sys.exit(1)
        logger.info(f"发现 {len(task_ids)} 个任务")
    elif args.task_ids:
        task_ids = args.task_ids
    else:
        cli.print_help()
        sys.exit(1)

    try:
        if len(task_ids) == 1:
            # Single task: verbose evaluation.
            score, message, _ = evaluate_task(task_ids[0], args.project_root)
            passed = score == 1
        else:
            # Several tasks: batch evaluation, success only if all pass.
            results = evaluate_batch(task_ids, args.project_root, args.output)
            passed = all(r['score'] == 1 for r in results)
        sys.exit(0 if passed else 1)
    except KeyboardInterrupt:
        print("\n\n⏹ 评测已取消")
        sys.exit(1)
    except Exception as e:
        logger.error(f"❌ 错误: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()

12
scripts/utils/__init__.py Normal file
View File

@@ -0,0 +1,12 @@
"""
JADE Benchmark 辅助工具
配置加载、文件处理等辅助功能
"""
from .config_loader import load_config, get_vm_config, get_network_config
__all__ = [
'load_config',
'get_vm_config',
'get_network_config'
]

View File

@@ -0,0 +1,51 @@
"""
配置文件加载器
"""
import json
import os
def load_config(config_path=None):
    """Load the JSON configuration file.

    Args:
        config_path: Path of the config file.  When ``None``, ``config.json``
            two directories above this module's directory is used.

    Returns:
        dict: The parsed configuration.

    Raises:
        FileNotFoundError: If the config file does not exist.
    """
    if config_path is None:
        # <root>/scripts/utils/<this file>  ->  <root>/config.json
        here = os.path.abspath(__file__)
        root_dir = os.path.dirname(os.path.dirname(os.path.dirname(here)))
        config_path = os.path.join(root_dir, "config.json")

    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    raise FileNotFoundError(f"配置文件不存在: {config_path}")
def get_vm_config(config=None):
    """Return the ``vmware`` section of the configuration.

    Args:
        config: Already loaded configuration; loaded with ``load_config()``
            when ``None``.

    Returns:
        The ``vmware`` value, or an empty dict when the key is absent.
    """
    source = load_config() if config is None else config
    return source.get('vmware', {})
def get_network_config(config=None):
    """Return the ``network`` section of the configuration.

    Args:
        config: Already loaded configuration; loaded with ``load_config()``
            when ``None``.

    Returns:
        The ``network`` value, or an empty dict when the key is absent.
    """
    source = load_config() if config is None else config
    return source.get('network', {})
if __name__ == "__main__":
    # Manual smoke test: load the default config file and pretty-print it.
    # load_config() raises FileNotFoundError when that file does not exist.
    config = load_config()
    print("配置加载成功:")
    print(json.dumps(config, indent=2, ensure_ascii=False))