179 lines
6.3 KiB
Python
179 lines
6.3 KiB
Python
import json
|
||
import numpy as np
|
||
import os
|
||
import sys
|
||
import re
|
||
|
||
def load_xrd_data(file_path):
    """Read an exported XRD text file (e.g. background_result.txt), skipping header metadata.

    Every non-blank line whose whitespace-separated tokens ALL parse as floats
    and that holds at least two values contributes its first two values as one
    row. Any other line (header text, metadata, single numbers) is skipped.

    Args:
        file_path: Path to the text file.

    Returns:
        A float ``np.ndarray`` of shape ``(n, 2)``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If no numeric row with at least two values is found.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    data = []
    # 'utf-8-sig' strips a leading byte-order mark if present. With plain
    # 'utf-8' the BOM stays glued to the first token, float() rejects it and
    # the first data row of a BOM-prefixed export would be silently dropped.
    with open(file_path, 'r', encoding='utf-8-sig', errors='ignore') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            try:
                values = [float(x) for x in parts]
            except ValueError:
                # Header / metadata line: not entirely numeric.
                continue
            if len(values) >= 2:
                data.append(values[:2])

    if not data:
        raise ValueError(f"在文件中未找到有效的数值数据: {file_path}")

    return np.array(data)
|
||
|
||
def load_peak_report(file_path):
    """Parse a JADE Peak Search Report (.pid) and extract its numeric peak table.

    The table is taken to start after the first line containing both
    ``"2-Theta"`` and ``"d("``. After that, every line with at least six
    whitespace-separated tokens whose first (up to) eight tokens all parse as
    floats becomes one peak row. Other lines are skipped.

    Column order as noted by the original author (assumed, confirm against the
    JADE export): 0 2-Theta, 1 d, 2 BG, 3 Height, 4 I%, 5 Area, 6 I%, 7 FWHM.

    Args:
        file_path: Path to the ``.pid`` report.

    Returns:
        ``(peaks, metadata)``:
        - ``peaks`` is a 2-D float ``np.ndarray`` with one row per peak.
        - ``metadata`` is a dict that holds ``'reported_peak_count'`` when a
          "Peak Search Report (N Peaks" title line is present.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If no peak row could be parsed.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    peaks = []
    metadata = {}
    table_started = False

    # 'utf-8-sig' drops a leading byte-order mark so the first line matches normally.
    with open(file_path, 'r', encoding='utf-8-sig', errors='ignore') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            # Title line, e.g. "JADE: Peak Search Report (72 Peaks, Max P/N = 19.3)"
            if "Peak Search Report" in line:
                match = re.search(r"\((\d+) Peaks", line)
                if match:
                    metadata['reported_peak_count'] = int(match.group(1))

            # Table header: data rows follow it.
            if "2-Theta" in line and "d(" in line:
                table_started = True
                continue

            if not table_started:
                continue

            parts = line.split()
            # Peak rows normally carry 8-9 fields; shorter lines are not peaks.
            if len(parts) < 6:
                continue
            try:
                peaks.append([float(p) for p in parts[:8]])
            except ValueError:
                # Non-numeric text inside the table region: skip it.
                continue

    if not peaks:
        raise ValueError(f"未能从 Peak 报告中解析出有效数据: {file_path}")

    # Rows may hold 6, 7 or 8 values. np.array() on such a ragged list either
    # raises (NumPy >= 1.24) or builds an object array that cannot be compared
    # numerically, so trim every row to the narrowest width to keep a 2-D
    # float array.
    width = min(len(row) for row in peaks)
    return np.array([row[:width] for row in peaks], dtype=float), metadata
|
||
|
||
def evaluate(gt_path, agent_path, tolerance=1e-4, mode="xrd_data"):
    """Compare a ground-truth file with an agent's output file.

    Supported modes:
      - ``"xrd_data"``: compare the (2-Theta, Intensity) data points.
      - ``"peak_report"``: compare the peak table of a JADE peak-search report.
    Any other ``mode`` value is handled like ``"xrd_data"``.

    The comparison passes when both arrays have the same shape and the largest
    element-wise absolute difference is ``<= tolerance``, so ``tolerance=0``
    means an exact match.

    Args:
        gt_path: Path to the ground-truth file.
        agent_path: Path to the agent's output file.
        tolerance: Maximum allowed absolute difference.
        mode: ``"xrd_data"`` or ``"peak_report"``.

    Returns:
        ``(score, message)``:
        - ``score`` is 1 on success, 0 otherwise.
        - ``message`` explains the result.
        Any exception raised while loading or comparing is caught and reported
        as score 0.
    """
    try:
        if mode == "peak_report":
            gt_data, _gt_meta = load_peak_report(gt_path)
            agent_data, _agent_meta = load_peak_report(agent_path)
            # A differing number of peaks is reported explicitly. An equal
            # number of peaks with a different column count falls through to
            # the generic shape check below instead of being mislabeled as a
            # peak-count mismatch.
            if len(gt_data) != len(agent_data):
                return 0, f"失败: Peak 数量不匹配。GT {len(gt_data)}, Agent {len(agent_data)}"
        else:
            gt_data = load_xrd_data(gt_path)
            agent_data = load_xrd_data(agent_path)

        if gt_data.shape != agent_data.shape:
            return 0, f"失败: 数据维度不匹配。GT 形状 {gt_data.shape}, Agent 形状 {agent_data.shape}"

        max_error = np.max(np.abs(gt_data - agent_data))

        # '<=' lets an error exactly at the threshold pass, and lets
        # tolerance=0 accept identical data (with '<' it never could).
        if max_error <= tolerance:
            return 1, f"成功: 最大绝对误差 {max_error:.2e} <= 阈值 {tolerance}"
        return 0, f"失败: 最大绝对误差 {max_error:.2e} 超过阈值 {tolerance}"

    except Exception as e:
        # Top-level boundary of the evaluator: any failure scores 0 with its reason.
        return 0, f"错误: {str(e)}"
