Initial commit
scripts/tools/run_eval.py (new file, 261 lines added)
@@ -0,0 +1,261 @@
"""
|
||||
评测入口脚本
|
||||
支持单任务或批量评测
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from core.evaluator import evaluate
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def evaluate_task(task_id, project_root=".", verbose=True):
|
||||
"""
|
||||
评测单个任务
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
project_root: 项目根目录
|
||||
verbose: 是否详细输出
|
||||
|
||||
Returns:
|
||||
tuple: (score, message, details)
|
||||
"""
|
||||
task_dir = os.path.join(project_root, "tasks", task_id)
|
||||
task_json_path = os.path.join(task_dir, "task.json")
|
||||
|
||||
# 检查任务配置
|
||||
if not os.path.exists(task_json_path):
|
||||
logger.error(f"❌ 任务配置不存在: {task_json_path}")
|
||||
return 0, "任务配置不存在", {}
|
||||
|
||||
# 加载任务配置
|
||||
with open(task_json_path, 'r', encoding='utf-8') as f:
|
||||
task_config = json.load(f)
|
||||
|
||||
if verbose:
|
||||
print("\n" + "=" * 60)
|
||||
print(f"📝 评测任务: {task_id}")
|
||||
print("=" * 60)
|
||||
print(f"类别: {task_config.get('category', 'N/A')}")
|
||||
print(f"难度: {task_config.get('difficulty', 'N/A')}")
|
||||
print(f"指令: {task_config.get('instruction', 'N/A')}")
|
||||
print("=" * 60)
|
||||
|
||||
# 获取评测配置
|
||||
eval_config = task_config.get('evaluation', {})
|
||||
method = eval_config.get('method', 'xrd_data_compare')
|
||||
|
||||
# 构建文件路径
|
||||
gt_path = os.path.join(task_dir, eval_config.get('ground_truth', ''))
|
||||
agent_path = os.path.join(task_dir, eval_config.get('target_output', ''))
|
||||
tolerance = eval_config.get('tolerance', 1e-4)
|
||||
|
||||
# 检查文件
|
||||
if not os.path.exists(gt_path):
|
||||
logger.error(f"❌ Ground truth文件不存在: {gt_path}")
|
||||
return 0, "Ground truth文件不存在", {}
|
||||
|
||||
if not os.path.exists(agent_path):
|
||||
logger.error(f"❌ Agent输出文件不存在: {agent_path}")
|
||||
return 0, "Agent输出文件不存在", {}
|
||||
|
||||
# 执行评测
|
||||
try:
|
||||
if method == 'xrd_data_compare':
|
||||
score, message = evaluate(gt_path, agent_path, tolerance, mode="xrd_data")
|
||||
elif method == 'peak_report_compare':
|
||||
score, message = evaluate(gt_path, agent_path, tolerance, mode="peak_report")
|
||||
else:
|
||||
logger.warning(f"⚠️ 未知的评测方法: {method}")
|
||||
score, message = 0, f"未知的评测方法: {method}"
|
||||
|
||||
details = {
|
||||
"task_id": task_id,
|
||||
"method": method,
|
||||
"ground_truth": gt_path,
|
||||
"agent_output": agent_path,
|
||||
"tolerance": tolerance,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
if verbose:
|
||||
print(f"\n📊 评测结果:")
|
||||
print(f" Score: {score}")
|
||||
print(f" {message}")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
return score, message, details
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 评测失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 0, f"评测失败: {str(e)}", {}
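# For reference, a minimal sketch of the task.json layout that evaluate_task
# expects, inferred from the fields read above; the category, instruction, and
# file names here are hypothetical placeholders:
#
# {
#     "category": "smoothing",
#     "difficulty": "easy",
#     "instruction": "Smooth the raw XRD pattern",
#     "evaluation": {
#         "method": "xrd_data_compare",
#         "ground_truth": "ground_truth/smoothed_gt.csv",
#         "target_output": "outputs/smoothed.csv",
#         "tolerance": 1e-4
#     }
# }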
def evaluate_batch(task_ids, project_root=".", output_file=None):
    """
    Evaluate multiple tasks in batch.

    Args:
        task_ids: List of task IDs
        project_root: Project root directory
        output_file: Output file for results (JSON format)

    Returns:
        list: Per-task result dicts
    """
    print("\n" + "=" * 60)
    print("📊 Batch evaluation")
    print("=" * 60)
    print(f"Number of tasks: {len(task_ids)}")
    print("=" * 60 + "\n")

    results = []
    total_score = 0

    for i, task_id in enumerate(task_ids, 1):
        print(f"\n[{i}/{len(task_ids)}] Evaluating: {task_id}")
        score, message, details = evaluate_task(task_id, project_root, verbose=False)

        result = {
            "task_id": task_id,
            "score": score,
            "message": message,
            **details
        }
        results.append(result)
        total_score += score

        status = "✅ PASS" if score == 1 else "❌ FAIL"
        print(f"  {status}: {message}")

    # Summary statistics
    pass_count = sum(1 for r in results if r['score'] == 1)
    pass_rate = pass_count / len(task_ids) * 100 if task_ids else 0
    avg_score = total_score / len(task_ids) if task_ids else 0

    print("\n" + "=" * 60)
    print("📈 Evaluation summary")
    print("=" * 60)
    print(f"Total tasks: {len(task_ids)}")
    print(f"Passed: {pass_count}")
    print(f"Failed: {len(task_ids) - pass_count}")
    print(f"Pass rate: {pass_rate:.1f}%")
    print(f"Average score: {avg_score:.2f}")
    print("=" * 60 + "\n")

    # Save results
    if output_file:
        output_data = {
            "timestamp": datetime.now().isoformat(),
            "total_tasks": len(task_ids),
            "pass_count": pass_count,
            "pass_rate": pass_rate,
            "results": results
        }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)

        print(f"📄 Detailed results saved to: {output_file}\n")

    return results


def discover_tasks(project_root="."):
    """
    Discover all tasks automatically.

    Returns:
        list: List of task IDs
    """
    tasks_dir = os.path.join(project_root, "tasks")

    if not os.path.exists(tasks_dir):
        return []

    task_ids = []
    for item in os.listdir(tasks_dir):
        task_dir = os.path.join(tasks_dir, item)
        task_json = os.path.join(task_dir, "task.json")

        if os.path.isdir(task_dir) and os.path.exists(task_json):
            task_ids.append(item)

    return sorted(task_ids)


def main():
    parser = argparse.ArgumentParser(
        description="JADE Benchmark evaluation tool",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Usage examples:
  # Evaluate a single task
  python scripts/tools/run_eval.py smoothing_001

  # Evaluate multiple tasks
  python scripts/tools/run_eval.py smoothing_001 peak_search_001

  # Evaluate all tasks
  python scripts/tools/run_eval.py --all

  # Save results to a file
  python scripts/tools/run_eval.py --all --output results.json
"""
    )

    parser.add_argument("task_ids", nargs="*", help="List of task IDs")
    parser.add_argument("--all", action="store_true", help="Evaluate all tasks")
    parser.add_argument("--project-root", default=".", help="Project root directory")
    parser.add_argument("--output", help="Output file for results (JSON format)")

    args = parser.parse_args()

    # Determine which tasks to evaluate
    if args.all:
        task_ids = discover_tasks(args.project_root)
        if not task_ids:
            logger.error("❌ No tasks found")
            sys.exit(1)
        logger.info(f"Found {len(task_ids)} task(s)")
    elif args.task_ids:
        task_ids = args.task_ids
    else:
        parser.print_help()
        sys.exit(1)

    # Run the evaluation
    try:
        if len(task_ids) == 1:
            # Single-task evaluation
            score, message, _ = evaluate_task(task_ids[0], args.project_root)
            sys.exit(0 if score == 1 else 1)
        else:
            # Batch evaluation
            results = evaluate_batch(task_ids, args.project_root, args.output)

            # Exit code: 0 if all tasks passed, 1 otherwise
            all_pass = all(r['score'] == 1 for r in results)
            sys.exit(0 if all_pass else 1)

    except KeyboardInterrupt:
        print("\n\n⏹ Evaluation cancelled")
        sys.exit(1)
    except Exception as e:
        logger.error(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
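# A hypothetical programmatic use of this module (task IDs are placeholders and
# assume the tasks/<task_id>/task.json layout sketched above):
#
#   from run_eval import evaluate_task, evaluate_batch
#   score, message, details = evaluate_task("smoothing_001", project_root=".")
#   evaluate_batch(["smoothing_001", "peak_search_001"], output_file="results.json")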