497 lines
17 KiB
Python
497 lines
17 KiB
Python
import json
|
||
from typing import Dict, Any, List, Optional, Tuple
|
||
import random
|
||
|
||
def convert_to_target_format(source_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
将源JSON格式转换为目标格式
|
||
|
||
Args:
|
||
source_data: 源数据字典
|
||
|
||
Returns:
|
||
转换后的数据字典,如果不是单选题则返回None
|
||
"""
|
||
# 检查是否有generated_options字段
|
||
if "generated_options" not in source_data:
|
||
return None
|
||
|
||
generated_options = source_data["generated_options"]
|
||
|
||
# 只处理单选题,跳过判断题
|
||
if generated_options.get("question_type") != "multiple_choice":
|
||
return None
|
||
|
||
# 获取题目内容
|
||
question = source_data.get("choice_question", "")
|
||
if not question:
|
||
return None
|
||
|
||
# 获取选项
|
||
options = generated_options.get("options", {})
|
||
if len(options) != 4:
|
||
return None
|
||
|
||
# 获取正确答案
|
||
correct_answer = generated_options.get("correct_answer", "")
|
||
if correct_answer not in ["A", "B", "C", "D"]:
|
||
return None
|
||
|
||
# 构建目标格式
|
||
target_data = {
|
||
"question": question,
|
||
"choices": {
|
||
"text": [
|
||
options.get("A", ""),
|
||
options.get("B", ""),
|
||
options.get("C", ""),
|
||
options.get("D", "")
|
||
],
|
||
"label": ["A", "B", "C", "D"]
|
||
},
|
||
"answer": f"[ANSWER]{correct_answer}[/ANSWER]",
|
||
"prompt": "You are an expert in materials science. Please answer the following materials science question by selecting the correct option. You MUST include the letter of the correct answer at the end of your response within the following tags: [ANSWER] and [/ANSWER]. For example: [ANSWER]A[/ANSWER]."
|
||
}
|
||
|
||
return target_data
|
||
|
||
def classify_questions_by_difficulty(questions: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
|
||
"""
|
||
按难度分类题目
|
||
|
||
Args:
|
||
questions: 题目列表
|
||
|
||
Returns:
|
||
按难度分类的题目字典
|
||
"""
|
||
difficulty_groups = {
|
||
"hard_early_stop": [], # 困难题(答错后早停)
|
||
"easy_all_correct": [], # 简单题(所有采样都答对)
|
||
"mixed": [], # 混合题(部分对部分错)
|
||
"unknown": [] # 未知难度
|
||
}
|
||
|
||
for question in questions:
|
||
generated_options = question.get("generated_options", {})
|
||
sampling_summary = generated_options.get("sampling_summary", {})
|
||
difficulty_label = sampling_summary.get("difficulty_label", "unknown")
|
||
|
||
if difficulty_label in difficulty_groups:
|
||
difficulty_groups[difficulty_label].append(question)
|
||
else:
|
||
difficulty_groups["unknown"].append(question)
|
||
|
||
return difficulty_groups
|
||
|
||
def select_questions_by_ratio(difficulty_groups: Dict[str, List[Dict[str, Any]]],
|
||
selection_ratios: Dict[str, float],
|
||
random_seed: Optional[int] = None) -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
|
||
"""
|
||
按比例选择题目
|
||
|
||
Args:
|
||
difficulty_groups: 按难度分类的题目
|
||
selection_ratios: 各难度等级的选择比例 (0.0-1.0)
|
||
random_seed: 随机种子
|
||
|
||
Returns:
|
||
选中的题目列表和选择统计信息
|
||
"""
|
||
if random_seed is not None:
|
||
random.seed(random_seed)
|
||
|
||
selected_questions = []
|
||
selection_stats = {}
|
||
|
||
for difficulty, questions in difficulty_groups.items():
|
||
total_count = len(questions)
|
||
ratio = selection_ratios.get(difficulty, 0.0)
|
||
|
||
# 计算要选择的题目数量
|
||
if ratio <= 0:
|
||
selected_count = 0
|
||
elif ratio >= 1:
|
||
selected_count = total_count
|
||
else:
|
||
selected_count = int(total_count * ratio)
|
||
|
||
# 随机选择题目
|
||
if selected_count > 0 and total_count > 0:
|
||
if selected_count >= total_count:
|
||
selected = questions
|
||
else:
|
||
selected = random.sample(questions, selected_count)
|
||
selected_questions.extend(selected)
|
||
else:
|
||
selected = []
|
||
|
||
# 记录统计信息
|
||
selection_stats[difficulty] = {
|
||
"total": total_count,
|
||
"selected": len(selected),
|
||
"ratio_target": ratio,
|
||
"ratio_actual": len(selected) / total_count if total_count > 0 else 0
|
||
}
|
||
|
||
# 打乱最终题目顺序
|
||
random.shuffle(selected_questions)
|
||
|
||
return selected_questions, selection_stats
|
||
|
||
def batch_convert_questions_with_difficulty_filter(input_file: str,
|
||
output_file: str,
|
||
selection_ratios: Dict[str, float],
|
||
random_seed: Optional[int] = None) -> None:
|
||
"""
|
||
批量转换题目格式,支持按难度筛选
|
||
|
||
Args:
|
||
input_file: 输入文件路径
|
||
output_file: 输出文件路径
|
||
selection_ratios: 各难度等级的选择比例
|
||
random_seed: 随机种子
|
||
"""
|
||
print("正在加载数据...")
|
||
|
||
# 判断输入文件格式
|
||
with open(input_file, 'r', encoding='utf-8') as f:
|
||
data = json.load(f)
|
||
|
||
# 处理两种可能的输入格式
|
||
if isinstance(data, dict) and "questions" in data:
|
||
# 格式:{"questions": [...], "其他字段": ...}
|
||
source_questions = data["questions"]
|
||
print(f"检测到完整格式数据,包含其他元数据")
|
||
elif isinstance(data, list):
|
||
# 格式:[{题目1}, {题目2}, ...]
|
||
source_questions = data
|
||
print(f"检测到题目列表格式")
|
||
else:
|
||
raise ValueError("不支持的输入文件格式")
|
||
|
||
print(f"加载了 {len(source_questions)} 道题目")
|
||
|
||
# 按难度分类题目
|
||
print("正在按难度分类题目...")
|
||
difficulty_groups = classify_questions_by_difficulty(source_questions)
|
||
|
||
print("题目难度分布:")
|
||
total_multiple_choice = 0
|
||
for difficulty, questions in difficulty_groups.items():
|
||
# 统计该难度下的单选题数量
|
||
mc_count = sum(1 for q in questions
|
||
if q.get("generated_options", {}).get("question_type") == "multiple_choice")
|
||
total_multiple_choice += mc_count
|
||
print(f" {difficulty}: {len(questions)} 道总题目, {mc_count} 道单选题")
|
||
|
||
print(f"可转换的单选题总数: {total_multiple_choice}")
|
||
|
||
# 按比例选择题目
|
||
print("\n正在按比例选择题目...")
|
||
print("选择比例设置:")
|
||
for difficulty, ratio in selection_ratios.items():
|
||
if difficulty in difficulty_groups:
|
||
print(f" {difficulty}: {ratio*100:.1f}%")
|
||
|
||
selected_questions, selection_stats = select_questions_by_ratio(
|
||
difficulty_groups, selection_ratios, random_seed
|
||
)
|
||
|
||
print(f"\n题目选择结果:")
|
||
total_selected = 0
|
||
for difficulty, stats in selection_stats.items():
|
||
print(f" {difficulty}:")
|
||
print(f" 总数: {stats['total']}")
|
||
print(f" 选中: {stats['selected']}")
|
||
print(f" 目标比例: {stats['ratio_target']*100:.1f}%")
|
||
print(f" 实际比例: {stats['ratio_actual']*100:.1f}%")
|
||
total_selected += stats['selected']
|
||
|
||
print(f"总共选中: {total_selected} 道题目")
|
||
|
||
# 转换选中的题目
|
||
print("\n正在转换题目格式...")
|
||
converted_questions = []
|
||
conversion_stats = {
|
||
"selected": total_selected,
|
||
"multiple_choice": 0,
|
||
"true_false": 0,
|
||
"other": 0,
|
||
"converted": 0,
|
||
"failed": 0
|
||
}
|
||
|
||
for i, question in enumerate(selected_questions):
|
||
try:
|
||
# 统计题目类型
|
||
generated_options = question.get("generated_options", {})
|
||
question_type = generated_options.get("question_type", "unknown")
|
||
|
||
if question_type == "multiple_choice":
|
||
conversion_stats["multiple_choice"] += 1
|
||
elif question_type == "true_false":
|
||
conversion_stats["true_false"] += 1
|
||
else:
|
||
conversion_stats["other"] += 1
|
||
|
||
# 转换题目
|
||
converted = convert_to_target_format(question)
|
||
if converted:
|
||
converted_questions.append(converted)
|
||
conversion_stats["converted"] += 1
|
||
else:
|
||
conversion_stats["failed"] += 1
|
||
|
||
except Exception as e:
|
||
print(f"第{i+1}题转换失败: {e}")
|
||
conversion_stats["failed"] += 1
|
||
|
||
# 保存结果
|
||
print("正在保存转换结果...")
|
||
output_data = {
|
||
"questions": converted_questions,
|
||
"metadata": {
|
||
"total_original_questions": len(source_questions),
|
||
"selection_ratios": selection_ratios,
|
||
"selection_stats": selection_stats,
|
||
"conversion_stats": conversion_stats,
|
||
"random_seed": random_seed
|
||
}
|
||
}
|
||
|
||
with open(output_file, 'w', encoding='utf-8') as f:
|
||
json.dump(converted_questions, f, ensure_ascii=False, indent=2)
|
||
|
||
# 打印最终统计信息
|
||
print(f"\n转换完成!")
|
||
print(f"选中题目数: {conversion_stats['selected']}")
|
||
print(f"单选题: {conversion_stats['multiple_choice']}")
|
||
print(f"判断题: {conversion_stats['true_false']}")
|
||
print(f"其他类型: {conversion_stats['other']}")
|
||
print(f"成功转换: {conversion_stats['converted']}")
|
||
print(f"转换失败: {conversion_stats['failed']}")
|
||
print(f"最终转换率: {conversion_stats['converted']/conversion_stats['selected']*100:.1f}%")
|
||
print(f"结果已保存到: {output_file}")
|
||
|
||
def validate_converted_questions(questions: List[Dict[str, Any]]) -> Dict[str, int]:
|
||
"""
|
||
验证转换后的题目格式
|
||
|
||
Args:
|
||
questions: 转换后的题目列表
|
||
|
||
Returns:
|
||
验证统计信息
|
||
"""
|
||
stats = {
|
||
"total": len(questions),
|
||
"valid": 0,
|
||
"invalid": 0,
|
||
"missing_question": 0,
|
||
"invalid_choices": 0,
|
||
"invalid_answer": 0
|
||
}
|
||
|
||
for i, q in enumerate(questions):
|
||
is_valid = True
|
||
|
||
# 检查question字段
|
||
if not q.get("question", "").strip():
|
||
stats["missing_question"] += 1
|
||
is_valid = False
|
||
|
||
# 检查choices字段
|
||
choices = q.get("choices", {})
|
||
text_list = choices.get("text", [])
|
||
label_list = choices.get("label", [])
|
||
|
||
if (len(text_list) != 4 or len(label_list) != 4 or
|
||
label_list != ["A", "B", "C", "D"] or
|
||
any(not str(text).strip() for text in text_list)):
|
||
stats["invalid_choices"] += 1
|
||
is_valid = False
|
||
|
||
# 检查answer字段
|
||
answer = q.get("answer", "")
|
||
if not (answer.startswith("[ANSWER]") and answer.endswith("[/ANSWER]") and
|
||
answer[8:-9] in ["A", "B", "C", "D"]):
|
||
stats["invalid_answer"] += 1
|
||
is_valid = False
|
||
|
||
if is_valid:
|
||
stats["valid"] += 1
|
||
else:
|
||
stats["invalid"] += 1
|
||
print(f"第{i+1}题格式无效")
|
||
|
||
return stats
|
||
|
||
def create_difficulty_config_template():
|
||
"""创建难度配置模板"""
|
||
template = {
|
||
"hard_early_stop": 1.0, # 困难题选择100%
|
||
"easy_all_correct": 0.1, # 简单题选择10%
|
||
"mixed": 0.5, # 混合题选择50%
|
||
"unknown": 0.0 # 未知难度题目选择0%
|
||
}
|
||
|
||
print("难度选择比例配置模板:")
|
||
print(json.dumps(template, indent=2))
|
||
print("\n说明:")
|
||
print("- 1.0 = 100% (全部选择)")
|
||
print("- 0.5 = 50% (选择一半)")
|
||
print("- 0.1 = 10% (选择10%)")
|
||
print("- 0.0 = 0% (不选择)")
|
||
|
||
return template
|
||
|
||
def main():
|
||
"""主函数"""
|
||
# 文件路径配置
|
||
INPUT_FILE = "/home/ubuntu/50T/LYT/MatBench/layer2/PGEE/code/stepy_complete_choice_questions_with_sampling.json"
|
||
OUTPUT_FILE = "/home/ubuntu/50T/LYT/MatBench/layer2/PGEE/code/stepz_final_choice_questions_filtered.json"
|
||
|
||
# 难度选择比例配置
|
||
# 可以根据需要调整这些比例
|
||
SELECTION_RATIOS = {
|
||
"hard_early_stop": 1.0, # 困难题选择100% (全部)
|
||
"easy_all_correct": 0.0, # 简单题选择10%
|
||
"mixed": 0.0, # 混合题选择30%
|
||
"unknown": 0.0 # 未知难度不选择
|
||
}
|
||
|
||
# 随机种子,保证结果可复现
|
||
RANDOM_SEED = 42
|
||
|
||
try:
|
||
# 显示配置信息
|
||
print("=== 难度筛选配置 ===")
|
||
print("选择比例:")
|
||
for difficulty, ratio in SELECTION_RATIOS.items():
|
||
print(f" {difficulty}: {ratio*100:.1f}%")
|
||
print(f"随机种子: {RANDOM_SEED}")
|
||
print()
|
||
|
||
# 批量转换(包含难度筛选)
|
||
batch_convert_questions_with_difficulty_filter(
|
||
INPUT_FILE,
|
||
OUTPUT_FILE,
|
||
SELECTION_RATIOS,
|
||
RANDOM_SEED
|
||
)
|
||
|
||
# 验证转换结果
|
||
print("\n正在验证转换结果...")
|
||
with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
|
||
result_data = json.load(f)
|
||
|
||
# 检查输出文件格式
|
||
if "questions" in result_data:
|
||
converted_questions = result_data["questions"]
|
||
metadata = result_data.get("metadata", {})
|
||
|
||
print("\n=== 元数据信息 ===")
|
||
if metadata:
|
||
print(f"原始题目总数: {metadata.get('total_original_questions', 'N/A')}")
|
||
print(f"随机种子: {metadata.get('random_seed', 'N/A')}")
|
||
else:
|
||
converted_questions = result_data
|
||
|
||
validation_stats = validate_converted_questions(converted_questions)
|
||
|
||
print(f"\n=== 验证结果 ===")
|
||
print(f"总题目数: {validation_stats['total']}")
|
||
print(f"格式正确: {validation_stats['valid']}")
|
||
print(f"格式错误: {validation_stats['invalid']}")
|
||
|
||
if validation_stats['invalid'] > 0:
|
||
print(f" 缺少题目: {validation_stats['missing_question']}")
|
||
print(f" 选项格式错误: {validation_stats['invalid_choices']}")
|
||
print(f" 答案格式错误: {validation_stats['invalid_answer']}")
|
||
|
||
print(f"格式正确率: {validation_stats['valid']/validation_stats['total']*100:.1f}%")
|
||
|
||
except Exception as e:
|
||
print(f"程序执行失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
def interactive_config():
|
||
"""交互式配置选择比例"""
|
||
print("=== 交互式难度选择配置 ===")
|
||
|
||
difficulties = ["hard_early_stop", "easy_all_correct", "mixed", "unknown"]
|
||
difficulty_names = {
|
||
"hard_early_stop": "困难题(答错早停)",
|
||
"easy_all_correct": "简单题(全部答对)",
|
||
"mixed": "混合题(部分对错)",
|
||
"unknown": "未知难度题"
|
||
}
|
||
|
||
ratios = {}
|
||
|
||
for diff in difficulties:
|
||
while True:
|
||
try:
|
||
ratio_input = input(f"请输入{difficulty_names.get(diff, diff)}的选择比例 (0-100%): ").strip()
|
||
if ratio_input.endswith('%'):
|
||
ratio_input = ratio_input[:-1]
|
||
|
||
ratio_percent = float(ratio_input)
|
||
if 0 <= ratio_percent <= 100:
|
||
ratios[diff] = ratio_percent / 100.0
|
||
break
|
||
else:
|
||
print("请输入0-100之间的数值")
|
||
except ValueError:
|
||
print("请输入有效的数值")
|
||
|
||
print("\n配置结果:")
|
||
for diff, ratio in ratios.items():
|
||
print(f" {difficulty_names.get(diff, diff)}: {ratio*100:.1f}%")
|
||
|
||
return ratios
|
||
|
||
def test_difficulty_distribution(input_file: str):
|
||
"""测试文件中的难度分布"""
|
||
print(f"正在分析文件难度分布: {input_file}")
|
||
|
||
with open(input_file, 'r', encoding='utf-8') as f:
|
||
data = json.load(f)
|
||
|
||
# 处理两种可能的输入格式
|
||
if isinstance(data, dict) and "questions" in data:
|
||
questions = data["questions"]
|
||
elif isinstance(data, list):
|
||
questions = data
|
||
else:
|
||
print("不支持的文件格式")
|
||
return
|
||
|
||
difficulty_groups = classify_questions_by_difficulty(questions)
|
||
|
||
print(f"\n难度分布分析:")
|
||
print(f"总题目数: {len(questions)}")
|
||
|
||
for difficulty, question_list in difficulty_groups.items():
|
||
mc_count = sum(1 for q in question_list
|
||
if q.get("generated_options", {}).get("question_type") == "multiple_choice")
|
||
print(f" {difficulty}:")
|
||
print(f" 总数: {len(question_list)}")
|
||
print(f" 单选题: {mc_count}")
|
||
print(f" 占比: {len(question_list)/len(questions)*100:.1f}%")
|
||
|
||
if __name__ == "__main__":
|
||
# 可以先测试难度分布
|
||
# test_difficulty_distribution("/path/to/your/input/file.json")
|
||
|
||
# 可以使用交互式配置
|
||
# ratios = interactive_config()
|
||
|
||
# 运行主程序
|
||
main()
|
||
|
||
# 显示配置模板
|
||
# create_difficulty_config_template()
|