重构eval代码

This commit is contained in:
lzy
2025-05-28 15:43:50 +08:00
parent 9f5318c23d
commit 9abd8fc1c5
39 changed files with 2468 additions and 166 deletions

View File

@@ -0,0 +1,111 @@
import re
import numpy as np
from typing import List, Dict, Any, Optional
from sklearn.metrics import precision_score, recall_score, f1_score
import logging
logger = logging.getLogger(__name__)
class MetricsCalculator:
    """Compute evaluation metrics for tagged multiple-choice answers.

    Both the reference answer and the model answer are expected to carry
    their option labels inside ``[ANSWER]...[/ANSWER]`` tags, as a
    comma-separated list (for example ``[ANSWER]A, C[/ANSWER]``).
    """

    # DOTALL lets the answer body contain line breaks. The negative
    # lookahead keeps a match from spanning a second opening tag, so a
    # stray unclosed ``[ANSWER]`` cannot swallow the text up to a later,
    # well-formed answer block.
    _ANSWER_PATTERN = re.compile(
        r'\[ANSWER\]((?:(?!\[ANSWER\]).)*?)\[/ANSWER\]',
        re.DOTALL,
    )

    @staticmethod
    def extract_answer(answer_string: str) -> Optional[str]:
        """Extract the text of the first well-formed answer block.

        Args:
            answer_string: Text that may contain ``[ANSWER]...[/ANSWER]``.

        Returns:
            The stripped text between the tags, or None when the input is
            empty/None or contains no complete answer block.
        """
        if not answer_string:
            return None
        match = MetricsCalculator._ANSWER_PATTERN.search(answer_string)
        if match is None:
            return None
        return match.group(1).strip()

    @staticmethod
    def parse_answer(answer: Optional[str]) -> List[str]:
        """Split a comma-separated answer string into labels.

        Args:
            answer: Answer text such as ``"A, B"``, or None.

        Returns:
            The stripped labels in their original order. Empty tokens
            (from an empty answer, a trailing comma, or doubled commas)
            are dropped, so None and ``""`` both yield an empty list.
        """
        if answer is None:
            return []
        tokens = (token.strip() for token in answer.split(','))
        return [token for token in tokens if token]

    @classmethod
    def compute_metrics(cls, data: List[Dict[str, Any]]) -> Dict[str, float]:
        """Compute exact-match accuracy and multi-label precision/recall/F1.

        Args:
            data: Items that each provide ``"answer"`` (reference text),
                ``"llm_answer"`` (predicted text) and
                ``"choices"["label"]`` (the option labels of that item).

        Returns:
            A dict with the keys ``accuracy``, ``precision_micro``,
            ``recall_micro``, ``f1_micro``, ``precision_macro``,
            ``recall_macro`` and ``f1_macro``. Every value is 0.0 when
            ``data`` is empty; the precision/recall/F1 values stay 0.0
            when no item declares any choice label.
        """
        metrics = {
            "accuracy": 0.0,
            "precision_micro": 0.0,
            "recall_micro": 0.0,
            "f1_micro": 0.0,
            "precision_macro": 0.0,
            "recall_macro": 0.0,
            "f1_macro": 0.0,
        }
        # Nothing to score: avoid np.mean([]) (nan plus a RuntimeWarning)
        # and avoid handing empty arrays to scikit-learn.
        if not data:
            return metrics

        # Extract and parse both answers of every item; sets give
        # order-insensitive comparison and O(1) membership tests below.
        true_sets = [
            set(cls.parse_answer(cls.extract_answer(item["answer"])))
            for item in data
        ]
        pred_sets = [
            set(cls.parse_answer(cls.extract_answer(item["llm_answer"])))
            for item in data
        ]

        # An item counts as correct only when both sides yielded at least
        # one label and the label sets are identical.
        hits = [
            1 if truth and pred and truth == pred else 0
            for truth, pred in zip(true_sets, pred_sets)
        ]
        metrics["accuracy"] = float(np.mean(hits))

        # The label universe is the union of the choice labels of all items.
        all_labels = sorted(
            {label for item in data for label in item["choices"]["label"]}
        )
        if not all_labels:
            # Without labels the indicator matrices would have zero columns.
            return metrics

        # One row per item, one 0/1 column per label in all_labels.
        y_true_multi = np.array(
            [[1 if label in truth else 0 for label in all_labels]
             for truth in true_sets]
        )
        y_pred_multi = np.array(
            [[1 if label in pred else 0 for label in all_labels]
             for pred in pred_sets]
        )

        for average in ("micro", "macro"):
            metrics[f"precision_{average}"] = float(precision_score(
                y_true_multi, y_pred_multi, average=average, zero_division=0))
            metrics[f"recall_{average}"] = float(recall_score(
                y_true_multi, y_pred_multi, average=average, zero_division=0))
            metrics[f"f1_{average}"] = float(f1_score(
                y_true_multi, y_pred_multi, average=average, zero_division=0))

        logger.info("Metrics computed successfully")
        return metrics