594 lines
21 KiB
Python
594 lines
21 KiB
Python
import json
|
||
from typing import Dict, Any, List, Optional, Tuple
|
||
import random
|
||
from collections import Counter
|
||
|
||
def convert_to_target_format(source_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
将源JSON格式转换为目标格式
|
||
"""
|
||
if "generated_options" not in source_data:
|
||
return None
|
||
|
||
generated_options = source_data["generated_options"]
|
||
|
||
# 只处理单选题
|
||
if generated_options.get("question_type") != "multiple_choice":
|
||
return None
|
||
|
||
question = source_data.get("choice_question", "")
|
||
if not question:
|
||
return None
|
||
|
||
options = generated_options.get("options", {})
|
||
if len(options) != 4:
|
||
return None
|
||
|
||
correct_answer = generated_options.get("correct_answer", "")
|
||
if correct_answer not in ["A", "B", "C", "D"]:
|
||
return None
|
||
|
||
target_data = {
|
||
"question": question,
|
||
"choices": {
|
||
"text": [
|
||
options.get("A", ""),
|
||
options.get("B", ""),
|
||
options.get("C", ""),
|
||
options.get("D", "")
|
||
],
|
||
"label": ["A", "B", "C", "D"]
|
||
},
|
||
"answer": f"[ANSWER]{correct_answer}[/ANSWER]",
|
||
"prompt": "You are an expert in materials science. Please answer the following materials science question by selecting the correct option. You MUST include the letter of the correct answer at the end of your response within the following tags: [ANSWER] and [/ANSWER]. For example: [ANSWER]A[/ANSWER]."
|
||
}
|
||
|
||
return target_data
|
||
|
||
def extract_answer_from_question(question: Dict[str, Any]) -> Optional[str]:
|
||
"""从转换后的题目中提取答案选项"""
|
||
answer_text = question.get("answer", "")
|
||
if answer_text.startswith("[ANSWER]") and answer_text.endswith("[/ANSWER]"):
|
||
answer = answer_text[8:-9]
|
||
if answer in ["A", "B", "C", "D"]:
|
||
return answer
|
||
return None
|
||
|
||
def shuffle_question_options(question: Dict[str, Any], new_correct_answer: str) -> Dict[str, Any]:
|
||
"""
|
||
重新排列题目选项,使正确答案变为指定选项
|
||
|
||
Args:
|
||
question: 题目字典
|
||
new_correct_answer: 新的正确答案选项 (A/B/C/D)
|
||
|
||
Returns:
|
||
重新排列后的题目
|
||
"""
|
||
# 获取当前正确答案
|
||
current_answer = extract_answer_from_question(question)
|
||
if not current_answer:
|
||
return question
|
||
|
||
# 如果已经是目标答案,不需要改变
|
||
if current_answer == new_correct_answer:
|
||
return question
|
||
|
||
# 获取当前选项
|
||
choices = question.get("choices", {})
|
||
current_texts = choices.get("text", [])
|
||
current_labels = choices.get("label", ["A", "B", "C", "D"])
|
||
|
||
if len(current_texts) != 4 or len(current_labels) != 4:
|
||
return question
|
||
|
||
# 找到当前正确答案的索引
|
||
current_index = current_labels.index(current_answer)
|
||
new_index = current_labels.index(new_correct_answer)
|
||
|
||
# 交换选项
|
||
new_texts = current_texts[:]
|
||
new_texts[new_index], new_texts[current_index] = new_texts[current_index], new_texts[new_index]
|
||
|
||
# 创建新的题目
|
||
new_question = question.copy()
|
||
new_question["choices"] = {
|
||
"text": new_texts,
|
||
"label": ["A", "B", "C", "D"]
|
||
}
|
||
new_question["answer"] = f"[ANSWER]{new_correct_answer}[/ANSWER]"
|
||
|
||
return new_question
|
||
|
||
def balance_answer_distribution_by_shuffling(questions: List[Dict[str, Any]],
|
||
random_seed: Optional[int] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||
"""
|
||
通过重新排列选项来平衡答案分布
|
||
|
||
Args:
|
||
questions: 题目列表
|
||
random_seed: 随机种子
|
||
|
||
Returns:
|
||
平衡后的题目列表和统计信息
|
||
"""
|
||
if random_seed is not None:
|
||
random.seed(random_seed)
|
||
|
||
total_questions = len(questions)
|
||
target_per_answer = total_questions // 4
|
||
remainder = total_questions % 4
|
||
|
||
print(f"\n=== 答案分布平衡 (重排选项法) ===")
|
||
print(f"总题目数: {total_questions}")
|
||
print(f"标准分配: 每个选项 {target_per_answer} 道题")
|
||
if remainder > 0:
|
||
print(f"余数: {remainder} 道题 (将分配给前{remainder}个选项)")
|
||
|
||
# 统计当前答案分布
|
||
answer_groups = {"A": [], "B": [], "C": [], "D": []}
|
||
for i, question in enumerate(questions):
|
||
answer = extract_answer_from_question(question)
|
||
if answer and answer in answer_groups:
|
||
answer_groups[answer].append((i, question))
|
||
|
||
print(f"\n当前答案分布:")
|
||
for answer in ["A", "B", "C", "D"]:
|
||
count = len(answer_groups[answer])
|
||
ratio = count / total_questions if total_questions > 0 else 0
|
||
print(f" {answer}: {count} ({ratio*100:.1f}%)")
|
||
|
||
# 计算目标分配(前remainder个选项多分配1道题)
|
||
target_counts = {}
|
||
for i, answer in enumerate(["A", "B", "C", "D"]):
|
||
if i < remainder:
|
||
target_counts[answer] = target_per_answer + 1
|
||
else:
|
||
target_counts[answer] = target_per_answer
|
||
|
||
print(f"\n目标分配:")
|
||
for answer in ["A", "B", "C", "D"]:
|
||
print(f" {answer}: {target_counts[answer]} 道题")
|
||
|
||
# 计算需要调整的数量
|
||
surplus_questions = [] # (question_index, question, from_answer)
|
||
deficit_needed = [] # (to_answer, count_needed)
|
||
|
||
for answer in ["A", "B", "C", "D"]:
|
||
current_count = len(answer_groups[answer])
|
||
target_count = target_counts[answer]
|
||
difference = current_count - target_count
|
||
|
||
if difference > 0:
|
||
# 有多余的题目,需要转移出去
|
||
print(f" {answer}: 多 {difference} 道题")
|
||
# 随机选择要转移的题目
|
||
questions_to_move = random.sample(answer_groups[answer], difference)
|
||
for q_idx, q in questions_to_move:
|
||
surplus_questions.append((q_idx, q, answer))
|
||
elif difference < 0:
|
||
# 缺少题目,需要接收
|
||
needed = -difference
|
||
print(f" {answer}: 少 {needed} 道题")
|
||
deficit_needed.extend([(answer, 1)] * needed)
|
||
|
||
# 打乱顺序以避免偏向性
|
||
random.shuffle(surplus_questions)
|
||
random.shuffle(deficit_needed)
|
||
|
||
# 执行调整
|
||
balanced_questions = questions[:] # 复制原题目列表
|
||
|
||
print(f"\n开始重新分配 {len(surplus_questions)} 道题:")
|
||
|
||
for i, ((q_idx, question, from_answer), (to_answer, _)) in enumerate(zip(surplus_questions, deficit_needed)):
|
||
# 重新排列这道题的选项
|
||
new_question = shuffle_question_options(question, to_answer)
|
||
balanced_questions[q_idx] = new_question
|
||
|
||
print(f" 第{i+1}次调整: 题目{q_idx+1} 答案从 {from_answer} 改为 {to_answer}")
|
||
|
||
# 验证最终分布
|
||
final_counter = Counter()
|
||
for question in balanced_questions:
|
||
answer = extract_answer_from_question(question)
|
||
if answer:
|
||
final_counter[answer] += 1
|
||
|
||
print(f"\n平衡后答案分布:")
|
||
max_deviation = 0
|
||
target_ratio = 0.25
|
||
|
||
for answer in ["A", "B", "C", "D"]:
|
||
count = final_counter.get(answer, 0)
|
||
ratio = count / total_questions if total_questions > 0 else 0
|
||
deviation = abs(ratio - target_ratio)
|
||
max_deviation = max(max_deviation, deviation)
|
||
print(f" {answer}: {count} ({ratio*100:.1f}%)")
|
||
|
||
# 统计信息
|
||
balance_info = {
|
||
"original_total": total_questions,
|
||
"final_total": total_questions, # 题目总数不变
|
||
"target_per_answer": target_per_answer,
|
||
"remainder": remainder,
|
||
"final_distribution": dict(final_counter),
|
||
"max_deviation": max_deviation,
|
||
"adjustments_made": len(surplus_questions),
|
||
"perfectly_balanced": max_deviation <= 0.05
|
||
}
|
||
|
||
if balance_info["perfectly_balanced"]:
|
||
print(f"✅ 完美平衡!最大偏差: {max_deviation*100:.1f}%")
|
||
else:
|
||
print(f"📊 接近平衡,最大偏差: {max_deviation*100:.1f}%")
|
||
|
||
print(f"总共调整了 {balance_info['adjustments_made']} 道题的答案")
|
||
|
||
return balanced_questions, balance_info
|
||
|
||
def classify_questions_by_difficulty(questions: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
|
||
"""
|
||
按难度分类题目
|
||
|
||
Args:
|
||
questions: 题目列表
|
||
|
||
Returns:
|
||
按难度分类的题目字典
|
||
"""
|
||
difficulty_groups = {
|
||
"hard_early_stop": [], # 困难题(答错后早停)
|
||
"easy_all_correct": [], # 简单题(所有采样都答对)
|
||
"mixed": [], # 混合题(部分对部分错)
|
||
"unknown": [] # 未知难度
|
||
}
|
||
|
||
for question in questions:
|
||
generated_options = question.get("generated_options", {})
|
||
sampling_summary = generated_options.get("sampling_summary", {})
|
||
difficulty_label = sampling_summary.get("difficulty_label", "unknown")
|
||
|
||
if difficulty_label in difficulty_groups:
|
||
difficulty_groups[difficulty_label].append(question)
|
||
else:
|
||
difficulty_groups["unknown"].append(question)
|
||
|
||
return difficulty_groups
|
||
|
||
def select_questions_by_ratio(difficulty_groups: Dict[str, List[Dict[str, Any]]],
|
||
selection_ratios: Dict[str, float],
|
||
random_seed: Optional[int] = None) -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
|
||
"""
|
||
按比例选择题目
|
||
|
||
Args:
|
||
difficulty_groups: 按难度分类的题目
|
||
selection_ratios: 各难度等级的选择比例 (0.0-1.0)
|
||
random_seed: 随机种子
|
||
|
||
Returns:
|
||
选中的题目列表和选择统计信息
|
||
"""
|
||
if random_seed is not None:
|
||
random.seed(random_seed)
|
||
|
||
selected_questions = []
|
||
selection_stats = {}
|
||
|
||
for difficulty, questions in difficulty_groups.items():
|
||
total_count = len(questions)
|
||
ratio = selection_ratios.get(difficulty, 0.0)
|
||
|
||
# 计算要选择的题目数量
|
||
if ratio <= 0:
|
||
selected_count = 0
|
||
elif ratio >= 1:
|
||
selected_count = total_count
|
||
else:
|
||
selected_count = int(total_count * ratio)
|
||
|
||
# 随机选择题目
|
||
if selected_count > 0 and total_count > 0:
|
||
if selected_count >= total_count:
|
||
selected = questions
|
||
else:
|
||
selected = random.sample(questions, selected_count)
|
||
selected_questions.extend(selected)
|
||
else:
|
||
selected = []
|
||
|
||
# 记录统计信息
|
||
selection_stats[difficulty] = {
|
||
"total": total_count,
|
||
"selected": len(selected),
|
||
"ratio_target": ratio,
|
||
"ratio_actual": len(selected) / total_count if total_count > 0 else 0
|
||
}
|
||
|
||
# 打乱最终题目顺序
|
||
random.shuffle(selected_questions)
|
||
|
||
return selected_questions, selection_stats
|
||
|
||
def batch_convert_questions_with_difficulty_filter(input_file: str,
|
||
output_file: str,
|
||
selection_ratios: Dict[str, float],
|
||
balance_answers: bool = True,
|
||
random_seed: Optional[int] = None) -> None:
|
||
"""
|
||
批量转换题目格式,支持按难度筛选和答案平衡
|
||
|
||
Args:
|
||
input_file: 输入文件路径
|
||
output_file: 输出文件路径
|
||
selection_ratios: 各难度等级的选择比例
|
||
balance_answers: 是否平衡答案分布
|
||
random_seed: 随机种子
|
||
"""
|
||
print("=== 批量转换题目(难度筛选 + 答案平衡)===")
|
||
print(f"输入文件: {input_file}")
|
||
print(f"输出文件: {output_file}")
|
||
print(f"答案平衡: {'开启' if balance_answers else '关闭'}")
|
||
print(f"随机种子: {random_seed}")
|
||
|
||
# 加载数据
|
||
print("\n正在加载数据...")
|
||
with open(input_file, 'r', encoding='utf-8') as f:
|
||
data = json.load(f)
|
||
|
||
# 处理两种可能的输入格式
|
||
if isinstance(data, dict) and "questions" in data:
|
||
source_questions = data["questions"]
|
||
print(f"检测到完整格式数据,包含其他元数据")
|
||
elif isinstance(data, list):
|
||
source_questions = data
|
||
print(f"检测到题目列表格式")
|
||
else:
|
||
raise ValueError("不支持的输入文件格式")
|
||
|
||
print(f"加载了 {len(source_questions)} 道题目")
|
||
|
||
# 按难度分类题目
|
||
print("\n正在按难度分类题目...")
|
||
difficulty_groups = classify_questions_by_difficulty(source_questions)
|
||
|
||
print("题目难度分布:")
|
||
total_multiple_choice = 0
|
||
for difficulty, questions in difficulty_groups.items():
|
||
# 统计该难度下的单选题数量
|
||
mc_count = sum(1 for q in questions
|
||
if q.get("generated_options", {}).get("question_type") == "multiple_choice")
|
||
total_multiple_choice += mc_count
|
||
print(f" {difficulty}: {len(questions)} 道总题目, {mc_count} 道单选题")
|
||
|
||
print(f"可转换的单选题总数: {total_multiple_choice}")
|
||
|
||
# 按比例选择题目
|
||
print("\n正在按比例选择题目...")
|
||
print("选择比例设置:")
|
||
for difficulty, ratio in selection_ratios.items():
|
||
if difficulty in difficulty_groups:
|
||
print(f" {difficulty}: {ratio*100:.1f}%")
|
||
|
||
selected_questions, selection_stats = select_questions_by_ratio(
|
||
difficulty_groups, selection_ratios, random_seed
|
||
)
|
||
|
||
print(f"\n题目选择结果:")
|
||
total_selected = 0
|
||
for difficulty, stats in selection_stats.items():
|
||
print(f" {difficulty}:")
|
||
print(f" 总数: {stats['total']}")
|
||
print(f" 选中: {stats['selected']}")
|
||
print(f" 目标比例: {stats['ratio_target']*100:.1f}%")
|
||
print(f" 实际比例: {stats['ratio_actual']*100:.1f}%")
|
||
total_selected += stats['selected']
|
||
|
||
print(f"总共选中: {total_selected} 道题目")
|
||
|
||
# 转换选中的题目
|
||
print("\n正在转换题目格式...")
|
||
converted_questions = []
|
||
conversion_stats = {
|
||
"selected": total_selected,
|
||
"multiple_choice": 0,
|
||
"true_false": 0,
|
||
"other": 0,
|
||
"converted": 0,
|
||
"failed": 0
|
||
}
|
||
|
||
for i, question in enumerate(selected_questions):
|
||
try:
|
||
# 统计题目类型
|
||
generated_options = question.get("generated_options", {})
|
||
question_type = generated_options.get("question_type", "unknown")
|
||
|
||
if question_type == "multiple_choice":
|
||
conversion_stats["multiple_choice"] += 1
|
||
elif question_type == "true_false":
|
||
conversion_stats["true_false"] += 1
|
||
else:
|
||
conversion_stats["other"] += 1
|
||
|
||
# 转换题目
|
||
converted = convert_to_target_format(question)
|
||
if converted:
|
||
converted_questions.append(converted)
|
||
conversion_stats["converted"] += 1
|
||
else:
|
||
conversion_stats["failed"] += 1
|
||
|
||
except Exception as e:
|
||
print(f"第{i+1}题转换失败: {e}")
|
||
conversion_stats["failed"] += 1
|
||
|
||
print(f"转换完成: {conversion_stats['converted']} 道题目成功转换")
|
||
|
||
# 对转换后的题目进行答案分布平衡
|
||
balance_info = None
|
||
if balance_answers and converted_questions:
|
||
print("\n正在对转换后的题目进行答案分布平衡...")
|
||
|
||
balanced_questions, balance_info = balance_answer_distribution_by_shuffling(
|
||
converted_questions,
|
||
random_seed=random_seed
|
||
)
|
||
|
||
converted_questions = balanced_questions
|
||
conversion_stats["final_count"] = len(converted_questions)
|
||
|
||
# 保存结果
|
||
print("正在保存转换结果...")
|
||
|
||
with open(output_file, 'w', encoding='utf-8') as f:
|
||
json.dump(converted_questions, f, ensure_ascii=False, indent=2)
|
||
|
||
# 打印最终统计信息
|
||
print(f"\n=== 转换完成!===")
|
||
print(f"选中题目数: {conversion_stats['selected']}")
|
||
print(f"单选题: {conversion_stats['multiple_choice']}")
|
||
print(f"判断题: {conversion_stats['true_false']}")
|
||
print(f"其他类型: {conversion_stats['other']}")
|
||
print(f"成功转换: {conversion_stats['converted']}")
|
||
print(f"转换失败: {conversion_stats['failed']}")
|
||
|
||
if balance_answers and balance_info:
|
||
print(f"答案平衡后: {conversion_stats.get('final_count', conversion_stats['converted'])}")
|
||
print(f"调整题目数: {balance_info['adjustments_made']}")
|
||
print(f"最终转换率: {conversion_stats.get('final_count', conversion_stats['converted'])/conversion_stats['selected']*100:.1f}%")
|
||
else:
|
||
print(f"最终转换率: {conversion_stats['converted']/conversion_stats['selected']*100:.1f}%")
|
||
|
||
print(f"结果已保存到: {output_file}")
|
||
|
||
def validate_converted_questions(questions: List[Dict[str, Any]]) -> Dict[str, int]:
|
||
"""
|
||
验证转换后的题目格式
|
||
"""
|
||
stats = {
|
||
"total": len(questions),
|
||
"valid": 0,
|
||
"invalid": 0,
|
||
"missing_question": 0,
|
||
"invalid_choices": 0,
|
||
"invalid_answer": 0
|
||
}
|
||
|
||
for i, q in enumerate(questions):
|
||
is_valid = True
|
||
|
||
# 检查question字段
|
||
if not q.get("question", "").strip():
|
||
stats["missing_question"] += 1
|
||
is_valid = False
|
||
|
||
# 检查choices字段
|
||
choices = q.get("choices", {})
|
||
text_list = choices.get("text", [])
|
||
label_list = choices.get("label", [])
|
||
|
||
if (len(text_list) != 4 or len(label_list) != 4 or
|
||
label_list != ["A", "B", "C", "D"] or
|
||
any(not str(text).strip() for text in text_list)):
|
||
stats["invalid_choices"] += 1
|
||
is_valid = False
|
||
|
||
# 检查answer字段
|
||
answer = q.get("answer", "")
|
||
if not (answer.startswith("[ANSWER]") and answer.endswith("[/ANSWER]") and
|
||
answer[8:-9] in ["A", "B", "C", "D"]):
|
||
stats["invalid_answer"] += 1
|
||
is_valid = False
|
||
|
||
if is_valid:
|
||
stats["valid"] += 1
|
||
else:
|
||
stats["invalid"] += 1
|
||
|
||
return stats
|
||
|
||
def main():
|
||
"""主函数"""
|
||
# 文件路径配置
|
||
INPUT_FILE = "/home/ubuntu/50T/LYT/MatBench/layer2/PGEE/code/stepy_complete_choice_questions_with_sampling.json"
|
||
OUTPUT_FILE = "/home/ubuntu/50T/LYT/MatBench/layer2/PGEE/code/stepz_final_choice_questions_filtered.json"
|
||
|
||
# 难度选择比例配置
|
||
SELECTION_RATIOS = {
|
||
"hard_early_stop": 1.0, # 困难题选择10%
|
||
"easy_all_correct": 0.35, # 简单题选择3.5%
|
||
"mixed": 0.0, # 混合题选择0%
|
||
"unknown": 0.0 # 未知难度不选择
|
||
}
|
||
|
||
# 随机种子,保证结果可复现
|
||
RANDOM_SEED = 42
|
||
|
||
# 是否启用答案平衡
|
||
BALANCE_ANSWERS = True
|
||
|
||
try:
|
||
# 显示配置信息
|
||
print("=== 难度筛选配置 ===")
|
||
print("选择比例:")
|
||
for difficulty, ratio in SELECTION_RATIOS.items():
|
||
print(f" {difficulty}: {ratio*100:.1f}%")
|
||
print(f"随机种子: {RANDOM_SEED}")
|
||
print(f"启用答案平衡: {BALANCE_ANSWERS}")
|
||
print()
|
||
|
||
# 批量转换(包含难度筛选和答案平衡)
|
||
batch_convert_questions_with_difficulty_filter(
|
||
INPUT_FILE,
|
||
OUTPUT_FILE,
|
||
SELECTION_RATIOS,
|
||
balance_answers=BALANCE_ANSWERS,
|
||
random_seed=RANDOM_SEED
|
||
)
|
||
|
||
# 验证转换结果
|
||
print("\n正在验证转换结果...")
|
||
with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
|
||
result_data = json.load(f)
|
||
|
||
validation_stats = validate_converted_questions(result_data)
|
||
|
||
print(f"\n=== 验证结果 ===")
|
||
print(f"总题目数: {validation_stats['total']}")
|
||
print(f"格式正确: {validation_stats['valid']}")
|
||
print(f"格式错误: {validation_stats['invalid']}")
|
||
|
||
if validation_stats['invalid'] > 0:
|
||
print(f" 缺少题目: {validation_stats['missing_question']}")
|
||
print(f" 选项格式错误: {validation_stats['invalid_choices']}")
|
||
print(f" 答案格式错误: {validation_stats['invalid_answer']}")
|
||
|
||
print(f"格式正确率: {validation_stats['valid']/validation_stats['total']*100:.1f}%")
|
||
|
||
# 验证最终答案分布
|
||
if BALANCE_ANSWERS:
|
||
print(f"\n=== 最终答案分布验证 ===")
|
||
final_answers = []
|
||
for q in result_data:
|
||
answer = extract_answer_from_question(q)
|
||
if answer:
|
||
final_answers.append(answer)
|
||
|
||
final_counter = Counter(final_answers)
|
||
total = len(final_answers)
|
||
|
||
for answer in ["A", "B", "C", "D"]:
|
||
count = final_counter.get(answer, 0)
|
||
ratio = count / total if total > 0 else 0
|
||
print(f" {answer}: {count} ({ratio*100:.1f}%)")
|
||
|
||
except Exception as e:
|
||
print(f"程序执行失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
if __name__ == "__main__":
|
||
main()
|