Initial commit

This commit is contained in:
2026-01-12 18:30:12 +08:00
commit 214e15c04c
102 changed files with 27857 additions and 0 deletions

178
scripts/core/evaluator.py Normal file
View File

@@ -0,0 +1,178 @@
import json
import numpy as np
import os
import sys
import re
def load_xrd_data(file_path):
    """
    Load the numeric rows of an exported XRD text file.

    Every line is split on whitespace. A line contributes a row only when it
    has at least two tokens and *all* of its tokens parse as floats; the first
    two values of such a line are kept. Anything else (blank lines, header
    or metadata text) is ignored.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A NumPy array with one two-value row per accepted line.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If no line of the file yields a numeric row.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    def _parse_row(text):
        # Return the first two floats of a fully numeric line, else None.
        tokens = text.split()
        if len(tokens) < 2:
            return None
        try:
            numbers = [float(token) for token in tokens]
        except ValueError:
            # At least one token is not numeric: treat the line as metadata.
            return None
        return numbers[:2]

    # Undecodable bytes are dropped rather than aborting the whole read.
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as handle:
        rows = [pair for pair in map(_parse_row, handle) if pair is not None]

    if not rows:
        raise ValueError(f"在文件中未找到有效的数值数据: {file_path}")
    return np.array(rows)
def load_peak_report(file_path):
    """
    Parse the numeric peak table of a JADE Peak Search Report (.pid).

    The table is taken to start after the header line that contains both
    "2-Theta" and "d(". From then on, every line with at least six
    whitespace-separated fields whose leading fields (up to eight) all parse
    as floats is recorded as one peak. Lines that do not fit are skipped.

    Because rows are accepted with six, seven or eight numeric fields, the
    collected rows can differ in length. They are truncated to the shortest
    accepted row so the result is always a rectangular float array instead
    of a ragged sequence that NumPy rejects.

    Args:
        file_path: Path of the peak report to read.

    Returns:
        A tuple ``(peaks, metadata)``. ``peaks`` is a 2-D NumPy float array
        with one row per peak. ``metadata`` is a dict that contains
        ``'reported_peak_count'`` when the report title line states a peak
        count, e.g. "Peak Search Report (72 Peaks, Max P/N = 19.3)".

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If no peak row could be parsed.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    peaks = []
    metadata = {}
    table_started = False

    # Stream the file line by line; undecodable bytes are dropped.
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            # Title line, e.g. "JADE: Peak Search Report (72 Peaks, ...)".
            if "Peak Search Report" in line:
                match = re.search(r"\((\d+) Peaks", line)
                if match:
                    metadata['reported_peak_count'] = int(match.group(1))

            # Column header: the data rows follow it.
            if "2-Theta" in line and "d(" in line:
                table_started = True
                continue

            if not table_started:
                continue

            parts = line.split()
            if len(parts) < 6:
                continue
            try:
                # Column order per the report header (as noted by the
                # original author): 0: 2-Theta, 1: d, 2: BG, 3: Height,
                # 4: I%, 5: Area, 6: I%, 7: FWHM.
                peaks.append([float(p) for p in parts[:8]])
            except ValueError:
                # Explanatory text or a malformed row: skip it.
                continue

    if not peaks:
        raise ValueError(f"未能从 Peak 报告中解析出有效数据: {file_path}")

    # Keep only the columns every accepted row provides, so np.array gets a
    # rectangular input even when some rows carry fewer than eight fields.
    width = min(len(row) for row in peaks)
    return np.array([row[:width] for row in peaks]), metadata