|
||
|
||
def evaluate_by_config(config_path):
    """Run the evaluation described by a task configuration JSON file.

    The config must contain an ``"evaluation"`` object with ``"ground_truth"``
    and ``"target_output"`` paths. These are joined onto the parent of the
    directory that holds the config file. Optional entries:
      - ``"tolerance"`` (float-convertible).
      - ``"type"``: ``"peak_report"`` forces peak-report mode. A ``.pid``
        ground-truth file does the same.
    Top-level ``"id"`` and ``"instruction"`` are only printed.

    Args:
        config_path: Path to the config file. A relative path that does not
            start with ``'.'`` and lacks the ``.json`` suffix gets it appended.

    Returns:
        The ``(score, message)`` tuple produced by ``evaluate``.

    Raises:
        KeyError: If ``ground_truth`` or ``target_output`` is missing.
        ValueError / TypeError: If ``tolerance`` is not float-convertible.
        OSError / json.JSONDecodeError: From opening or parsing the config.
    """
    # Compatibility: allow a bare task id such as "instructions/smoothing_001".
    if (not os.path.isabs(config_path)
            and not config_path.startswith('.')
            and not config_path.endswith('.json')):
        config_path = config_path + ".json"

    # The config is expected one level below the project root
    # (e.g. <root>/instructions/x.json), so data paths are resolved from <root>.
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(config_path)))

    # 'utf-8-sig' accepts configs saved with a BOM, which json.load rejects otherwise.
    with open(config_path, 'r', encoding='utf-8-sig') as f:
        config = json.load(f)

    eval_cfg = config.get('evaluation', {})
    missing = [key for key in ('ground_truth', 'target_output') if key not in eval_cfg]
    if missing:
        raise KeyError(f"配置文件 {config_path} 的 evaluation 缺少字段: {', '.join(missing)}")

    gt_path = os.path.join(base_dir, eval_cfg['ground_truth'])
    agent_path = os.path.join(base_dir, eval_cfg['target_output'])
    # Coerce so a tolerance written as a JSON string (e.g. "1e-4") still works.
    tolerance = float(eval_cfg.get('tolerance', 1e-4))

    # Auto-detect the comparison mode.
    mode = "xrd_data"
    if gt_path.lower().endswith('.pid') or eval_cfg.get('type') == 'peak_report':
        mode = "peak_report"
        # Peak search results may vary slightly between environments, so the
        # default tolerance is looser when none is configured.
        if 'tolerance' not in eval_cfg:
            tolerance = 1e-2

    print(f"--- 正在执行评测: {config.get('id', 'unknown')} ---")
    print(f"指令: {config.get('instruction')}")
    print(f"模式: {mode}")

    score, message = evaluate(gt_path, agent_path, tolerance, mode=mode)
    return score, message
|
||
|
||
if __name__ == "__main__":
    if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
        # Mode A: evaluate from a task config file.
        try:
            score, message = evaluate_by_config(sys.argv[1])
        except Exception as e:
            print(f"配置文件解析失败: {e}")
            sys.exit(1)
    elif len(sys.argv) >= 3:
        # Mode B: compare two files directly. Mirror evaluate_by_config's
        # auto-detection so .pid peak reports are parsed as peak tables with
        # the looser default tolerance, not as raw (2-Theta, Intensity) data.
        gt_arg, agent_arg = sys.argv[1], sys.argv[2]
        if gt_arg.lower().endswith('.pid'):
            score, message = evaluate(gt_arg, agent_arg, 1e-2, mode="peak_report")
        else:
            score, message = evaluate(gt_arg, agent_arg)
    else:
        print("用法:")
        print(" python scripts/evaluator.py instructions/smoothing_001.json")
        print(" python scripts/evaluator.py <gt_file_path> <agent_file_path>")
        sys.exit(1)

    print("-" * 30)
    print(f"Score: {score}")
    print(f"Reason: {message}")
    print("-" * 30)

    # Exit status doubles as the verdict: 0 on success, 1 otherwise.
    sys.exit(0 if score == 1 else 1)
|