# sci-gui-agent-benchmark/scripts/tools/run_eval.py

"""
评测入口脚本
支持单任务或批量评测
"""
import os
import sys
import json
import argparse
import logging
from datetime import datetime
# 添加父目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.evaluator import evaluate
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
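
# Note on the imported evaluate(): this script calls it as
# evaluate(gt_path, agent_path, tolerance, mode=...) and expects a
# (score, message) tuple back. That interface is inferred from the call
# sites below; see core/evaluator.py for the authoritative signature.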

def evaluate_task(task_id, project_root=".", verbose=True):
    """
    Evaluate a single task.

    Args:
        task_id: task ID
        project_root: project root directory
        verbose: whether to print detailed output

    Returns:
        tuple: (score, message, details)
    """
    task_dir = os.path.join(project_root, "tasks", task_id)
    task_json_path = os.path.join(task_dir, "task.json")

    # Check that the task configuration exists
    if not os.path.exists(task_json_path):
        logger.error(f"❌ Task configuration not found: {task_json_path}")
        return 0, "Task configuration not found", {}

    # Load the task configuration
    with open(task_json_path, 'r', encoding='utf-8') as f:
        task_config = json.load(f)

    if verbose:
        print("\n" + "=" * 60)
        print(f"📝 Evaluating task: {task_id}")
        print("=" * 60)
        print(f"Category: {task_config.get('category', 'N/A')}")
        print(f"Difficulty: {task_config.get('difficulty', 'N/A')}")
        print(f"Instruction: {task_config.get('instruction', 'N/A')}")
        print("=" * 60)

    # Read the evaluation configuration
    eval_config = task_config.get('evaluation', {})
    method = eval_config.get('method', 'xrd_data_compare')

    # Build file paths
    gt_path = os.path.join(task_dir, eval_config.get('ground_truth', ''))
    agent_path = os.path.join(task_dir, eval_config.get('target_output', ''))
    tolerance = eval_config.get('tolerance', 1e-4)

    # Check that both files exist
    if not os.path.exists(gt_path):
        logger.error(f"❌ Ground truth file not found: {gt_path}")
        return 0, "Ground truth file not found", {}
    if not os.path.exists(agent_path):
        logger.error(f"❌ Agent output file not found: {agent_path}")
        return 0, "Agent output file not found", {}

    # Run the evaluation
    try:
        if method == 'xrd_data_compare':
            score, message = evaluate(gt_path, agent_path, tolerance, mode="xrd_data")
        elif method == 'peak_report_compare':
            score, message = evaluate(gt_path, agent_path, tolerance, mode="peak_report")
        else:
            logger.warning(f"⚠️ Unknown evaluation method: {method}")
            score, message = 0, f"Unknown evaluation method: {method}"

        details = {
            "task_id": task_id,
            "method": method,
            "ground_truth": gt_path,
            "agent_output": agent_path,
            "tolerance": tolerance,
            "timestamp": datetime.now().isoformat()
        }

        if verbose:
            print("\n📊 Evaluation result:")
            print(f"  Score: {score}")
            print(f"  {message}")
            print("=" * 60 + "\n")

        return score, message, details
    except Exception as e:
        logger.error(f"❌ Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        return 0, f"Evaluation failed: {str(e)}", {}
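
# For reference, a task.json that provides the fields read by evaluate_task()
# might look like the sketch below. The structure mirrors the keys accessed
# above; the concrete values and relative file names are illustrative
# assumptions, not copied from a real task.
#
# {
#   "category": "smoothing",
#   "difficulty": "easy",
#   "instruction": "Apply smoothing to the XRD pattern and export the result",
#   "evaluation": {
#     "method": "xrd_data_compare",
#     "ground_truth": "ground_truth.csv",
#     "target_output": "agent_output.csv",
#     "tolerance": 1e-4
#   }
# }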

def evaluate_batch(task_ids, project_root=".", output_file=None):
    """
    Evaluate multiple tasks in batch.

    Args:
        task_ids: list of task IDs
        project_root: project root directory
        output_file: optional path to write results as JSON
    """
    print("\n" + "=" * 60)
    print("📊 Batch evaluation")
    print("=" * 60)
    print(f"Number of tasks: {len(task_ids)}")
    print("=" * 60 + "\n")

    results = []
    total_score = 0

    for i, task_id in enumerate(task_ids, 1):
        print(f"\n[{i}/{len(task_ids)}] Evaluating: {task_id}")
        score, message, details = evaluate_task(task_id, project_root, verbose=False)

        result = {
            "task_id": task_id,
            "score": score,
            "message": message,
            **details
        }
        results.append(result)
        total_score += score

        status = "✅ PASS" if score == 1 else "❌ FAIL"
        print(f"  {status}: {message}")

    # Summary statistics
    pass_count = sum(1 for r in results if r['score'] == 1)
    pass_rate = pass_count / len(task_ids) * 100 if task_ids else 0

    print("\n" + "=" * 60)
    print("📈 Evaluation summary")
    print("=" * 60)
    print(f"Total tasks: {len(task_ids)}")
    print(f"Passed: {pass_count}")
    print(f"Failed: {len(task_ids) - pass_count}")
    print(f"Pass rate: {pass_rate:.1f}%")
    print(f"Average score: {total_score / len(task_ids):.2f}")
    print("=" * 60 + "\n")

    # Save results to file if requested
    if output_file:
        output_data = {
            "timestamp": datetime.now().isoformat(),
            "total_tasks": len(task_ids),
            "pass_count": pass_count,
            "pass_rate": pass_rate,
            "results": results
        }
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"📄 Detailed results saved to: {output_file}\n")

    return results

def discover_tasks(project_root="."):
    """
    Automatically discover all tasks.

    Returns:
        list: list of task IDs
    """
    tasks_dir = os.path.join(project_root, "tasks")
    if not os.path.exists(tasks_dir):
        return []

    task_ids = []
    for item in os.listdir(tasks_dir):
        task_dir = os.path.join(tasks_dir, item)
        task_json = os.path.join(task_dir, "task.json")
        # A directory counts as a task only if it contains a task.json
        if os.path.isdir(task_dir) and os.path.exists(task_json):
            task_ids.append(item)

    return sorted(task_ids)
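
# discover_tasks() assumes the repository layout sketched below, where each
# task lives in its own directory under tasks/ and is identified by the
# presence of a task.json (the directory names here are illustrative):
#
# tasks/
#   smoothing_001/
#     task.json
#   peak_search_001/
#     task.json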

def main():
    parser = argparse.ArgumentParser(
        description="JADE Benchmark evaluation tool",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Evaluate a single task
  python scripts/run_eval.py smoothing_001

  # Evaluate multiple tasks
  python scripts/run_eval.py smoothing_001 peak_search_001

  # Evaluate all tasks
  python scripts/run_eval.py --all

  # Save results to a file
  python scripts/run_eval.py --all --output results.json
"""
    )
    parser.add_argument("task_ids", nargs="*", help="list of task IDs")
    parser.add_argument("--all", action="store_true", help="evaluate all tasks")
    parser.add_argument("--project-root", default=".", help="project root directory")
    parser.add_argument("--output", help="output file for results (JSON format)")
    args = parser.parse_args()

    # Determine which tasks to evaluate
    if args.all:
        task_ids = discover_tasks(args.project_root)
        if not task_ids:
            logger.error("❌ No tasks found")
            sys.exit(1)
        logger.info(f"Discovered {len(task_ids)} tasks")
    elif args.task_ids:
        task_ids = args.task_ids
    else:
        parser.print_help()
        sys.exit(1)

    # Run the evaluation
    try:
        if len(task_ids) == 1:
            # Single-task evaluation
            score, message, _ = evaluate_task(task_ids[0], args.project_root)
            sys.exit(0 if score == 1 else 1)
        else:
            # Batch evaluation
            results = evaluate_batch(task_ids, args.project_root, args.output)
            # Exit code: 0 if every task passed, otherwise 1
            all_pass = all(r['score'] == 1 for r in results)
            sys.exit(0 if all_pass else 1)
    except KeyboardInterrupt:
        print("\n\n⏹ Evaluation cancelled")
        sys.exit(1)
    except Exception as e:
        logger.error(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
if __name__ == "__main__":
main()