# MatBench/layer2/PGEE/code/step3_deduplication.py
"""
0. 将问题从xls提取为json
1. 将问题进行拆分
2. 翻译成英文
3. 去重
4. 使用大模型进行难度评估和筛选
"""
import json
import pickle

import numpy as np
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity

from prompts import CLEAN_PROMPTS, SELECT_QUESTION_PROMPT  # used by later pipeline steps, not here

API_KEY = "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"
BASE_URL = "https://vip.apiyi.com/v1"
MODEL_GPT = "text-embedding-ada-002"
MODELS = ["deepseek-reasoner", "claude-3-7-sonnet-20250219", "qwen-max", "deepseek-chat", "gemini-pro"]  # chat models for step 4; unused in this script

# Shared client so we don't reconstruct it on every embedding request
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

def get_embedding(text):
    """Embed a single text with the configured embedding model."""
    response = client.embeddings.create(
        model=MODEL_GPT,
        input=text,
    )
    return response.data[0].embedding

def compute_embeddings(texts):
    """Embed each text one request at a time; see the batched sketch below for a faster variant."""
    embeddings = []
    for i, text in enumerate(texts):
        print("Embedding {}/{}".format(i + 1, len(texts)))
        embeddings.append(get_embedding(text))
    return np.array(embeddings)

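# A minimal batched sketch: the OpenAI embeddings endpoint also accepts a list
# of inputs, which cuts the per-request overhead of the one-call-per-question
# loop above. The batch size of 100 is an assumption, not a tuned value; the
# entries of response.data come back indexed in input order, so we sort by
# index defensively before collecting the vectors.
def compute_embeddings_batched(texts, batch_size=100):
    embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        response = client.embeddings.create(model=MODEL_GPT, input=batch)
        for item in sorted(response.data, key=lambda d: d.index):
            embeddings.append(item.embedding)
    return np.array(embeddings)
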
def load_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def save_json(data, file_path):
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=2)

def save_embeddings(embeddings, file_path):
    with open(file_path, 'wb') as file:
        pickle.dump(embeddings, file)

def load_embeddings(file_path):
    with open(file_path, 'rb') as file:
        return pickle.load(file)

def deduplicate_qa(data, save_vectors=True):
    """Embed every question, optionally cache the vectors, then deduplicate."""
    questions = [item['question'] for item in data]
    # Generate the embedding vectors
    question_embeddings = compute_embeddings(questions)
    if save_vectors:
        print("Saving the question embeddings...")
        save_embeddings(question_embeddings, '/home/ubuntu/50T/fsy/layer2/QA/question_embeddings.pkl')
    # Deduplication logic
    filtered_data, duplicate_entries = de_emphasize(question_embeddings, data)
    return filtered_data, duplicate_entries

def deduplicate_qa_pkl(data, pkl_path):
    """Deduplicate using question embeddings previously cached as a pickle file."""
    question_embeddings = load_embeddings(pkl_path)
    filtered_data, duplicate_entries = de_emphasize(question_embeddings, data)
    return filtered_data, duplicate_entries

def de_emphasize(question_embeddings, data, similarity_threshold=0.99):
    """Greedy deduplication: keep the first occurrence, drop later near-duplicates."""
    unique_indices = []
    duplicate_entries = []  # holds the duplicate QA pairs that get filtered out
    for i in range(len(data)):
        print("Checking {}/{}".format(i + 1, len(data)))
        duplicate_found = False
        for j in unique_indices:
            # Semantic similarity between the two questions
            question_sim = cosine_similarity([question_embeddings[i]], [question_embeddings[j]])[0][0]
            # Above the threshold, treat the pair as duplicates
            if question_sim > similarity_threshold:
                duplicate_found = True
                # Record both the current QA pair and the kept pair it matched
                duplicate_entries.append({
                    "duplicate_question": data[i]['question'],
                    "duplicate_answer": data[i]['answer'],
                    "matched_question": data[j]['question'],
                    "matched_answer": data[j]['answer']
                })
                break
        if not duplicate_found:
            unique_indices.append(i)
    # Build the deduplicated dataset
    filtered_data = [data[i] for i in unique_indices]
    return filtered_data, duplicate_entries

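# The loop above makes one cosine_similarity call per candidate pair, which is
# slow in Python for large n. A sketch of an equivalent vectorized pass: build
# the full n x n similarity matrix in a single call, then apply the same greedy
# keep-first policy. This assumes the n x n float matrix fits in memory (fine
# for a few thousand questions); the output matches de_emphasize.
def de_emphasize_vectorized(question_embeddings, data, similarity_threshold=0.99):
    sim = cosine_similarity(question_embeddings)  # full n x n matrix at once
    unique_indices = []
    duplicate_entries = []
    for i in range(len(data)):
        # Compare only against already-kept items, mirroring the loop above
        matches = [j for j in unique_indices if sim[i, j] > similarity_threshold]
        if matches:
            j = matches[0]  # first kept match, same tie-breaking as de_emphasize
            duplicate_entries.append({
                "duplicate_question": data[i]['question'],
                "duplicate_answer": data[i]['answer'],
                "matched_question": data[j]['question'],
                "matched_answer": data[j]['answer']
            })
        else:
            unique_indices.append(i)
    return [data[i] for i in unique_indices], duplicate_entries
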
# Main program
if __name__ == '__main__':
    input_file = '/home/ubuntu/50T/fsy/layer2/PGEE/code/dataset.json'  # input JSON file
    output_file = '/home/ubuntu/50T/fsy/layer2/PGEE/code/onrepeat_99.json'  # deduplicated output
    duplicates_file = '/home/ubuntu/50T/fsy/layer2/PGEE/code/duplicates_99.json'  # filtered-out QA pairs
    pkl_path = "/home/ubuntu/50T/fsy/layer2/PGEE/question_embeddings.pkl"

    qa_data = load_json(input_file)
    # To deduplicate from scratch and cache the embeddings as a pkl file:
    # filtered_data, duplicate_entries = deduplicate_qa(qa_data)
    # Deduplicate using the cached pkl file:
    filtered_data, duplicate_entries = deduplicate_qa_pkl(qa_data, pkl_path)
    # Still to do (step 4): filter the questions by difficulty
    # Still to do: for non-multiple-choice questions, use the most similar answers as distractors
    # Save the deduplicated QA pairs and the duplicates
    save_json(filtered_data, output_file)
    save_json(duplicate_entries, duplicates_file)
    print(f"Deduplication complete: {len(qa_data)} QA pairs before, {len(filtered_data)} remaining after.")
    print(f"Saved {len(duplicate_entries)} duplicate QA pairs to {duplicates_file}.")