"""
|
||
评测入口脚本
|
||
支持单任务或批量评测
|
||
"""
|
||
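
# A minimal sketch of the repository layout this script assumes, inferred from the
# imports and paths used below (only "core", "scripts" and "tasks" are taken from
# this file; everything else is illustrative):
#
#   <project_root>/
#       core/evaluator.py         # provides evaluate()
#       scripts/run_eval.py       # this script
#       tasks/<task_id>/task.json # per-task configuration
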
import os
import sys
import json
import argparse
import logging
from datetime import datetime

# Add the parent directory to the import path so that `core` can be imported
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core.evaluator import evaluate
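
# NOTE: evaluate() is assumed, based on its usage in evaluate_task() below, to take
# (ground_truth_path, agent_output_path, tolerance, mode=...) and to return a
# (score, message) tuple, with score == 1 meaning the task passed.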

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)


def evaluate_task(task_id, project_root=".", verbose=True):
    """
    Evaluate a single task.

    Args:
        task_id: task ID
        project_root: project root directory
        verbose: whether to print detailed output

    Returns:
        tuple: (score, message, details)
    """
    task_dir = os.path.join(project_root, "tasks", task_id)
    task_json_path = os.path.join(task_dir, "task.json")

    # Check that the task configuration exists
    if not os.path.exists(task_json_path):
        logger.error(f"❌ Task configuration not found: {task_json_path}")
        return 0, "Task configuration not found", {}

    # Load the task configuration
    with open(task_json_path, 'r', encoding='utf-8') as f:
        task_config = json.load(f)

    if verbose:
        print("\n" + "=" * 60)
        print(f"📝 Evaluating task: {task_id}")
        print("=" * 60)
        print(f"Category: {task_config.get('category', 'N/A')}")
        print(f"Difficulty: {task_config.get('difficulty', 'N/A')}")
        print(f"Instruction: {task_config.get('instruction', 'N/A')}")
        print("=" * 60)
    # Read the evaluation configuration
    eval_config = task_config.get('evaluation', {})
    method = eval_config.get('method', 'xrd_data_compare')
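
    # Illustrative shape of the "evaluation" block read above (the file names are
    # made up for this example; only the keys are taken from this script):
    #
    #   "evaluation": {
    #       "method": "xrd_data_compare",           # or "peak_report_compare"
    #       "ground_truth": "expected_output.csv",   # relative to the task directory
    #       "target_output": "agent_output.csv",     # relative to the task directory
    #       "tolerance": 1e-4
    #   }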

    # Build file paths (ground truth and agent output, relative to the task directory)
    gt_path = os.path.join(task_dir, eval_config.get('ground_truth', ''))
    agent_path = os.path.join(task_dir, eval_config.get('target_output', ''))
    tolerance = eval_config.get('tolerance', 1e-4)

    # Check that both files exist
    if not os.path.exists(gt_path):
        logger.error(f"❌ Ground truth file not found: {gt_path}")
        return 0, "Ground truth file not found", {}

    if not os.path.exists(agent_path):
        logger.error(f"❌ Agent output file not found: {agent_path}")
        return 0, "Agent output file not found", {}

    # Run the evaluation
    try:
        if method == 'xrd_data_compare':
            score, message = evaluate(gt_path, agent_path, tolerance, mode="xrd_data")
        elif method == 'peak_report_compare':
            score, message = evaluate(gt_path, agent_path, tolerance, mode="peak_report")
        else:
            logger.warning(f"⚠️ Unknown evaluation method: {method}")
            score, message = 0, f"Unknown evaluation method: {method}"

        details = {
            "task_id": task_id,
            "method": method,
            "ground_truth": gt_path,
            "agent_output": agent_path,
            "tolerance": tolerance,
            "timestamp": datetime.now().isoformat()
        }

        if verbose:
            print("\n📊 Evaluation result:")
            print(f"  Score: {score}")
            print(f"  {message}")
            print("=" * 60 + "\n")

        return score, message, details

    except Exception as e:
        logger.error(f"❌ Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        return 0, f"Evaluation failed: {str(e)}", {}


def evaluate_batch(task_ids, project_root=".", output_file=None):
    """
    Evaluate multiple tasks in a batch.

    Args:
        task_ids: list of task IDs
        project_root: project root directory
        output_file: file to write results to (JSON format)
    """
    print("\n" + "=" * 60)
    print("📊 Batch evaluation")
    print("=" * 60)
    print(f"Number of tasks: {len(task_ids)}")
    print("=" * 60 + "\n")

    results = []
    total_score = 0

    for i, task_id in enumerate(task_ids, 1):
        print(f"\n[{i}/{len(task_ids)}] Evaluating: {task_id}")
        score, message, details = evaluate_task(task_id, project_root, verbose=False)

        result = {
            "task_id": task_id,
            "score": score,
            "message": message,
            **details
        }
        results.append(result)
        total_score += score

        status = "✅ PASS" if score == 1 else "❌ FAIL"
        print(f"  {status}: {message}")

    # Summary statistics (guard against an empty task list)
    pass_count = sum(1 for r in results if r['score'] == 1)
    pass_rate = pass_count / len(task_ids) * 100 if task_ids else 0
    avg_score = total_score / len(task_ids) if task_ids else 0

    print("\n" + "=" * 60)
    print("📈 Evaluation summary")
    print("=" * 60)
    print(f"Total tasks: {len(task_ids)}")
    print(f"Passed: {pass_count}")
    print(f"Failed: {len(task_ids) - pass_count}")
    print(f"Pass rate: {pass_rate:.1f}%")
    print(f"Average score: {avg_score:.2f}")
    print("=" * 60 + "\n")

    # Save the results
    if output_file:
        output_data = {
            "timestamp": datetime.now().isoformat(),
            "total_tasks": len(task_ids),
            "pass_count": pass_count,
            "pass_rate": pass_rate,
            "results": results
        }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)

        print(f"📄 Detailed results saved to: {output_file}\n")

    return results


def discover_tasks(project_root="."):
    """
    Automatically discover all tasks.

    Returns:
        list: task IDs
    """
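    # Only directories under tasks/ that contain a task.json are treated as tasks,
    # e.g. tasks/smoothing_001/task.json (task ID taken from the examples in main()).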
    tasks_dir = os.path.join(project_root, "tasks")

    if not os.path.exists(tasks_dir):
        return []

    task_ids = []
    for item in os.listdir(tasks_dir):
        task_dir = os.path.join(tasks_dir, item)
        task_json = os.path.join(task_dir, "task.json")

        if os.path.isdir(task_dir) and os.path.exists(task_json):
            task_ids.append(item)

    return sorted(task_ids)


def main():
    parser = argparse.ArgumentParser(
        description="JADE Benchmark evaluation tool",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Evaluate a single task
  python scripts/run_eval.py smoothing_001

  # Evaluate multiple tasks
  python scripts/run_eval.py smoothing_001 peak_search_001

  # Evaluate all tasks
  python scripts/run_eval.py --all

  # Save results to a file
  python scripts/run_eval.py --all --output results.json
"""
    )

    parser.add_argument("task_ids", nargs="*", help="task IDs to evaluate")
    parser.add_argument("--all", action="store_true", help="evaluate all tasks")
    parser.add_argument("--project-root", default=".", help="project root directory")
    parser.add_argument("--output", help="output file for results (JSON format)")

    args = parser.parse_args()

    # Decide which tasks to evaluate
    if args.all:
        task_ids = discover_tasks(args.project_root)
        if not task_ids:
            logger.error("❌ No tasks found")
            sys.exit(1)
        logger.info(f"Found {len(task_ids)} tasks")
    elif args.task_ids:
        task_ids = args.task_ids
    else:
        parser.print_help()
        sys.exit(1)

    # Run the evaluation
    try:
        if len(task_ids) == 1:
            # Single-task evaluation
            score, message, _ = evaluate_task(task_ids[0], args.project_root)
            sys.exit(0 if score == 1 else 1)
        else:
            # Batch evaluation
            results = evaluate_batch(task_ids, args.project_root, args.output)

            # Exit code: 0 if every task passed, 1 otherwise
            all_pass = all(r['score'] == 1 for r in results)
            sys.exit(0 if all_pass else 1)

    except KeyboardInterrupt:
        print("\n\n⏹ Evaluation cancelled")
        sys.exit(1)
    except Exception as e:
        logger.error(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()