Files
MatBench/layer2/process/step1and2.py
2025-05-28 11:00:24 +08:00

132 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
对821道英文问题进行处理
1. 判断是否包含多个子问题,将问题拆分为完整子问题(去掉推理过程,只保留最后结果)
2. 判断题目类型
3. 将题目做成选择题
对计算题,在数值附近随机生成三个相似答案作为错误选项
对简答题,与标准答案最相近的其他问题的答案作为三个错误选项
4. 将正确和错误选项随机打乱生成ABCD选择题的模型
5. 添加prompt并将选择题包裹在[ANSWER]<answer>[/ANSWER]标签中
6. 模型打分
"""
import json
import os
import re
import time

from openai import OpenAI

from prompts import SINGLE_QUESTION_PROMPTS, QA_TYPE_PROMPTS, ONLY_ANSWER_PROMPTS
# The API key is read from the environment so that no secret is stored in the
# repository; API_KEY is None when OPENAI_API_KEY is not set.
# NOTE(review): a literal key used to be committed on this line -- treat it as
# leaked and rotate it with the provider.
API_KEY = os.environ.get("OPENAI_API_KEY")
# Base URL handed to the OpenAI client for every request made in this file.
BASE_URL = "https://vip.apiyi.com/v1"
# Model identifier passed as `model=` to chat.completions.create.
MODEL_DEEPSEEK_V3 = "deepseek-chat"
def load_data(file_path):
    """Open ``file_path`` as UTF-8 text and return the JSON value it holds."""
    with open(file_path, mode='r', encoding='utf-8') as source:
        return json.load(source)
def process_response(response):
    """Pull the JSON text out of a model reply and clean up its escapes.

    The text inside the first ``` / ```json fence is used when a fence is
    present; otherwise the whole reply, stripped of surrounding whitespace.
    Backslashes inside ``$...$`` spans are doubled, then ``\\"`` and ``\\'``
    are replaced by the bare quote characters.

    Returns:
        The cleaned text. Nothing is parsed here; the caller is expected to
        run ``json.loads`` on the result.
    """
    fenced = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
    if fenced is None:
        payload = response.strip()
    else:
        payload = fenced.group(1)

    def _double_backslashes(match):
        # Applied to each $...$ span so its backslashes survive JSON decoding.
        return match.group(1).replace('\\', '\\\\')

    payload = re.sub(r'(\$[^\$]*\$)', _double_backslashes, payload)
    # Order matters: the quote un-escaping runs after the $...$ doubling.
    for escaped, plain in (('\\"', '"'), ("\\'", "'")):
        payload = payload.replace(escaped, plain)
    return payload
def save_data(data, output_file):
    """Write ``data`` to ``output_file`` as UTF-8 JSON.

    The output is indented by two spaces and non-ASCII characters are written
    as-is rather than as ``\\uXXXX`` escapes.
    """
    with open(output_file, mode='w', encoding='utf-8') as sink:
        json.dump(obj=data, fp=sink, indent=2, ensure_ascii=False)
def split_complex_question(question, answer):
    """Ask the chat model to break one Q/A pair into independent sub-questions.

    Args:
        question: Question text substituted for ``{question}`` in
            SINGLE_QUESTION_PROMPTS.
        answer: Answer text substituted for ``{answer}`` in the same prompt.

    Returns:
        ``1`` when the reply contains "It's a single issue."; otherwise the
        value decoded by ``json.loads`` from the cleaned reply. If anything
        raises while building the request, calling the API or decoding the
        reply, the error is printed and a one-element list holding the
        original question/answer pair is returned instead.
    """
    llm = OpenAI(api_key=API_KEY, base_url=BASE_URL)
    system_msg = {
        "role": "system",
        "content": "You are an expert in decomposing complex technical questions into independent sub-questions and providing corresponding complete answers with preserved context, precision, and technical terminology. ",
    }
    try:
        user_msg = {
            "role": "user",
            "content": SINGLE_QUESTION_PROMPTS.replace("{question}", question).replace("{answer}", answer),
        }
        completion = llm.chat.completions.create(
            model=MODEL_DEEPSEEK_V3,
            messages=[system_msg, user_msg],
            stream=False,
            temperature=0,
        )
        reply = completion.choices[0].message.content.strip()
        if "It's a single issue." in reply:
            return 1
        return json.loads(process_response(reply))
    except Exception as e:
        print(f"API调用错误: {e}")
        return [{"question": question, "answer": answer}]
def _is_usable_split(split_data):
    """Return True when ``split_data`` is a non-empty list of Q/A dicts."""
    if not isinstance(split_data, list) or not split_data:
        return False
    return all(
        isinstance(part, dict) and "question" in part and "answer" in part
        for part in split_data
    )


def single_question_process(data):
    """Expand every record into one record per sub-question.

    Each item of ``data`` is sent to ``split_complex_question``. When that
    returns a non-empty list whose entries are dicts with both a "question"
    and an "answer" key, one output record is produced per entry. In every
    other case (the ``1`` sentinel, an empty list, a non-list value, or a list
    with malformed entries) the original question/answer pair is kept, so a
    bad model reply can neither drop a question nor abort the whole run.

    Args:
        data: Sequence of dicts that carry "idx", "question" and "answer".

    Returns:
        A new list of ``{"idx", "question", "answer"}`` dicts; every output
        record carries the "idx" of the item it came from.
    """
    single_question_data = []
    total = len(data)
    for i, item in enumerate(data):
        print(f"处理第 {i+1}/{total} 条数据...")
        question = item["question"]
        answer = item["answer"]
        split_data = split_complex_question(question, answer)
        if _is_usable_split(split_data):
            pairs = [(part["question"], part["answer"]) for part in split_data]
        else:
            # Fall back to the unsplit pair instead of losing or crashing on it.
            pairs = [(question, answer)]
        single_question_data.extend(
            {"idx": item["idx"], "question": sub_q, "answer": sub_a}
            for sub_q, sub_a in pairs
        )
        # Pause for two seconds after every tenth item.
        if (i + 1) % 10 == 0:
            time.sleep(2)
    return single_question_data
def classify_qa_type(question, answer):
    """Ask the chat model for the type of a question/answer pair.

    The pair is substituted into QA_TYPE_PROMPTS and the stripped, lower-cased
    reply is mapped to a label.

    Returns:
        "Calculation", "Multiple choice" or "True/False" when the reply is
        exactly "1", "2" or "3" respectively; "Other" for any other reply, and
        also when the request raises (the error is printed in that case).
    """
    llm = OpenAI(api_key=API_KEY, base_url=BASE_URL)
    labels = {"1": "Calculation", "2": "Multiple choice", "3": "True/False"}
    try:
        prompt = QA_TYPE_PROMPTS.replace("{question}", question).replace("{answer}", answer)
        completion = llm.chat.completions.create(
            model=MODEL_DEEPSEEK_V3,
            messages=[
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": prompt},
            ],
            stream=False,
        )
        verdict = completion.choices[0].message.content.strip().lower()
    except Exception as e:
        print(f"API调用错误: {e}")
        return "Other"
    return labels.get(verdict, "Other")
def qa_type_process(data):
    """Attach a "type" label to every record of ``data``.

    Each record's "question" and "answer" are passed to ``classify_qa_type``
    and the returned label is stored under "type" on that same record. The
    loop pauses for two seconds after every tenth record.

    Returns:
        The same ``data`` object, with its records modified in place.
    """
    total = len(data)
    for position, record in enumerate(data, start=1):
        print(f"处理第 {position}/{total} 条数据...")
        record["type"] = classify_qa_type(record["question"], record["answer"])
        if position % 10 == 0:
            time.sleep(2)
    return data
def main(input_file="/home/ubuntu/50T/fsy/layer2/QA/code/821.json",
         output_file="/home/ubuntu/50T/fsy/layer2/QA/code/processed_data.json"):
    """Run step 1 (sub-question splitting) and step 2 (type labelling).

    Args:
        input_file: Path of the JSON file holding the source questions.
        output_file: Path the labelled records are written to.
    """
    data = load_data(input_file)
    # Step 1: one record per independent sub-question.
    single_question_data = single_question_process(data)
    # Step 2: add a "type" label to every record.
    qa_type_data = qa_type_process(single_question_data)
    # Persist the result; previously nothing was written even though the
    # message below reports the output as saved.
    save_data(qa_type_data, output_file)
    print(f"处理完成,结果已保存到 {output_file}")


if __name__ == "__main__":
    main()