sci-gui-agent-benchmark/scripts/core/evaluator.py
import json
import numpy as np
import os
import sys
import re
def load_xrd_data(file_path):
    """
    Read a txt file exported from XRD (e.g. background_result.txt), skipping the metadata header.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    data = []
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split()
            try:
                values = [float(x) for x in parts]
                if len(values) >= 2:
                    data.append(values[:2])
            except ValueError:
                continue
    if not data:
        raise ValueError(f"No valid numeric data found in file: {file_path}")
    return np.array(data)
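# --- Illustrative self-check (a sketch, not part of the original evaluator) ----------
# load_xrd_data() accepts any whitespace-separated two-column numeric text and silently
# skips non-numeric header lines, so a minimal round trip can be demonstrated with a
# temporary file. The sample values below are made up; this helper is never called by
# the evaluator itself.
def _demo_load_xrd_data():
    import tempfile
    sample = "# Metadata: hypothetical sample\n10.02 153.0\n10.04 161.0\n"
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False, encoding='utf-8') as tmp:
        tmp.write(sample)
        tmp_path = tmp.name
    try:
        arr = load_xrd_data(tmp_path)
    finally:
        os.remove(tmp_path)
    assert arr.shape == (2, 2)  # two rows of (2-Theta, Intensity)
    return arr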
def load_peak_report(file_path):
    """
    Parse a JADE Peak Search Report (.pid) and extract the numeric rows
    of the table section.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    peaks = []
    metadata = {}
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        lines = f.readlines()
    # Find where the table starts (after the line containing "2-Theta")
    table_started = False
    for line in lines:
        line = line.strip()
        if not line:
            continue
        # Extract metadata, e.g.: JADE: Peak Search Report (72 Peaks, Max P/N = 19.3)
        if "Peak Search Report" in line:
            match = re.search(r"\((\d+) Peaks", line)
            if match:
                metadata['reported_peak_count'] = int(match.group(1))
        # Detect the table header
        if "2-Theta" in line and "d(" in line:
            table_started = True
            continue
        if table_started:
            parts = line.split()
            # A peak report row usually has 8-9 fields
            if len(parts) >= 6:
                try:
                    # Try to convert the leading fields to float
                    # 0: 2-Theta, 1: d, 2: BG, 3: Height, 4: I%, 5: Area, 6: I%, 7: FWHM
                    peak_data = [float(p) for p in parts[:8]]
                    peaks.append(peak_data)
                except ValueError:
                    # A conversion failure means descriptive text or a blank line; skip it
                    continue
    if not peaks:
        raise ValueError(f"Failed to parse any valid data from the peak report: {file_path}")
    return np.array(peaks), metadata
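# --- Illustrative self-check (a sketch, not part of the original evaluator) ----------
# load_peak_report() relies on three structural cues: a "Peak Search Report (N Peaks"
# metadata line, a header line containing both "2-Theta" and "d(", and data rows with
# at least six whitespace-separated numeric fields. The synthetic report below follows
# those cues; it is a made-up example, not a verbatim JADE export, and this helper is
# never called by the evaluator itself.
def _demo_load_peak_report():
    import tempfile
    sample = (
        "JADE: Peak Search Report (2 Peaks, Max P/N = 19.3)\n"
        "2-Theta   d(A)    BG    Height   I%    Area    I%    FWHM\n"
        "26.640    3.343   120   4500     100   52000   100   0.120\n"
        "20.860    4.255   110   900      20    9800    19    0.150\n"
    )
    with tempfile.NamedTemporaryFile('w', suffix='.pid', delete=False, encoding='utf-8') as tmp:
        tmp.write(sample)
        tmp_path = tmp.name
    try:
        peaks, meta = load_peak_report(tmp_path)
    finally:
        os.remove(tmp_path)
    assert meta['reported_peak_count'] == 2
    assert peaks.shape == (2, 8)  # 8 numeric columns per detected peak
    return peaks, meta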
def evaluate(gt_path, agent_path, tolerance=1e-4, mode="xrd_data"):
    """
    Compare the Ground Truth against the Agent Output.
    Two modes are supported:
    - xrd_data: compare the raw (2-Theta, Intensity) data points
    - peak_report: compare the list of peak-search results
    """
    try:
        if mode == "peak_report":
            gt_data, gt_meta = load_peak_report(gt_path)
            agent_data, agent_meta = load_peak_report(agent_path)
            # For peak reports we mainly care about the 2-Theta positions and heights.
            # The full table is compared here, but with a relaxed tolerance, since the
            # peak-search algorithm may differ very slightly between environments.
            if gt_data.shape != agent_data.shape:
                # A mismatched peak count is an immediate failure
                return 0, f"Failure: peak counts do not match. GT {len(gt_data)}, Agent {len(agent_data)}"
        else:
            gt_data = load_xrd_data(gt_path)
            agent_data = load_xrd_data(agent_path)
            if gt_data.shape != agent_data.shape:
                return 0, f"Failure: data shapes do not match. GT shape {gt_data.shape}, Agent shape {agent_data.shape}"
        diff = np.abs(gt_data - agent_data)
        max_error = np.max(diff)
        if max_error < tolerance:
            return 1, f"Success: maximum absolute error {max_error:.2e} < threshold {tolerance}"
        else:
            return 0, f"Failure: maximum absolute error {max_error:.2e} exceeds threshold {tolerance}"
    except Exception as e:
        return 0, f"Error: {str(e)}"
def evaluate_by_config(config_path):
    """
    Run an evaluation based on a task configuration file.
    """
    # Compatibility handling: the argument may be a task ID rather than a full path
    if not os.path.isabs(config_path) and not config_path.startswith('.'):
        # Try to complete the path, e.g. instructions/smoothing_001.json
        if not config_path.endswith('.json'):
            config_path = config_path + ".json"
        if not os.path.exists(config_path):
            # Try looking it up under the task directory
            pass
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(config_path)))
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    eval_cfg = config.get('evaluation', {})
    # Resolve relative paths
    gt_path = os.path.join(base_dir, eval_cfg['ground_truth'])
    agent_path = os.path.join(base_dir, eval_cfg['target_output'])
    tolerance = eval_cfg.get('tolerance', 1e-4)
    # Auto-detect the evaluation mode
    mode = "xrd_data"
    if gt_path.lower().endswith('.pid') or eval_cfg.get('type') == 'peak_report':
        mode = "peak_report"
        # Use a more relaxed default tolerance for peak reports, since the algorithm
        # can be affected by small environmental differences
        if 'tolerance' not in eval_cfg:
            tolerance = 1e-2
    print(f"--- Running evaluation: {config.get('id', 'unknown')} ---")
    print(f"Instruction: {config.get('instruction')}")
    print(f"Mode: {mode}")
    score, message = evaluate(gt_path, agent_path, tolerance, mode=mode)
    return score, message
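# --- Illustrative config layout (a sketch, not part of the original evaluator) -------
# evaluate_by_config() only reads the keys shown below; the id, instruction, and file
# paths here are placeholders. Relative paths are resolved against the parent of the
# directory that contains the config file.
_EXAMPLE_CONFIG = {
    "id": "smoothing_001",
    "instruction": "placeholder task instruction",
    "evaluation": {
        "ground_truth": "ground_truth/smoothing_001.txt",    # placeholder path
        "target_output": "agent_output/smoothing_001.txt",   # placeholder path
        "tolerance": 1e-4,   # optional; defaults to 1e-4 (1e-2 for peak reports)
        "type": "xrd_data"   # only "peak_report" (or a .pid ground truth) switches modes
    }
}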
if __name__ == "__main__":
    if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
        # Mode A: config mode
        try:
            score, message = evaluate_by_config(sys.argv[1])
        except Exception as e:
            print(f"Failed to parse the config file: {e}")
            sys.exit(1)
    elif len(sys.argv) >= 3:
        # Mode B: direct comparison mode
        score, message = evaluate(sys.argv[1], sys.argv[2])
    else:
        print("Usage:")
        print(" python scripts/evaluator.py instructions/smoothing_001.json")
        print(" python scripts/evaluator.py <gt_file_path> <agent_file_path>")
        sys.exit(1)
    print("-" * 30)
    print(f"Score: {score}")
    print(f"Reason: {message}")
    print("-" * 30)
    sys.exit(0 if score == 1 else 1)