"""
Evaluation metrics for choice-labelled LLM answers.

@file
@author: Yutang Li
@mail: yt.li2@siat.ac.cn
@date: 2025-05-28
@version: 1.0
"""
import logging
import re
from typing import Any, Dict, List, Optional

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

logger = logging.getLogger(__name__)

# Matches the first "[ANSWER]...[/ANSWER]" span. DOTALL lets the answer body
# contain newlines, which free-form model output frequently does.
_ANSWER_PATTERN = re.compile(r'\[ANSWER\](.*?)\[/ANSWER\]', re.DOTALL)

# Keys of the dict returned by MetricsCalculator.compute_metrics, in order.
_METRIC_KEYS = (
    "accuracy",
    "precision_micro",
    "recall_micro",
    "f1_micro",
    "precision_macro",
    "recall_macro",
    "f1_macro",
)


class MetricsCalculator:
    """Evaluation metrics calculator."""

    @staticmethod
    def extract_answer(answer_string: str) -> Optional[str]:
        """
        Extract the answer enclosed in [ANSWER]...[/ANSWER] tags.

        Args:
            answer_string: Text that may contain the tagged answer.

        Returns:
            The stripped text between the first pair of tags (which may span
            several lines), or None if the input is empty or has no tags.
        """
        if not answer_string:
            return None
        match = _ANSWER_PATTERN.search(answer_string)
        if match:
            return match.group(1).strip()
        return None

    @staticmethod
    def parse_answer(answer: Optional[str]) -> List[str]:
        """
        Split a comma-separated answer string into individual labels.

        Args:
            answer: Answer string such as "A, C", or None.

        Returns:
            Stripped, non-empty labels in their original order; [] for None or
            a blank answer. Empty tokens (e.g. from a trailing comma) are
            dropped so "A, B," and "A, B" parse identically.
        """
        if answer is None:
            return []
        tokens = (part.strip() for part in answer.split(','))
        return [token for token in tokens if token]

    @staticmethod
    def _binarize(answers: List[List[str]], labels: List[str]) -> np.ndarray:
        """Encode each answer as a 0/1 row over ``labels`` (labels outside it are ignored)."""
        rows = []
        for answer in answers:
            chosen = set(answer)
            rows.append([1 if label in chosen else 0 for label in labels])
        # reshape keeps a well-defined 2-D shape even when there are no rows/labels
        return np.array(rows, dtype=int).reshape(len(answers), len(labels))

    @classmethod
    def compute_metrics(cls, data: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Compute exact-match accuracy and multi-label precision/recall/F1.

        Args:
            data: Items each holding "answer" (ground truth) and "llm_answer"
                (prediction), both parsed via ``extract_answer``, plus
                ``item["choices"]["label"]`` listing the option labels used to
                build the multi-label vectors.

        Returns:
            Dict with "accuracy" (fraction of items whose non-empty predicted
            label set equals the non-empty true label set) and micro/macro
            averaged "precision", "recall" and "f1". All values are 0.0 when
            ``data`` is empty; the precision/recall/F1 values are 0.0 when no
            choice labels are present.
        """
        if not data:
            logger.warning("compute_metrics received no data; all metrics set to 0.0")
            return dict.fromkeys(_METRIC_KEYS, 0.0)

        # Extract and parse the true and predicted answers.
        true_answers: List[List[str]] = []
        pred_answers: List[List[str]] = []
        for item in data:
            true_answers.append(cls.parse_answer(cls.extract_answer(item["answer"])))
            pred_answers.append(cls.parse_answer(cls.extract_answer(item["llm_answer"])))

        # Exact-match accuracy: a missing/empty answer on either side is wrong.
        matches = [
            1 if true_ans and pred_ans and set(true_ans) == set(pred_ans) else 0
            for true_ans, pred_ans in zip(true_answers, pred_answers)
        ]
        accuracy = float(np.mean(matches))

        # Label space is the union of all items' choice labels; predicted
        # labels outside it contribute nothing to the vectors.
        labels = sorted({label for item in data for label in item["choices"]["label"]})

        metrics: Dict[str, float] = {"accuracy": accuracy}
        if labels:
            y_true = cls._binarize(true_answers, labels)
            y_pred = cls._binarize(pred_answers, labels)
        else:
            logger.warning("No choice labels found; precision/recall/F1 set to 0.0")

        scorers = (("precision", precision_score), ("recall", recall_score), ("f1", f1_score))
        for average in ("micro", "macro"):
            for name, scorer in scorers:
                if labels:
                    value = float(scorer(y_true, y_pred, average=average, zero_division=0))
                else:
                    value = 0.0
                metrics[f"{name}_{average}"] = value

        logger.info("Metrics computed successfully")
        return metrics