diff --git a/eval_framework/config/config.yaml b/eval_framework/config/config.yaml
new file mode 100644
index 0000000..8ef7f91
--- /dev/null
+++ b/eval_framework/config/config.yaml
@@ -0,0 +1,36 @@
+# API configuration
+api:
+  key: "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"
+  base_url: "https://vip.apiyi.com/v1"
+  temperature: 0
+  max_retries: 10
+  # Multiple models are supported
+  models:
+    - "qwen-max-2025-01-25"
+    - "gpt-4o"
+  # Or use a single model (backward compatible)
+  # model: "qwen-max-2025-01-25"
+
+# System prompt
+system_prompt: "You are an expert in the field of materials science, adept at answering questions related to fundamental aspects of materials science, including material structure, properties, processing, and applications."
+
+# Evaluation configuration
+evaluation:
+  max_workers: 8
+  input_file: "/home/ubuntu/50T/LYT/MatBench/layer1/ALL-merge/merged.json"
+  # Output configuration
+  output:
+    base_dir: "results"
+    auto_timestamp: true
+    filename_template: "{model}.json"
+    summary_filename: "summary.json"
+    # Output format options
+    export_formats:
+      - "json"   # Detailed JSON results
+      - "csv"    # CSV table
+      - "excel"  # Excel workbook (requires openpyxl)
+
+# Logging configuration
+logging:
+  level: "INFO"
+  format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
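For orientation, the snippet below is a minimal sketch of how this file is consumed: the api.models list is the preferred multi-model form, and the commented-out api.model key is the single-model fallback, mirroring get_models_from_config in eval_framework/src/utils.py. The path is assumed to be relative to the repository root.

import yaml

# Minimal sketch: load the config and resolve the model list.
with open("eval_framework/config/config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

api_cfg = cfg["api"]
# Prefer the "models" list; fall back to the single "model" entry if present.
models = api_cfg.get("models") or ([api_cfg["model"]] if "model" in api_cfg else [])
print(models)  # ['qwen-max-2025-01-25', 'gpt-4o'] with the config above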
diff --git a/eval_framework/main.py b/eval_framework/main.py
new file mode 100644
index 0000000..6c7bf2b
--- /dev/null
+++ b/eval_framework/main.py
@@ -0,0 +1,164 @@
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, Any
+
+from src import (
+    DataLoader, LLMClient, Evaluator,
+    load_config, save_results, save_metrics, save_summary,
+    setup_logging, print_metrics, print_summary,
+    get_models_from_config, generate_output_dir, generate_model_output_path
+)
+
+logger = logging.getLogger(__name__)
+
+def evaluate_single_model(
+    model_name: str,
+    data: list,
+    config: Dict[str, Any],
+    output_dir: str
+) -> Dict[str, Any]:
+    """
+    Evaluate a single model.
+
+    Args:
+        model_name: Model name
+        data: Evaluation data
+        config: Configuration dictionary
+        output_dir: Output directory
+
+    Returns:
+        Dictionary containing metrics and results
+    """
+    logger.info(f"Starting evaluation for model: {model_name}")
+
+    # Initialize the LLM client
+    llm_client = LLMClient(
+        api_key=config['api']['key'],
+        base_url=config['api']['base_url'],
+        model=model_name,
+        temperature=config['api']['temperature'],
+        max_retries=config['api']['max_retries']
+    )
+
+    # Initialize the evaluator
+    evaluator = Evaluator(
+        llm_client=llm_client,
+        system_prompt=config['system_prompt']
+    )
+
+    # Run the evaluation
+    max_workers = config['evaluation']['max_workers']
+    metrics, results = evaluator.evaluate(data, max_workers=max_workers)
+
+    # Build the output file path
+    filename_template = config['evaluation']['output']['filename_template']
+    output_file = generate_model_output_path(output_dir, model_name, filename_template)
+
+    # Save results and metrics
+    save_results(results, output_file)
+    save_metrics(metrics, output_file)
+
+    logger.info(f"Model {model_name} evaluation completed. Results saved to {output_file}")
+
+    return {
+        "metrics": metrics,
+        "results": results,
+        "output_file": output_file
+    }
+
+def main():
+    parser = argparse.ArgumentParser(description="Materials science LLM evaluation framework")
+    parser.add_argument("--config", default="eval_framework/config/config.yaml", help="Path to the config file")
+    parser.add_argument("--input", help="Input data file path (overrides the config file)")
+    parser.add_argument("--output-dir", help="Output directory path (overrides the config file)")
+    parser.add_argument("--workers", type=int, help="Number of worker threads (overrides the config file)")
+    parser.add_argument("--models", nargs="+", help="List of models to evaluate (overrides the config file)")
+    parser.add_argument("--no-timestamp", action="store_true", help="Do not create a timestamped output folder")
+
+    args = parser.parse_args()
+
+    # Load the configuration
+    config = load_config(args.config)
+
+    # Disable the timestamped folder if requested
+    if args.no_timestamp:
+        config['evaluation']['output']['auto_timestamp'] = False
+
+    # Set up logging
+    setup_logging(
+        level=config.get('logging', {}).get('level', 'INFO'),
+        format_str=config.get('logging', {}).get('format')
+    )
+
+    logger.info("Starting multi-model evaluation framework")
+
+    # Resolve the input path and worker count
+    input_file = args.input or config['evaluation']['input_file']
+    if args.workers:
+        config['evaluation']['max_workers'] = args.workers
+
+    # Resolve the list of models
+    if args.models:
+        models = args.models
+        logger.info(f"Using models from command line: {models}")
+    else:
+        models = get_models_from_config(config)
+        logger.info(f"Using models from config: {models}")
+
+    # Create the output directory
+    if args.output_dir:
+        output_dir = args.output_dir
+        Path(output_dir).mkdir(parents=True, exist_ok=True)
+    else:
+        output_dir = generate_output_dir(config)
+
+    logger.info(f"Output directory: {output_dir}")
+
+    try:
+        # Load the data
+        logger.info(f"Loading data from {input_file}")
+        data = DataLoader.load_and_validate_data(input_file)
+
+        if not data:
+            logger.error("No valid data found")
+            return
+
+        logger.info(f"Loaded {len(data)} valid data items")
+
+        # Collect results for all models
+        all_results = {}
+
+        # Evaluate the models one by one
+        for i, model_name in enumerate(models, 1):
+            logger.info(f"Evaluating model {i}/{len(models)}: {model_name}")
+
+            try:
+                model_result = evaluate_single_model(model_name, data[:10], config, output_dir)
+                all_results[model_name] = model_result
+
+                # Print the current model's metrics
+                print_metrics(model_result["metrics"], model_name)
+
+            except Exception as e:
+                logger.error(f"Failed to evaluate model {model_name}: {e}")
+                continue
+
+        # Save the aggregated summary
+        if all_results:
+            summary_filename = config['evaluation']['output']['summary_filename']
+            save_summary(all_results, output_dir, summary_filename)
+
+            # Print the cross-model comparison
+            print_summary(all_results)
+
+            logger.info(f"Summary saved to {Path(output_dir) / summary_filename}")
+
+        logger.info("Multi-model evaluation completed successfully")
+
+    except Exception as e:
+        logger.error(f"Evaluation failed: {e}")
+        raise
+
+if __name__ == "__main__":
+    main()
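The entry point can also be driven programmatically. The following is a hedged sketch that chains the functions introduced in this diff (load_config, DataLoader.load_and_validate_data, generate_output_dir, get_models_from_config, evaluate_single_model); the sys.path tweak and the config path are assumptions for illustration, and the input_file in the config points at the author's dataset.

# Equivalent CLI usage (run from the repository root, illustrative):
#   python eval_framework/main.py --models gpt-4o --workers 4 --no-timestamp
import sys
sys.path.insert(0, "eval_framework")  # so "from src import ..." resolves, as main.py expects

from main import evaluate_single_model
from src import DataLoader, load_config, get_models_from_config, generate_output_dir

config = load_config("eval_framework/config/config.yaml")
data = DataLoader.load_and_validate_data(config['evaluation']['input_file'])
output_dir = generate_output_dir(config)

for model_name in get_models_from_config(config):
    result = evaluate_single_model(model_name, data, config, output_dir)
    print(model_name, result["metrics"])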
diff --git a/eval_framework/src/__init__.py b/eval_framework/src/__init__.py
new file mode 100644
index 0000000..d534827
--- /dev/null
+++ b/eval_framework/src/__init__.py
@@ -0,0 +1,26 @@
+from .data_loader import DataLoader
+from .llm_client import LLMClient
+from .evaluator import Evaluator
+from .metrics import MetricsCalculator
+from .utils import (
+    load_config, save_results, save_metrics, save_summary,
+    setup_logging, print_metrics, print_summary,
+    get_models_from_config, generate_output_dir, generate_model_output_path
+)
+
+__all__ = [
+    'DataLoader',
+    'LLMClient',
+    'Evaluator',
+    'MetricsCalculator',
+    'load_config',
+    'save_results',
+    'save_metrics',
+    'save_summary',
+    'setup_logging',
+    'print_metrics',
+    'print_summary',
+    'get_models_from_config',
+    'generate_output_dir',
+    'generate_model_output_path'
+]
diff --git a/eval_framework/src/__pycache__/__init__.cpython-311.pyc b/eval_framework/src/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..430cb87
Binary files /dev/null and b/eval_framework/src/__pycache__/__init__.cpython-311.pyc differ
diff --git a/eval_framework/src/__pycache__/__init__.cpython-312.pyc b/eval_framework/src/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000..58a17be
Binary files /dev/null and b/eval_framework/src/__pycache__/__init__.cpython-312.pyc differ
diff --git a/eval_framework/src/__pycache__/data_loader.cpython-311.pyc b/eval_framework/src/__pycache__/data_loader.cpython-311.pyc
new file mode 100644
index 0000000..07ff9cb
Binary files /dev/null and b/eval_framework/src/__pycache__/data_loader.cpython-311.pyc differ
diff --git a/eval_framework/src/__pycache__/data_loader.cpython-312.pyc b/eval_framework/src/__pycache__/data_loader.cpython-312.pyc
new file mode 100644
index 0000000..5414c43
Binary files /dev/null and b/eval_framework/src/__pycache__/data_loader.cpython-312.pyc differ
diff --git a/eval_framework/src/__pycache__/evaluator.cpython-311.pyc b/eval_framework/src/__pycache__/evaluator.cpython-311.pyc
new file mode 100644
index 0000000..2ac9afd
Binary files /dev/null and b/eval_framework/src/__pycache__/evaluator.cpython-311.pyc differ
diff --git a/eval_framework/src/__pycache__/evaluator.cpython-312.pyc b/eval_framework/src/__pycache__/evaluator.cpython-312.pyc
new file mode 100644
index 0000000..b5e2a0e
Binary files /dev/null and b/eval_framework/src/__pycache__/evaluator.cpython-312.pyc differ
diff --git a/eval_framework/src/__pycache__/llm_client.cpython-311.pyc b/eval_framework/src/__pycache__/llm_client.cpython-311.pyc
new file mode 100644
index 0000000..ddffac6
Binary files /dev/null and b/eval_framework/src/__pycache__/llm_client.cpython-311.pyc differ
diff --git a/eval_framework/src/__pycache__/llm_client.cpython-312.pyc b/eval_framework/src/__pycache__/llm_client.cpython-312.pyc
new file mode 100644
index 0000000..5f212b7
Binary files /dev/null and b/eval_framework/src/__pycache__/llm_client.cpython-312.pyc differ
diff --git a/eval_framework/src/__pycache__/metrics.cpython-311.pyc b/eval_framework/src/__pycache__/metrics.cpython-311.pyc
new file mode 100644
index 0000000..4ba0133
Binary files /dev/null and b/eval_framework/src/__pycache__/metrics.cpython-311.pyc differ
diff --git a/eval_framework/src/__pycache__/metrics.cpython-312.pyc b/eval_framework/src/__pycache__/metrics.cpython-312.pyc
new file mode 100644
index 0000000..fd3282b
Binary files /dev/null and b/eval_framework/src/__pycache__/metrics.cpython-312.pyc differ
diff --git a/eval_framework/src/__pycache__/utils.cpython-311.pyc b/eval_framework/src/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000..40fd14f
Binary files /dev/null and b/eval_framework/src/__pycache__/utils.cpython-311.pyc differ
diff --git a/eval_framework/src/__pycache__/utils.cpython-312.pyc b/eval_framework/src/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000..487e507
Binary files /dev/null and b/eval_framework/src/__pycache__/utils.cpython-312.pyc differ
"""数据加载器,负责加载和验证数据""" + + @staticmethod + def load_json_data(filepath: str) -> List[Dict[str, Any]]: + """ + 从JSON文件加载数据 + + Args: + filepath: JSON文件路径 + + Returns: + 加载的数据列表 + + Raises: + FileNotFoundError: 文件不存在 + json.JSONDecodeError: JSON格式错误 + """ + try: + with open(filepath, 'r', encoding='utf-8') as file: + data = json.load(file) + logger.info(f"Successfully loaded {len(data)} items from {filepath}") + return data + except FileNotFoundError: + logger.error(f"File not found: {filepath}") + raise + except json.JSONDecodeError as e: + logger.error(f"JSON decode error in {filepath}: {e}") + raise + + @staticmethod + def validate_data_item(item: Dict[str, Any]) -> bool: + """ + 验证数据项是否包含必要字段 + + Args: + item: 数据项 + + Returns: + 是否有效 + """ + required_fields = ['question', 'choices', 'answer', 'prompt'] + for field in required_fields: + if field not in item: + logger.warning(f"Missing required field: {field}") + return False + + if 'text' not in item['choices'] or 'label' not in item['choices']: + logger.warning("Missing 'text' or 'label' in choices") + return False + + return True + + @classmethod + def load_and_validate_data(cls, filepath: str) -> List[Dict[str, Any]]: + """ + 加载并验证数据 + + Args: + filepath: JSON文件路径 + + Returns: + 验证后的数据列表 + """ + data = cls.load_json_data(filepath) + valid_data = [] + + for i, item in enumerate(data): + if cls.validate_data_item(item): + valid_data.append(item) + else: + logger.warning(f"Invalid data item at index {i}, skipping") + + logger.info(f"Validated {len(valid_data)} out of {len(data)} items") + return valid_data diff --git a/eval_framework/src/evaluator.py b/eval_framework/src/evaluator.py new file mode 100644 index 0000000..2d4254b --- /dev/null +++ b/eval_framework/src/evaluator.py @@ -0,0 +1,98 @@ +import logging +import concurrent.futures +from typing import List, Dict, Any, Tuple +from tqdm import tqdm + +from .llm_client import LLMClient +from .metrics import MetricsCalculator + +logger = logging.getLogger(__name__) + +class Evaluator: + """评估器,协调整个评估流程""" + + def __init__(self, llm_client: LLMClient, system_prompt: str): + """ + 初始化评估器 + + Args: + llm_client: LLM客户端 + system_prompt: 系统提示词 + """ + self.llm_client = llm_client + self.system_prompt = system_prompt + self.metrics_calculator = MetricsCalculator() + + def process_item(self, item: Dict[str, Any], index: int) -> Dict[str, Any]: + """ + 处理单个数据项 + + Args: + item: 数据项 + index: 数据项索引 + + Returns: + 处理结果 + """ + question = item['question'] + text = item['choices']['text'] + label = item['choices']['label'] + prompt = item['prompt'] + expected_answer = item['answer'].strip() + + # 格式化选择项 + formatted_choices = " ".join([f"({lbl}) {txt}" for lbl, txt in zip(label, text)]) + user_input = f"{question} {formatted_choices}. 
{prompt}" + + # 获取LLM响应 + llm_answer = self.llm_client.get_response(user_input, self.system_prompt) + + return { + 'index': index, + 'question': question, + 'choices': item['choices'], + 'answer': expected_answer, + 'llm_answer': llm_answer + } + + def evaluate(self, data: List[Dict[str, Any]], max_workers: int = 5) -> Tuple[Dict[str, float], List[Dict[str, Any]]]: + """ + 评估数据集 + + Args: + data: 数据集 + max_workers: 最大工作线程数 + + Returns: + 评估指标和详细结果 + """ + results = [] + + logger.info(f"Starting evaluation with {max_workers} workers") + + with tqdm(total=len(data), desc="Processing items") as pbar: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # 提交所有任务 + future_to_index = { + executor.submit(self.process_item, item, i): i + for i, item in enumerate(data) + } + + # 收集结果 + for future in concurrent.futures.as_completed(future_to_index): + try: + result = future.result() + results.append(result) + pbar.update(1) + except Exception as e: + logger.error(f"Error processing item: {e}") + pbar.update(1) + + # 按索引排序结果 + results.sort(key=lambda x: x['index']) + + # 计算指标 + metrics = self.metrics_calculator.compute_metrics(results) + + logger.info("Evaluation completed successfully") + return metrics, results diff --git a/eval_framework/src/llm_client.py b/eval_framework/src/llm_client.py new file mode 100644 index 0000000..6ce60ef --- /dev/null +++ b/eval_framework/src/llm_client.py @@ -0,0 +1,60 @@ +import logging +import time +from typing import Optional +from openai import OpenAI + +logger = logging.getLogger(__name__) + +class LLMClient: + """LLM客户端,负责与API交互""" + + def __init__(self, api_key: str, base_url: str, model: str, + temperature: float = 0, max_retries: int = 10): + """ + 初始化LLM客户端 + + Args: + api_key: API密钥 + base_url: API基础URL + model: 模型名称 + temperature: 温度参数 + max_retries: 最大重试次数 + """ + self.client = OpenAI(api_key=api_key, base_url=base_url) + self.model = model + self.temperature = temperature + self.max_retries = max_retries + + def get_response(self, user_input: str, system_prompt: str) -> str: + """ + 获取LLM响应 + + Args: + user_input: 用户输入 + system_prompt: 系统提示词 + + Returns: + LLM响应,失败时返回"error!" + """ + retries = 0 + while retries < self.max_retries: + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_input} + ], + temperature=self.temperature + ) + answer = response.choices[0].message.content + return answer + + except Exception as e: + retries += 1 + logger.warning(f"API call failed (Attempt {retries}/{self.max_retries}): {e}") + if retries < self.max_retries: + time.sleep(2 ** retries) # 指数退避 + + logger.error(f"Failed to get response after {self.max_retries} attempts") + return "error!" 
diff --git a/eval_framework/src/llm_client.py b/eval_framework/src/llm_client.py
new file mode 100644
index 0000000..6ce60ef
--- /dev/null
+++ b/eval_framework/src/llm_client.py
@@ -0,0 +1,60 @@
+import logging
+import time
+from typing import Optional
+from openai import OpenAI
+
+logger = logging.getLogger(__name__)
+
+class LLMClient:
+    """LLM client responsible for talking to the API."""
+
+    def __init__(self, api_key: str, base_url: str, model: str,
+                 temperature: float = 0, max_retries: int = 10):
+        """
+        Initialize the LLM client.
+
+        Args:
+            api_key: API key
+            base_url: API base URL
+            model: Model name
+            temperature: Sampling temperature
+            max_retries: Maximum number of retries
+        """
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+        self.model = model
+        self.temperature = temperature
+        self.max_retries = max_retries
+
+    def get_response(self, user_input: str, system_prompt: str) -> str:
+        """
+        Get a response from the LLM.
+
+        Args:
+            user_input: The user input
+            system_prompt: The system prompt
+
+        Returns:
+            The LLM response, or "error!" on failure
+        """
+        retries = 0
+        while retries < self.max_retries:
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=[
+                        {"role": "system", "content": system_prompt},
+                        {"role": "user", "content": user_input}
+                    ],
+                    temperature=self.temperature
+                )
+                answer = response.choices[0].message.content
+                return answer
+
+            except Exception as e:
+                retries += 1
+                logger.warning(f"API call failed (Attempt {retries}/{self.max_retries}): {e}")
+                if retries < self.max_retries:
+                    time.sleep(2 ** retries)  # Exponential backoff
+
+        logger.error(f"Failed to get response after {self.max_retries} attempts")
+        return "error!"
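A short usage sketch for the client; the key, base URL, and question are placeholders. Note that the retry loop sleeps 2**retries seconds between attempts (2 s, 4 s, 8 s, ...) and returns the literal string "error!" once max_retries is exhausted, so callers should treat that value as a sentinel.

from src import LLMClient  # assumes cwd is eval_framework/

client = LLMClient(
    api_key="sk-...",                       # placeholder
    base_url="https://api.example.com/v1",  # placeholder
    model="gpt-4o",
    temperature=0,
    max_retries=3,
)
reply = client.get_response(
    "What is the coordination number of an FCC lattice? "
    "(A) 6 (B) 8 (C) 12. Answer with [ANSWER]...[/ANSWER].",
    system_prompt="You are an expert in materials science.",
)
if reply == "error!":   # sentinel returned after all retries fail
    print("API call failed")
else:
    print(reply)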
diff --git a/eval_framework/src/metrics.py b/eval_framework/src/metrics.py
new file mode 100644
index 0000000..dcfb93d
--- /dev/null
+++ b/eval_framework/src/metrics.py
@@ -0,0 +1,111 @@
+import re
+import numpy as np
+from typing import List, Dict, Any, Optional
+from sklearn.metrics import precision_score, recall_score, f1_score
+import logging
+
+logger = logging.getLogger(__name__)
+
+class MetricsCalculator:
+    """Calculator for evaluation metrics."""
+
+    @staticmethod
+    def extract_answer(answer_string: str) -> Optional[str]:
+        """
+        Extract the answer from a response string.
+
+        Args:
+            answer_string: String containing the answer
+
+        Returns:
+            The extracted answer, or None if not found
+        """
+        if not answer_string:
+            return None
+
+        match = re.search(r'\[ANSWER\](.*?)\[/ANSWER\]', answer_string)
+        if match:
+            return match.group(1).strip()
+        return None
+
+    @staticmethod
+    def parse_answer(answer: Optional[str]) -> List[str]:
+        """
+        Parse an answer string into a list of labels.
+
+        Args:
+            answer: The answer string
+
+        Returns:
+            List of answer labels
+        """
+        if answer is None:
+            return []
+        return [a.strip() for a in answer.split(',')]
+
+    @classmethod
+    def compute_metrics(cls, data: List[Dict[str, Any]]) -> Dict[str, float]:
+        """
+        Compute evaluation metrics.
+
+        Args:
+            data: Items containing ground-truth and predicted answers
+
+        Returns:
+            Dictionary of evaluation metrics
+        """
+        true_answers = []
+        pred_answers = []
+
+        # Extract and parse answers
+        for item in data:
+            true_ans = cls.extract_answer(item["answer"])
+            pred_ans = cls.extract_answer(item["llm_answer"])
+
+            true_answers.append(cls.parse_answer(true_ans))
+            pred_answers.append(cls.parse_answer(pred_ans))
+
+        # Compute exact-match accuracy
+        correct_counts = []
+        for true_ans, pred_ans in zip(true_answers, pred_answers):
+            if true_ans and pred_ans and set(true_ans) == set(pred_ans):
+                correct_counts.append(1)
+            else:
+                correct_counts.append(0)
+
+        accuracy = np.mean(correct_counts)
+
+        # Build multi-label indicator vectors
+        all_labels = set()
+        for item in data:
+            choices = item["choices"]["label"]
+            for label in choices:
+                all_labels.add(label)
+
+        all_labels = sorted(list(all_labels))
+
+        y_true_multi = []
+        y_pred_multi = []
+
+        for true_ans, pred_ans in zip(true_answers, pred_answers):
+            true_vector = [1 if label in (true_ans or []) else 0 for label in all_labels]
+            pred_vector = [1 if label in (pred_ans or []) else 0 for label in all_labels]
+            y_true_multi.append(true_vector)
+            y_pred_multi.append(pred_vector)
+
+        y_true_multi = np.array(y_true_multi)
+        y_pred_multi = np.array(y_pred_multi)
+
+        # Compute the individual metrics
+        metrics = {
+            "accuracy": accuracy,
+            "precision_micro": precision_score(y_true_multi, y_pred_multi, average='micro', zero_division=0),
+            "recall_micro": recall_score(y_true_multi, y_pred_multi, average='micro', zero_division=0),
+            "f1_micro": f1_score(y_true_multi, y_pred_multi, average='micro', zero_division=0),
+            "precision_macro": precision_score(y_true_multi, y_pred_multi, average='macro', zero_division=0),
+            "recall_macro": recall_score(y_true_multi, y_pred_multi, average='macro', zero_division=0),
+            "f1_macro": f1_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)
+        }
+
+        logger.info("Metrics computed successfully")
+        return metrics
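A worked example of the scoring: answers are pulled from [ANSWER]...[/ANSWER], split on commas, and an item counts as correct only when the two label sets match exactly; precision/recall/F1 are then computed on the multi-label indicator vectors. The item below is invented and assumes the working directory is eval_framework/.

from src import MetricsCalculator  # assumes cwd is eval_framework/

print(MetricsCalculator.extract_answer("Reasoning... [ANSWER]A, C[/ANSWER]"))  # -> "A, C"
print(MetricsCalculator.parse_answer("A, C"))                                  # -> ['A', 'C']

# Accuracy is an exact set match: {"A", "C"} == {"C", "A"} counts as correct,
# while a partial match such as {"A"} vs {"A", "C"} counts as wrong.
item = {
    "choices": {"label": ["A", "B", "C", "D"]},
    "answer": "[ANSWER]A, C[/ANSWER]",
    "llm_answer": "The best answers are [ANSWER]C, A[/ANSWER]",
}
print(MetricsCalculator.compute_metrics([item]))  # accuracy 1.0; micro/macro P/R/F1 follow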
summary_data["models"][model_name] = { + "metrics": model_result["metrics"], + "data_count": len(model_result["results"]) + } + + # 添加模型对比表 + if len(all_results) > 1: + comparison = {} + metric_names = [col for col in df.columns if col != "Data Count"] + + for metric in metric_names: + comparison[metric] = df[metric].to_dict() + + summary_data["comparison"] = comparison + + with open(summary_path, 'w', encoding='utf-8') as f: + json.dump(summary_data, f, indent=2, ensure_ascii=False) + + # 保存CSV格式的汇总表格 + csv_filename = summary_filename.replace('.json', '.csv') + csv_path = output_path / csv_filename + + # 重置索引以便模型名称也作为列保存 + df_for_csv = df.reset_index() + df_for_csv.to_csv(csv_path, index=False, encoding='utf-8') + + # 保存Excel格式(如果需要) + excel_filename = summary_filename.replace('.json', '.xlsx') + excel_path = output_path / excel_filename + + try: + # 创建Excel文件,包含多个工作表 + with pd.ExcelWriter(excel_path, engine='openpyxl') as writer: + # 主要结果表 + df_for_csv.to_excel(writer, sheet_name='Summary', index=False) + + # 如果有多个模型,创建排名表 + if len(all_results) > 1: + ranking_df = create_ranking_dataframe(df) + ranking_df.to_excel(writer, sheet_name='Rankings', index=False) + + except ImportError: + logging.warning("openpyxl not installed, skipping Excel export") + + logging.info(f"Summary saved to {summary_path}") + logging.info(f"CSV summary saved to {csv_path}") + +def create_ranking_dataframe(df: pd.DataFrame) -> pd.DataFrame: + """ + 创建模型排名DataFrame + + Args: + df: 原始结果DataFrame + + Returns: + 包含排名的DataFrame + """ + # 排除非指标列 + metric_columns = [col for col in df.columns if col != "Data Count"] + + # 为每个指标创建排名(假设数值越大越好,可以根据需要调整) + ranking_data = [] + + for metric in metric_columns: + # 创建排名(降序,数值越大排名越前) + ranks = df[metric].rank(method='min', ascending=False) + + for model_name in df.index: + ranking_data.append({ + 'Model': model_name, + 'Metric': metric, + 'Value': df.loc[model_name, metric], + 'Rank': int(ranks[model_name]) + }) + + ranking_df = pd.DataFrame(ranking_data) + return ranking_df + +def print_summary(all_results: Dict[str, Dict]) -> None: + """ + 打印所有模型的汇总结果 + + Args: + all_results: 所有模型的结果字典 + """ + print("\n" + "="*100) + print("SUMMARY - ALL MODELS COMPARISON") + print("="*100) + + if not all_results: + print("No results to display") + return + + # 创建DataFrame + df = create_results_dataframe(all_results) + + if df.empty: + print("No valid results to display") + return + + # 使用tabulate打印美观的表格 + print(tabulate( + df, + headers=df.columns, + tablefmt='grid', + floatfmt='.4f', + showindex=True + )) + + # 如果有多个模型,显示最佳模型 + if len(all_results) > 1: + print("\n" + "-"*100) + print("BEST PERFORMERS BY METRIC:") + print("-"*100) + + metric_columns = [col for col in df.columns if col != "Data Count"] + + for metric in metric_columns: + best_model = df[metric].idxmax() + best_value = df.loc[best_model, metric] + print(f"{metric.upper():<20}: {best_model:<30} ({best_value:.4f})") + + print("="*100) + +def setup_logging(level: str = "INFO", format_str: str = None, log_dir: str = "logs") -> None: + """ + 设置日志配置 + + Args: + level: 日志级别 + format_str: 日志格式 + log_dir: 日志目录 + """ + if format_str is None: + format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + # 创建日志目录 + Path(log_dir).mkdir(parents=True, exist_ok=True) + + # 生成日志文件名(包含时间戳) + timestamp = datetime.now().strftime("%Y%m%d_%H%M") + log_file = Path(log_dir) / f"evaluation_{timestamp}.log" + + logging.basicConfig( + level=getattr(logging, level.upper()), + format=format_str, + handlers=[ + logging.StreamHandler(), + 
diff --git a/layer1/ALL-merge/eval.py b/layer1/ALL-merge/eval.py
index f914e59..e69de29 100644
--- a/layer1/ALL-merge/eval.py
+++ b/layer1/ALL-merge/eval.py
@@ -1,166 +0,0 @@
-import json
-import threading
-from tqdm import tqdm
-import concurrent.futures
-from openai import OpenAI
-import numpy as np
-from sklearn.metrics import precision_score, recall_score, f1_score
-import re
-
-client = OpenAI(
-    api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d",
-    base_url="https://vip.apiyi.com/v1"
-)
-
-thread_lock = threading.Lock()
-
-def load_json_data(filepath):
-    with open(filepath, 'r') as file:
-        data = json.load(file)
-    return data
-
-def get_response(input,max_retries=10):
-    retries = 0
-    while retries