Refactor eval code
eval_framework/config/config.yaml (Normal file, 36 lines)
@@ -0,0 +1,36 @@
# API configuration
api:
  key: "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"
  base_url: "https://vip.apiyi.com/v1"
  temperature: 0
  max_retries: 10
  # Multiple models are supported
  models:
    - "qwen-max-2025-01-25"
    - "gpt-4o"
  # Or use a single model (backward compatible)
  # model: "qwen-max-2025-01-25"

# System prompt
system_prompt: "You are an expert in the field of materials science, adept at answering questions related to fundamental aspects of materials science, including material structure, properties, processing, and applications."

# Evaluation configuration
evaluation:
  max_workers: 8
  input_file: "/home/ubuntu/50T/LYT/MatBench/layer1/ALL-merge/merged.json"
  # Output configuration
  output:
    base_dir: "results"
    auto_timestamp: true
    filename_template: "{model}.json"
    summary_filename: "summary.json"
    # Output format options
    export_formats:
      - "json"   # detailed JSON results
      - "csv"    # CSV table
      - "excel"  # Excel table (requires openpyxl)

# Logging configuration
logging:
  level: "INFO"
  format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

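A quick sketch of how this nesting is consumed at runtime; the path is the repository default, adjust it if the config lives elsewhere:

import yaml

with open("eval_framework/config/config.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["api"]["models"])                     # ['qwen-max-2025-01-25', 'gpt-4o']
print(cfg["evaluation"]["output"]["base_dir"])  # 'results'
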
eval_framework/main.py (Normal file, 164 lines)
@@ -0,0 +1,164 @@
import argparse
import logging
from pathlib import Path
from typing import Dict, Any

from src import (
    DataLoader, LLMClient, Evaluator,
    load_config, save_results, save_metrics, save_summary,
    setup_logging, print_metrics, print_summary,
    get_models_from_config, generate_output_dir, generate_model_output_path
)

logger = logging.getLogger(__name__)


def evaluate_single_model(
    model_name: str,
    data: list,
    config: Dict[str, Any],
    output_dir: str
) -> Dict[str, Any]:
    """
    Evaluate a single model.

    Args:
        model_name: Model name
        data: Evaluation data
        config: Configuration dictionary
        output_dir: Output directory

    Returns:
        Dictionary containing the metrics and results
    """
    logger.info(f"Starting evaluation for model: {model_name}")

    # Initialize the LLM client
    llm_client = LLMClient(
        api_key=config['api']['key'],
        base_url=config['api']['base_url'],
        model=model_name,
        temperature=config['api']['temperature'],
        max_retries=config['api']['max_retries']
    )

    # Initialize the evaluator
    evaluator = Evaluator(
        llm_client=llm_client,
        system_prompt=config['system_prompt']
    )

    # Run the evaluation
    max_workers = config['evaluation']['max_workers']
    metrics, results = evaluator.evaluate(data, max_workers=max_workers)

    # Build the output file path
    filename_template = config['evaluation']['output']['filename_template']
    output_file = generate_model_output_path(output_dir, model_name, filename_template)

    # Save results and metrics
    save_results(results, output_file)
    save_metrics(metrics, output_file)

    logger.info(f"Model {model_name} evaluation completed. Results saved to {output_file}")

    return {
        "metrics": metrics,
        "results": results,
        "output_file": output_file
    }


def main():
    parser = argparse.ArgumentParser(description="Materials science LLM evaluation framework")
    parser.add_argument("--config", default="eval_framework/config/config.yaml", help="Path to the config file")
    parser.add_argument("--input", help="Input data file path (overrides the config file)")
    parser.add_argument("--output-dir", help="Output directory path (overrides the config file)")
    parser.add_argument("--workers", type=int, help="Number of worker threads (overrides the config file)")
    parser.add_argument("--models", nargs="+", help="Models to evaluate (overrides the config file)")
    parser.add_argument("--no-timestamp", action="store_true", help="Do not create a timestamped folder")

    args = parser.parse_args()

    # Load the configuration
    config = load_config(args.config)

    # If requested, disable the timestamped output folder
    if args.no_timestamp:
        config['evaluation']['output']['auto_timestamp'] = False

    # Set up logging
    setup_logging(
        level=config.get('logging', {}).get('level', 'INFO'),
        format_str=config.get('logging', {}).get('format')
    )

    logger.info("Starting multi-model evaluation framework")

    # Resolve the input path and worker count
    input_file = args.input or config['evaluation']['input_file']
    if args.workers:
        config['evaluation']['max_workers'] = args.workers

    # Resolve the model list
    if args.models:
        models = args.models
        logger.info(f"Using models from command line: {models}")
    else:
        models = get_models_from_config(config)
        logger.info(f"Using models from config: {models}")

    # Create the output directory
    if args.output_dir:
        output_dir = args.output_dir
        Path(output_dir).mkdir(parents=True, exist_ok=True)
    else:
        output_dir = generate_output_dir(config)

    logger.info(f"Output directory: {output_dir}")

    try:
        # Load the data
        logger.info(f"Loading data from {input_file}")
        data = DataLoader.load_and_validate_data(input_file)

        if not data:
            logger.error("No valid data found")
            return

        logger.info(f"Loaded {len(data)} valid data items")

        # Results for all models
        all_results = {}

        # Evaluate the models one by one
        for i, model_name in enumerate(models, 1):
            logger.info(f"Evaluating model {i}/{len(models)}: {model_name}")

            try:
                # Note: only the first 10 items are passed to the evaluator here
                model_result = evaluate_single_model(model_name, data[:10], config, output_dir)
                all_results[model_name] = model_result

                # Print this model's metrics
                print_metrics(model_result["metrics"], model_name)

            except Exception as e:
                logger.error(f"Failed to evaluate model {model_name}: {e}")
                continue

        # Save the cross-model summary
        if all_results:
            summary_filename = config['evaluation']['output']['summary_filename']
            save_summary(all_results, output_dir, summary_filename)

            # Print the comparison table
            print_summary(all_results)

            logger.info(f"Summary saved to {Path(output_dir) / summary_filename}")

        logger.info("Multi-model evaluation completed successfully")

    except Exception as e:
        logger.error(f"Evaluation failed: {e}")
        raise


if __name__ == "__main__":
    main()

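Beyond the CLI, the pieces above can be driven directly. A minimal sketch, assuming it runs from the repository root with the dependencies installed; the model name and the [:10] slice are illustrative and mirror the call in main():

import sys
sys.path.insert(0, "eval_framework")  # make `src` and `main` importable, as when running main.py directly

from src import DataLoader, load_config, generate_output_dir
from main import evaluate_single_model

config = load_config("eval_framework/config/config.yaml")
data = DataLoader.load_and_validate_data(config["evaluation"]["input_file"])
output_dir = generate_output_dir(config)

result = evaluate_single_model("gpt-4o", data[:10], config, output_dir)
print(result["metrics"])
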
eval_framework/src/__init__.py (Normal file, 26 lines)
@@ -0,0 +1,26 @@
from .data_loader import DataLoader
from .llm_client import LLMClient
from .evaluator import Evaluator
from .metrics import MetricsCalculator
from .utils import (
    load_config, save_results, save_metrics, save_summary,
    setup_logging, print_metrics, print_summary,
    get_models_from_config, generate_output_dir, generate_model_output_path
)

__all__ = [
    'DataLoader',
    'LLMClient',
    'Evaluator',
    'MetricsCalculator',
    'load_config',
    'save_results',
    'save_metrics',
    'save_summary',
    'setup_logging',
    'print_metrics',
    'print_summary',
    'get_models_from_config',
    'generate_output_dir',
    'generate_model_output_path'
]

BIN eval_framework/src/__pycache__/__init__.cpython-311.pyc (Normal file, binary file not shown)
BIN eval_framework/src/__pycache__/__init__.cpython-312.pyc (Normal file, binary file not shown)
BIN eval_framework/src/__pycache__/data_loader.cpython-311.pyc (Normal file, binary file not shown)
BIN eval_framework/src/__pycache__/data_loader.cpython-312.pyc (Normal file, binary file not shown)
BIN eval_framework/src/__pycache__/evaluator.cpython-311.pyc (Normal file, binary file not shown)
BIN eval_framework/src/__pycache__/evaluator.cpython-312.pyc (Normal file, binary file not shown)
BIN eval_framework/src/__pycache__/llm_client.cpython-311.pyc (Normal file, binary file not shown)
BIN eval_framework/src/__pycache__/llm_client.cpython-312.pyc (Normal file, binary file not shown)
BIN eval_framework/src/__pycache__/metrics.cpython-311.pyc (Normal file, binary file not shown)
BIN eval_framework/src/__pycache__/metrics.cpython-312.pyc (Normal file, binary file not shown)
eval_framework/src/data_loader.py (Normal file, 81 lines)
@@ -0,0 +1,81 @@
import json
import logging
from typing import List, Dict, Any

logger = logging.getLogger(__name__)


class DataLoader:
    """Data loader responsible for loading and validating data."""

    @staticmethod
    def load_json_data(filepath: str) -> List[Dict[str, Any]]:
        """
        Load data from a JSON file.

        Args:
            filepath: Path to the JSON file

        Returns:
            List of loaded items

        Raises:
            FileNotFoundError: The file does not exist
            json.JSONDecodeError: The file is not valid JSON
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as file:
                data = json.load(file)
            logger.info(f"Successfully loaded {len(data)} items from {filepath}")
            return data
        except FileNotFoundError:
            logger.error(f"File not found: {filepath}")
            raise
        except json.JSONDecodeError as e:
            logger.error(f"JSON decode error in {filepath}: {e}")
            raise

    @staticmethod
    def validate_data_item(item: Dict[str, Any]) -> bool:
        """
        Check that a data item contains the required fields.

        Args:
            item: Data item

        Returns:
            Whether the item is valid
        """
        required_fields = ['question', 'choices', 'answer', 'prompt']
        for field in required_fields:
            if field not in item:
                logger.warning(f"Missing required field: {field}")
                return False

        if 'text' not in item['choices'] or 'label' not in item['choices']:
            logger.warning("Missing 'text' or 'label' in choices")
            return False

        return True

    @classmethod
    def load_and_validate_data(cls, filepath: str) -> List[Dict[str, Any]]:
        """
        Load and validate data.

        Args:
            filepath: Path to the JSON file

        Returns:
            List of validated items
        """
        data = cls.load_json_data(filepath)
        valid_data = []

        for i, item in enumerate(data):
            if cls.validate_data_item(item):
                valid_data.append(item)
            else:
                logger.warning(f"Invalid data item at index {i}, skipping")

        logger.info(f"Validated {len(valid_data)} out of {len(data)} items")
        return valid_data

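For reference, an item shaped like the one below passes validate_data_item. The field values are invented; the [ANSWER]...[/ANSWER] wrapper in "answer" matches what MetricsCalculator.extract_answer expects later:

from src.data_loader import DataLoader  # assumes eval_framework/ is on sys.path

sample_item = {
    "question": "Which of the following are common strengthening mechanisms in metals?",
    "choices": {
        "label": ["A", "B", "C", "D"],
        "text": ["Grain refinement", "Solid-solution strengthening", "Work hardening", "Sublimation"],
    },
    "answer": "[ANSWER]A,B,C[/ANSWER]",
    "prompt": "Answer with the correct option labels wrapped in [ANSWER]...[/ANSWER].",
}

assert DataLoader.validate_data_item(sample_item)
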
eval_framework/src/evaluator.py (Normal file, 98 lines)
@@ -0,0 +1,98 @@
import logging
import concurrent.futures
from typing import List, Dict, Any, Tuple
from tqdm import tqdm

from .llm_client import LLMClient
from .metrics import MetricsCalculator

logger = logging.getLogger(__name__)


class Evaluator:
    """Evaluator that coordinates the whole evaluation flow."""

    def __init__(self, llm_client: LLMClient, system_prompt: str):
        """
        Initialize the evaluator.

        Args:
            llm_client: LLM client
            system_prompt: System prompt
        """
        self.llm_client = llm_client
        self.system_prompt = system_prompt
        self.metrics_calculator = MetricsCalculator()

    def process_item(self, item: Dict[str, Any], index: int) -> Dict[str, Any]:
        """
        Process a single data item.

        Args:
            item: Data item
            index: Index of the item

        Returns:
            Processing result
        """
        question = item['question']
        text = item['choices']['text']
        label = item['choices']['label']
        prompt = item['prompt']
        expected_answer = item['answer'].strip()

        # Format the answer choices
        formatted_choices = " ".join([f"({lbl}) {txt}" for lbl, txt in zip(label, text)])
        user_input = f"{question} {formatted_choices}. {prompt}"

        # Get the LLM response
        llm_answer = self.llm_client.get_response(user_input, self.system_prompt)

        return {
            'index': index,
            'question': question,
            'choices': item['choices'],
            'answer': expected_answer,
            'llm_answer': llm_answer
        }

    def evaluate(self, data: List[Dict[str, Any]], max_workers: int = 5) -> Tuple[Dict[str, float], List[Dict[str, Any]]]:
        """
        Evaluate a dataset.

        Args:
            data: Dataset
            max_workers: Maximum number of worker threads

        Returns:
            Evaluation metrics and detailed results
        """
        results = []

        logger.info(f"Starting evaluation with {max_workers} workers")

        with tqdm(total=len(data), desc="Processing items") as pbar:
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Submit all tasks
                future_to_index = {
                    executor.submit(self.process_item, item, i): i
                    for i, item in enumerate(data)
                }

                # Collect results as they complete
                for future in concurrent.futures.as_completed(future_to_index):
                    try:
                        result = future.result()
                        results.append(result)
                        pbar.update(1)
                    except Exception as e:
                        logger.error(f"Error processing item: {e}")
                        pbar.update(1)

        # Sort results back into their original order
        results.sort(key=lambda x: x['index'])

        # Compute metrics
        metrics = self.metrics_calculator.compute_metrics(results)

        logger.info("Evaluation completed successfully")
        return metrics, results

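To make the prompt assembly in process_item concrete, this is what it produces for an invented item (all values below are illustrative only; the real question, choices, and prompt come from the input JSON):

question = "Which crystal structure does copper adopt at room temperature?"
label = ["A", "B", "C"]
text = ["face-centered cubic", "body-centered cubic", "hexagonal close-packed"]
prompt = "Reply with the correct label wrapped in [ANSWER]...[/ANSWER]."

formatted_choices = " ".join(f"({lbl}) {txt}" for lbl, txt in zip(label, text))
user_input = f"{question} {formatted_choices}. {prompt}"
print(user_input)
# Printed as one line:
# Which crystal structure does copper adopt at room temperature? (A) face-centered cubic
# (B) body-centered cubic (C) hexagonal close-packed. Reply with the correct label wrapped in [ANSWER]...[/ANSWER].
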
eval_framework/src/llm_client.py (Normal file, 60 lines)
@@ -0,0 +1,60 @@
import logging
import time
from openai import OpenAI

logger = logging.getLogger(__name__)


class LLMClient:
    """LLM client responsible for talking to the API."""

    def __init__(self, api_key: str, base_url: str, model: str,
                 temperature: float = 0, max_retries: int = 10):
        """
        Initialize the LLM client.

        Args:
            api_key: API key
            base_url: API base URL
            model: Model name
            temperature: Sampling temperature
            max_retries: Maximum number of retries
        """
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model
        self.temperature = temperature
        self.max_retries = max_retries

    def get_response(self, user_input: str, system_prompt: str) -> str:
        """
        Get a response from the LLM.

        Args:
            user_input: User message
            system_prompt: System prompt

        Returns:
            The LLM response, or "error!" if all attempts fail
        """
        retries = 0
        while retries < self.max_retries:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_input}
                    ],
                    temperature=self.temperature
                )
                answer = response.choices[0].message.content
                return answer

            except Exception as e:
                retries += 1
                logger.warning(f"API call failed (Attempt {retries}/{self.max_retries}): {e}")
                if retries < self.max_retries:
                    time.sleep(2 ** retries)  # exponential backoff

        logger.error(f"Failed to get response after {self.max_retries} attempts")
        return "error!"

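A minimal usage sketch; the key, model, and question are placeholders, and with max_retries=3 a persistent failure gives up after roughly 2 + 4 seconds of backoff:

from src.llm_client import LLMClient  # assumes eval_framework/ is on sys.path

client = LLMClient(
    api_key="sk-...",                      # placeholder key
    base_url="https://vip.apiyi.com/v1",
    model="gpt-4o",
    temperature=0,
    max_retries=3,
)
answer = client.get_response(
    "Which crystal structure does copper adopt? (A) fcc (B) bcc. Reply with [ANSWER]X[/ANSWER].",
    "You are an expert in the field of materials science.",
)
print(answer)
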
eval_framework/src/metrics.py (Normal file, 111 lines)
@@ -0,0 +1,111 @@
import re
import logging
from typing import List, Dict, Any, Optional

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

logger = logging.getLogger(__name__)


class MetricsCalculator:
    """Calculator for the evaluation metrics."""

    @staticmethod
    def extract_answer(answer_string: str) -> Optional[str]:
        """
        Extract the answer from a response string.

        Args:
            answer_string: String containing the answer

        Returns:
            The extracted answer, or None if not found
        """
        if not answer_string:
            return None

        match = re.search(r'\[ANSWER\](.*?)\[/ANSWER\]', answer_string)
        if match:
            return match.group(1).strip()
        return None

    @staticmethod
    def parse_answer(answer: Optional[str]) -> List[str]:
        """
        Parse an answer string into a list of labels.

        Args:
            answer: Answer string

        Returns:
            List of labels
        """
        if answer is None:
            return []
        return [a.strip() for a in answer.split(',')]

    @classmethod
    def compute_metrics(cls, data: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Compute the evaluation metrics.

        Args:
            data: Items containing the ground-truth and predicted answers

        Returns:
            Dictionary of metrics
        """
        true_answers = []
        pred_answers = []

        # Extract and parse the answers
        for item in data:
            true_ans = cls.extract_answer(item["answer"])
            pred_ans = cls.extract_answer(item["llm_answer"])

            true_answers.append(cls.parse_answer(true_ans))
            pred_answers.append(cls.parse_answer(pred_ans))

        # Exact-match accuracy over answer sets
        correct_counts = []
        for true_ans, pred_ans in zip(true_answers, pred_answers):
            if true_ans and pred_ans and set(true_ans) == set(pred_ans):
                correct_counts.append(1)
            else:
                correct_counts.append(0)

        accuracy = float(np.mean(correct_counts))

        # Build the multi-label vectors
        all_labels = set()
        for item in data:
            choices = item["choices"]["label"]
            for label in choices:
                all_labels.add(label)

        all_labels = sorted(all_labels)

        y_true_multi = []
        y_pred_multi = []

        for true_ans, pred_ans in zip(true_answers, pred_answers):
            true_vector = [1 if label in (true_ans or []) else 0 for label in all_labels]
            pred_vector = [1 if label in (pred_ans or []) else 0 for label in all_labels]
            y_true_multi.append(true_vector)
            y_pred_multi.append(pred_vector)

        y_true_multi = np.array(y_true_multi)
        y_pred_multi = np.array(y_pred_multi)

        # Compute the metrics; cast numpy scalars to built-in float so they stay JSON-serializable
        metrics = {
            "accuracy": accuracy,
            "precision_micro": float(precision_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)),
            "recall_micro": float(recall_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)),
            "f1_micro": float(f1_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)),
            "precision_macro": float(precision_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)),
            "recall_macro": float(recall_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)),
            "f1_macro": float(f1_score(y_true_multi, y_pred_multi, average='macro', zero_division=0))
        }

        logger.info("Metrics computed successfully")
        return metrics

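A small self-check of the parsing helpers (the response string is invented):

from src.metrics import MetricsCalculator  # assumes eval_framework/ is on sys.path

raw = "Based on the phase diagram, the correct options are [ANSWER]A, C[/ANSWER]."
extracted = MetricsCalculator.extract_answer(raw)   # "A, C"
labels = MetricsCalculator.parse_answer(extracted)  # ["A", "C"]
print(extracted, labels)
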
eval_framework/src/utils.py (Normal file, 360 lines)
@@ -0,0 +1,360 @@
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List

import pandas as pd
import yaml
from tabulate import tabulate


def load_config(config_path: str) -> Dict[str, Any]:
    """
    Load the configuration file.

    Args:
        config_path: Path to the config file

    Returns:
        Configuration dictionary
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config


def get_models_from_config(config: Dict[str, Any]) -> List[str]:
    """
    Get the list of models from the configuration.

    Args:
        config: Configuration dictionary

    Returns:
        List of model names
    """
    api_config = config['api']

    # Prefer the `models` list
    if 'models' in api_config and api_config['models']:
        return api_config['models']
    # Backward compatibility: fall back to a single `model`
    elif 'model' in api_config:
        return [api_config['model']]
    else:
        raise ValueError("No models specified in configuration")


def generate_output_dir(config: Dict[str, Any]) -> str:
    """
    Build the output directory path.

    Args:
        config: Configuration dictionary

    Returns:
        Output directory path
    """
    output_config = config['evaluation']['output']
    base_dir = output_config['base_dir']
    auto_timestamp = output_config.get('auto_timestamp', True)

    # Base directory
    base_path = Path(base_dir)

    if auto_timestamp:
        # Create a timestamped folder (year, month, day, hour, minute)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M")
        output_dir = base_path / timestamp
    else:
        output_dir = base_path

    # Make sure the directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    return str(output_dir)


def generate_model_output_path(output_dir: str, model_name: str, filename_template: str) -> str:
    """
    Build the output file path for a specific model.

    Args:
        output_dir: Output directory
        model_name: Model name
        filename_template: Filename template

    Returns:
        Full output file path
    """
    # Sanitize special characters in the model name
    safe_model_name = model_name.replace('/', '_').replace(':', '_')
    filename = filename_template.format(model=safe_model_name)
    return str(Path(output_dir) / filename)

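As an illustration, with the default "{model}.json" template the path helper resolves like this (the model name is invented; POSIX path shown):

from src.utils import generate_model_output_path  # assumes eval_framework/ is on sys.path

path = generate_model_output_path("results/20250101_1200", "qwen/max:latest", "{model}.json")
print(path)  # results/20250101_1200/qwen_max_latest.json
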
def save_results(results: list, filepath: str) -> None:
    """
    Save results to a JSON file.

    Args:
        results: List of results
        filepath: Destination path
    """
    # Make sure the directory exists
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)


def save_metrics(metrics: Dict[str, float], filepath: str) -> None:
    """
    Save the evaluation metrics to a JSON file.

    Args:
        metrics: Metrics dictionary
        filepath: Path of the corresponding results file
    """
    # Put the metrics file next to the results file
    metrics_path = Path(filepath).parent / f"{Path(filepath).stem}_metrics.json"

    # Add a timestamp and other metadata
    metrics_with_meta = {
        "timestamp": datetime.now().isoformat(),
        "metrics": metrics
    }

    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(metrics_with_meta, f, indent=2, ensure_ascii=False)


def create_results_dataframe(all_results: Dict[str, Dict]) -> pd.DataFrame:
    """
    Convert the per-model results into a DataFrame.

    Args:
        all_results: Results dictionary keyed by model name

    Returns:
        DataFrame with one row of metrics per model
    """
    if not all_results:
        return pd.DataFrame()

    # Collect the metrics of every model
    data = []
    for model_name, model_result in all_results.items():
        row = {"Model": model_name}
        row.update(model_result["metrics"])
        row["Data Count"] = len(model_result["results"])
        data.append(row)

    # Build the DataFrame
    df = pd.DataFrame(data)

    # Use the model name as the index
    df = df.set_index("Model")

    # Order the columns (Data Count last)
    metric_columns = [col for col in df.columns if col != "Data Count"]
    df = df[metric_columns + ["Data Count"]]

    return df

def save_summary(all_results: Dict[str, Dict], output_dir: str, summary_filename: str) -> None:
    """
    Save the summary across all models.

    Args:
        all_results: Results dictionary keyed by model name
        output_dir: Output directory
        summary_filename: Summary filename
    """
    output_path = Path(output_dir)

    # Build the DataFrame
    df = create_results_dataframe(all_results)

    if df.empty:
        logging.warning("No results to save in summary")
        return

    # Save the detailed summary as JSON
    summary_path = output_path / summary_filename
    summary_data = {
        "timestamp": datetime.now().isoformat(),
        "models_count": len(all_results),
        "models": {}
    }

    for model_name, model_result in all_results.items():
        summary_data["models"][model_name] = {
            "metrics": model_result["metrics"],
            "data_count": len(model_result["results"])
        }

    # Add a model comparison table
    if len(all_results) > 1:
        comparison = {}
        metric_names = [col for col in df.columns if col != "Data Count"]

        for metric in metric_names:
            comparison[metric] = df[metric].to_dict()

        summary_data["comparison"] = comparison

    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary_data, f, indent=2, ensure_ascii=False)

    # Save the summary table as CSV
    csv_filename = summary_filename.replace('.json', '.csv')
    csv_path = output_path / csv_filename

    # Reset the index so the model name is saved as a column
    df_for_csv = df.reset_index()
    df_for_csv.to_csv(csv_path, index=False, encoding='utf-8')

    # Save an Excel version (if possible)
    excel_filename = summary_filename.replace('.json', '.xlsx')
    excel_path = output_path / excel_filename

    try:
        # Create an Excel file with multiple sheets
        with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
            # Main results sheet
            df_for_csv.to_excel(writer, sheet_name='Summary', index=False)

            # With more than one model, add a ranking sheet
            if len(all_results) > 1:
                ranking_df = create_ranking_dataframe(df)
                ranking_df.to_excel(writer, sheet_name='Rankings', index=False)

    except ImportError:
        logging.warning("openpyxl not installed, skipping Excel export")

    logging.info(f"Summary saved to {summary_path}")
    logging.info(f"CSV summary saved to {csv_path}")


def create_ranking_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Build a DataFrame of model rankings.

    Args:
        df: Original results DataFrame

    Returns:
        DataFrame with one rank per model and metric
    """
    # Exclude non-metric columns
    metric_columns = [col for col in df.columns if col != "Data Count"]

    # Rank every metric (higher values are assumed to be better; adjust if needed)
    ranking_data = []

    for metric in metric_columns:
        # Rank in descending order (larger values rank first)
        ranks = df[metric].rank(method='min', ascending=False)

        for model_name in df.index:
            ranking_data.append({
                'Model': model_name,
                'Metric': metric,
                'Value': df.loc[model_name, metric],
                'Rank': int(ranks[model_name])
            })

    ranking_df = pd.DataFrame(ranking_data)
    return ranking_df

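For orientation, the summary.json written by save_summary above ends up roughly shaped like the dictionary below; model names and numbers are invented, and the metric set is abridged:

example_summary = {
    "timestamp": "2025-01-01T12:00:00",
    "models_count": 2,
    "models": {
        "qwen-max-2025-01-25": {"metrics": {"accuracy": 0.80, "f1_micro": 0.85}, "data_count": 10},
        "gpt-4o": {"metrics": {"accuracy": 0.70, "f1_micro": 0.78}, "data_count": 10},
    },
    "comparison": {  # present only when more than one model was evaluated
        "accuracy": {"qwen-max-2025-01-25": 0.80, "gpt-4o": 0.70},
    },
}
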
def print_summary(all_results: Dict[str, Dict]) -> None:
    """
    Print the summary for all models.

    Args:
        all_results: Results dictionary keyed by model name
    """
    print("\n" + "=" * 100)
    print("SUMMARY - ALL MODELS COMPARISON")
    print("=" * 100)

    if not all_results:
        print("No results to display")
        return

    # Build the DataFrame
    df = create_results_dataframe(all_results)

    if df.empty:
        print("No valid results to display")
        return

    # Pretty-print the table with tabulate
    print(tabulate(
        df,
        headers=df.columns,
        tablefmt='grid',
        floatfmt='.4f',
        showindex=True
    ))

    # With more than one model, show the best performer per metric
    if len(all_results) > 1:
        print("\n" + "-" * 100)
        print("BEST PERFORMERS BY METRIC:")
        print("-" * 100)

        metric_columns = [col for col in df.columns if col != "Data Count"]

        for metric in metric_columns:
            best_model = df[metric].idxmax()
            best_value = df.loc[best_model, metric]
            print(f"{metric.upper():<20}: {best_model:<30} ({best_value:.4f})")

    print("=" * 100)


def setup_logging(level: str = "INFO", format_str: str = None, log_dir: str = "logs") -> None:
    """
    Configure logging.

    Args:
        level: Log level
        format_str: Log format
        log_dir: Log directory
    """
    if format_str is None:
        format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # Create the log directory
    Path(log_dir).mkdir(parents=True, exist_ok=True)

    # Build a timestamped log filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    log_file = Path(log_dir) / f"evaluation_{timestamp}.log"

    logging.basicConfig(
        level=getattr(logging, level.upper()),
        format=format_str,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(log_file, encoding='utf-8')
        ]
    )


def print_metrics(metrics: Dict[str, float], model_name: str = None) -> None:
    """
    Print the evaluation metrics.

    Args:
        metrics: Metrics dictionary
        model_name: Model name
    """
    title = f"EVALUATION RESULTS - {model_name}" if model_name else "EVALUATION RESULTS"
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)

    # Use a single-row DataFrame for a tidy display
    df = pd.DataFrame([metrics])
    print(tabulate(
        df,
        headers=df.columns,
        tablefmt='grid',
        floatfmt='.4f',
        showindex=False
    ))

    print("=" * 60)