"""
|
|
@file
|
|
@brief 主程序入口,负责加载配置、处理命令行参数、执行模型评估等
|
|
@author: Yutang Li
|
|
@mail: yt.li2@siat.ac.cn
|
|
@date: 2025-05-28
|
|
@version: 1.0
|
|
"""

import argparse
import logging
from pathlib import Path
from typing import Dict, Any

from src import (
    DataLoader, LLMClient, Evaluator,
    load_config, save_results, save_metrics, save_summary,
    setup_logging, print_metrics, print_summary,
    get_models_from_config, generate_output_dir, generate_model_output_path
)
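
# Configuration keys this script reads (a minimal sketch, inferred only from the
# lookups below; the project's actual config.yaml may define more fields and defaults):
#
#   api:
#     key: ...
#     base_url: ...
#     temperature: ...
#     max_retries: ...
#   system_prompt: ...
#   evaluation:
#     input_file: ...
#     max_workers: ...
#     output:
#       auto_timestamp: ...
#       filename_template: ...
#       summary_filename: ...
#   logging:
#     level: ...
#     format: ...
#
# The model list is obtained via get_models_from_config(config); its exact key
# layout is not visible from this file.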

logger = logging.getLogger(__name__)


def evaluate_single_model(
    model_name: str,
    data: list,
    config: Dict[str, Any],
    output_dir: str
) -> Dict[str, Any]:
    """
    Evaluate a single model.

    Args:
        model_name: Name of the model.
        data: Evaluation data.
        config: Configuration dictionary.
        output_dir: Output directory.

    Returns:
        A dictionary containing the metrics and results.
    """
    logger.info(f"Starting evaluation for model: {model_name}")

    # Initialize the LLM client
    llm_client = LLMClient(
        api_key=config['api']['key'],
        base_url=config['api']['base_url'],
        model=model_name,
        temperature=config['api']['temperature'],
        max_retries=config['api']['max_retries']
    )

    # Initialize the evaluator
    evaluator = Evaluator(
        llm_client=llm_client,
        system_prompt=config['system_prompt']
    )

    # Run the evaluation
    max_workers = config['evaluation']['max_workers']
    metrics, results = evaluator.evaluate(data, max_workers=max_workers)

    # Build the output file path
    filename_template = config['evaluation']['output']['filename_template']
    output_file = generate_model_output_path(output_dir, model_name, filename_template)

    # Save the results and metrics
    save_results(results, output_file)
    save_metrics(metrics, output_file)

    logger.info(f"Model {model_name} evaluation completed. Results saved to {output_file}")

    return {
        "metrics": metrics,
        "results": results,
        "output_file": output_file
    }


def main():
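    """
    Command-line entry point: parse arguments, load the configuration, and
    evaluate each configured model, saving per-model results and a summary.

    Example invocation (a sketch; the script path, model names, and worker
    count are placeholders, not values defined in this file):

        python <path/to/this/script> --models model-a model-b --workers 8 --no-timestamp
    """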
    parser = argparse.ArgumentParser(description="Materials science LLM evaluation framework")
    parser.add_argument("--config", default="eval_framework/config/config.yaml", help="Path to the configuration file")
    parser.add_argument("--input", help="Input data file path (overrides the config file)")
    parser.add_argument("--output-dir", help="Output directory path (overrides the config file)")
    parser.add_argument("--workers", type=int, help="Number of worker threads (overrides the config file)")
    parser.add_argument("--models", nargs="+", help="List of models to evaluate (overrides the config file)")
    parser.add_argument("--no-timestamp", action="store_true", help="Do not use a timestamped output folder")

    args = parser.parse_args()

    # Load the configuration
    config = load_config(args.config)

    # If --no-timestamp was given, disable the timestamped output folder
    if args.no_timestamp:
        config['evaluation']['output']['auto_timestamp'] = False

    # Set up logging
    setup_logging(
        level=config.get('logging', {}).get('level', 'INFO'),
        format_str=config.get('logging', {}).get('format')
    )

    logger.info("Starting multi-model evaluation framework")

    # Resolve the input path and the number of worker threads
    input_file = args.input or config['evaluation']['input_file']
    if args.workers:
        config['evaluation']['max_workers'] = args.workers

    # Determine the list of models to evaluate
    if args.models:
        models = args.models
        logger.info(f"Using models from command line: {models}")
    else:
        models = get_models_from_config(config)
        logger.info(f"Using models from config: {models}")

    # Create the output directory
    if args.output_dir:
        output_dir = args.output_dir
        Path(output_dir).mkdir(parents=True, exist_ok=True)
    else:
        output_dir = generate_output_dir(config)

    logger.info(f"Output directory: {output_dir}")

    try:
        # Load the data
        logger.info(f"Loading data from {input_file}")
        data = DataLoader.load_and_validate_data(input_file)

        if not data:
            logger.error("No valid data found")
            return

        logger.info(f"Loaded {len(data)} valid data items")

        # Collect the results of all models
        all_results = {}

        # Evaluate the models one by one
        for i, model_name in enumerate(models, 1):
            logger.info(f"Evaluating model {i}/{len(models)}: {model_name}")

            try:
                model_result = evaluate_single_model(model_name, data, config, output_dir)
                all_results[model_name] = model_result

                # Print the current model's metrics
                print_metrics(model_result["metrics"], model_name)

            except Exception as e:
                logger.error(f"Failed to evaluate model {model_name}: {e}")
                continue

        # Save the aggregated summary
        if all_results:
            summary_filename = config['evaluation']['output']['summary_filename']
            save_summary(all_results, output_dir, summary_filename)

            # Print a cross-model comparison
            print_summary(all_results)

            logger.info(f"Summary saved to {Path(output_dir) / summary_filename}")

        logger.info("Multi-model evaluation completed successfully")

    except Exception as e:
        logger.error(f"Evaluation failed: {e}")
        raise


if __name__ == "__main__":
    main()