"""
|
||
0. 将问题从xls提取为json
|
||
1. 将问题进行拆分
|
||
2. 翻译成英文
|
||
3. 去重
|
||
4. 使用大模型进行难度评估和筛选
|
||
"""
|
||
import json
import pickle

import numpy as np
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity

from prompts import CLEAN_PROMPTS, SELECT_QUESTION_PROMPT

API_KEY="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"
|
||
BASE_URL="https://vip.apiyi.com/v1"
|
||
MODEL_GPT="text-embedding-ada-002"
|
||
MODELS = ["deepseek-reasoner", "claude-3-7-sonnet-20250219", "qwen-max", "deepseek-chat", "gemini-pro"]
|
||
|
||
def get_embedding(text):
    """Return the embedding vector for a single text."""
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
    response = client.embeddings.create(
        model=MODEL_GPT,
        input=text
    )
    return response.data[0].embedding

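# Optional hedge (sketch, not part of the original pipeline): embedding calls
# can fail transiently, so a thin retry wrapper around get_embedding may help
# on long runs. The retry count and delay are assumptions, and the broad
# `Exception` catch is deliberate since SDK exception types vary by version.
import time

def get_embedding_with_retry(text, retries=3, delay=2.0):
    for attempt in range(retries):
        try:
            return get_embedding(text)
        except Exception as exc:
            if attempt == retries - 1:
                raise
            print(f"Embedding call failed ({exc}); retrying in {delay}s")
            time.sleep(delay)
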
def compute_embeddings(texts):
    embeddings = []
    for i, text in enumerate(texts):
        print("Processing {}/{}".format(i + 1, len(texts)))
        embeddings.append(get_embedding(text))
    return np.array(embeddings)

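# Possible speed-up (sketch): the OpenAI embeddings endpoint also accepts a
# list of inputs, so the questions can be embedded in batches instead of one
# request per question. The batch size of 100 is an assumption, not a
# documented limit of the proxy at BASE_URL.
def compute_embeddings_batched(texts, batch_size=100):
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
    embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        response = client.embeddings.create(model=MODEL_GPT, input=batch)
        # response.data preserves the input order.
        embeddings.extend(item.embedding for item in response.data)
    return np.array(embeddings)
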
def load_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def save_json(data, file_path):
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=2)

def save_embeddings(embeddings, file_path):
    with open(file_path, 'wb') as file:
        pickle.dump(embeddings, file)


def load_embeddings(file_path):
    with open(file_path, 'rb') as file:
        return pickle.load(file)

def deduplicate_qa(data, save_vectors=True):
    questions = [item['question'] for item in data]

    # Generate the embedding vectors.
    question_embeddings = compute_embeddings(questions)

    if save_vectors:
        print("Saving the question embeddings...")
        save_embeddings(question_embeddings, '/home/ubuntu/50T/fsy/layer2/QA/question_embeddings.pkl')

    # Deduplication logic.
    filtered_data, duplicate_entries = de_emphasize(question_embeddings, data)

    return filtered_data, duplicate_entries

def deduplicate_qa_pkl(data, pkl_path):
    question_embeddings = load_embeddings(pkl_path)
    filtered_data, duplicate_entries = de_emphasize(question_embeddings, data)

    return filtered_data, duplicate_entries

def de_emphasize(question_embeddings, data, similarity_threshold=0.99):
    unique_indices = []
    duplicate_entries = []  # Holds the QA pairs judged to be duplicates.
    for i in range(len(data)):
        print("Processing {}/{}".format(i + 1, len(data)))
        duplicate_found = False
        for j in unique_indices:
            # Semantic similarity between the two questions.
            question_sim = cosine_similarity([question_embeddings[i]], [question_embeddings[j]])[0][0]

            # If the similarity exceeds the threshold, treat the pair as a duplicate.
            if question_sim > similarity_threshold:
                duplicate_found = True
                # Record both the current QA pair and the one it matched.
                duplicate_entries.append({
                    "duplicate_question": data[i]['question'],
                    "duplicate_answer": data[i]['answer'],
                    "matched_question": data[j]['question'],
                    "matched_answer": data[j]['answer']
                })
                break

        if not duplicate_found:
            unique_indices.append(i)

    # Build the deduplicated dataset.
    filtered_data = [data[i] for i in unique_indices]
    return filtered_data, duplicate_entries

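# Sketch of a vectorized alternative (assumption: the full n x n similarity
# matrix fits in memory). Computing all pairwise cosine similarities in one
# call replaces the per-pair calls inside the loop above; the keep-first
# semantics are unchanged.
def de_emphasize_vectorized(question_embeddings, data, similarity_threshold=0.99):
    sim = cosine_similarity(question_embeddings)
    unique_indices = []
    duplicate_entries = []
    for i in range(len(data)):
        match = next((j for j in unique_indices if sim[i, j] > similarity_threshold), None)
        if match is None:
            unique_indices.append(i)
        else:
            duplicate_entries.append({
                "duplicate_question": data[i]['question'],
                "duplicate_answer": data[i]['answer'],
                "matched_question": data[match]['question'],
                "matched_answer": data[match]['answer']
            })
    filtered_data = [data[i] for i in unique_indices]
    return filtered_data, duplicate_entries
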
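# Sketch of the difficulty-assessment step listed in the module docstring (it
# is not wired into main yet). The prompt contract is an assumption: here
# SELECT_QUESTION_PROMPT is treated as a system prompt and the raw model reply
# is returned; the real format lives in prompts.py.
def assess_difficulty(question, model=MODELS[0]):
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SELECT_QUESTION_PROMPT},
            {"role": "user", "content": question},
        ],
    )
    return response.choices[0].message.content
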
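# Sketch for the distractor step noted in main below (assumption: the answers
# are embedded with the same embedding model as the questions). For a
# non-multiple-choice question, the answers closest to the correct one are
# picked as plausible wrong options.
def pick_distractors(answer_embeddings, answers, correct_idx, k=3):
    sims = cosine_similarity([answer_embeddings[correct_idx]], answer_embeddings)[0]
    ranked = np.argsort(-sims)                        # most similar first
    candidates = [i for i in ranked if i != correct_idx]
    return [answers[i] for i in candidates[:k]]
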
# Main program
if __name__ == '__main__':
    input_file = '/home/ubuntu/50T/fsy/layer2/PGEE/code/dataset.json'    # Input JSON file
    output_file = '/home/ubuntu/50T/fsy/layer2/PGEE/code/onrepeat_99.json'    # Deduplicated output file
    duplicates_file = '/home/ubuntu/50T/fsy/layer2/PGEE/code/duplicates_99.json'    # File for the filtered-out QA pairs
    pkl_path = "/home/ubuntu/50T/fsy/layer2/PGEE/question_embeddings.pkl"
    qa_data = load_json(input_file)

    # Deduplicate from scratch and save the computed vectors to a pkl file:
    # filtered_data, duplicate_entries = deduplicate_qa(qa_data)

    # Or load an existing pkl file and deduplicate against it:
    filtered_data, duplicate_entries = deduplicate_qa_pkl(qa_data, pkl_path)

    # TODO: filter the questions by difficulty.

    # TODO: for non-multiple-choice questions, use the most similar answers
    # as wrong options (distractors).

    # Save the deduplicated QA pairs and the removed duplicates.
    save_json(filtered_data, output_file)
    save_json(duplicate_entries, duplicates_file)

    print(f"Deduplication finished: {len(qa_data)} QA pairs before, {len(filtered_data)} remaining after.")
    print(f"Saved {len(duplicate_entries)} duplicate QA pairs to {duplicates_file}.")