"""
@file
@author: Yutang Li
@mail: yt.li2@siat.ac.cn
@date: 2025-05-28
@version: 1.0
"""


import re
import numpy as np
from typing import List, Dict, Any, Optional
from sklearn.metrics import precision_score, recall_score, f1_score
import logging

logger = logging.getLogger(__name__)


class MetricsCalculator:
    """Evaluation-metric calculator for multiple-choice answers.

    Ground-truth and model answers are expected to be embedded in text as
    ``[ANSWER]A,B[/ANSWER]``; multiple selected choices are comma-separated
    labels inside the tag pair.
    """

    # Compiled once at class-definition time so repeated extract_answer()
    # calls skip the re-module cache lookup; non-greedy so only the first
    # tag pair is captured.
    _ANSWER_RE = re.compile(r'\[ANSWER\](.*?)\[/ANSWER\]')

    # Keys of the metrics dict, in the order they are reported.
    _METRIC_KEYS = (
        "accuracy",
        "precision_micro", "recall_micro", "f1_micro",
        "precision_macro", "recall_macro", "f1_macro",
    )

    @staticmethod
    def extract_answer(answer_string: str) -> Optional[str]:
        """Extract the answer embedded in a response string.

        Args:
            answer_string: Text that may contain an
                ``[ANSWER]...[/ANSWER]`` tag pair. ``None``/empty is allowed.

        Returns:
            The stripped content of the first tag pair, or ``None`` when the
            input is falsy or no tag pair is found.
        """
        if not answer_string:
            return None

        match = MetricsCalculator._ANSWER_RE.search(answer_string)
        if match:
            return match.group(1).strip()
        return None

    @staticmethod
    def parse_answer(answer: Optional[str]) -> List[str]:
        """Parse a comma-separated answer string into a list of labels.

        Args:
            answer: Answer string such as ``"A, B"``, or ``None``.

        Returns:
            List of stripped labels; empty list when ``answer`` is ``None``.
            Note: a non-None string always yields at least one element
            (``""`` parses to ``[""]``), matching str.split semantics.
        """
        if answer is None:
            return []
        return [a.strip() for a in answer.split(',')]

    @classmethod
    def compute_metrics(cls, data: List[Dict[str, Any]]) -> Dict[str, float]:
        """Compute accuracy plus micro/macro precision, recall and F1.

        Args:
            data: Items with keys ``"answer"`` (ground-truth text),
                ``"llm_answer"`` (model response text) and
                ``"choices"]["label"]`` (iterable of candidate labels).

        Returns:
            Dict with keys ``accuracy``, ``precision_micro``,
            ``recall_micro``, ``f1_micro``, ``precision_macro``,
            ``recall_macro`` and ``f1_macro`` as plain floats.
            All zeros for empty input.
        """
        # Guard: np.mean([]) warns and yields NaN, and sklearn rejects empty
        # label matrices — report zeros explicitly for empty input instead.
        if not data:
            return {key: 0.0 for key in cls._METRIC_KEYS}

        true_answers: List[List[str]] = []
        pred_answers: List[List[str]] = []

        # Extract and parse both the gold and the model answers.
        for item in data:
            true_ans = cls.extract_answer(item["answer"])
            pred_ans = cls.extract_answer(item["llm_answer"])

            true_answers.append(cls.parse_answer(true_ans))
            pred_answers.append(cls.parse_answer(pred_ans))

        # Exact-set-match accuracy: an item counts as correct only when both
        # sides parsed to a non-empty, identical set of labels.
        correct_counts = [
            1 if t and p and set(t) == set(p) else 0
            for t, p in zip(true_answers, pred_answers)
        ]
        # Cast to float: np.mean returns np.float64, signature promises float.
        accuracy = float(np.mean(correct_counts))

        # Collect the full label vocabulary from every item's choice list.
        all_labels = set()
        for item in data:
            for label in item["choices"]["label"]:
                all_labels.add(label)
        all_labels = sorted(all_labels)

        # Build multi-hot indicator matrices over the sorted vocabulary.
        # parse_answer always returns a list, so no None-guard is needed.
        y_true_multi = []
        y_pred_multi = []
        for t, p in zip(true_answers, pred_answers):
            y_true_multi.append([1 if label in t else 0 for label in all_labels])
            y_pred_multi.append([1 if label in p else 0 for label in all_labels])

        y_true = np.array(y_true_multi)
        y_pred = np.array(y_pred_multi)

        # zero_division=0 keeps sklearn from warning/NaN-ing when a label
        # never appears in the predictions or the ground truth.
        metrics = {
            "accuracy": accuracy,
            "precision_micro": float(precision_score(y_true, y_pred, average='micro', zero_division=0)),
            "recall_micro": float(recall_score(y_true, y_pred, average='micro', zero_division=0)),
            "f1_micro": float(f1_score(y_true, y_pred, average='micro', zero_division=0)),
            "precision_macro": float(precision_score(y_true, y_pred, average='macro', zero_division=0)),
            "recall_macro": float(recall_score(y_true, y_pred, average='macro', zero_division=0)),
            "f1_macro": float(f1_score(y_true, y_pred, average='macro', zero_division=0)),
        }

        logger.info("Metrics computed successfully")
        return metrics
|