Refactor eval code
26  eval_framework/src/__init__.py  Normal file
@@ -0,0 +1,26 @@
from .data_loader import DataLoader
from .llm_client import LLMClient
from .evaluator import Evaluator
from .metrics import MetricsCalculator
from .utils import (
    load_config, save_results, save_metrics, save_summary,
    setup_logging, print_metrics, print_summary,
    get_models_from_config, generate_output_dir, generate_model_output_path
)

__all__ = [
    'DataLoader',
    'LLMClient',
    'Evaluator',
    'MetricsCalculator',
    'load_config',
    'save_results',
    'save_metrics',
    'save_summary',
    'setup_logging',
    'print_metrics',
    'print_summary',
    'get_models_from_config',
    'generate_output_dir',
    'generate_model_output_path'
]
BIN  eval_framework/src/__pycache__/__init__.cpython-311.pyc  Normal file (Binary file not shown.)
BIN  eval_framework/src/__pycache__/__init__.cpython-312.pyc  Normal file (Binary file not shown.)
BIN  eval_framework/src/__pycache__/data_loader.cpython-311.pyc  Normal file (Binary file not shown.)
BIN  eval_framework/src/__pycache__/data_loader.cpython-312.pyc  Normal file (Binary file not shown.)
BIN  eval_framework/src/__pycache__/evaluator.cpython-311.pyc  Normal file (Binary file not shown.)
BIN  eval_framework/src/__pycache__/evaluator.cpython-312.pyc  Normal file (Binary file not shown.)
BIN  eval_framework/src/__pycache__/llm_client.cpython-311.pyc  Normal file (Binary file not shown.)
BIN  eval_framework/src/__pycache__/llm_client.cpython-312.pyc  Normal file (Binary file not shown.)
BIN  eval_framework/src/__pycache__/metrics.cpython-311.pyc  Normal file (Binary file not shown.)
BIN  eval_framework/src/__pycache__/metrics.cpython-312.pyc  Normal file (Binary file not shown.)
BIN  eval_framework/src/__pycache__/utils.cpython-311.pyc  Normal file (Binary file not shown.)
BIN  eval_framework/src/__pycache__/utils.cpython-312.pyc  Normal file (Binary file not shown.)
81  eval_framework/src/data_loader.py  Normal file
@@ -0,0 +1,81 @@
import json
import logging
from typing import List, Dict, Any

logger = logging.getLogger(__name__)


class DataLoader:
    """Data loader responsible for loading and validating data."""

    @staticmethod
    def load_json_data(filepath: str) -> List[Dict[str, Any]]:
        """
        Load data from a JSON file.

        Args:
            filepath: Path to the JSON file

        Returns:
            List of loaded data items

        Raises:
            FileNotFoundError: If the file does not exist
            json.JSONDecodeError: If the file is not valid JSON
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as file:
                data = json.load(file)
                logger.info(f"Successfully loaded {len(data)} items from {filepath}")
                return data
        except FileNotFoundError:
            logger.error(f"File not found: {filepath}")
            raise
        except json.JSONDecodeError as e:
            logger.error(f"JSON decode error in {filepath}: {e}")
            raise

    @staticmethod
    def validate_data_item(item: Dict[str, Any]) -> bool:
        """
        Check that a data item contains all required fields.

        Args:
            item: Data item

        Returns:
            Whether the item is valid
        """
        required_fields = ['question', 'choices', 'answer', 'prompt']
        for field in required_fields:
            if field not in item:
                logger.warning(f"Missing required field: {field}")
                return False

        if 'text' not in item['choices'] or 'label' not in item['choices']:
            logger.warning("Missing 'text' or 'label' in choices")
            return False

        return True

    @classmethod
    def load_and_validate_data(cls, filepath: str) -> List[Dict[str, Any]]:
        """
        Load and validate data.

        Args:
            filepath: Path to the JSON file

        Returns:
            List of validated data items
        """
        data = cls.load_json_data(filepath)
        valid_data = []

        for i, item in enumerate(data):
            if cls.validate_data_item(item):
                valid_data.append(item)
            else:
                logger.warning(f"Invalid data item at index {i}, skipping")

        logger.info(f"Validated {len(valid_data)} out of {len(data)} items")
        return valid_data
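
For reference, a minimal sketch of the item shape that validate_data_item accepts. The field names come from the code above; the concrete question, choices, and file path are invented, and the answer is shown wrapped in the [ANSWER] tags that metrics.py later parses.

# Hypothetical item with the required fields: question, choices (label/text), answer, prompt.
example_item = {
    "question": "Which gas do plants absorb during photosynthesis?",
    "choices": {
        "label": ["A", "B", "C", "D"],
        "text": ["Oxygen", "Carbon dioxide", "Nitrogen", "Hydrogen"]
    },
    "answer": "[ANSWER]B[/ANSWER]",
    "prompt": "Answer with the letter of the correct choice."
}

# Loading and validating a whole file of such items (path is illustrative):
# items = DataLoader.load_and_validate_data("data/test.json")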
98  eval_framework/src/evaluator.py  Normal file
@@ -0,0 +1,98 @@
import logging
import concurrent.futures
from typing import List, Dict, Any, Tuple
from tqdm import tqdm

from .llm_client import LLMClient
from .metrics import MetricsCalculator

logger = logging.getLogger(__name__)


class Evaluator:
    """Evaluator that coordinates the whole evaluation pipeline."""

    def __init__(self, llm_client: LLMClient, system_prompt: str):
        """
        Initialize the evaluator.

        Args:
            llm_client: LLM client
            system_prompt: System prompt
        """
        self.llm_client = llm_client
        self.system_prompt = system_prompt
        self.metrics_calculator = MetricsCalculator()

    def process_item(self, item: Dict[str, Any], index: int) -> Dict[str, Any]:
        """
        Process a single data item.

        Args:
            item: Data item
            index: Index of the item

        Returns:
            Processing result
        """
        question = item['question']
        text = item['choices']['text']
        label = item['choices']['label']
        prompt = item['prompt']
        expected_answer = item['answer'].strip()

        # Format the answer choices
        formatted_choices = " ".join([f"({lbl}) {txt}" for lbl, txt in zip(label, text)])
        user_input = f"{question} {formatted_choices}. {prompt}"

        # Get the LLM response
        llm_answer = self.llm_client.get_response(user_input, self.system_prompt)

        return {
            'index': index,
            'question': question,
            'choices': item['choices'],
            'answer': expected_answer,
            'llm_answer': llm_answer
        }

    def evaluate(self, data: List[Dict[str, Any]], max_workers: int = 5) -> Tuple[Dict[str, float], List[Dict[str, Any]]]:
        """
        Evaluate a dataset.

        Args:
            data: Dataset
            max_workers: Maximum number of worker threads

        Returns:
            Evaluation metrics and detailed results
        """
        results = []

        logger.info(f"Starting evaluation with {max_workers} workers")

        with tqdm(total=len(data), desc="Processing items") as pbar:
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Submit all tasks
                future_to_index = {
                    executor.submit(self.process_item, item, i): i
                    for i, item in enumerate(data)
                }

                # Collect results as they complete
                for future in concurrent.futures.as_completed(future_to_index):
                    try:
                        result = future.result()
                        results.append(result)
                        pbar.update(1)
                    except Exception as e:
                        logger.error(f"Error processing item: {e}")
                        pbar.update(1)

        # Sort results back into input order
        results.sort(key=lambda x: x['index'])

        # Compute metrics
        metrics = self.metrics_calculator.compute_metrics(results)

        logger.info("Evaluation completed successfully")
        return metrics, results
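
A wiring sketch, not part of this commit, showing how Evaluator combines DataLoader and LLMClient; the API key, base URL, model name, prompt text, and data path are placeholders, and it assumes eval_framework is importable from the working directory.

from eval_framework.src import DataLoader, LLMClient, Evaluator

client = LLMClient(api_key="YOUR_KEY", base_url="https://api.example.com/v1", model="your-model")
evaluator = Evaluator(client, system_prompt="Reply with [ANSWER]<label>[/ANSWER].")

data = DataLoader.load_and_validate_data("data/test.json")   # hypothetical path
metrics, results = evaluator.evaluate(data, max_workers=5)

Since process_item runs on a thread pool, max_workers mainly bounds concurrent API calls; results are re-sorted by index afterwards, so input order is preserved.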
60  eval_framework/src/llm_client.py  Normal file
@@ -0,0 +1,60 @@
import logging
import time
from typing import Optional
from openai import OpenAI

logger = logging.getLogger(__name__)


class LLMClient:
    """LLM client responsible for talking to the API."""

    def __init__(self, api_key: str, base_url: str, model: str,
                 temperature: float = 0, max_retries: int = 10):
        """
        Initialize the LLM client.

        Args:
            api_key: API key
            base_url: API base URL
            model: Model name
            temperature: Sampling temperature
            max_retries: Maximum number of retries
        """
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model
        self.temperature = temperature
        self.max_retries = max_retries

    def get_response(self, user_input: str, system_prompt: str) -> str:
        """
        Get a response from the LLM.

        Args:
            user_input: User input
            system_prompt: System prompt

        Returns:
            The LLM response, or "error!" on failure
        """
        retries = 0
        while retries < self.max_retries:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_input}
                    ],
                    temperature=self.temperature
                )
                answer = response.choices[0].message.content
                return answer

            except Exception as e:
                retries += 1
                logger.warning(f"API call failed (Attempt {retries}/{self.max_retries}): {e}")
                if retries < self.max_retries:
                    time.sleep(2 ** retries)  # exponential backoff

        logger.error(f"Failed to get response after {self.max_retries} attempts")
        return "error!"
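
A standalone call sketch (credentials and model name are placeholders). Note that get_response retries with exponential backoff and ultimately returns the literal string "error!" rather than raising, so callers should check for it.

client = LLMClient(api_key="YOUR_KEY", base_url="https://api.example.com/v1",
                   model="your-model", temperature=0, max_retries=3)
reply = client.get_response("2 + 2 = ?", system_prompt="You are a careful assistant.")
if reply == "error!":
    print("All retries exhausted")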
111  eval_framework/src/metrics.py  Normal file
@@ -0,0 +1,111 @@
import re
import numpy as np
from typing import List, Dict, Any, Optional
from sklearn.metrics import precision_score, recall_score, f1_score
import logging

logger = logging.getLogger(__name__)


class MetricsCalculator:
    """Calculator for evaluation metrics."""

    @staticmethod
    def extract_answer(answer_string: str) -> Optional[str]:
        """
        Extract the answer from a response string.

        Args:
            answer_string: String containing the answer

        Returns:
            The extracted answer, or None if not found
        """
        if not answer_string:
            return None

        match = re.search(r'\[ANSWER\](.*?)\[/ANSWER\]', answer_string)
        if match:
            return match.group(1).strip()
        return None

    @staticmethod
    def parse_answer(answer: Optional[str]) -> List[str]:
        """
        Parse an answer into a list of labels.

        Args:
            answer: Answer string

        Returns:
            List of answer labels
        """
        if answer is None:
            return []
        return [a.strip() for a in answer.split(',')]

    @classmethod
    def compute_metrics(cls, data: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Compute evaluation metrics.

        Args:
            data: Data containing ground-truth and predicted answers

        Returns:
            Dictionary of evaluation metrics
        """
        true_answers = []
        pred_answers = []

        # Extract and parse answers
        for item in data:
            true_ans = cls.extract_answer(item["answer"])
            pred_ans = cls.extract_answer(item["llm_answer"])

            true_answers.append(cls.parse_answer(true_ans))
            pred_answers.append(cls.parse_answer(pred_ans))

        # Compute accuracy (exact set match)
        correct_counts = []
        for true_ans, pred_ans in zip(true_answers, pred_answers):
            if true_ans and pred_ans and set(true_ans) == set(pred_ans):
                correct_counts.append(1)
            else:
                correct_counts.append(0)

        accuracy = np.mean(correct_counts)

        # Build multi-label indicator vectors
        all_labels = set()
        for item in data:
            choices = item["choices"]["label"]
            for label in choices:
                all_labels.add(label)

        all_labels = sorted(list(all_labels))

        y_true_multi = []
        y_pred_multi = []

        for true_ans, pred_ans in zip(true_answers, pred_answers):
            true_vector = [1 if label in (true_ans or []) else 0 for label in all_labels]
            pred_vector = [1 if label in (pred_ans or []) else 0 for label in all_labels]
            y_true_multi.append(true_vector)
            y_pred_multi.append(pred_vector)

        y_true_multi = np.array(y_true_multi)
        y_pred_multi = np.array(y_pred_multi)

        # Compute the metric set
        metrics = {
            "accuracy": accuracy,
            "precision_micro": precision_score(y_true_multi, y_pred_multi, average='micro', zero_division=0),
            "recall_micro": recall_score(y_true_multi, y_pred_multi, average='micro', zero_division=0),
            "f1_micro": f1_score(y_true_multi, y_pred_multi, average='micro', zero_division=0),
            "precision_macro": precision_score(y_true_multi, y_pred_multi, average='macro', zero_division=0),
            "recall_macro": recall_score(y_true_multi, y_pred_multi, average='macro', zero_division=0),
            "f1_macro": f1_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)
        }

        logger.info("Metrics computed successfully")
        return metrics
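
A toy sanity check of the metric pipeline; the items are invented, but the answers use the [ANSWER]...[/ANSWER] tags that extract_answer expects.

toy = [
    {"answer": "[ANSWER]A[/ANSWER]", "llm_answer": "I think [ANSWER]A[/ANSWER]",
     "choices": {"label": ["A", "B"], "text": ["yes", "no"]}},
    {"answer": "[ANSWER]A,B[/ANSWER]", "llm_answer": "[ANSWER]B[/ANSWER]",
     "choices": {"label": ["A", "B"], "text": ["yes", "no"]}},
]
print(MetricsCalculator.compute_metrics(toy))
# Only the second item misses the full answer set, so accuracy is 0.5; precision, recall
# and F1 are computed over binary indicator vectors for the label set {A, B}.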
360  eval_framework/src/utils.py  Normal file
@@ -0,0 +1,360 @@
import json
import yaml
import logging
import pandas as pd
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List
from tabulate import tabulate


def load_config(config_path: str) -> Dict[str, Any]:
    """
    Load the configuration file.

    Args:
        config_path: Path to the configuration file

    Returns:
        Configuration dictionary
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config


def get_models_from_config(config: Dict[str, Any]) -> List[str]:
    """
    Get the list of models from the configuration.

    Args:
        config: Configuration dictionary

    Returns:
        List of model names
    """
    api_config = config['api']

    # Prefer the 'models' list
    if 'models' in api_config and api_config['models']:
        return api_config['models']
    # Backward compatibility: fall back to a single 'model'
    elif 'model' in api_config:
        return [api_config['model']]
    else:
        raise ValueError("No models specified in configuration")


def generate_output_dir(config: Dict[str, Any]) -> str:
    """
    Generate the output directory path.

    Args:
        config: Configuration dictionary

    Returns:
        Output directory path
    """
    output_config = config['evaluation']['output']
    base_dir = output_config['base_dir']
    auto_timestamp = output_config.get('auto_timestamp', True)

    # Base directory
    base_path = Path(base_dir)

    if auto_timestamp:
        # Create a timestamped folder (YYYYMMDD_HHMM)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M")
        output_dir = base_path / timestamp
    else:
        output_dir = base_path

    # Make sure the directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    return str(output_dir)


def generate_model_output_path(output_dir: str, model_name: str, filename_template: str) -> str:
    """
    Generate the output file path for a specific model.

    Args:
        output_dir: Output directory
        model_name: Model name
        filename_template: Filename template

    Returns:
        Full output file path
    """
    # Sanitize special characters in the model name
    safe_model_name = model_name.replace('/', '_').replace(':', '_')
    filename = filename_template.format(model=safe_model_name)
    return str(Path(output_dir) / filename)


def save_results(results: list, filepath: str) -> None:
    """
    Save results to a JSON file.

    Args:
        results: List of results
        filepath: Destination path
    """
    # Make sure the directory exists
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)


def save_metrics(metrics: Dict[str, float], filepath: str) -> None:
    """
    Save evaluation metrics to a JSON file.

    Args:
        metrics: Metrics dictionary
        filepath: Results file path used to derive the metrics file path
    """
    # Generate the metrics file path (in the same directory)
    metrics_path = Path(filepath).parent / f"{Path(filepath).stem}_metrics.json"

    # Add a timestamp and other metadata
    metrics_with_meta = {
        "timestamp": datetime.now().isoformat(),
        "metrics": metrics
    }

    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(metrics_with_meta, f, indent=2, ensure_ascii=False)


def create_results_dataframe(all_results: Dict[str, Dict]) -> pd.DataFrame:
    """
    Convert the results of all models into a DataFrame.

    Args:
        all_results: Dictionary of results for all models

    Returns:
        DataFrame containing the metrics of all models
    """
    if not all_results:
        return pd.DataFrame()

    # Collect the metric data for every model
    data = []
    for model_name, model_result in all_results.items():
        row = {"Model": model_name}
        row.update(model_result["metrics"])
        row["Data Count"] = len(model_result["results"])
        data.append(row)

    # Build the DataFrame
    df = pd.DataFrame(data)

    # Use the Model column as the index
    df = df.set_index("Model")

    # Reorder columns (put Data Count last)
    metric_columns = [col for col in df.columns if col != "Data Count"]
    df = df[metric_columns + ["Data Count"]]

    return df


def save_summary(all_results: Dict[str, Dict], output_dir: str, summary_filename: str) -> None:
    """
    Save a summary of the results of all models.

    Args:
        all_results: Dictionary of results for all models
        output_dir: Output directory
        summary_filename: Summary file name
    """
    output_path = Path(output_dir)

    # Build the DataFrame
    df = create_results_dataframe(all_results)

    if df.empty:
        logging.warning("No results to save in summary")
        return

    # Save a detailed JSON summary
    summary_path = output_path / summary_filename
    summary_data = {
        "timestamp": datetime.now().isoformat(),
        "models_count": len(all_results),
        "models": {}
    }

    for model_name, model_result in all_results.items():
        summary_data["models"][model_name] = {
            "metrics": model_result["metrics"],
            "data_count": len(model_result["results"])
        }

    # Add a model comparison table
    if len(all_results) > 1:
        comparison = {}
        metric_names = [col for col in df.columns if col != "Data Count"]

        for metric in metric_names:
            comparison[metric] = df[metric].to_dict()

        summary_data["comparison"] = comparison

    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary_data, f, indent=2, ensure_ascii=False)

    # Save the summary table as CSV
    csv_filename = summary_filename.replace('.json', '.csv')
    csv_path = output_path / csv_filename

    # Reset the index so the model name is saved as a column as well
    df_for_csv = df.reset_index()
    df_for_csv.to_csv(csv_path, index=False, encoding='utf-8')

    # Save an Excel version (if possible)
    excel_filename = summary_filename.replace('.json', '.xlsx')
    excel_path = output_path / excel_filename

    try:
        # Create an Excel file with multiple sheets
        with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
            # Main results sheet
            df_for_csv.to_excel(writer, sheet_name='Summary', index=False)

            # With multiple models, also create a rankings sheet
            if len(all_results) > 1:
                ranking_df = create_ranking_dataframe(df)
                ranking_df.to_excel(writer, sheet_name='Rankings', index=False)

    except ImportError:
        logging.warning("openpyxl not installed, skipping Excel export")

    logging.info(f"Summary saved to {summary_path}")
    logging.info(f"CSV summary saved to {csv_path}")


def create_ranking_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Create a DataFrame with model rankings.

    Args:
        df: Original results DataFrame

    Returns:
        DataFrame containing the rankings
    """
    # Exclude non-metric columns
    metric_columns = [col for col in df.columns if col != "Data Count"]

    # Rank models on each metric (assumes higher is better; adjust as needed)
    ranking_data = []

    for metric in metric_columns:
        # Rank in descending order (larger values rank first)
        ranks = df[metric].rank(method='min', ascending=False)

        for model_name in df.index:
            ranking_data.append({
                'Model': model_name,
                'Metric': metric,
                'Value': df.loc[model_name, metric],
                'Rank': int(ranks[model_name])
            })

    ranking_df = pd.DataFrame(ranking_data)
    return ranking_df


def print_summary(all_results: Dict[str, Dict]) -> None:
    """
    Print a summary of the results of all models.

    Args:
        all_results: Dictionary of results for all models
    """
    print("\n" + "="*100)
    print("SUMMARY - ALL MODELS COMPARISON")
    print("="*100)

    if not all_results:
        print("No results to display")
        return

    # Build the DataFrame
    df = create_results_dataframe(all_results)

    if df.empty:
        print("No valid results to display")
        return

    # Pretty-print the table with tabulate
    print(tabulate(
        df,
        headers=df.columns,
        tablefmt='grid',
        floatfmt='.4f',
        showindex=True
    ))

    # With multiple models, show the best performer per metric
    if len(all_results) > 1:
        print("\n" + "-"*100)
        print("BEST PERFORMERS BY METRIC:")
        print("-"*100)

        metric_columns = [col for col in df.columns if col != "Data Count"]

        for metric in metric_columns:
            best_model = df[metric].idxmax()
            best_value = df.loc[best_model, metric]
            print(f"{metric.upper():<20}: {best_model:<30} ({best_value:.4f})")

    print("="*100)


def setup_logging(level: str = "INFO", format_str: str = None, log_dir: str = "logs") -> None:
    """
    Configure logging.

    Args:
        level: Log level
        format_str: Log format
        log_dir: Log directory
    """
    if format_str is None:
        format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # Create the log directory
    Path(log_dir).mkdir(parents=True, exist_ok=True)

    # Generate a timestamped log file name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    log_file = Path(log_dir) / f"evaluation_{timestamp}.log"

    logging.basicConfig(
        level=getattr(logging, level.upper()),
        format=format_str,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(log_file, encoding='utf-8')
        ]
    )


def print_metrics(metrics: Dict[str, float], model_name: str = None) -> None:
    """
    Print evaluation metrics.

    Args:
        metrics: Metrics dictionary
        model_name: Model name
    """
    title = f"EVALUATION RESULTS - {model_name}" if model_name else "EVALUATION RESULTS"
    print("\n" + "="*60)
    print(title)
    print("="*60)

    # Build a single-row DataFrame for pretty display
    df = pd.DataFrame([metrics])
    print(tabulate(
        df,
        headers=df.columns,
        tablefmt='grid',
        floatfmt='.4f',
        showindex=False
    ))

    print("="*60)
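
A hypothetical driver showing how these helpers chain together; the config path is a placeholder, the config is assumed to contain the api.models and evaluation.output keys read above, and the metrics/results stand-ins would really come from Evaluator.evaluate().

from eval_framework.src import (load_config, setup_logging, get_models_from_config,
                                generate_output_dir, generate_model_output_path,
                                save_results, save_metrics, print_summary, save_summary)

config = load_config("config.yaml")                 # placeholder path
setup_logging(level="INFO")
output_dir = generate_output_dir(config)

all_results = {}
for model in get_models_from_config(config):
    out_path = generate_model_output_path(output_dir, model, "results_{model}.json")
    metrics, results = {"accuracy": 0.0}, []        # stand-ins for Evaluator.evaluate() output
    save_results(results, out_path)
    save_metrics(metrics, out_path)
    all_results[model] = {"metrics": metrics, "results": results}

print_summary(all_results)
save_summary(all_results, output_dir, "summary.json")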