layer2 commit
This commit is contained in:
89
layer2/process/step3.py
Normal file
89
layer2/process/step3.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
对821道英文问题进行处理
|
||||
1. 判断是否包含多个子问题,将问题拆分为完整子问题(去掉推理过程,只保留最后结果)
|
||||
2. 判断题目类型
|
||||
3. 将题目做成选择题
|
||||
对计算题,在数值附近随机生成三个相似答案作为错误选项
|
||||
对简答题,与标准答案最相近的其他问题的答案作为三个错误选项
|
||||
4. 将正确和错误选项随机打乱,生成ABCD选择题的模型
|
||||
5. 添加prompt,并将选择题包裹在[ANSWER]<answer>[/ANSWER]标签中
|
||||
6. 模型打分
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import random
|
||||
import copy
|
||||
|
||||
def generate_wrong_answers(json_file_path):
    """Attach three perturbed wrong answers to numeric calculation items.

    The input file must contain a JSON list of dict items. For every item
    whose ``type`` equals 1 (calculation question) and whose ``answer``
    contains at least one digit, three distractors are produced with
    ``generate_wrong_answer`` and stored under ``wrong_answers_1``,
    ``wrong_answers_2`` and ``wrong_answers_3``. All other items are left
    untouched.

    The augmented list is written next to the input file, with
    ``_with_wrong_answers`` inserted before the trailing ``.json`` suffix.

    Args:
        json_file_path: Path of the JSON file to read.

    Returns:
        The loaded list, with the qualifying items modified in place.
    """
    # Upper bound on retries when looking for distinct distractors; some
    # answers can only ever yield one or two different perturbations.
    max_attempts = 30
    suffix = '.json'

    with open(json_file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    for item in data:
        # Items without a type/answer are skipped instead of aborting the run.
        if item.get('type') != 1:  # only calculation questions
            continue
        answer = item.get('answer')
        if answer is None:
            continue
        # A JSON number would break the character scan below; use its text.
        answer_text = answer if isinstance(answer, str) else str(answer)
        if not any(char.isdigit() for char in answer_text):
            continue

        # Prefer distractors that differ from the correct answer and from
        # each other, otherwise the multiple-choice options are degenerate.
        wrong_answers = []
        for _ in range(max_attempts):
            candidate = generate_wrong_answer(answer_text)
            if candidate != answer_text and candidate not in wrong_answers:
                wrong_answers.append(candidate)
                if len(wrong_answers) == 3:
                    break
        # Best effort: if not enough distinct values exist, still fill all
        # three keys so downstream consumers always find them.
        while len(wrong_answers) < 3:
            wrong_answers.append(generate_wrong_answer(answer_text))

        item['wrong_answers_1'] = wrong_answers[0]
        item['wrong_answers_2'] = wrong_answers[1]
        item['wrong_answers_3'] = wrong_answers[2]

    # Only the trailing extension is rewritten. A plain str.replace would also
    # touch '.json' inside directory names, and would leave the path unchanged
    # (overwriting the input) when the file has no '.json' suffix.
    path_text = str(json_file_path)
    if path_text.endswith(suffix):
        output_path = path_text[:-len(suffix)] + '_with_wrong_answers.json'
    else:
        output_path = path_text + '_with_wrong_answers.json'

    with open(output_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=2)

    return data
|
||||
|
||||
def generate_wrong_answer(correct_answer):
|
||||
# 强化版正则表达式:支持普通数、科学计数法、Unicode负号、LaTeX指数、千位逗号
|
||||
number_pattern = (
|
||||
r'([-+]?\d{1,3}(?:,\d{3})*(?:\.\d+)?|\d*\.?\d+)' # 主数字部分
|
||||
r'(?:\s*[×x*]?\s*10(?:\^|\^{|{)?[-−⁻]?\d+(?:\})?)?' # 科学计数部分,可选
|
||||
)
|
||||
matches = list(re.finditer(number_pattern, correct_answer, re.IGNORECASE))
|
||||
if not matches:
|
||||
return correct_answer # 没找到数字,返回原文
|
||||
|
||||
wrong_answer = correct_answer
|
||||
for match in matches[::-1]: # 反向替换防止位置偏移
|
||||
full_match = match.group(0)
|
||||
base = match.group(1).replace(',', '') # 去除逗号用于数值运算
|
||||
|
||||
try:
|
||||
# 转换成 float
|
||||
base_value = float(base)
|
||||
perturbed_value = perturb_number(base_value)
|
||||
|
||||
# 保留原来的指数部分(如 x 10^6),只替换数字
|
||||
wrong_value_str = full_match.replace(match.group(1), format_similar(base, perturbed_value))
|
||||
start, end = match.span()
|
||||
wrong_answer = wrong_answer[:start] + wrong_value_str + wrong_answer[end:]
|
||||
except:
|
||||
continue
|
||||
|
||||
return wrong_answer
|
||||
|
||||
def perturb_number(value: float) -> float:
    """Return a value close to, but different from, ``value``.

    A relative offset drawn uniformly from 3%-15% of ``abs(value)`` is added
    or subtracted (direction chosen at random). Because the offset is always
    smaller than the magnitude, the sign of a non-zero input is preserved.

    Zero has no magnitude to scale, so a relative offset would leave it at
    zero; in that case a value of unit scale (magnitude 1.03-1.15, random
    sign) is returned instead so the result is never equal to the input.

    Args:
        value: The correct numeric value.

    Returns:
        The perturbed value.
    """
    noise = random.uniform(0.03, 0.15)  # relative perturbation, 3%..15%
    direction = random.choice([-1, 1])

    if value == 0:
        # Relative noise cannot move zero; fall back to an absolute offset.
        return direction * (1.0 + noise)

    # For any non-zero input this is value * (1 +/- noise), which can neither
    # reach zero nor change sign, so no further correction is needed.
    return value + direction * abs(value) * noise
|
||||
|
||||
def format_similar(original_str: str, value: float) -> str:
    """Format ``value`` with the same number of decimals as ``original_str``.

    If ``original_str`` contains a decimal point, ``value`` is rendered with
    as many fractional digits as the original; otherwise it is rounded to an
    integer.

    Rounding to the original precision can collapse a slightly different
    value back onto the original text (``"1"`` with 1.1, ``"0.5"`` with
    0.53). When ``value`` differs from the original number but would print
    identically, the result is moved by one unit in the last kept place, in
    the direction of ``value``, so a changed value always prints differently.

    Args:
        original_str: The number as it appeared in the source text (no
            thousands separators).
        value: The number to render.

    Returns:
        ``value`` as a string at the original's precision.
    """
    has_fraction = '.' in original_str
    # Trailing '^' / '}' are stripped in case exponent markup is attached.
    decimal_places = (
        len(original_str.split('.')[-1].rstrip('^}')) if has_fraction else 0
    )

    def render(number):
        if has_fraction:
            return f"{number:.{decimal_places}f}"
        return str(int(round(number)))

    formatted = render(value)

    try:
        original_value = float(original_str)
    except ValueError:
        # Original is not a plain number; nothing to compare against.
        return formatted

    # Same value requested, or the rendering already differs: done.
    if value == original_value or float(formatted) != original_value:
        return formatted

    # The perturbation was lost to rounding; step one unit in the last place.
    step = 10.0 ** -decimal_places
    if value > original_value:
        return render(original_value + step)
    return render(original_value - step)
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # The input file may be given as the first command-line argument; without
    # one, the original benchmark path is used so existing runs still work.
    input_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else '/home/ubuntu/50T/fsy/benchmark/4is_type.json'
    )
    data = generate_wrong_answers(input_path)
|
||||
Reference in New Issue
Block a user