Files
MatBench/eval_framework/src/metrics.py
2025-05-28 17:29:03 +08:00

121 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
@file
@author: Yutang Li
@mail: yt.li2@siat.ac.cn
@date: 2025-05-28
@version: 1.0
"""
import re
import numpy as np
from typing import List, Dict, Any, Optional
from sklearn.metrics import precision_score, recall_score, f1_score
import logging
logger = logging.getLogger(__name__)
class MetricsCalculator:
    """Evaluation-metric calculator for multiple-choice QA answers.

    Answers are expected to be wrapped in ``[ANSWER]...[/ANSWER]`` tags and
    may contain several comma-separated choice labels (e.g. ``"A, C"``).
    """

    @staticmethod
    def extract_answer(answer_string: str) -> Optional[str]:
        """
        Extract the answer enclosed in [ANSWER]...[/ANSWER] tags.

        Args:
            answer_string: String that may contain a tagged answer.

        Returns:
            The extracted (stripped) answer, or None when the input is
            empty/None or no tag pair is found.
        """
        if not answer_string:
            return None
        match = re.search(r'\[ANSWER\](.*?)\[/ANSWER\]', answer_string)
        if match:
            return match.group(1).strip()
        return None

    @staticmethod
    def parse_answer(answer: Optional[str]) -> List[str]:
        """
        Parse a comma-separated answer string into a list of labels.

        Args:
            answer: Answer string such as "A, C", or None.

        Returns:
            List of stripped labels. Empty fragments are dropped, so an
            empty answer ("") or a trailing comma ("A,") no longer
            produces a spurious '' element.
        """
        if answer is None:
            return []
        # Filter empty pieces: "".split(',') == [''] would otherwise leak
        # an empty-string "label" into set comparisons and label vectors.
        return [part.strip() for part in answer.split(',') if part.strip()]

    @classmethod
    def compute_metrics(cls, data: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Compute accuracy and micro/macro precision, recall and F1.

        Args:
            data: Items with keys "answer" (gold) and "llm_answer"
                (prediction), both containing [ANSWER]...[/ANSWER] tags,
                plus "choices" whose "label" entry lists all candidate
                labels for that item.

        Returns:
            Dict with keys accuracy, precision_micro, recall_micro,
            f1_micro, precision_macro, recall_macro, f1_macro — all
            plain Python floats (JSON-serializable).
        """
        # Guard: with no data there is nothing to score. Without this,
        # np.mean([]) returns NaN (with a RuntimeWarning) and sklearn
        # raises on empty indicator arrays.
        if not data:
            return {
                "accuracy": 0.0,
                "precision_micro": 0.0,
                "recall_micro": 0.0,
                "f1_micro": 0.0,
                "precision_macro": 0.0,
                "recall_macro": 0.0,
                "f1_macro": 0.0,
            }

        # Extract and parse gold / predicted answers.
        true_answers: List[List[str]] = []
        pred_answers: List[List[str]] = []
        for item in data:
            true_answers.append(cls.parse_answer(cls.extract_answer(item["answer"])))
            pred_answers.append(cls.parse_answer(cls.extract_answer(item["llm_answer"])))

        # Exact-match accuracy: label sets must be non-empty and equal.
        correct = [
            1 if t and p and set(t) == set(p) else 0
            for t, p in zip(true_answers, pred_answers)
        ]
        accuracy = float(np.mean(correct))

        # Build multi-label indicator vectors over every candidate label
        # seen in the dataset (sorted for a deterministic column order).
        all_labels = sorted({
            label for item in data for label in item["choices"]["label"]
        })
        y_true_multi = np.array(
            [[1 if label in t else 0 for label in all_labels] for t in true_answers]
        )
        y_pred_multi = np.array(
            [[1 if label in p else 0 for label in all_labels] for p in pred_answers]
        )

        # Cast sklearn's np.float64 results to plain float so the dict
        # serializes cleanly (e.g. json.dumps).
        metrics = {
            "accuracy": accuracy,
            "precision_micro": float(precision_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)),
            "recall_micro": float(recall_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)),
            "f1_micro": float(f1_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)),
            "precision_macro": float(precision_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)),
            "recall_macro": float(recall_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)),
            "f1_macro": float(f1_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)),
        }
        logger.info("Metrics computed successfully")
        return metrics