# Listing metadata from the source viewer: 977 lines, 41 KiB, Python.
import json
import logging
import os
import random
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Any, List, Tuple, Optional

import openai
from tqdm import tqdm

# Configure the root logger once at import time: INFO level, with a
# timestamp and level name in front of every message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class ChoiceOptionsGenerator:
    """Generate and difficulty-label answer options for exam questions via an LLM.

    For each question the generator asks the chat model for a candidate
    (four multiple-choice options or a true/false statement), checks the
    candidate's structure, and then asks the same model to answer the
    candidate. A candidate the model answers incorrectly is kept immediately
    ("early stop"); otherwise sampling continues up to
    ``max_sampling_attempts`` and one of the collected candidates is chosen
    at random. Every returned dict carries a ``sampling_summary`` entry.
    """

    def __init__(self, api_key: str, base_url: str, model_name: str, max_workers: int = 20):
        """Store connection settings and create per-instance threading helpers.

        Args:
            api_key: Key passed to ``openai.OpenAI``.
            base_url: Endpoint passed to ``openai.OpenAI``.
            model_name: Model name used for every chat completion call.
            max_workers: Stored on the instance; not read by this class.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.max_workers = max_workers
        # One OpenAI client per worker thread (see get_client).
        self.thread_local = threading.local()
        self.lock = threading.Lock()
        self.max_retries = 5
        self.max_sampling_attempts = 6

    def get_client(self):
        """Return the calling thread's OpenAI client, creating it on first use."""
        if not hasattr(self.thread_local, 'client'):
            self.thread_local.client = openai.OpenAI(
                api_key=self.api_key,
                base_url=self.base_url
            )
        return self.thread_local.client

    def generate_options_with_sampling(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
        """Sample candidates until the model answers one incorrectly.

        Each attempt generates a candidate, validates it and tests the model
        on it. Any exception inside an attempt is logged and the attempt is
        skipped. The first candidate the model gets wrong is returned
        immediately. If none is found, a random valid candidate is returned;
        if there is no valid candidate at all, rule-based fallback options
        are returned.

        Args:
            question_data: Source question dict, passed through to the prompt
                builder, the validators and the fallback builder.

        Returns:
            The selected candidate dict with ``performance_test`` and
            ``sampling_summary`` added, or the fallback dict.
        """
        attempts_results = []

        for attempt in range(self.max_sampling_attempts):
            try:
                candidate = self._attempt_generate_options(question_data)

                if not self._validate_options_quality(candidate, question_data):
                    with self.lock:
                        logging.warning(f"第{attempt+1}次采样 - 选项质量验证失败")
                    continue

                # Ask the model to answer the question it has just produced.
                is_model_correct = self._test_model_performance(candidate, question_data)

                candidate["performance_test"] = {
                    "model_answered_correctly": is_model_correct,
                    "sampling_attempt": attempt + 1,
                }
                attempts_results.append(candidate)

                with self.lock:
                    logging.info(f"第{attempt+1}次采样 - 模型{'答对' if is_model_correct else '答错'}了")

                # A wrong answer means the candidate is hard: stop sampling.
                if not is_model_correct:
                    return self._finalize_result(candidate, attempts_results, "early_stop_incorrect")

            except Exception as e:
                with self.lock:
                    logging.warning(f"第{attempt+1}次采样失败: {e}")
                continue

        if attempts_results:
            # Prefer candidates the model got wrong. They normally trigger the
            # early return above, so this list is expected to be empty here.
            incorrect_results = [
                r for r in attempts_results
                if not r.get("performance_test", {}).get("model_answered_correctly", True)
            ]
            if incorrect_results:
                selected = random.choice(incorrect_results)
                return self._finalize_result(selected, attempts_results, "mixed_results")

            selected = random.choice(attempts_results)
            return self._finalize_result(selected, attempts_results, "all_samples_correct")

        # Every attempt failed generation or validation.
        logging.error("所有采样都失败,使用备用选项")
        return self._create_fallback_options(question_data)

    def _finalize_result(self, selected_result: Dict[str, Any], all_results: List[Dict], result_type: str) -> Dict[str, Any]:
        """Attach a ``sampling_summary`` to the selected candidate and return it.

        Args:
            selected_result: Candidate to return; it is mutated in place.
            all_results: Every validated candidate collected while sampling.
            result_type: ``"early_stop_incorrect"``, ``"all_samples_correct"``
                or ``"mixed_results"``.

        Returns:
            ``selected_result`` with the summary (including
            ``difficulty_label``) added.
        """
        total_attempts = len(all_results)
        correct_count = sum(
            1 for r in all_results
            if r.get("performance_test", {}).get("model_answered_correctly", True)
        )
        incorrect_count = total_attempts - correct_count

        performance = selected_result.get("performance_test", {})
        selected_result["sampling_summary"] = {
            "result_type": result_type,
            "total_sampling_attempts": total_attempts,
            "correct_answers": correct_count,
            "incorrect_answers": incorrect_count,
            "is_early_stop": result_type == "early_stop_incorrect",
            "is_all_correct": result_type == "all_samples_correct",
            "selected_attempt": performance.get("sampling_attempt", 1),
            "selected_was_correct": performance.get("model_answered_correctly", True)
        }

        # Coarse difficulty label derived from the result type.
        if result_type == "early_stop_incorrect":
            difficulty_label = "hard_early_stop"
        elif result_type == "all_samples_correct":
            difficulty_label = "easy_all_correct"
        else:
            difficulty_label = "mixed"

        selected_result["sampling_summary"]["difficulty_label"] = difficulty_label

        with self.lock:
            logging.info(f"题目标记: {difficulty_label} (正确{correct_count}/{total_attempts}次)")

        return selected_result

    def _test_model_performance(self, generated_question: Dict[str, Any], original_data: Dict[str, Any]) -> bool:
        """Return whether the model answers the generated question correctly.

        Dispatches on ``generated_question["question_type"]``. Unknown types
        and any exception are reported as True ("answered correctly"), so a
        failed test never marks a question as hard.
        """
        try:
            question_type = generated_question.get("question_type", "")

            if question_type == "true_false":
                return self._test_true_false_question(generated_question)
            elif question_type == "multiple_choice":
                return self._test_multiple_choice_question(generated_question, original_data)
            else:
                logging.warning(f"未知题目类型: {question_type}")
                return True

        except Exception as e:
            logging.error(f"测试模型性能时出错: {e}")
            return True

    @staticmethod
    def _parse_true_false_reply(reply: str) -> Optional[str]:
        """Map a free-text reply to ``"True"``/``"False"``, or None if neither appears.

        Matching is case-insensitive; "true" is checked before "false".
        """
        lowered = reply.lower()
        if "true" in lowered:
            return "True"
        if "false" in lowered:
            return "False"
        return None

    @staticmethod
    def _parse_choice_letter(reply: str) -> str:
        """Return the first stand-alone option letter (A-D) in ``reply``, or ``""``.

        A letter counts only when it is not adjacent to another ASCII letter,
        so the "A" inside "ANSWER: B" is skipped and "B" is returned.
        """
        match = re.search(r'(?<![A-Z])([ABCD])(?![A-Z])', reply.upper())
        return match.group(1) if match else ""

    def _test_true_false_question(self, question_data: Dict[str, Any]) -> bool:
        """Ask the model to judge the statement and compare with the stored answer.

        Returns True when the model's verdict equals ``correct_answer``, and
        also when the data is incomplete, the reply cannot be parsed, or the
        API call fails.
        """
        statement = question_data.get("statement", "")
        correct_answer = question_data.get("correct_answer", "")

        if not statement or not correct_answer:
            logging.warning("判断题数据不完整")
            return True

        test_prompt = f"""
请判断以下陈述的正误。请仔细分析每个细节,考虑所有可能的条件和例外情况。

陈述:{statement}

请只输出 "True" 或 "False",不要解释:
"""

        try:
            client = self.get_client()
            response = client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "你是一个材料科学专家。请仔细分析陈述,考虑所有技术细节和特殊情况,只输出True或False。"},
                    {"role": "user", "content": test_prompt}
                ],
                temperature=0.1,
                max_tokens=10
            )

            raw_answer = (response.choices[0].message.content or "").strip()
            # Case-insensitive: replies such as "true" or "FALSE" are valid.
            model_answer = self._parse_true_false_reply(raw_answer)
            if model_answer is None:
                logging.warning(f"模型回答格式异常: {raw_answer}")
                return True

            is_correct = model_answer == correct_answer
            logging.debug(f"判断题测试 - 正确答案: {correct_answer}, 模型答案: {model_answer}, 结果: {'正确' if is_correct else '错误'}")

            return is_correct

        except Exception as e:
            logging.error(f"测试判断题时出错: {e}")
            return True

    def _test_multiple_choice_question(self, question_data: Dict[str, Any], original_data: Dict[str, Any]) -> bool:
        """Ask the model to pick an option and compare with the stored answer.

        The question text comes from ``original_data["choice_question"]`` and
        the options from ``question_data["options"]``, listed in sorted key
        order. Returns True when the model's letter equals ``correct_answer``,
        and also when the data is incomplete, the reply contains no
        stand-alone option letter, or the API call fails.
        """
        options = question_data.get("options", {})
        correct_answer = question_data.get("correct_answer", "")
        original_question = original_data.get("choice_question", "")

        if not options or not correct_answer or not original_question:
            logging.warning("选择题数据不完整")
            return True

        options_text = "".join(f"{key}. {options[key]}\n" for key in sorted(options.keys()))

        test_prompt = f"""
以下是一道材料科学专业题目,请仔细分析每个选项,考虑所有技术细节和约束条件。

题目:{original_question}

选项:
{options_text}

请选择最准确的答案,只输出选项字母(A、B、C或D):
"""

        try:
            client = self.get_client()
            response = client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "你是一个材料科学专家。请深入分析题目,仔细比较各选项的技术准确性,只输出选项字母。"},
                    {"role": "user", "content": test_prompt}
                ],
                temperature=0.1,
                max_tokens=10
            )

            model_answer = (response.choices[0].message.content or "").strip().upper()
            # Stand-alone letters only: taking the first A-D character would
            # read "ANSWER: B" as "A".
            model_choice = self._parse_choice_letter(model_answer)

            if not model_choice:
                logging.warning(f"模型回答格式异常: {model_answer}")
                return True

            is_correct = model_choice == correct_answer.upper()
            logging.debug(f"选择题测试 - 正确答案: {correct_answer}, 模型答案: {model_choice}, 结果: {'正确' if is_correct else '错误'}")

            return is_correct

        except Exception as e:
            logging.error(f"测试选择题时出错: {e}")
            return True

    def _create_fallback_options(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
        """Build rule-based options without calling the model.

        For ``question_type == "true_false"`` the original question text
        becomes the statement and the answer is derived from
        ``correct_option``. For any other type the correct option is shuffled
        together with three placeholder distractors.
        """
        question_type = question_data.get("question_type", "")
        correct_option = question_data.get("correct_option", "")

        fallback_summary = {
            "result_type": "fallback",
            "difficulty_label": "unknown_fallback",
            "total_sampling_attempts": 0,
            "is_early_stop": False,
            "is_all_correct": False
        }

        if question_type == "true_false":
            return {
                "question_type": "true_false",
                "statement": question_data.get("choice_question", ""),
                "options": ["True", "False"],
                "correct_answer": self._determine_true_false_answer(correct_option),
                "explanation": "基于题目分析的判断结果",
                "sampling_summary": fallback_summary
            }

        distractors = ["选项B", "选项C", "选项D"]
        all_options = [correct_option] + distractors
        random.shuffle(all_options)

        labels = ["A", "B", "C", "D"]
        correct_label = labels[all_options.index(correct_option)]

        return {
            "question_type": "multiple_choice",
            "options": dict(zip(labels, all_options)),
            "correct_answer": correct_label,
            "explanation": "基于规则生成的备用选项",
            "sampling_summary": fallback_summary
        }

    def _determine_true_false_answer(self, correct_option: str) -> str:
        """Map a free-text correct option to ``"True"`` or ``"False"``.

        The input is lower-cased and stripped. A bare ``t``/``f`` is matched
        exactly; the remaining indicators are matched as substrings, true
        indicators first. Anything unrecognised defaults to ``"True"``.
        """
        true_indicators = ["true", "正确", "是", "对", "√", "yes"]
        false_indicators = ["false", "错误", "否", "错", "×", "no"]

        correct_lower = correct_option.lower().strip()

        # Single letters are compared exactly and in lower case. The previous
        # upper-case "T"/"F" indicators could never match the lower-cased
        # input, so "F" fell through to the "True" default.
        if correct_lower == "t":
            return "True"
        if correct_lower == "f":
            return "False"

        if any(indicator in correct_lower for indicator in true_indicators):
            return "True"
        elif any(indicator in correct_lower for indicator in false_indicators):
            return "False"
        else:
            return "True"

    def _create_enhanced_multiple_choice_prompt(self, question: str, correct_option: str, original_question: str, question_type: str) -> str:
        """Return the prompt asking for four deliberately confusable options."""
        return f"""
你是材料科学领域的顶级专家,专门设计能够挑战AI系统的高难度考试题目。你的目标是创建连GPT-4、Claude等先进AI都可能答错的题目。

题目: {question}
正确答案: {correct_option}
原始题目: {original_question}
题目类型: {question_type}

**核心策略:让AI模型产生错误的直觉判断**

**超高难度干扰项设计原则:**

1. **认知偏差利用型干扰项**:
- 利用常见的材料科学概念混淆(如弹性模量vs剪切模量)
- 设计基于"第一印象"正确但深入分析错误的选项
- 使用数值接近但单位或条件不同的陷阱

2. **专业直觉陷阱型干扰项**:
- 基于材料科学中"看似显而易见"但实际错误的推理
- 利用不同材料体系间的类比误区
- 设计需要区分相似工艺或现象的细微差别

3. **多层次验证陷阱**:
- 创建需要同时考虑多个材料参数的复杂选项
- 设计表面符合某个理论但忽略关键约束条件的选项
- 使用真实存在但在当前情境下不适用的材料数据

输出格式(严格JSON,无注释):
{{
"question_type": "multiple_choice",
"options": {{
"A": "选项A内容",
"B": "选项B内容",
"C": "选项C内容",
"D": "选项D内容"
}},
"correct_answer": "A",
"explanation": "详细解释为什么正确答案正确,以及每个干扰项的具体迷惑策略"
}}

现在,请基于以上原则设计一道能让先进AI模型都可能答错的高难度选择题。
"""

    def _create_enhanced_true_false_prompt(self, question: str, correct_option: str, original_question: str) -> str:
        """Return the prompt asking for a hard true/false statement."""
        return f"""
你是材料科学专家,需要设计能够挑战AI判断能力的高难度判断题。

题目: {question}
正确答案: {correct_option}
原始题目: {original_question}

**设计高难度判断题的策略:**

1. **微妙条件陷阱**:设计在特定条件下成立但一般情况下错误(或相反)的陈述
2. **精确性陷阱**:使用"总是"、"从不"、"所有"等绝对词汇的微妙误用
3. **概念边界模糊**:涉及材料科学中定义边界模糊的概念
4. **数值精度陷阱**:涉及需要精确数值判断的陈述

输出格式(严格JSON,无注释):
{{
"question_type": "true_false",
"statement": "需要判断的复杂陈述句",
"options": ["True", "False"],
"correct_answer": "True或False",
"explanation": "详细解释判断理由和可能的误解点"
}}
"""

    def create_options_prompt(self, question_data: Dict[str, Any]) -> str:
        """Build the generation prompt matching the question's ``question_type``.

        Reads ``choice_question``, ``correct_option``, ``question`` and
        ``question_type`` from ``question_data``; every type other than
        ``"true_false"`` gets the multiple-choice prompt.
        """
        choice_question = question_data.get("choice_question", "")
        correct_option = question_data.get("correct_option", "")
        original_question = question_data.get("question", "")
        question_type = question_data.get("question_type", "")

        if question_type == "true_false":
            return self._create_enhanced_true_false_prompt(choice_question, correct_option, original_question)
        else:
            return self._create_enhanced_multiple_choice_prompt(choice_question, correct_option, original_question, question_type)

    def _attempt_generate_options(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
        """Make one generation call and parse the reply as JSON.

        Raises:
            ValueError: If no JSON object can be recovered from the reply.
            Exception: Whatever the API client raises is propagated.
        """
        client = self.get_client()
        prompt = self.create_options_prompt(question_data)

        response = client.chat.completions.create(
            model=self.model_name,
            messages=[
                {
                    "role": "system",
                    "content": "你是一个材料科学专业的教育评估专家。请严格按照要求的JSON格式输出,不要添加任何额外的文本、注释或代码块标记。确保输出的JSON语法完全正确。"
                },
                {"role": "user", "content": prompt}
            ],
            temperature=0.9,
            max_tokens=2000,
            top_p=0.95
        )

        result_text = response.choices[0].message.content.strip()
        logging.debug(f"AI响应: {result_text}")

        json_result = self._extract_and_fix_json(result_text)
        return json_result

    def _extract_and_fix_json(self, response_text: str) -> Dict[str, Any]:
        """Extract the outermost ``{...}`` span from a reply and parse it.

        The span is parsed as-is first. The syntax repairs run only if that
        fails, and the field-by-field salvage only if the repaired text still
        fails.

        Raises:
            ValueError: If no brace-delimited span exists or salvage fails.
        """
        response_text = re.sub(r'```json\s*', '', response_text)
        response_text = re.sub(r'```\s*$', '', response_text)

        json_start = response_text.find('{')
        json_end = response_text.rfind('}') + 1

        if json_start == -1 or json_end <= json_start:
            raise ValueError("无法在响应中找到JSON格式内容")

        json_str = response_text[json_start:json_end]

        # Parse the untouched text first: the repair regexes are heuristic
        # and can damage valid JSON (e.g. a "//" inside a string value).
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass

        json_str = self._fix_json_syntax(json_str)

        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            logging.error(f"JSON解析失败: {e}")
            json_str = self._aggressive_json_fix(json_str)
            return json.loads(json_str)

    def _fix_json_syntax(self, json_str: str) -> str:
        """Apply regex repairs for common JSON mistakes and return the result.

        Removes ``//`` and ``/* */`` comments, drops trailing commas before
        ``}`` / ``]``, and converts single-quoted keys and values to double
        quotes. The repairs are textual and not aware of string boundaries.
        """
        # "(?<!:)" keeps "scheme://" URLs from being treated as comments.
        json_str = re.sub(r'(?<!:)//.*?(?=\n|$)', '', json_str)
        json_str = re.sub(r'/\*.*?\*/', '', json_str, flags=re.DOTALL)
        json_str = re.sub(r',\s*}', '}', json_str)
        json_str = re.sub(r',\s*]', ']', json_str)
        json_str = re.sub(r"'([^']*)':", r'"\1":', json_str)
        json_str = re.sub(r":\s*'([^']*)'", r': "\1"', json_str)

        return json_str

    def _aggressive_json_fix(self, json_str: str) -> str:
        """Salvage known fields from unparseable JSON with regexes.

        Extracts ``question_type``, ``statement``, ``correct_answer``,
        ``explanation`` and ``options`` (an A-D object or a list of strings)
        and re-serialises them.

        Returns:
            A JSON string of the salvaged fields.

        Raises:
            ValueError: If ``question_type`` is missing or fewer than three
                fields could be recovered.
        """
        try:
            patterns = {
                'question_type': r'"question_type"\s*:\s*"([^"]*)"',
                'statement': r'"statement"\s*:\s*"([^"]*)"',
                'correct_answer': r'"correct_answer"\s*:\s*"([^"]*)"',
                'explanation': r'"explanation"\s*:\s*"([^"]*)"'
            }

            extracted = {}
            for key, pattern in patterns.items():
                match = re.search(pattern, json_str)
                if match:
                    extracted[key] = match.group(1)

            options_match = re.search(r'"options"\s*:\s*{([^}]*)}', json_str)
            if options_match:
                options_content = options_match.group(1)
                option_pattern = r'"([ABCD])"\s*:\s*"([^"]*)"'
                extracted['options'] = {
                    m.group(1): m.group(2)
                    for m in re.finditer(option_pattern, options_content)
                }
            else:
                # True/false payloads give options as a list of strings.
                list_match = re.search(r'"options"\s*:\s*\[([^\]]*)\]', json_str)
                if list_match:
                    extracted['options'] = re.findall(r'"([^"]*)"', list_match.group(1))

            if 'question_type' in extracted and len(extracted) >= 3:
                return json.dumps(extracted, ensure_ascii=False)

        except Exception as e:
            logging.error(f"激进修复失败: {e}")

        raise ValueError("无法修复JSON格式")

    def _validate_options_quality(self, result: Dict[str, Any], original_data: Dict[str, Any]) -> bool:
        """Return whether a generated candidate is structurally usable.

        Dispatches on ``result["question_type"]``; empty results and unknown
        types are rejected.
        """
        if not result:
            return False

        question_type = result.get("question_type", "")

        if question_type == "true_false":
            return self._validate_true_false_quality(result)
        elif question_type == "multiple_choice":
            return self._validate_multiple_choice_quality(result, original_data)

        return False

    def _validate_true_false_quality(self, result: Dict[str, Any]) -> bool:
        """Check a true/false candidate.

        Requires the four expected fields, exactly the two options
        ``"True"``/``"False"``, and a ``correct_answer`` that is one of them.
        """
        required_fields = ["statement", "options", "correct_answer", "explanation"]

        if not all(field in result for field in required_fields):
            return False

        options = result.get("options", [])
        if not (len(options) == 2 and "True" in options and "False" in options):
            return False

        correct_answer = result.get("correct_answer", "")
        if correct_answer not in ["True", "False"]:
            return False

        return True

    def _validate_multiple_choice_quality(self, result: Dict[str, Any], original_data: Dict[str, Any]) -> bool:
        """Check a multiple-choice candidate.

        Requires ``options``, ``correct_answer`` and ``explanation``;
        ``options`` must be a dict with exactly the keys A-D whose values are
        at least two characters long and pairwise distinct (ignoring case and
        surrounding whitespace); ``correct_answer`` must be one of A-D.
        ``original_data`` is accepted but not used.
        """
        if not all(key in result for key in ["options", "correct_answer", "explanation"]):
            return False

        options = result.get("options", {})

        # Reject non-dict options (e.g. None or a list) instead of raising.
        if not isinstance(options, dict):
            return False

        if len(options) != 4 or not all(label in options for label in ["A", "B", "C", "D"]):
            return False

        correct_answer = result.get("correct_answer", "")
        if correct_answer not in ["A", "B", "C", "D"]:
            return False

        if any(len(str(option).strip()) < 2 for option in options.values()):
            return False

        option_values = [str(option).strip().lower() for option in options.values()]
        if len(set(option_values)) != 4:
            return False

        return True

    def generate_options(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate options for one question using the sampling strategy.

        NOTE(review): ``generate_options_with_sampling`` returns a non-empty
        dict on every path (fallback options at worst), so the basic-retry
        branch below looks unreachable. Confirm whether a fallback result was
        meant to trigger the retry path.
        """
        result = self.generate_options_with_sampling(question_data)

        if result:
            return result

        logging.warning("采样生成失败,回退到原始生成方法")
        return self._generate_with_basic_retry(question_data)

    def _generate_with_basic_retry(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate options with plain retries and no model-performance test.

        Tries up to ``max_retries`` times, sleeping 1 s after an invalid
        candidate and 2 s after an exception (except after the last try).
        Returns the first valid candidate, tagged with a ``basic_retry``
        summary recording how many attempts were made, or fallback options.
        """
        for attempt in range(self.max_retries):
            is_last_attempt = attempt >= self.max_retries - 1
            try:
                result = self._attempt_generate_options(question_data)

                if self._validate_options_quality(result, question_data):
                    result["sampling_summary"] = {
                        "result_type": "basic_retry",
                        "difficulty_label": "unknown_retry",
                        # Number of generation calls actually made.
                        "total_sampling_attempts": attempt + 1,
                        "is_early_stop": False,
                        "is_all_correct": False
                    }
                    return result

                if not is_last_attempt:
                    logging.warning(f"第{attempt+1}次生成的选项质量不佳,重试中...")
                    time.sleep(1)

            except Exception as e:
                logging.error(f"第{attempt+1}次生成选项失败: {e}")
                if not is_last_attempt:
                    time.sleep(2)

        logging.error("所有重试都失败,使用备用选项生成")
        return self._create_fallback_options(question_data)


def process_single_question(generator, question, question_index):
    """Generate options for one question and wrap the outcome in a result record.

    Args:
        generator: Object providing ``generate_options(question)`` and
            ``_create_fallback_options(question)``.
        question: Source question dict; it is copied, not mutated.
        question_index: Position of the question in the work list, stored on
            the record under ``question_index``.

    Returns:
        A copy of ``question`` with ``generated_options``,
        ``generation_status`` and ``question_index`` added.
        ``generation_status`` is ``"success"`` for model-generated options,
        ``"fallback"`` when the generator returned rule-based fallback
        options, and ``"failed"`` (with ``error_message``) when generation
        raised.
    """
    try:
        options_data = generator.generate_options(question)
        sampling_info = options_data.get("sampling_summary", {})

        complete_question = question.copy()
        complete_question["generated_options"] = options_data
        # Fallback options are not a successful generation. Labelling them
        # "success" would inflate the success rate computed downstream.
        if sampling_info.get("result_type") == "fallback":
            complete_question["generation_status"] = "fallback"
        else:
            complete_question["generation_status"] = "success"
        complete_question["question_index"] = question_index

        # Pull the sampling details used in the progress log line.
        difficulty_label = sampling_info.get("difficulty_label", "unknown")
        attempts = sampling_info.get("total_sampling_attempts", 1)
        is_early_stop = sampling_info.get("is_early_stop", False)

        logging.info(f"第{question_index+1}题完成 - {difficulty_label} - 采样{attempts}次 - {'早停' if is_early_stop else '全采样'}")

        return complete_question

    except Exception as e:
        logging.error(f"第{question_index+1}题处理失败: {e}")

        failed_question = question.copy()
        failed_question["generated_options"] = generator._create_fallback_options(question)
        failed_question["generation_status"] = "failed"
        failed_question["error_message"] = str(e)
        failed_question["question_index"] = question_index

        return failed_question


def main():
    """Run the full pipeline: load questions, generate options, report, save.

    Reads the question list from ``INPUT_FILE``, shuffles it, generates
    options for every question concurrently, prints sampling, API-usage and
    per-type statistics, and writes the questions plus statistics to
    ``OUTPUT_FILE``.

    The API key is read from the ``OPENAI_API_KEY`` environment variable.

    Raises:
        SystemExit: If ``OPENAI_API_KEY`` is not set.
    """
    # A secret must not be committed to source; take it from the environment.
    API_KEY = os.environ.get("OPENAI_API_KEY", "").strip()
    if not API_KEY:
        raise SystemExit("未设置环境变量 OPENAI_API_KEY,无法调用API")
    BASE_URL = "https://vip.apiyi.com/v1"
    MODEL_NAME = "deepseek-chat"
    MAX_WORKERS = 20

    INPUT_FILE = "/home/ubuntu/50T/LYT/MatBench/layer2/PGEE/code/step7_no_perp_convertible.json"
    OUTPUT_FILE = "/home/ubuntu/50T/LYT/MatBench/layer2/PGEE/code/stepy_complete_choice_questions_with_sampling.json"

    def summary_of(q):
        """Return the ``sampling_summary`` dict of a result record (or {})."""
        return q.get("generated_options", {}).get("sampling_summary", {})

    # Load the data.
    print("正在加载题目数据...")
    with open(INPUT_FILE, 'r', encoding='utf-8') as f:
        questions = json.load(f)

    random.shuffle(questions)  # process the questions in random order
    print(f"加载了 {len(questions)} 道题目")

    # Every rate below divides by the number of questions, so stop early
    # instead of failing with ZeroDivisionError on an empty input file.
    if not questions:
        print("没有可处理的题目,已退出")
        return

    # Distribution of question types in the input.
    type_counts = {}
    for q in questions:
        qtype = q.get("question_type", "unknown")
        type_counts[qtype] = type_counts.get(qtype, 0) + 1

    print("题目类型分布:")
    for qtype, count in type_counts.items():
        print(f" {qtype}: {count} 道")

    generator = ChoiceOptionsGenerator(API_KEY, BASE_URL, MODEL_NAME, MAX_WORKERS)

    print(f"\n开始生成选项,每题最多采样{generator.max_sampling_attempts}次...")
    print("策略:答错题目会早停,答对题目会继续采样直到上限")

    status_emoji = {
        "hard_early_stop": "🔥",
        "easy_all_correct": "✅",
        "mixed": "⚡",
        "unknown_fallback": "❓",
        "unknown_retry": "🔄"
    }

    # Fan the questions out over a thread pool; the work is API-bound.
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_question = {
            executor.submit(process_single_question, generator, question, i): (question, i)
            for i, question in enumerate(questions)
        }

        with tqdm(total=len(questions), desc="生成选项") as pbar:
            temp_results = []

            for future in as_completed(future_to_question):
                # Only future.result() is guarded. A wider try could append a
                # second, "failed" record for a question already stored.
                try:
                    result = future.result()
                except Exception as e:
                    logging.error(f"处理结果时发生错误: {e}")
                    original_question, question_index = future_to_question[future]

                    failed_result = original_question.copy()
                    failed_result["generated_options"] = generator._create_fallback_options(original_question)
                    failed_result["generation_status"] = "processing_failed"
                    failed_result["error_message"] = str(e)
                    failed_result["question_index"] = question_index

                    temp_results.append(failed_result)
                    pbar.update(1)
                else:
                    temp_results.append(result)
                    pbar.update(1)

                    # Show the latest question's difficulty in the bar title.
                    difficulty_label = summary_of(result).get("difficulty_label", "unknown")
                    pbar.set_description(f"生成选项 {status_emoji.get(difficulty_label, '❓')}")

    # Restore submission order, then drop the temporary index field.
    complete_questions = sorted(temp_results, key=lambda x: x.get("question_index", 0))
    for question in complete_questions:
        question.pop("question_index", None)

    total_questions = len(complete_questions)

    # ---- Sampling statistics ----
    print("\n=== 采样结果统计 ===")

    sampling_stats = {
        "hard_early_stop": 0,    # model answered wrong, sampling stopped early
        "easy_all_correct": 0,   # model answered every sample correctly
        "mixed": 0,              # mixed outcomes
        "unknown_fallback": 0,   # rule-based fallback options
        "unknown_retry": 0,      # basic-retry path
        "total": total_questions
    }

    early_stop_questions = []
    all_correct_questions = []

    total_generation_calls = 0
    total_validation_calls = 0

    for q in complete_questions:
        sampling_info = summary_of(q)

        difficulty_label = sampling_info.get("difficulty_label", "unknown_fallback")
        attempts = sampling_info.get("total_sampling_attempts", 0)

        if difficulty_label in sampling_stats:
            sampling_stats[difficulty_label] += 1

        if sampling_info.get("is_early_stop", False):
            early_stop_questions.append(q)

        if sampling_info.get("is_all_correct", False):
            all_correct_questions.append(q)

        # Estimate API usage from the recorded attempt count: one generation
        # call per attempt, plus one model-test call per attempt except on
        # the fallback and basic-retry paths.
        total_generation_calls += attempts
        if difficulty_label not in ["unknown_fallback", "unknown_retry"]:
            total_validation_calls += attempts

    total_api_calls = total_generation_calls + total_validation_calls

    early_stop_rate = len(early_stop_questions) / total_questions
    all_correct_rate = len(all_correct_questions) / total_questions
    avg_api_calls = total_api_calls / total_questions

    print("题目标记分布:")
    for label, count in sampling_stats.items():
        if label != "total" and count > 0:
            percentage = (count / sampling_stats["total"]) * 100
            print(f" {label}: {count} 道 ({percentage:.1f}%)")

    print("\n关键指标:")
    print(f" 早停困难题(答错后早停): {len(early_stop_questions)} 道")
    print(f" 全正确简单题(所有采样都答对): {len(all_correct_questions)} 道")
    print(f" 早停率: {early_stop_rate*100:.1f}%")
    print(f" 全正确率: {all_correct_rate*100:.1f}%")

    # ---- API cost ----
    print("\n=== API调用统计 ===")
    print(f"总生成调用: {total_generation_calls}")
    print(f"总验证调用: {total_validation_calls}")
    print(f"总API调用: {total_api_calls}")
    print(f"平均每题调用: {avg_api_calls:.1f}")

    # ---- Sampling efficiency ----
    if early_stop_questions:
        early_stop_attempts = [summary_of(q).get("total_sampling_attempts", 0)
                               for q in early_stop_questions]
        avg_early_stop_attempts = sum(early_stop_attempts) / len(early_stop_attempts)
        print(f"早停题目平均采样次数: {avg_early_stop_attempts:.1f}")

    if all_correct_questions:
        all_correct_attempts = [summary_of(q).get("total_sampling_attempts", 0)
                                for q in all_correct_questions]
        avg_all_correct_attempts = sum(all_correct_attempts) / len(all_correct_attempts)
        print(f"全正确题目平均采样次数: {avg_all_correct_attempts:.1f}")

    # ---- Per question type ----
    print("\n=== 各题型采样效果 ===")
    type_sampling_analysis = {}

    for q in complete_questions:
        qtype = q.get("question_type", "unknown")
        difficulty_label = summary_of(q).get("difficulty_label", "unknown")

        stats = type_sampling_analysis.setdefault(qtype, {
            "hard_early_stop": 0,
            "easy_all_correct": 0,
            "mixed": 0,
            "unknown": 0,
            "total": 0
        })
        stats["total"] += 1

        if difficulty_label in ("hard_early_stop", "easy_all_correct", "mixed"):
            stats[difficulty_label] += 1
        else:
            stats["unknown"] += 1

    for qtype, stats in type_sampling_analysis.items():
        if stats["total"] > 0:
            print(f"{qtype}:")
            type_early_stop_pct = (stats["hard_early_stop"] / stats["total"]) * 100
            type_all_correct_pct = (stats["easy_all_correct"] / stats["total"]) * 100
            print(f" 早停率: {type_early_stop_pct:.1f}% ({stats['hard_early_stop']}/{stats['total']})")
            print(f" 全正确率: {type_all_correct_pct:.1f}% ({stats['easy_all_correct']}/{stats['total']})")

    success_count = sum(1 for q in complete_questions if q.get("generation_status") == "success")
    failed_count = total_questions - success_count

    # ---- Save ----
    final_output = {
        "questions": complete_questions,
        "sampling_statistics": {
            "label_distribution": {k: v for k, v in sampling_stats.items() if k != "total"},
            "early_stop_count": len(early_stop_questions),
            "all_correct_count": len(all_correct_questions),
            "early_stop_rate": early_stop_rate,
            "all_correct_rate": all_correct_rate,
            "total_questions": total_questions
        },
        "api_usage": {
            "total_generation_calls": total_generation_calls,
            "total_validation_calls": total_validation_calls,
            "total_api_calls": total_api_calls,
            "average_calls_per_question": avg_api_calls
        },
        "generation_metadata": {
            "generation_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "model_used": MODEL_NAME,
            "max_sampling_attempts": generator.max_sampling_attempts,
            "success_rate": success_count / total_questions
        }
    }

    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        json.dump(final_output, f, ensure_ascii=False, indent=2)

    # ---- Success rate ----
    print("\n=== 生成成功率统计 ===")
    print(f"总共处理: {total_questions} 道题目")
    print(f"成功生成: {success_count} 道")
    print(f"使用备用方案: {failed_count} 道")
    print(f"成功率: {success_count/total_questions*100:.2f}%")

    # ---- Strategy assessment ----
    print("\n=== 策略效果评估 ===")

    if len(early_stop_questions) > 0:
        print("✅ 早停策略有效:成功识别出困难题目")
        print(f" 困难题目数量: {len(early_stop_questions)} 道")

        # Show up to three early-stop examples.
        print(" 早停题目示例:")
        for i, q in enumerate(early_stop_questions[:3]):
            qtype = q.get("question_type", "unknown")
            attempts = summary_of(q).get("total_sampling_attempts", 0)
            print(f" {i+1}. {qtype}题,第{attempts}次采样后早停")
    else:
        print("⚠️ 没有题目触发早停,可能需要调整难度")

    if len(all_correct_questions) > 0:
        print("✅ 全采样策略有效:识别出简单题目")
        print(f" 简单题目数量: {len(all_correct_questions)} 道")

        # Show up to three all-correct examples.
        print(" 全正确题目示例:")
        for i, q in enumerate(all_correct_questions[:3]):
            qtype = q.get("question_type", "unknown")
            attempts = summary_of(q).get("total_sampling_attempts", 0)
            print(f" {i+1}. {qtype}题,{attempts}次采样全部答对")
    else:
        print("⚠️ 没有题目全部答对,生成的题目可能都比较困难")

    # ---- Tuning suggestions ----
    print("\n=== 优化建议 ===")

    if early_stop_rate < 0.2:
        print("• 早停率偏低,建议:")
        print(" - 增强提示词的迷惑性设计")
        print(" - 提高选项生成的创造性(增加temperature)")
        print(" - 添加更多AI容易犯错的陷阱类型")

    if all_correct_rate > 0.6:
        print("• 全正确率过高,建议:")
        print(" - 检查题目是否过于简单")
        print(" - 提升干扰选项的质量")
        print(" - 增加专业深度和复杂性")

    if early_stop_rate > 0.8:
        print("• 早停率过高,建议:")
        print(" - 适当降低题目难度")
        print(" - 平衡难易程度分布")
        print(" - 检查是否过度设计陷阱")

    if avg_api_calls > 8:
        print("• API调用次数偏高,建议:")
        print(" - 优化提示词提高首次生成质量")
        print(" - 考虑减少最大采样次数")
        print(" - 改进验证逻辑减少失败率")

    print(f"\n结果已保存到: {OUTPUT_FILE}")
    print("包含完整的题目数据、采样统计和API使用情况")


def export_analysis_report(questions: List[Dict], output_path: str):
    """Write a JSON report that groups questions by their difficulty label.

    Questions are bucketed by
    ``generated_options.sampling_summary.difficulty_label`` into early-stop
    (``hard_early_stop``), all-correct (``easy_all_correct``) and ``mixed``
    groups; other labels are only counted in the total.

    Args:
        questions: Result records as produced by the generation pipeline.
        output_path: Destination file, written as UTF-8 JSON.

    The report holds a ``summary`` of counts and rates (rates are 0.0 for an
    empty input) plus the first 10 early-stop, first 10 all-correct and
    first 5 mixed records as examples.
    """
    early_stop_questions = []
    all_correct_questions = []
    mixed_questions = []

    buckets = {
        "hard_early_stop": early_stop_questions,
        "easy_all_correct": all_correct_questions,
        "mixed": mixed_questions,
    }

    for q in questions:
        sampling_info = q.get("generated_options", {}).get("sampling_summary", {})
        bucket = buckets.get(sampling_info.get("difficulty_label", "unknown"))
        if bucket is not None:
            bucket.append(q)

    total = len(questions)

    def rate(count: int) -> float:
        """Return count/total, or 0.0 when there are no questions."""
        return count / total if total else 0.0

    report = {
        "summary": {
            "total_questions": total,
            "early_stop_questions": len(early_stop_questions),
            "all_correct_questions": len(all_correct_questions),
            "mixed_questions": len(mixed_questions),
            "early_stop_rate": rate(len(early_stop_questions)),
            "all_correct_rate": rate(len(all_correct_questions))
        },
        "early_stop_examples": early_stop_questions[:10],
        "all_correct_examples": all_correct_questions[:10],
        "mixed_examples": mixed_questions[:5]
    }

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, ensure_ascii=False, indent=2)

    print(f"分析报告已保存到: {output_path}")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()