重构eval代码
This commit is contained in:
111
eval_framework/src/metrics.py
Normal file
111
eval_framework/src/metrics.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import re
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sklearn.metrics import precision_score, recall_score, f1_score
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MetricsCalculator:
    """Computes evaluation metrics for multiple-choice LLM answers.

    Answers are expected as comma-separated labels wrapped in
    ``[ANSWER]...[/ANSWER]`` tags, e.g. ``"[ANSWER]A, C[/ANSWER]"``.
    """

    # Compiled once (hoisted out of the per-call path). DOTALL lets the
    # answer span multiple lines; the original pattern silently returned
    # None for multi-line answers.
    _ANSWER_RE = re.compile(r'\[ANSWER\](.*?)\[/ANSWER\]', re.DOTALL)

    @staticmethod
    def extract_answer(answer_string: str) -> Optional[str]:
        """Extract the answer text from a tagged answer string.

        Args:
            answer_string: String possibly containing an
                ``[ANSWER]...[/ANSWER]`` span.

        Returns:
            The stripped answer text, or ``None`` if the input is
            empty/None or contains no answer tags.
        """
        if not answer_string:
            return None
        match = MetricsCalculator._ANSWER_RE.search(answer_string)
        return match.group(1).strip() if match else None

    @staticmethod
    def parse_answer(answer: Optional[str]) -> List[str]:
        """Split a comma-separated answer string into a list of labels.

        Args:
            answer: Raw answer string, or ``None``.

        Returns:
            List of stripped, non-empty labels. Empty for ``None`` or a
            blank string — the original returned ``['']`` for ``""``,
            which made an empty answer look like a real prediction.
        """
        if not answer:
            return []
        # Drop fragments that are empty after stripping (e.g. "A,,B"),
        # so stray commas cannot pollute the label sets compared below.
        return [part.strip() for part in answer.split(',') if part.strip()]

    @classmethod
    def compute_metrics(cls, data: List[Dict[str, Any]]) -> Dict[str, float]:
        """Compute accuracy plus micro/macro precision, recall and F1.

        Args:
            data: Items with keys ``"answer"`` (ground truth),
                ``"llm_answer"`` (model output) and ``"choices"`` (a dict
                whose ``"label"`` entry lists the candidate labels).

        Returns:
            Mapping of metric name to plain ``float``. All-zero metrics
            are returned for empty input (previously ``np.mean([])``
            produced NaN with a warning and sklearn raised on the empty
            multi-label arrays).
        """
        metric_names = (
            "accuracy",
            "precision_micro", "recall_micro", "f1_micro",
            "precision_macro", "recall_macro", "f1_macro",
        )
        if not data:
            return {name: 0.0 for name in metric_names}

        # Extract and parse gold / predicted answers.
        true_answers: List[List[str]] = []
        pred_answers: List[List[str]] = []
        for item in data:
            true_answers.append(cls.parse_answer(cls.extract_answer(item["answer"])))
            pred_answers.append(cls.parse_answer(cls.extract_answer(item["llm_answer"])))

        # Exact-match accuracy: label sets must coincide and both sides
        # must be non-empty (an unparseable answer never counts correct).
        correct = [
            1 if t and p and set(t) == set(p) else 0
            for t, p in zip(true_answers, pred_answers)
        ]
        accuracy = float(np.mean(correct))

        # Label vocabulary over all items, sorted for stable column order.
        all_labels = sorted(
            {label for item in data for label in item["choices"]["label"]}
        )

        # Multi-hot encode gold and predicted label sets.
        y_true_multi = np.array(
            [[1 if label in t else 0 for label in all_labels] for t in true_answers]
        )
        y_pred_multi = np.array(
            [[1 if label in p else 0 for label in all_labels] for p in pred_answers]
        )

        # zero_division=0 keeps sklearn quiet when a label is never
        # predicted; cast to float so callers get JSON-friendly values.
        metrics = {
            "accuracy": accuracy,
            "precision_micro": float(precision_score(
                y_true_multi, y_pred_multi, average='micro', zero_division=0)),
            "recall_micro": float(recall_score(
                y_true_multi, y_pred_multi, average='micro', zero_division=0)),
            "f1_micro": float(f1_score(
                y_true_multi, y_pred_multi, average='micro', zero_division=0)),
            "precision_macro": float(precision_score(
                y_true_multi, y_pred_multi, average='macro', zero_division=0)),
            "recall_macro": float(recall_score(
                y_true_multi, y_pred_multi, average='macro', zero_division=0)),
            "f1_macro": float(f1_score(
                y_true_multi, y_pred_multi, average='macro', zero_division=0)),
        }

        logger.info("Metrics computed successfully")
        return metrics
|
||||
Reference in New Issue
Block a user