126 lines
4.7 KiB
Python
126 lines
4.7 KiB
Python
import json
|
||
from openai import OpenAI
|
||
import time
|
||
import os
|
||
from tqdm import tqdm
|
||
import pandas as pd
|
||
import re
|
||
|
||
# Credentials come from the environment instead of being hard-coded: a key
# committed to source is readable by anyone with access to the repository.
# NOTE(review): the key that used to be embedded here must be treated as
# leaked and revoked/rotated.
# When both variables are unset this is None and the OpenAI client falls back
# to its own OPENAI_API_KEY lookup (raising if that is missing too).
API_KEY = os.environ.get("APIYI_API_KEY") or os.environ.get("OPENAI_API_KEY")

# OpenAI-compatible endpoint; overridable without editing the script.
BASE_URL = os.environ.get("APIYI_BASE_URL", "https://vip.apiyi.com/v1")

MODEL_DEEPSEEK_V3 = "deepseek-chat"

# The only labels classify_question accepts from the model.
CATEGORIES = ['Physics', 'Chemistry', 'Biological', 'Unknown']


# Load JSON data.
def load_data(file_path):
    """Parse the UTF-8 JSON file at *file_path* and return the decoded object."""
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)


def classify_question(idx, len, question, options):
    """Ask the chat model to assign *question* to one of ``CATEGORIES``.

    Args:
        idx: Zero-based index of the question; only used for progress output.
        len: Total number of questions; only used for progress output.
            (This shadows the builtin ``len`` inside the function; the name is
            kept so existing callers remain compatible.)
        question: The question text.
        options: The answer options, already formatted as one string.

    Returns:
        A label from ``CATEGORIES`` on success, or ``'Error'`` when all
        attempts either raised or produced a label outside ``CATEGORIES``.
    """
    # The closing tag really is "[\CATEGORY]" (backslash); it is spelled "\\"
    # here because "\C" is an invalid escape sequence in a normal string.
    prompt = f"""
Please classify the given question into one of these four categories: 'Physics', 'Chemistry', 'Biological' or 'Unknown'.\n
Please format your response by wrapping the category name with the tags [CATEGORY] and [\\CATEGORY]. For example, your response should look like one of these:\n
- [CATEGORY]Physics[\\CATEGORY]
- [CATEGORY]Chemistry[\\CATEGORY]
- [CATEGORY]Biological[\\CATEGORY]
- [CATEGORY]Unknown[\\CATEGORY]
Question: {question}\n
Options: {options}\n
"""

    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

    # Retry mechanism.
    max_retries = 3
    retry_delay_seconds = 2

    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=MODEL_DEEPSEEK_V3,
                messages=[
                    {"role": "system", "content": "You are a helpful educational assistant."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.3,
                stream=False,
            )

            # content can be None (e.g. a filtered reply); treat that as empty
            # text instead of crashing on .strip().
            classification = (response.choices[0].message.content or "").strip()
            # Kept inside the try so a parsing failure counts as a failed attempt.
            extracted_category = string_extraction(idx, len, classification)

            if extracted_category in CATEGORIES:
                return extracted_category

            print(f"Invalid category '{extracted_category}' returned. Retrying. {attempt + 1}/{max_retries}")

        except Exception as e:
            # Broad catch on purpose: network, API and parsing errors are all
            # retried the same way.
            print(f"Error on attempt {attempt + 1}/{max_retries}: {e}")
            if attempt == max_retries - 1:
                return 'Error'  # Maximum number of retries reached.
            # Back off before the next attempt.
            time.sleep(retry_delay_seconds)

    return 'Error'


# Matches "[CATEGORY]...[\CATEGORY]". "[/CATEGORY]" is also accepted because
# models often "correct" the backslash closing tag to a slash. DOTALL lets the
# label span line breaks.
_CATEGORY_PATTERN = re.compile(r'\[CATEGORY\](.*?)\[[\\/]CATEGORY\]', re.DOTALL)


def string_extraction(idx, len, classification):
    """Extract the category label wrapped in [CATEGORY] tags.

    Args:
        idx: Zero-based index of the question; only used for progress output.
        len: Total number of questions; only used for progress output.
        classification: Raw model response text.

    Returns:
        The text between the first pair of tags with surrounding whitespace
        stripped, or ``'Unknown'`` when no tagged label is present.
    """
    match = _CATEGORY_PATTERN.search(classification)
    # Check for a match before calling .group(); otherwise an untagged
    # response raises AttributeError instead of falling back to 'Unknown'.
    category = match.group(1).strip() if match else 'Unknown'
    print(f"{idx + 1}/{len}: {category}")
    return category


def main(
    input_path='/home/ubuntu/50T/fsy/MatBench/layer1/ALL-merge/merged.json',
    output_path='/home/ubuntu/50T/fsy/MatBench/layer1/ALL-merge/merged_classified.json',
):
    """Classify every question in a JSON dataset and save a labelled copy.

    Each record is copied and given a ``'subject_category'`` field holding the
    result of ``classify_question``. The running category distribution is
    printed every 100 questions. The labelled records are written to
    *output_path*, and the final distribution is printed.

    Args:
        input_path: JSON file holding a list of question records. Each record
            is read for ``'question'`` and ``'choices'``, which must contain
            parallel ``'label'`` and ``'text'`` lists.
        output_path: Destination file for the labelled records (UTF-8, indent=4).
    """
    data = load_data(input_path)
    data_length = len(data)

    results = []
    # Running tally, pre-seeded so every category appears in the report.
    # classify_question can also return 'Error'; that key is added on demand
    # instead of raising KeyError.
    category_counts = dict.fromkeys(CATEGORIES, 0)

    for i, item in enumerate(tqdm(data, desc="Classifying questions")):
        question = item.get('question', '')
        choices = item.get('choices') or {}
        labels = choices.get('label', [])
        texts = choices.get('text', [])
        formatted_choices = " ".join(f"({lbl}) {txt}" for lbl, txt in zip(labels, texts))

        classification = classify_question(i, data_length, question, formatted_choices)

        # Copy so the loaded input records are not mutated.
        item_with_classification = item.copy()
        item_with_classification['subject_category'] = classification
        results.append(item_with_classification)
        category_counts[classification] = category_counts.get(classification, 0) + 1

        # Every 100 questions, report the distribution so far.
        if (i + 1) % 100 == 0:
            print(f"Processed {i+1} questions. Current distribution:")
            for category, count in category_counts.items():
                print(f"{category}: {count}")

        # Crude API rate limiting between requests.
        time.sleep(0.5)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)

    # Build a Series directly: DataFrame([]) has no 'subject_category' column,
    # so the DataFrame route crashed on an empty dataset.
    final_counts = pd.Series(
        [record['subject_category'] for record in results],
        name='subject_category',
        dtype=object,
    ).value_counts()
    print("\nFinal distribution of questions by category:")
    print(final_counts)

    print(f"\nTask completed. Results saved to '{output_path}'")


# Script entry point. The stray trailing "|" after main() was a SyntaxError.
if __name__ == "__main__":
    main()