MatBench/eval_framework/main.py

"""
@file
@brief 主程序入口,负责加载配置、处理命令行参数、执行模型评估等
@author: Yutang Li
@mail: yt.li2@siat.ac.cn
@date: 2025-05-28
@version: 1.0
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Any

from src import (
    DataLoader, LLMClient, Evaluator,
    load_config, save_results, save_metrics, save_summary,
    setup_logging, print_metrics, print_summary,
    get_models_from_config, generate_output_dir, generate_model_output_path
)

logger = logging.getLogger(__name__)


def evaluate_single_model(
    model_name: str,
    data: list,
    config: Dict[str, Any],
    output_dir: str
) -> Dict[str, Any]:
    """
    Evaluate a single model.

    Args:
        model_name: Name of the model.
        data: Evaluation data.
        config: Configuration dictionary.
        output_dir: Output directory.

    Returns:
        A dictionary containing the metrics and results.
    """
    logger.info(f"Starting evaluation for model: {model_name}")

    # Initialize the LLM client
    llm_client = LLMClient(
        api_key=config['api']['key'],
        base_url=config['api']['base_url'],
        model=model_name,
        temperature=config['api']['temperature'],
        max_retries=config['api']['max_retries']
    )

    # Initialize the evaluator
    evaluator = Evaluator(
        llm_client=llm_client,
        system_prompt=config['system_prompt']
    )

    # Run the evaluation
    max_workers = config['evaluation']['max_workers']
    metrics, results = evaluator.evaluate(data, max_workers=max_workers)

    # Generate the output file path
    filename_template = config['evaluation']['output']['filename_template']
    output_file = generate_model_output_path(output_dir, model_name, filename_template)
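
    # Illustrative note (an assumption, not taken from this repository's actual
    # config): with a template such as "{model_name}_results.json",
    # generate_model_output_path would be expected to yield something like
    # "<output_dir>/<model_name>_results.json" for each evaluated model.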

    # Save the results and metrics
    save_results(results, output_file)
    save_metrics(metrics, output_file)

    logger.info(f"Model {model_name} evaluation completed. Results saved to {output_file}")

    return {
        "metrics": metrics,
        "results": results,
        "output_file": output_file
    }


def main():
    parser = argparse.ArgumentParser(description="Materials-science LLM evaluation framework")
    parser.add_argument("--config", default="eval_framework/config/config.yaml", help="Path to the configuration file")
    parser.add_argument("--input", help="Path to the input data file (overrides the config file)")
    parser.add_argument("--output-dir", help="Path to the output directory (overrides the config file)")
    parser.add_argument("--workers", type=int, help="Number of worker threads (overrides the config file)")
    parser.add_argument("--models", nargs="+", help="List of models to evaluate (overrides the config file)")
    parser.add_argument("--no-timestamp", action="store_true", help="Do not use a timestamped output folder")
    args = parser.parse_args()

    # Load the configuration
    config = load_config(args.config)
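
    # Expected configuration layout: a sketch inferred from the keys read in this
    # file, not a copy of the repository's actual config.yaml (which may contain
    # additional keys). The top-level key holding the model list is assumed,
    # since it is only read through get_models_from_config().
    #
    #   api:
    #     key: ...
    #     base_url: ...
    #     temperature: ...
    #     max_retries: ...
    #   system_prompt: ...
    #   models: [...]          # assumed key name
    #   evaluation:
    #     input_file: ...
    #     max_workers: ...
    #     output:
    #       auto_timestamp: true
    #       filename_template: ...
    #       summary_filename: ...
    #   logging:
    #     level: INFO
    #     format: ...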

    # If the user asked to skip the timestamp, update the configuration
    if args.no_timestamp:
        config['evaluation']['output']['auto_timestamp'] = False

    # Set up logging
    setup_logging(
        level=config.get('logging', {}).get('level', 'INFO'),
        format_str=config.get('logging', {}).get('format')
    )

    logger.info("Starting multi-model evaluation framework")

    # Resolve the input path and the number of worker threads
    input_file = args.input or config['evaluation']['input_file']
    if args.workers:
        config['evaluation']['max_workers'] = args.workers

    # Get the list of models
    if args.models:
        models = args.models
        logger.info(f"Using models from command line: {models}")
    else:
        models = get_models_from_config(config)
        logger.info(f"Using models from config: {models}")

    # Generate the output directory
    if args.output_dir:
        output_dir = args.output_dir
        Path(output_dir).mkdir(parents=True, exist_ok=True)
    else:
        output_dir = generate_output_dir(config)
    logger.info(f"Output directory: {output_dir}")

    try:
        # Load the data
        logger.info(f"Loading data from {input_file}")
        data = DataLoader.load_and_validate_data(input_file)
        if not data:
            logger.error("No valid data found")
            return
        logger.info(f"Loaded {len(data)} valid data items")

        # Store the results for all models
        all_results = {}

        # Evaluate the models one by one
        for i, model_name in enumerate(models, 1):
            logger.info(f"Evaluating model {i}/{len(models)}: {model_name}")
            try:
                model_result = evaluate_single_model(model_name, data, config, output_dir)
                all_results[model_name] = model_result

                # Print the results for the current model
                print_metrics(model_result["metrics"], model_name)
            except Exception as e:
                logger.error(f"Failed to evaluate model {model_name}: {e}")
                continue

        # Save the summary results
        if all_results:
            summary_filename = config['evaluation']['output']['summary_filename']
            save_summary(all_results, output_dir, summary_filename)

            # Print the cross-model comparison
            print_summary(all_results)

            logger.info(f"Summary saved to {Path(output_dir) / summary_filename}")

        logger.info("Multi-model evaluation completed successfully")

    except Exception as e:
        logger.error(f"Evaluation failed: {e}")
        raise


if __name__ == "__main__":
    main()