Files
MatBench/layer2/PGEE/code/stepy_gen_option.py

977 lines
41 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import openai
from typing import Dict, Any, List, Tuple, Optional
import time
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from tqdm import tqdm
import random
import re
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class ChoiceOptionsGenerator:
    """Generates exam options through an OpenAI-compatible API.

    Each worker thread lazily builds its own API client (clients are not
    guaranteed thread-safe); a shared lock serializes log output.
    """

    def __init__(self, api_key: str, base_url: str, model_name: str, max_workers: int = 20):
        # Endpoint credentials / model selection.
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.max_workers = max_workers
        # Per-thread client storage plus a lock for synchronized logging.
        self.thread_local = threading.local()
        self.lock = threading.Lock()
        # Retry budget for the basic path and sampling budget for the main path.
        self.max_retries = 5
        self.max_sampling_attempts = 6

    def get_client(self):
        """Return the OpenAI client bound to the current thread, creating it on first use."""
        client = getattr(self.thread_local, 'client', None)
        if client is None:
            client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
            self.thread_local.client = client
        return client
def generate_options_with_sampling(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
"""使用多次采样策略生成选项"""
attempts_results = []
for attempt in range(self.max_sampling_attempts):
try:
# 生成一个候选选项
candidate = self._attempt_generate_options(question_data)
if not self._validate_options_quality(candidate, question_data):
with self.lock:
logging.warning(f"{attempt+1}次采样 - 选项质量验证失败")
continue
# 测试模型是否能正确回答这个问题
is_model_correct = self._test_model_performance(candidate, question_data)
candidate["performance_test"] = {
"model_answered_correctly": is_model_correct,
"sampling_attempt": attempt + 1,
}
attempts_results.append(candidate)
with self.lock:
logging.info(f"{attempt+1}次采样 - 模型{'答对' if is_model_correct else '答错'}")
# 如果模型答错了,这是一个好的困难题目,早停
if not is_model_correct:
return self._finalize_result(candidate, attempts_results, "early_stop_incorrect")
except Exception as e:
with self.lock:
logging.warning(f"{attempt+1}次采样失败: {e}")
continue
# 所有采样都完成了,选择一个结果
if attempts_results:
# 检查是否所有采样都答对了
all_correct = all(r.get("performance_test", {}).get("model_answered_correctly", True)
for r in attempts_results)
if all_correct:
selected = random.choice(attempts_results)
return self._finalize_result(selected, attempts_results, "all_samples_correct")
else:
# 优先选择答错的
incorrect_results = [r for r in attempts_results
if not r.get("performance_test", {}).get("model_answered_correctly", True)]
if incorrect_results:
selected = random.choice(incorrect_results)
return self._finalize_result(selected, attempts_results, "mixed_results")
else:
selected = random.choice(attempts_results)
return self._finalize_result(selected, attempts_results, "mixed_results")
# 所有采样都失败
logging.error("所有采样都失败,使用备用选项")
return self._create_fallback_options(question_data)
def _finalize_result(self, selected_result: Dict[str, Any], all_results: List[Dict], result_type: str) -> Dict[str, Any]:
"""完善最终结果的标记信息"""
# 统计所有采样的结果
total_attempts = len(all_results)
correct_count = sum(1 for r in all_results
if r.get("performance_test", {}).get("model_answered_correctly", True))
incorrect_count = total_attempts - correct_count
# 添加汇总信息
selected_result["sampling_summary"] = {
"result_type": result_type, # early_stop_incorrect, all_samples_correct, mixed_results
"total_sampling_attempts": total_attempts,
"correct_answers": correct_count,
"incorrect_answers": incorrect_count,
"is_early_stop": result_type == "early_stop_incorrect",
"is_all_correct": result_type == "all_samples_correct",
"selected_attempt": selected_result.get("performance_test", {}).get("sampling_attempt", 1),
"selected_was_correct": selected_result.get("performance_test", {}).get("model_answered_correctly", True)
}
# 简化的难度标记
if result_type == "early_stop_incorrect":
difficulty_label = "hard_early_stop"
elif result_type == "all_samples_correct":
difficulty_label = "easy_all_correct"
else:
difficulty_label = "mixed"
selected_result["sampling_summary"]["difficulty_label"] = difficulty_label
with self.lock:
logging.info(f"题目标记: {difficulty_label} (正确{correct_count}/{total_attempts}次)")
return selected_result
def _test_model_performance(self, generated_question: Dict[str, Any], original_data: Dict[str, Any]) -> bool:
"""测试模型是否能正确回答生成的问题"""
try:
question_type = generated_question.get("question_type", "")
if question_type == "true_false":
return self._test_true_false_question(generated_question)
elif question_type == "multiple_choice":
return self._test_multiple_choice_question(generated_question, original_data)
else:
logging.warning(f"未知题目类型: {question_type}")
return True # 默认认为模型答对了
except Exception as e:
logging.error(f"测试模型性能时出错: {e}")
return True # 出错时默认认为模型答对了
    def _test_true_false_question(self, question_data: Dict[str, Any]) -> bool:
        """Probe the test model with a generated true/false statement.

        Returns True when the model's normalized reply matches the stored
        correct answer. Defensively returns True (i.e. "model was right") on
        incomplete data, unrecognized output, or API errors, so noisy failures
        do not masquerade as hard questions.
        """
        statement = question_data.get("statement", "")
        correct_answer = question_data.get("correct_answer", "")
        if not statement or not correct_answer:
            logging.warning("判断题数据不完整")
            return True
        # NOTE(review): the prompt literal below is reproduced verbatim; its
        # leading-column layout is part of the runtime string.
        test_prompt = f"""
请判断以下陈述的正误。请仔细分析每个细节,考虑所有可能的条件和例外情况。
陈述:{statement}
请只输出 "True""False",不要解释:
"""
        try:
            client = self.get_client()
            response = client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "你是一个材料科学专家。请仔细分析陈述考虑所有技术细节和特殊情况只输出True或False。"},
                    {"role": "user", "content": test_prompt}
                ],
                temperature=0.1,  # near-deterministic judging
                max_tokens=10     # only "True"/"False" expected
            )
            model_answer = response.choices[0].message.content.strip()
            # Normalize: "True" is checked before "False" anywhere in the reply.
            if "True" in model_answer:
                model_answer = "True"
            elif "False" in model_answer:
                model_answer = "False"
            else:
                logging.warning(f"模型回答格式异常: {model_answer}")
                return True  # malformed reply: count as correct
            is_correct = model_answer == correct_answer
            logging.debug(f"判断题测试 - 正确答案: {correct_answer}, 模型答案: {model_answer}, 结果: {'正确' if is_correct else '错误'}")
            return is_correct
        except Exception as e:
            logging.error(f"测试判断题时出错: {e}")
            return True
    def _test_multiple_choice_question(self, question_data: Dict[str, Any], original_data: Dict[str, Any]) -> bool:
        """Probe the test model with the generated multiple-choice question.

        Builds the full question text (original stem + generated A-D options),
        asks the model, and returns True when it picks the stored correct
        label. Defensively returns True on incomplete data, unparseable
        output, or API errors.
        """
        options = question_data.get("options", {})
        correct_answer = question_data.get("correct_answer", "")
        original_question = original_data.get("choice_question", "")
        if not options or not correct_answer or not original_question:
            logging.warning("选择题数据不完整")
            return True
        # Assemble the option list in label order (A, B, C, D).
        options_text = ""
        for key in sorted(options.keys()):
            options_text += f"{key}. {options[key]}\n"
        test_prompt = f"""
以下是一道材料科学专业题目,请仔细分析每个选项,考虑所有技术细节和约束条件。
题目:{original_question}
选项:
{options_text}
请选择最准确的答案只输出选项字母A、B、C或D
"""
        try:
            client = self.get_client()
            response = client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "你是一个材料科学专家。请深入分析题目,仔细比较各选项的技术准确性,只输出选项字母。"},
                    {"role": "user", "content": test_prompt}
                ],
                temperature=0.1,  # near-deterministic answering
                max_tokens=10     # a single letter is expected
            )
            model_answer = response.choices[0].message.content.strip().upper()
            # Take the first A-D letter that appears in the reply.
            model_choice = ""
            for char in model_answer:
                if char in ["A", "B", "C", "D"]:
                    model_choice = char
                    break
            if not model_choice:
                logging.warning(f"模型回答格式异常: {model_answer}")
                return True  # unparseable reply: count as correct
            is_correct = model_choice == correct_answer.upper()
            logging.debug(f"选择题测试 - 正确答案: {correct_answer}, 模型答案: {model_choice}, 结果: {'正确' if is_correct else '错误'}")
            return is_correct
        except Exception as e:
            logging.error(f"测试选择题时出错: {e}")
            return True
