Initial commit
This commit is contained in:
178
scripts/core/evaluator.py
Normal file
178
scripts/core/evaluator.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
|
||||
def load_xrd_data(file_path):
    """Load numeric data rows from an exported XRD text file.

    The file (e.g. ``background_result.txt``) may begin with metadata lines.
    A line is kept only when every whitespace-separated token on it parses
    as a float and there are at least two tokens; the first two values of
    such a line form one row of the result. All other lines are skipped.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A NumPy array with one row per accepted line and two columns.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If no line of the file yields a numeric row.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    def _parse_row(text):
        # Return the first two numbers of a fully numeric line, else None.
        try:
            numbers = [float(token) for token in text.split()]
        except ValueError:
            # At least one token is not numeric: header / metadata line.
            return None
        return numbers[:2] if len(numbers) >= 2 else None

    with open(file_path, 'r', encoding='utf-8', errors='ignore') as handle:
        candidates = (_parse_row(raw_line) for raw_line in handle)
        points = [row for row in candidates if row is not None]

    if not points:
        raise ValueError(f"在文件中未找到有效的数值数据: {file_path}")

    return np.array(points)
|
||||
|
||||
def load_peak_report(file_path):
    """Parse the peak table of a JADE Peak Search Report (``.pid``).

    Lines before the table header are only inspected for the report title
    (e.g. ``JADE: Peak Search Report (72 Peaks, Max P/N = 19.3)``), from
    which the advertised peak count is extracted. The table starts after
    the line containing both ``2-Theta`` and ``d(``. From then on, a line
    is taken as a peak row when it has at least six whitespace-separated
    fields and its leading fields (at most eight) all parse as floats;
    anything else (notes, footers) is skipped.

    Column order as recorded by the original author:
    0: 2-Theta, 1: d, 2: BG, 3: Height, 4: I%, 5: Area, 6: I%, 7: FWHM.

    Args:
        file_path: Path of the report file to read.

    Returns:
        A tuple ``(peaks, metadata)``. ``peaks`` is a 2-D NumPy array with
        one row per parsed peak. ``metadata`` is a dict that contains
        ``'reported_peak_count'`` when the title line could be parsed.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If no peak row could be parsed.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    peaks = []
    metadata = {}
    table_started = False

    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            # Title line: pull out the peak count announced by the report.
            if "Peak Search Report" in line:
                match = re.search(r"\((\d+) Peaks", line)
                if match:
                    metadata['reported_peak_count'] = int(match.group(1))

            # Table header: numeric rows are expected after this line.
            if "2-Theta" in line and "d(" in line:
                table_started = True
                continue

            if not table_started:
                continue

            parts = line.split()
            # Peak rows usually carry 8-9 fields; require at least 6.
            if len(parts) < 6:
                continue
            try:
                peaks.append([float(p) for p in parts[:8]])
            except ValueError:
                # Not a numeric row (explanatory text etc.): skip it.
                continue

    if not peaks:
        raise ValueError(f"未能从 Peak 报告中解析出有效数据: {file_path}")

    # Rows may legitimately hold 6, 7 or 8 values. Building an array from
    # rows of unequal length fails (or yields an object array on old NumPy),
    # so keep only the columns that every row provides.
    width = min(len(row) for row in peaks)
    return np.array([row[:width] for row in peaks]), metadata
|
||||
|
||||
def evaluate(gt_path, agent_path, tolerance=1e-4, mode="xrd_data"):
    """Compare the ground-truth file with the agent's output file.

    Two modes are supported:

    - ``"xrd_data"``: both files are read with ``load_xrd_data`` and the
      raw (2-Theta, intensity) points are compared.
    - ``"peak_report"``: both files are read with ``load_peak_report`` and
      the complete peak tables are compared.

    The check passes when both arrays have the same shape and the largest
    absolute element-wise difference is strictly below ``tolerance``.

    Args:
        gt_path: Path of the ground-truth file.
        agent_path: Path of the file produced by the agent.
        tolerance: Exclusive upper bound for the maximum absolute error.
        mode: ``"peak_report"`` selects the peak-table comparison; any
            other value selects the raw-data comparison.

    Returns:
        A tuple ``(score, message)`` where ``score`` is ``1`` on success and
        ``0`` otherwise, and ``message`` explains the outcome. Exceptions
        raised while loading or comparing are not propagated; they are
        reported as ``score == 0`` with the exception text in ``message``.
    """
    try:
        if mode == "peak_report":
            # The metadata returned by the loader is not needed here.
            gt_data, _ = load_peak_report(gt_path)
            agent_data, _ = load_peak_report(agent_path)

            # Report a peak-count mismatch only when the number of rows
            # really differs; a pure column-count difference is handled by
            # the generic dimension check below so the message is accurate.
            if len(gt_data) != len(agent_data):
                return 0, f"失败: Peak 数量不匹配。GT {len(gt_data)}, Agent {len(agent_data)}"
        else:
            gt_data = load_xrd_data(gt_path)
            agent_data = load_xrd_data(agent_path)

        if gt_data.shape != agent_data.shape:
            return 0, f"失败: 数据维度不匹配。GT 形状 {gt_data.shape}, Agent 形状 {agent_data.shape}"

        diff = np.abs(gt_data - agent_data)
        max_error = np.max(diff)

        if max_error < tolerance:
            return 1, f"成功: 最大绝对误差 {max_error:.2e} < 阈值 {tolerance}"
        return 0, f"失败: 最大绝对误差 {max_error:.2e} 超过阈值 {tolerance}"

    except Exception as e:
        # Top-level boundary of the evaluator: any failure becomes a zero
        # score with the reason attached instead of crashing the caller.
        return 0, f"错误: {str(e)}"
|
||||
|
||||
def evaluate_by_config(config_path):
    """Run an evaluation described by a task configuration file.

    The JSON config is expected to hold an ``evaluation`` object with the
    keys ``ground_truth`` and ``target_output`` and, optionally,
    ``tolerance`` and ``type``. Both file paths are joined onto the
    directory two levels above the config file.

    Peak-report mode is selected when the ground-truth path ends with
    ``.pid`` (case-insensitive) or when ``type`` equals ``'peak_report'``.
    When the config gives no tolerance, the default is ``1e-2`` in
    peak-report mode and ``1e-4`` otherwise.

    Args:
        config_path: Path of the JSON config. A value that is neither
            absolute nor starting with ``.`` gets ``.json`` appended when
            that suffix is missing.

    Returns:
        The ``(score, message)`` pair produced by ``evaluate``.
    """
    # Compatibility: a bare task id such as "instructions/smoothing_001"
    # is completed with the ".json" suffix.
    looks_like_task_id = not (os.path.isabs(config_path) or config_path.startswith('.'))
    if looks_like_task_id and not config_path.endswith('.json'):
        config_path = f"{config_path}.json"
    # NOTE: a fallback lookup in the task directory for a missing file was
    # planned here but never implemented; a missing file fails at open().

    config_dir = os.path.dirname(os.path.abspath(config_path))
    base_dir = os.path.dirname(config_dir)

    with open(config_path, 'r', encoding='utf-8') as config_file:
        config = json.load(config_file)

    eval_cfg = config.get('evaluation', {})

    # Paths in the config are relative to base_dir.
    gt_path = os.path.join(base_dir, eval_cfg['ground_truth'])
    agent_path = os.path.join(base_dir, eval_cfg['target_output'])

    # Detect the comparison mode. Peak reports get a looser default
    # tolerance because the peak-search algorithm may be slightly
    # environment dependent.
    is_peak_report = gt_path.lower().endswith('.pid') or eval_cfg.get('type') == 'peak_report'
    mode = "peak_report" if is_peak_report else "xrd_data"
    default_tolerance = 1e-2 if is_peak_report else 1e-4
    tolerance = eval_cfg.get('tolerance', default_tolerance)

    print(f"--- 正在执行评测: {config.get('id', 'unknown')} ---")
    print(f"指令: {config.get('instruction')}")
    print(f"模式: {mode}")

    result_score, result_message = evaluate(gt_path, agent_path, tolerance, mode=mode)
    return result_score, result_message
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point. The process exits with status 0 only when
    # the evaluation score is 1.
    cli_args = sys.argv[1:]

    if len(cli_args) == 1 and cli_args[0].endswith('.json'):
        # Mode A: a single JSON task config.
        try:
            score, message = evaluate_by_config(cli_args[0])
        except Exception as e:
            print(f"配置文件解析失败: {e}")
            sys.exit(1)
    elif len(cli_args) >= 2:
        # Mode B: compare two files directly with the default settings.
        score, message = evaluate(cli_args[0], cli_args[1])
    else:
        print("用法:")
        print("  python scripts/evaluator.py instructions/smoothing_001.json")
        print("  python scripts/evaluator.py <gt_file_path> <agent_file_path>")
        sys.exit(1)

    separator = "-" * 30
    print(separator)
    print(f"Score: {score}")
    print(f"Reason: {message}")
    print(separator)

    exit_code = 0 if score == 1 else 1
    sys.exit(exit_code)
|
||||
Reference in New Issue
Block a user