90 lines
3.4 KiB
Python
90 lines
3.4 KiB
Python
"""
|
||
对821道英文问题进行处理
|
||
1. 判断是否包含多个子问题,将问题拆分为完整子问题(去掉推理过程,只保留最后结果)
|
||
2. 判断题目类型
|
||
3. 将题目做成选择题
|
||
对计算题,在数值附近随机生成三个相似答案作为错误选项
|
||
对简答题,与标准答案最相近的其他问题的答案作为三个错误选项
|
||
4. 将正确和错误选项随机打乱,生成ABCD选择题的模型
|
||
5. 添加prompt,并将选择题包裹在[ANSWER]<answer>[/ANSWER]标签中
|
||
6. 模型打分
|
||
"""
|
||
import json
|
||
import re
|
||
import random
|
||
import copy
|
||
|
||
def generate_wrong_answers(json_file_path):
|
||
# 读取 JSON 文件
|
||
with open(json_file_path, 'r', encoding='utf-8') as file:
|
||
data = json.load(file)
|
||
|
||
# 处理每个数据项
|
||
for item in data:
|
||
if item['type'] == 1: # 判断是否为计算题
|
||
answer = item['answer']
|
||
if any(char.isdigit() for char in answer):
|
||
wrong_answers = []
|
||
for _ in range(3):
|
||
wrong_answers.append(generate_wrong_answer(answer))
|
||
item['wrong_answers_1'] = wrong_answers[0]
|
||
item['wrong_answers_2'] = wrong_answers[1]
|
||
item['wrong_answers_3'] = wrong_answers[2]
|
||
|
||
with open(json_file_path.replace('.json', '_with_wrong_answers.json'), 'w', encoding='utf-8') as file:
|
||
json.dump(data, file, ensure_ascii=False, indent=2)
|
||
|
||
return data
|
||
|
||
def generate_wrong_answer(correct_answer):
|
||
# 强化版正则表达式:支持普通数、科学计数法、Unicode负号、LaTeX指数、千位逗号
|
||
number_pattern = (
|
||
r'([-+]?\d{1,3}(?:,\d{3})*(?:\.\d+)?|\d*\.?\d+)' # 主数字部分
|
||
r'(?:\s*[×x*]?\s*10(?:\^|\^{|{)?[-−⁻]?\d+(?:\})?)?' # 科学计数部分,可选
|
||
)
|
||
matches = list(re.finditer(number_pattern, correct_answer, re.IGNORECASE))
|
||
if not matches:
|
||
return correct_answer # 没找到数字,返回原文
|
||
|
||
wrong_answer = correct_answer
|
||
for match in matches[::-1]: # 反向替换防止位置偏移
|
||
full_match = match.group(0)
|
||
base = match.group(1).replace(',', '') # 去除逗号用于数值运算
|
||
|
||
try:
|
||
# 转换成 float
|
||
base_value = float(base)
|
||
perturbed_value = perturb_number(base_value)
|
||
|
||
# 保留原来的指数部分(如 x 10^6),只替换数字
|
||
wrong_value_str = full_match.replace(match.group(1), format_similar(base, perturbed_value))
|
||
start, end = match.span()
|
||
wrong_answer = wrong_answer[:start] + wrong_value_str + wrong_answer[end:]
|
||
except:
|
||
continue
|
||
|
||
return wrong_answer
|
||
|
||
def perturb_number(value):
|
||
# 根据数量级添加扰动(高斯扰动 + 偏差)
|
||
magnitude = abs(value)
|
||
noise = random.uniform(0.03, 0.15) # 扰动比例 3%~15%
|
||
direction = random.choice([-1, 1])
|
||
new_value = value + direction * magnitude * noise
|
||
|
||
# 防止扰动结果为 0 或变号
|
||
if abs(new_value) < 1e-10:
|
||
new_value = value * 1.1
|
||
return new_value
|
||
|
||
def format_similar(original_str, value):
|
||
# 保留与原始字符串小数位一致
|
||
if '.' in original_str:
|
||
decimal_places = len(original_str.split('.')[-1].rstrip('^}')) # 忽略 ^10^6 中的后缀
|
||
return f"{value:.{decimal_places}f}"
|
||
else:
|
||
return str(int(round(value)))
|
||
|
||
if __name__ == "__main__":
|
||
data = generate_wrong_answers('/home/ubuntu/50T/fsy/benchmark/4is_type.json')
|