def _create_fallback_options(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
"""当AI生成失败时的备用选项生成"""
question_type = question_data.get("question_type", "")
correct_option = question_data.get("correct_option", "")
if question_type == "true_false":
return {
"question_type": "true_false",
"statement": question_data.get("choice_question", ""),
"options": ["True", "False"],
"correct_answer": self._determine_true_false_answer(correct_option),
"explanation": "基于题目分析的判断结果",
"sampling_summary": {
"result_type": "fallback",
"difficulty_label": "unknown_fallback",
"total_sampling_attempts": 0,
"is_early_stop": False,
"is_all_correct": False
}
}
else:
distractors = ["选项B", "选项C", "选项D"]
all_options = [correct_option] + distractors
random.shuffle(all_options)
correct_index = all_options.index(correct_option)
correct_label = ["A", "B", "C", "D"][correct_index]
return {
"question_type": "multiple_choice",
"options": {
"A": all_options[0],
"B": all_options[1],
"C": all_options[2],
"D": all_options[3]
},
"correct_answer": correct_label,
"explanation": "基于规则生成的备用选项",
"sampling_summary": {
"result_type": "fallback",
"difficulty_label": "unknown_fallback",
"total_sampling_attempts": 0,
"is_early_stop": False,
"is_all_correct": False
}
}
def _determine_true_false_answer(self, correct_option: str) -> str:
"""确定判断题的正确答案"""
true_indicators = ["true", "正确", "", "", "T", "", "yes"]
false_indicators = ["false", "错误", "", "", "F", "×", "no"]
correct_lower = correct_option.lower().strip()
if any(indicator in correct_lower for indicator in true_indicators):
return "True"
elif any(indicator in correct_lower for indicator in false_indicators):
return "False"
else:
return "True"
    # Prompt-building methods (unchanged logic; the Chinese prompt text is
    # runtime data and is kept verbatim).
    def _create_enhanced_multiple_choice_prompt(self, question: str, correct_option: str, original_question: str, question_type: str) -> str:
        """Build the high-difficulty multiple-choice generation prompt.

        The prompt instructs the generator model to design distractors that
        exploit common AI misconceptions and asks for strict JSON output with
        keys: question_type, options (A-D), correct_answer, explanation.
        """
        return f"""
你是材料科学领域的顶级专家专门设计能够挑战AI系统的高难度考试题目。你的目标是创建连GPT-4、Claude等先进AI都可能答错的题目。
题目: {question}
正确答案: {correct_option}
原始题目: {original_question}
题目类型: {question_type}
**核心策略让AI模型产生错误的直觉判断**
**超高难度干扰项设计原则:**
1. **认知偏差利用型干扰项**
- 利用常见的材料科学概念混淆如弹性模量vs剪切模量
- 设计基于"第一印象"正确但深入分析错误的选项
- 使用数值接近但单位或条件不同的陷阱
2. **专业直觉陷阱型干扰项**
- 基于材料科学中"看似显而易见"但实际错误的推理
- 利用不同材料体系间的类比误区
- 设计需要区分相似工艺或现象的细微差别
3. **多层次验证陷阱**
- 创建需要同时考虑多个材料参数的复杂选项
- 设计表面符合某个理论但忽略关键约束条件的选项
- 使用真实存在但在当前情境下不适用的材料数据
输出格式严格JSON无注释
{{
"question_type": "multiple_choice",
"options": {{
"A": "选项A内容",
"B": "选项B内容",
"C": "选项C内容",
"D": "选项D内容"
}},
"correct_answer": "A",
"explanation": "详细解释为什么正确答案正确,以及每个干扰项的具体迷惑策略"
}}
现在请基于以上原则设计一道能让先进AI模型都可能答错的高难度选择题。
"""
    def _create_enhanced_true_false_prompt(self, question: str, correct_option: str, original_question: str) -> str:
        """Build the high-difficulty true/false generation prompt.

        Asks the generator model for a tricky statement plus strict JSON with
        keys: question_type, statement, options, correct_answer, explanation.
        """
        return f"""
你是材料科学专家需要设计能够挑战AI判断能力的高难度判断题。
题目: {question}
正确答案: {correct_option}
原始题目: {original_question}
**设计高难度判断题的策略:**
1. **微妙条件陷阱**:设计在特定条件下成立但一般情况下错误(或相反)的陈述
2. **精确性陷阱**:使用"总是""从不""所有"等绝对词汇的微妙误用
3. **概念边界模糊**:涉及材料科学中定义边界模糊的概念
4. **数值精度陷阱**:涉及需要精确数值判断的陈述
输出格式严格JSON无注释
{{
"question_type": "true_false",
"statement": "需要判断的复杂陈述句",
"options": ["True", "False"],
"correct_answer": "True或False",
"explanation": "详细解释判断理由和可能的误解点"
}}
"""
def create_options_prompt(self, question_data: Dict[str, Any]) -> str:
"""创建生成选项的提示词"""
choice_question = question_data.get("choice_question", "")
correct_option = question_data.get("correct_option", "")
original_question = question_data.get("question", "")
question_type = question_data.get("question_type", "")
if question_type == "true_false":
return self._create_enhanced_true_false_prompt(choice_question, correct_option, original_question)
else:
return self._create_enhanced_multiple_choice_prompt(choice_question, correct_option, original_question, question_type)
    def _attempt_generate_options(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
        """Make a single generation API call and parse its JSON payload.

        Raises on API failure or unrecoverable JSON; callers (the sampling
        loop and the basic-retry loop) handle retries.
        """
        client = self.get_client()
        prompt = self.create_options_prompt(question_data)
        response = client.chat.completions.create(
            model=self.model_name,
            messages=[
                {
                    "role": "system",
                    "content": "你是一个材料科学专业的教育评估专家。请严格按照要求的JSON格式输出不要添加任何额外的文本、注释或代码块标记。确保输出的JSON语法完全正确。"
                },
                {"role": "user", "content": prompt}
            ],
            temperature=0.9,  # high temperature: encourage varied distractors
            max_tokens=2000,
            top_p=0.95
        )
        result_text = response.choices[0].message.content.strip()
        logging.debug(f"AI响应: {result_text}")
        # Extract/repair the JSON body of the reply (raises if unrecoverable).
        json_result = self._extract_and_fix_json(result_text)
        return json_result
def _extract_and_fix_json(self, response_text: str) -> Dict[str, Any]:
"""从响应文本中提取并修复JSON"""
response_text = re.sub(r'```json\s*', '', response_text)
response_text = re.sub(r'```\s*$', '', response_text)
json_start = response_text.find('{')
json_end = response_text.rfind('}') + 1
if json_start == -1 or json_end <= json_start:
raise ValueError("无法在响应中找到JSON格式内容")
json_str = response_text[json_start:json_end]
json_str = self._fix_json_syntax(json_str)
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
logging.error(f"JSON解析失败: {e}")
json_str = self._aggressive_json_fix(json_str)
return json.loads(json_str)
def _fix_json_syntax(self, json_str: str) -> str:
"""修复常见的JSON语法错误"""
json_str = re.sub(r'//.*?(?=\n|$)', '', json_str)
json_str = re.sub(r'/\*.*?\*/', '', json_str, flags=re.DOTALL)
json_str = re.sub(r',\s*}', '}', json_str)
json_str = re.sub(r',\s*]', ']', json_str)
json_str = re.sub(r"'([^']*)':", r'"\1":', json_str)
json_str = re.sub(r":\s*'([^']*)'", r': "\1"', json_str)
return json_str
def _aggressive_json_fix(self, json_str: str) -> str:
"""更激进的JSON修复方法"""
try:
patterns = {
'question_type': r'"question_type"\s*:\s*"([^"]*)"',
'correct_answer': r'"correct_answer"\s*:\s*"([^"]*)"',
'explanation': r'"explanation"\s*:\s*"([^"]*)"'
}
extracted = {}
for key, pattern in patterns.items():
match = re.search(pattern, json_str)
if match:
extracted[key] = match.group(1)
options_match = re.search(r'"options"\s*:\s*{([^}]*)}', json_str)
if options_match:
options_content = options_match.group(1)
options = {}
option_pattern = r'"([ABCD])"\s*:\s*"([^"]*)"'
for match in re.finditer(option_pattern, options_content):
options[match.group(1)] = match.group(2)
extracted['options'] = options
if 'question_type' in extracted and len(extracted) >= 3:
return json.dumps(extracted, ensure_ascii=False)
except Exception as e:
logging.error(f"激进修复失败: {e}")
raise ValueError("无法修复JSON格式")
def _validate_options_quality(self, result: Dict[str, Any], original_data: Dict[str, Any]) -> bool:
"""验证生成选项的质量"""
if not result:
return False
question_type = result.get("question_type", "")
if question_type == "true_false":
return self._validate_true_false_quality(result)
elif question_type == "multiple_choice":
return self._validate_multiple_choice_quality(result, original_data)
return False
def _validate_true_false_quality(self, result: Dict[str, Any]) -> bool:
"""验证判断题质量"""
required_fields = ["statement", "options", "correct_answer", "explanation"]
if not all(field in result for field in required_fields):
return False
options = result.get("options", [])
if not (len(options) == 2 and "True" in options and "False" in options):
return False
correct_answer = result.get("correct_answer", "")
if correct_answer not in ["True", "False"]:
return False
return True
def _validate_multiple_choice_quality(self, result: Dict[str, Any], original_data: Dict[str, Any]) -> bool:
"""验证选择题质量"""
if not all(key in result for key in ["options", "correct_answer", "explanation"]):
return False
options = result.get("options", {})
if len(options) != 4 or not all(label in options for label in ["A", "B", "C", "D"]):
return False
correct_answer = result.get("correct_answer", "")
if correct_answer not in ["A", "B", "C", "D"]:
return False
if any(len(str(option).strip()) < 2 for option in options.values()):
return False
option_values = [str(option).strip().lower() for option in options.values()]
if len(set(option_values)) != 4:
return False
return True
def generate_options(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
"""为单个题目生成选项,使用多次采样策略"""
result = self.generate_options_with_sampling(question_data)
if result:
return result
logging.warning("采样生成失败,回退到原始生成方法")
return self._generate_with_basic_retry(question_data)
def _generate_with_basic_retry(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
"""基础重试生成方法"""
for attempt in range(self.max_retries):
try:
result = self._attempt_generate_options(question_data)
if self._validate_options_quality(result, question_data):
# 为基础重试添加采样信息
result["sampling_summary"] = {
"result_type": "basic_retry",
"difficulty_label": "unknown_retry",
"total_sampling_attempts": 1,
"is_early_stop": False,
"is_all_correct": False
}
return result
else:
if attempt < self.max_retries - 1:
logging.warning(f"{attempt+1}次生成的选项质量不佳,重试中...")
time.sleep(1)
continue
except Exception as e:
logging.error(f"{attempt+1}次生成选项失败: {e}")
if attempt < self.max_retries - 1:
time.sleep(2)
continue
logging.error("所有重试都失败,使用备用选项生成")
return self._create_fallback_options(question_data)
def process_single_question(generator, question, question_index):
    """Generate options for one question and wrap the outcome for collection.

    Args:
        generator: ChoiceOptionsGenerator used for generation (and for the
            fallback options when generation raises).
        question: raw question dict; it is copied, never mutated.
        question_index: original position, kept so results can be re-sorted
            into input order later.

    Returns:
        A copy of ``question`` extended with ``generated_options``,
        ``generation_status`` ("success" or "failed") and ``question_index``
        (plus ``error_message`` on failure). Never raises.

    Cleanup: the original built an unused ``status_emoji`` dict (its values
    were mostly empty, garbled glyphs) and an unused ``is_all_correct``
    local; both dead locals are removed.
    """
    try:
        options_data = generator.generate_options(question)
        complete_question = question.copy()
        complete_question["generated_options"] = options_data
        complete_question["generation_status"] = "success"
        complete_question["question_index"] = question_index
        # Pull sampling metadata purely for the progress log line.
        sampling_info = options_data.get("sampling_summary", {})
        difficulty_label = sampling_info.get("difficulty_label", "unknown")
        attempts = sampling_info.get("total_sampling_attempts", 1)
        is_early_stop = sampling_info.get("is_early_stop", False)
        logging.info(f"第{question_index+1}题完成 - {difficulty_label} - 采样{attempts}次 - {'早停' if is_early_stop else '全采样'}")
        return complete_question
    except Exception as e:
        logging.error(f"第{question_index+1}题处理失败: {e}")
        # Keep the question in the output stream with fallback options.
        failed_question = question.copy()
        failed_question["generated_options"] = generator._create_fallback_options(question)
        failed_question["generation_status"] = "failed"
        failed_question["error_message"] = str(e)
        failed_question["question_index"] = question_index
        return failed_question
def main():
    """Entry point: load questions, generate options concurrently with the
    early-stop sampling strategy, print statistics, and save everything to
    OUTPUT_FILE."""
    # --- configuration ---
    # SECURITY NOTE(review): the API key is hard-coded; move it into an
    # environment variable / secrets store before committing or sharing.
    API_KEY = "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"
    BASE_URL = "https://vip.apiyi.com/v1"
    MODEL_NAME = "deepseek-chat"
    MAX_WORKERS = 20
    INPUT_FILE = "/home/ubuntu/50T/LYT/MatBench/layer2/PGEE/code/step7_no_perp_convertible.json"
    OUTPUT_FILE = "/home/ubuntu/50T/LYT/MatBench/layer2/PGEE/code/stepy_complete_choice_questions_with_sampling.json"
    # --- load data ---
    print("正在加载题目数据...")
    with open(INPUT_FILE, 'r', encoding='utf-8') as f:
        questions = json.load(f)
    import random  # NOTE(review): redundant with the module-level import
    random.shuffle(questions)  # randomize processing order
    # questions = questions[:100]  # limit to the first 100 questions for a test run
    print(f"加载了 {len(questions)} 道题目")
    # Question-type distribution.
    type_counts = {}
    for q in questions:
        qtype = q.get("question_type", "unknown")
        type_counts[qtype] = type_counts.get(qtype, 0) + 1
    print("题目类型分布:")
    for qtype, count in type_counts.items():
        print(f" {qtype}: {count} 道")
    # --- initialize the generator ---
    generator = ChoiceOptionsGenerator(API_KEY, BASE_URL, MODEL_NAME, MAX_WORKERS)
    print(f"\n开始生成选项,每题最多采样{generator.max_sampling_attempts}次...")
    print("策略:答错题目会早停,答对题目会继续采样直到上限")
    # Concurrent processing with a ThreadPoolExecutor.
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Submit every question as its own task; remember its input position.
        future_to_question = {
            executor.submit(process_single_question, generator, question, i): (question, i)
            for i, question in enumerate(questions)
        }
        # tqdm progress bar over completed futures.
        with tqdm(total=len(questions), desc="生成选项") as pbar:
            # Collect results as they finish (arbitrary order).
            temp_results = []
            for future in as_completed(future_to_question):
                try:
                    result = future.result()
                    temp_results.append(result)
                    pbar.update(1)
                    # Refresh the progress-bar description with the latest label.
                    sampling_info = result.get("generated_options", {}).get("sampling_summary", {})
                    difficulty_label = sampling_info.get("difficulty_label", "unknown")
                    # NOTE(review): several emoji values below appear to have
                    # been lost in the file's encoding (empty strings).
                    status_emoji = {
                        "hard_early_stop": "🔥",
                        "easy_all_correct": "",
                        "mixed": "",
                        "unknown_fallback": "",
                        "unknown_retry": "🔄"
                    }
                    desc = f"生成选项 {status_emoji.get(difficulty_label, '')}"
                    pbar.set_description(desc)
                except Exception as e:
                    logging.error(f"处理结果时发生错误: {e}")
                    original_question, question_index = future_to_question[future]
                    # Build a failure placeholder so the question is not lost.
                    failed_result = original_question.copy()
                    failed_result["generated_options"] = generator._create_fallback_options(original_question)
                    failed_result["generation_status"] = "processing_failed"
                    failed_result["error_message"] = str(e)
                    failed_result["question_index"] = question_index
                    temp_results.append(failed_result)
                    pbar.update(1)
    # Restore the original input ordering.
    complete_questions = sorted(temp_results, key=lambda x: x.get("question_index", 0))
    # Drop the temporary index field.
    for question in complete_questions:
        if "question_index" in question:
            del question["question_index"]
    # --- sampling statistics ---
    print("\n=== 采样结果统计 ===")
    sampling_stats = {
        "hard_early_stop": 0,  # hard questions (early stop after a wrong answer)
        "easy_all_correct": 0,  # easy questions (all samples answered correctly)
        "mixed": 0,  # mixed outcomes
        "unknown_fallback": 0,  # rule-based fallback path
        "unknown_retry": 0,  # basic-retry path
        "total": len(complete_questions)
    }
    early_stop_questions = []
    all_correct_questions = []
    total_api_calls = 0
    total_generation_calls = 0
    total_validation_calls = 0
    for q in complete_questions:
        options_data = q.get("generated_options", {})
        sampling_info = options_data.get("sampling_summary", {})
        difficulty_label = sampling_info.get("difficulty_label", "unknown_fallback")
        is_early_stop = sampling_info.get("is_early_stop", False)
        is_all_correct = sampling_info.get("is_all_correct", False)
        attempts = sampling_info.get("total_sampling_attempts", 0)
        # Label distribution.
        if difficulty_label in sampling_stats:
            sampling_stats[difficulty_label] += 1
        # Collect the special categories.
        if is_early_stop:
            early_stop_questions.append(q)
        if is_all_correct:
            all_correct_questions.append(q)
        # API-call accounting.
        total_generation_calls += attempts
        # Each sample also needs one validation call (except fallback paths).
        if difficulty_label not in ["unknown_fallback", "unknown_retry"]:
            total_validation_calls += attempts
    total_api_calls = total_generation_calls + total_validation_calls
    # --- print the statistics ---
    print("题目标记分布:")
    for label, count in sampling_stats.items():
        if label != "total" and count > 0:
            percentage = (count / sampling_stats["total"]) * 100
            print(f" {label}: {count} 道 ({percentage:.1f}%)")
    print(f"\n关键指标:")
    print(f" 早停困难题(答错后早停): {len(early_stop_questions)} 道")
    print(f" 全正确简单题(所有采样都答对): {len(all_correct_questions)} 道")
    print(f" 早停率: {len(early_stop_questions)/len(complete_questions)*100:.1f}%")
    print(f" 全正确率: {len(all_correct_questions)/len(complete_questions)*100:.1f}%")
    # API cost statistics.
    print(f"\n=== API调用统计 ===")
    print(f"总生成调用: {total_generation_calls}")
    print(f"总验证调用: {total_validation_calls}")
    print(f"总API调用: {total_api_calls}")
    print(f"平均每题调用: {total_api_calls/len(complete_questions):.1f}")
    # Sampling-efficiency analysis.
    if early_stop_questions:
        early_stop_attempts = [q.get("generated_options", {}).get("sampling_summary", {}).get("total_sampling_attempts", 0)
                               for q in early_stop_questions]
        avg_early_stop_attempts = sum(early_stop_attempts) / len(early_stop_attempts)
        print(f"早停题目平均采样次数: {avg_early_stop_attempts:.1f}")
    if all_correct_questions:
        all_correct_attempts = [q.get("generated_options", {}).get("sampling_summary", {}).get("total_sampling_attempts", 0)
                                for q in all_correct_questions]
        avg_all_correct_attempts = sum(all_correct_attempts) / len(all_correct_attempts)
        print(f"全正确题目平均采样次数: {avg_all_correct_attempts:.1f}")
    # Per-question-type breakdown.
    print(f"\n=== 各题型采样效果 ===")
    type_sampling_analysis = {}
    for q in complete_questions:
        qtype = q.get("question_type", "unknown")
        options_data = q.get("generated_options", {})
        sampling_info = options_data.get("sampling_summary", {})
        difficulty_label = sampling_info.get("difficulty_label", "unknown")
        if qtype not in type_sampling_analysis:
            type_sampling_analysis[qtype] = {
                "hard_early_stop": 0,
                "easy_all_correct": 0,
                "mixed": 0,
                "unknown": 0,
                "total": 0
            }
        type_sampling_analysis[qtype]["total"] += 1
        if difficulty_label == "hard_early_stop":
            type_sampling_analysis[qtype]["hard_early_stop"] += 1
        elif difficulty_label == "easy_all_correct":
            type_sampling_analysis[qtype]["easy_all_correct"] += 1
        elif difficulty_label == "mixed":
            type_sampling_analysis[qtype]["mixed"] += 1
        else:
            type_sampling_analysis[qtype]["unknown"] += 1
    for qtype, stats in type_sampling_analysis.items():
        if stats["total"] > 0:
            print(f"{qtype}:")
            early_stop_rate = (stats["hard_early_stop"] / stats["total"]) * 100
            all_correct_rate = (stats["easy_all_correct"] / stats["total"]) * 100
            print(f" 早停率: {early_stop_rate:.1f}% ({stats['hard_early_stop']}/{stats['total']})")
            print(f" 全正确率: {all_correct_rate:.1f}% ({stats['easy_all_correct']}/{stats['total']})")
    # --- save results ---
    final_output = {
        "questions": complete_questions,
        "sampling_statistics": {
            "label_distribution": {k: v for k, v in sampling_stats.items() if k != "total"},
            "early_stop_count": len(early_stop_questions),
            "all_correct_count": len(all_correct_questions),
            "early_stop_rate": len(early_stop_questions)/len(complete_questions),
            "all_correct_rate": len(all_correct_questions)/len(complete_questions),
            "total_questions": len(complete_questions)
        },
        "api_usage": {
            "total_generation_calls": total_generation_calls,
            "total_validation_calls": total_validation_calls,
            "total_api_calls": total_api_calls,
            "average_calls_per_question": total_api_calls/len(complete_questions)
        },
        "generation_metadata": {
            "generation_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "model_used": MODEL_NAME,
            "max_sampling_attempts": generator.max_sampling_attempts,
            "success_rate": sum(1 for q in complete_questions if q.get("generation_status") == "success") / len(complete_questions)
        }
    }
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        json.dump(final_output, f, ensure_ascii=False, indent=2)
    # Success-rate statistics.
    success_count = sum(1 for q in complete_questions if q.get("generation_status") == "success")
    failed_count = len(complete_questions) - success_count
    print(f"\n=== 生成成功率统计 ===")
    print(f"总共处理: {len(complete_questions)} 道题目")
    print(f"成功生成: {success_count} 道")
    print(f"使用备用方案: {failed_count} 道")
    print(f"成功率: {success_count/len(complete_questions)*100:.2f}%")
    # Strategy-effect evaluation.
    print(f"\n=== 策略效果评估 ===")
    if len(early_stop_questions) > 0:
        print("✅ 早停策略有效:成功识别出困难题目")
        print(f" 困难题目数量: {len(early_stop_questions)} 道")
        # Show a few early-stop examples.
        print(" 早停题目示例:")
        for i, q in enumerate(early_stop_questions[:3]):  # show only the first 3
            qtype = q.get("question_type", "unknown")
            attempts = q.get("generated_options", {}).get("sampling_summary", {}).get("total_sampling_attempts", 0)
            print(f" {i+1}. {qtype}题,第{attempts}次采样后早停")
    else:
        print("⚠️ 没有题目触发早停,可能需要调整难度")
    if len(all_correct_questions) > 0:
        print("✅ 全采样策略有效:识别出简单题目")
        print(f" 简单题目数量: {len(all_correct_questions)} 道")
        # Show a few all-correct examples.
        print(" 全正确题目示例:")
        for i, q in enumerate(all_correct_questions[:3]):  # show only the first 3
            qtype = q.get("question_type", "unknown")
            attempts = q.get("generated_options", {}).get("sampling_summary", {}).get("total_sampling_attempts", 0)
            print(f" {i+1}. {qtype}题,{attempts}次采样全部答对")
    else:
        print("⚠️ 没有题目全部答对,生成的题目可能都比较困难")
    # Tuning suggestions.
    print(f"\n=== 优化建议 ===")
    early_stop_rate = len(early_stop_questions)/len(complete_questions)
    all_correct_rate = len(all_correct_questions)/len(complete_questions)
    if early_stop_rate < 0.2:
        print("• 早停率偏低,建议:")
        print(" - 增强提示词的迷惑性设计")
        print(" - 提高选项生成的创造性增加temperature")
        print(" - 添加更多AI容易犯错的陷阱类型")
    if all_correct_rate > 0.6:
        print("• 全正确率过高,建议:")
        print(" - 检查题目是否过于简单")
        print(" - 提升干扰选项的质量")
        print(" - 增加专业深度和复杂性")
    if early_stop_rate > 0.8:
        print("• 早停率过高,建议:")
        print(" - 适当降低题目难度")
        print(" - 平衡难易程度分布")
        print(" - 检查是否过度设计陷阱")
    avg_api_calls = total_api_calls/len(complete_questions)
    if avg_api_calls > 8:
        print("• API调用次数偏高建议:")
        print(" - 优化提示词提高首次生成质量")
        print(" - 考虑减少最大采样次数")
        print(" - 改进验证逻辑减少失败率")
    print(f"\n结果已保存到: {OUTPUT_FILE}")
    print("包含完整的题目数据、采样统计和API使用情况")
def export_analysis_report(questions: List[Dict], output_path: str):
    """Write a JSON analysis report grouping questions by difficulty label.

    Bug fix: the original divided by ``len(questions)`` unconditionally and
    crashed with ZeroDivisionError on an empty list; rates now fall back
    to 0.0 when there are no questions.

    Args:
        questions: processed questions carrying ``generated_options`` with a
            ``sampling_summary`` block.
        output_path: destination path for the UTF-8 JSON report.
    """
    early_stop_questions = []
    all_correct_questions = []
    mixed_questions = []
    # Route each question into its label bucket; unknown labels are ignored.
    buckets = {
        "hard_early_stop": early_stop_questions,
        "easy_all_correct": all_correct_questions,
        "mixed": mixed_questions,
    }
    for q in questions:
        label = (q.get("generated_options", {})
                 .get("sampling_summary", {})
                 .get("difficulty_label", "unknown"))
        if label in buckets:
            buckets[label].append(q)
    total = len(questions)
    report = {
        "summary": {
            "total_questions": total,
            "early_stop_questions": len(early_stop_questions),
            "all_correct_questions": len(all_correct_questions),
            "mixed_questions": len(mixed_questions),
            # Guard against division by zero for an empty input list.
            "early_stop_rate": len(early_stop_questions) / total if total else 0.0,
            "all_correct_rate": len(all_correct_questions) / total if total else 0.0
        },
        "early_stop_examples": early_stop_questions[:10],    # first 10 early-stop examples
        "all_correct_examples": all_correct_questions[:10],  # first 10 all-correct examples
        "mixed_examples": mixed_questions[:5]                # first 5 mixed examples
    }
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, ensure_ascii=False, indent=2)
    print(f"分析报告已保存到: {output_path}")
# Script entry point: run the full generation pipeline when executed directly.
if __name__ == "__main__":
    main()