Files
MatBench/layer2/process/step3.py
2025-05-28 11:00:24 +08:00

90 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
对821道英文问题进行处理
1. 判断是否包含多个子问题,将问题拆分为完整子问题(去掉推理过程,只保留最后结果)
2. 判断题目类型
3. 将题目做成选择题
对计算题,在数值附近随机生成三个相似答案作为错误选项
对简答题,与标准答案最相近的其他问题的答案作为三个错误选项
4. 将正确和错误选项随机打乱生成ABCD选择题的模型
5. 添加prompt并将选择题包裹在[ANSWER]<answer>[/ANSWER]标签中
6. 模型打分
"""
import json
import re
import random
import copy
def generate_wrong_answers(json_file_path, num_wrong=3, max_attempts=30):
    """Attach numeric distractors to every calculation question in a JSON file.

    Reads a JSON list of question dicts from *json_file_path*. For each item
    whose ``type`` is 1 (calculation question) and whose ``answer`` contains
    at least one digit, draws ``num_wrong`` wrong answers with
    ``generate_wrong_answer`` and stores them under ``wrong_answers_1`` ...
    ``wrong_answers_<num_wrong>``. All other items are left untouched.

    The augmented list is written next to the input as
    ``<name>_with_wrong_answers.json``: only a trailing ``.json`` of the input
    path is replaced, so directory names are never rewritten and the input
    file itself is never overwritten.

    Args:
        json_file_path: Path of the input JSON file (a list of dicts).
        num_wrong: Distractors per question (default 3, for ABCD options).
        max_attempts: Random draws allowed per question while looking for
            distractors that differ from the answer and from each other.

    Returns:
        The augmented list of question dicts.
    """
    with open(json_file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    for item in data:
        # Only calculation questions (type 1) get numeric distractors.
        if item.get('type') != 1:
            continue
        answer = item.get('answer')
        if answer is None:
            continue
        answer = str(answer)  # tolerate numeric answers stored as JSON numbers
        if not any(ch.isdigit() for ch in answer):
            continue
        wrong_answers = _distinct_wrong_answers(answer, num_wrong, max_attempts)
        for index, wrong in enumerate(wrong_answers, start=1):
            item[f'wrong_answers_{index}'] = wrong

    with open(_wrong_answers_output_path(json_file_path), 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=2)
    return data


def _wrong_answers_output_path(json_file_path):
    """Return the output path: a trailing '.json' replaced by '_with_wrong_answers.json'."""
    path = str(json_file_path)
    stem = path[:-len('.json')] if path.lower().endswith('.json') else path
    return stem + '_with_wrong_answers.json'


def _distinct_wrong_answers(answer, count, max_attempts):
    """Draw *count* distractors, preferring ones distinct from *answer* and each other.

    A few-percent perturbation of a small integer often rounds back to the
    same value, so draws are repeated (at most *max_attempts* times). If not
    enough distinct values turn up, the remaining slots are filled with
    further draws even if they repeat, so exactly *count* entries are returned.
    """
    found = []
    for _ in range(max_attempts):
        if len(found) >= count:
            break
        candidate = generate_wrong_answer(answer)
        if candidate != answer and candidate not in found:
            found.append(candidate)
    while len(found) < count:
        found.append(generate_wrong_answer(answer))
    return found
# Matches one number in free text. Named group 'base' is the plain numeral:
# optional sign, then either comma-grouped thousands ("1,234.5"), a plain
# digit run of any length ("1234.5"), or a leading-dot decimal (".5").
# The optional tail is a scientific-notation exponent ("e-3", "× 10^6",
# "\times 10^{-5}", "* 10^{3}"). It is consumed so its digits are never
# perturbed as a separate number. A multiplication sign is required so that
# a neighbouring "10..." is not mistaken for an exponent.
NUMBER_PATTERN = re.compile(
    r'(?P<base>[-+]?(?:\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d+(?:\.\d+)?|\.\d+))'
    r'(?:e[-+−]?\d+'
    r'|\s*(?:[×x*·⋅]|\\times|\\cdot)\s*10(?:\^\{|\^|\{)?[-−⁻]?\d+\}?)?',
    re.IGNORECASE,
)


def generate_wrong_answer(correct_answer):
    """Return a plausible wrong answer by perturbing every number in *correct_answer*.

    For each number matched by NUMBER_PATTERN, only the mantissa (group
    'base') is replaced by a nearby value written in the same style. Any
    scientific-notation exponent, the units and the surrounding text are
    copied unchanged. For example, "2 × 10^2 Pa" may become "3 × 10^2 Pa",
    but the exponent is never altered.

    If the random perturbation rounds back to the original numeral, the
    numeral is moved by 1-3 units in its last decimal place instead.

    Returns *correct_answer* unchanged when it contains no numbers.
    """
    pieces = []
    cursor = 0
    for match in NUMBER_PATTERN.finditer(correct_answer):
        start, end = match.span('base')
        pieces.append(correct_answer[cursor:start])
        pieces.append(_perturbed_numeral(match.group('base')))
        # Text after the mantissa (e.g. the exponent) is copied verbatim with
        # the next chunk; finditer resumes after the full match, so the
        # exponent digits are never re-matched.
        cursor = end
    pieces.append(correct_answer[cursor:])
    return ''.join(pieces)


def _perturbed_numeral(numeral):
    """Return a numeral near *numeral*, formatted like it; *numeral* itself if unparseable."""
    plain = numeral.replace(',', '')  # drop thousands separators for arithmetic
    try:
        value = float(plain)
        candidate = format_similar(numeral, perturb_number(value))
        if float(candidate.replace(',', '')) == value:
            # The relative shift vanished after rounding to the original
            # precision (e.g. "2" or "0"). Step 1-3 units in the last place
            # instead, in a random direction, but never across or onto zero.
            decimals = len(plain.partition('.')[2])
            step = random.choice((1, 2, 3)) * 10.0 ** -decimals * random.choice((-1, 1))
            if value == 0 or (value + step) * value <= 0:
                step = abs(step) if value >= 0 else -abs(step)
            candidate = format_similar(numeral, value + step)
    except (ValueError, OverflowError):
        # e.g. an overflowing numeral whose perturbation is inf/NaN
        return numeral
    return candidate
def perturb_number(value):
    """Return *value* moved up or down by a random 3%-15% of its magnitude.

    The shift size is drawn uniformly (not Gaussian) and the direction is a
    coin flip. Because the shift never exceeds 15% of |value|, a non-zero
    input keeps its sign and never becomes zero.

    Zero has no magnitude to scale, so for ``value == 0`` a small positive
    number in [0.03, 0.15] is returned instead of zero.
    """
    fraction = random.uniform(0.03, 0.15)  # relative size of the shift
    direction = random.choice([-1, 1])
    if value == 0:
        return fraction
    return value + direction * abs(value) * fraction
def format_similar(original_str, value):
    """Format *value* in the same written style as the numeral *original_str*.

    - Keeps as many digits after the decimal point as *original_str* has; if
      *original_str* has no decimal point, *value* is rounded to an integer.
    - Uses thousands separators when *original_str* contains a comma.
    - Keeps an explicit leading '+' of *original_str* for non-negative results.
    - Never writes a value that rounds to zero with a '-' sign (no "-0.0").

    Args:
        original_str: The numeral being imitated, e.g. "3.50", "1,200", "+4".
        value: The number to format.

    Returns:
        The formatted string.

    Raises:
        ValueError / OverflowError: if *value* is NaN / infinite and
        *original_str* has no decimal point (integer rounding fails).
    """
    grouping = ',' if ',' in original_str else ''
    if '.' in original_str:
        fraction = original_str.split('.')[-1]
        decimals = re.match(r'\d*', fraction).end()  # count only the digits after the point
        text = f"{value:{grouping}.{decimals}f}"
    else:
        text = f"{int(round(value)):{grouping}}"
    if text.startswith('-') and float(text.replace(',', '')) == 0:
        text = text[1:]  # a tiny negative value rounded to zero: avoid "-0.0"
    if original_str.startswith('+') and not text.startswith('-'):
        text = '+' + text
    return text
if __name__ == "__main__":
    import sys

    # The input path may be given as the first command-line argument;
    # otherwise fall back to the dataset location this script was written for.
    input_path = sys.argv[1] if len(sys.argv) > 1 else '/home/ubuntu/50T/fsy/benchmark/4is_type.json'
    generate_wrong_answers(input_path)