def evaluate(gt_path, agent_path, tolerance=1e-4, mode="xrd_data"):
    """
    Compare an agent's output file against the ground-truth file.

    Two modes are supported:
      - "peak_report": both files are read with ``load_peak_report`` and the
        full peak tables are compared.
      - anything else (default "xrd_data"): both files are read with
        ``load_xrd_data`` and the two-column data points are compared.

    Args:
        gt_path: Path of the ground-truth file.
        agent_path: Path of the file produced by the agent.
        tolerance: The largest absolute element-wise difference must be
            strictly below this value to pass.
        mode: Comparison mode, see above.

    Returns:
        A tuple ``(score, message)`` where ``score`` is 1 on success and 0
        otherwise. Any exception raised while loading or comparing is turned
        into a score of 0 with the exception text in ``message``.
    """
    try:
        peak_mode = mode == "peak_report"
        if peak_mode:
            # The metadata part of the loader result is not used here.
            reference, _ = load_peak_report(gt_path)
            candidate, _ = load_peak_report(agent_path)
        else:
            reference = load_xrd_data(gt_path)
            candidate = load_xrd_data(agent_path)

        # Differently shaped arrays cannot be compared element-wise.
        if reference.shape != candidate.shape:
            if peak_mode:
                return 0, f"失败: Peak 数量不匹配。GT {len(reference)}, Agent {len(candidate)}"
            return 0, f"失败: 数据维度不匹配。GT 形状 {reference.shape}, Agent 形状 {candidate.shape}"

        max_error = np.max(np.abs(reference - candidate))
        if max_error < tolerance:
            return 1, f"成功: 最大绝对误差 {max_error:.2e} < 阈值 {tolerance}"
        return 0, f"失败: 最大绝对误差 {max_error:.2e} 超过阈值 {tolerance}"
    except Exception as exc:
        # Evaluation boundary: report the failure as a zero score.
        return 0, f"错误: {exc!s}"
def evaluate_by_config(config_path):
    """
    Run an evaluation described by a task configuration JSON file.

    The config's ``evaluation`` section supplies ``ground_truth`` and
    ``target_output`` (paths joined onto the directory two levels above the
    config file), plus optional ``tolerance`` and ``type``. Peak-report mode
    is selected when the ground-truth path ends in ".pid" or ``type`` is
    "peak_report"; otherwise raw XRD data mode is used.

    Args:
        config_path: Path of the config file. A relative path that does not
            start with "." and lacks a ".json" suffix gets ".json" appended.

    Returns:
        The ``(score, message)`` tuple produced by ``evaluate``.

    Raises:
        KeyError: If ``ground_truth`` or ``target_output`` is missing.
        OSError / json.JSONDecodeError: If the config cannot be opened or
            parsed.
    """
    # Compatibility: a bare task id may be passed instead of a file path.
    bare_reference = not (os.path.isabs(config_path) or config_path.startswith('.'))
    if bare_reference and not config_path.endswith('.json'):
        config_path += ".json"
    # NOTE: no fallback lookup (e.g. in a task directory) is performed when
    # the path does not exist; open() below simply raises in that case.

    config_dir = os.path.dirname(os.path.abspath(config_path))
    base_dir = os.path.dirname(config_dir)

    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    eval_cfg = config.get('evaluation', {})

    # Paths in the config are relative to base_dir.
    gt_path = os.path.join(base_dir, eval_cfg['ground_truth'])
    agent_path = os.path.join(base_dir, eval_cfg['target_output'])

    # Pick the mode, and with it the default tolerance: peak reports get a
    # looser default unless the config states a tolerance explicitly.
    wants_peak_report = (
        gt_path.lower().endswith('.pid')
        or eval_cfg.get('type') == 'peak_report'
    )
    if wants_peak_report:
        mode = "peak_report"
        tolerance = eval_cfg.get('tolerance', 1e-2)
    else:
        mode = "xrd_data"
        tolerance = eval_cfg.get('tolerance', 1e-4)

    print(f"--- 正在执行评测: {config.get('id', 'unknown')} ---")
    print(f"指令: {config.get('instruction')}")
    print(f"模式: {mode}")

    return evaluate(gt_path, agent_path, tolerance, mode=mode)
if __name__ == "__main__":
    # Command-line entry point. Exit status is 0 only when the score is 1.
    cli_args = sys.argv[1:]

    if len(cli_args) == 1 and cli_args[0].endswith('.json'):
        # Mode A: a single task config file.
        try:
            score, message = evaluate_by_config(cli_args[0])
        except Exception as e:
            print(f"配置文件解析失败: {e}")
            sys.exit(1)
    elif len(cli_args) >= 2:
        # Mode B: direct comparison of two files; extra arguments are ignored.
        score, message = evaluate(cli_args[0], cli_args[1])
    else:
        for usage_line in (
            "用法:",
            " python scripts/evaluator.py instructions/smoothing_001.json",
            " python scripts/evaluator.py <gt_file_path> <agent_file_path>",
        ):
            print(usage_line)
        sys.exit(1)

    separator = "-" * 30
    print(separator)
    print(f"Score: {score}")
    print(f"Reason: {message}")
    print(separator)

    exit_code = 0 if score == 1 else 1
    sys.exit(exit_code)