Files
MatBench/layer1/ALL-merge/classify.py
2025-06-03 10:23:41 +08:00

126 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from openai import OpenAI
import time
import os
from tqdm import tqdm
import pandas as pd
import re
# SECURITY: the API key used to be committed here in plain text. A key that
# has appeared in version control should be treated as leaked and rotated.
# It is now read from the environment; the value is None when the variable
# is unset. NOTE(review): the openai client is expected to reject a missing
# key with a configuration error - confirm against the installed version.
API_KEY = os.environ.get("OPENAI_API_KEY")
# OpenAI-compatible endpoint the chat requests are sent to.
BASE_URL = "https://vip.apiyi.com/v1"
# Model identifier passed to the chat-completions call.
MODEL_DEEPSEEK_V3 = "deepseek-chat"
# Labels accepted as a valid classification result.
CATEGORIES = ['Physics', 'Chemistry', 'Biological', 'Unknown']
# Load the JSON data
def load_data(file_path):
    """Parse the JSON document stored at ``file_path`` and return it.

    The file is opened as UTF-8 text; the result is whatever ``json.load``
    produces for its contents.
    """
    with open(file_path, mode='r', encoding='utf-8') as source:
        return json.load(source)
def classify_question(idx, len, question, options):
    """Ask the chat model which subject a multiple-choice question belongs to.

    Args:
        idx: Zero-based position of the question; forwarded to
            ``string_extraction`` for its progress output.
        len: Total number of questions; forwarded to ``string_extraction``.
            The name shadows the builtin but is kept so that existing
            callers keep working.
        question: Question text inserted into the prompt.
        options: Answer options, already formatted as one string.

    Returns:
        A member of ``CATEGORIES`` when the model returns a valid label
        within three attempts, otherwise the string ``'Error'``.
    """
    # The closing tag is written with an escaped backslash ("\\C"): the text
    # sent to the model still reads "[\CATEGORY]", but the source no longer
    # relies on the invalid escape sequence "\C".
    prompt = (
        "\n"
        "Please classify the given question into one of these three categories: "
        "'Physics', 'Chemistry', 'Biological' or 'Unknown'.\n\n"
        "Please format your response by wrapping the category name with the tags "
        "[CATEGORY] and [\\CATEGORY]. For example, your response should look like "
        "one of these:\n\n"
        "- [CATEGORY]Physics[\\CATEGORY]\n"
        "- [CATEGORY]Chemistry[\\CATEGORY]\n"
        "- [CATEGORY]Biological[\\CATEGORY]\n"
        "- [CATEGORY]Unknown[\\CATEGORY]\n"
        f"Question: {question}\n\n"
        f"Options: {options}\n\n"
    )
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

    # Retry mechanism: a failed request or an unusable answer both consume
    # one attempt.
    max_retries = 3
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=MODEL_DEEPSEEK_V3,
                messages=[
                    {"role": "system", "content": "You are a helpful educational assistant."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.3,
                stream=False,
            )
            classification = response.choices[0].message.content.strip()
            extracted_category = string_extraction(idx, len, classification)
        except Exception as e:
            # Deliberately broad: any failure of the remote call (or of
            # parsing its reply) is reported and retried, never propagated.
            print(f"Error on attempt {attempt + 1}/{max_retries}: {e}")
            # Wait before retrying, but not after the final attempt.
            if attempt < max_retries - 1:
                time.sleep(2)
            continue

        if extracted_category in CATEGORIES:
            return extracted_category
        print(f"Invalid category '{extracted_category}' returned. Retrying. {attempt + 1}/{max_retries}")

    # Every attempt either raised or produced an unknown label.
    return 'Error'
def string_extraction(idx, len, classification):
    """Pull the category label out of a model reply and log progress.

    The label is the text between ``[CATEGORY]`` and the closing tag. The
    closing tag is accepted with either a backslash (as requested in the
    prompt) or a forward slash, and surrounding whitespace is removed from
    the label.

    Args:
        idx: Zero-based position of the question (printed as ``idx + 1``).
        len: Total number of questions, used only in the progress line.
            The name shadows the builtin but is kept for compatibility.
        classification: Raw reply text returned by the model.

    Returns:
        The extracted label, or ``'Unknown'`` when no tagged label is found.
    """
    # DOTALL lets the label sit on its own line between the two tags.
    pattern = r'\[CATEGORY\](.*?)\[[\\/]CATEGORY\]'
    match = re.search(pattern, classification, re.DOTALL)
    # Resolve the fallback BEFORE printing: calling match.group() on a
    # failed search raises AttributeError.
    category = match.group(1).strip() if match else 'Unknown'
    print(f"{idx + 1}/{len}: {category}")
    return category
def main(input_path='/home/ubuntu/50T/fsy/MatBench/layer1/ALL-merge/merged.json',
         output_path='/home/ubuntu/50T/fsy/MatBench/layer1/ALL-merge/merged_classified.json'):
    """Classify every question in a JSON file and write the labelled copy.

    Each input item must provide ``item['choices']['text']`` and
    ``item['choices']['label']``; ``item['question']`` is optional. A copy
    of each item with an added ``'subject_category'`` key is collected and
    written to ``output_path`` as JSON.

    Args:
        input_path: JSON file holding the list of question items.
        output_path: Destination for the classified items.
    """
    # Load the data
    data = load_data(input_path)
    data_length = len(data)

    # Classified copies of the input items, in input order
    results = []

    # Process every question
    for i, item in enumerate(tqdm(data, desc="Classifying questions")):
        question = item.get('question', '')
        text = item['choices']['text']
        label = item['choices']['label']
        formatted_choices = " ".join(f"({lbl}) {txt}" for lbl, txt in zip(label, text))

        classification = classify_question(i, data_length, question, formatted_choices)

        # Attach the classification to a shallow copy of the item
        item_with_classification = item.copy()
        item_with_classification['subject_category'] = classification
        results.append(item_with_classification)

        # Report the running distribution every 100 questions
        if (i + 1) % 100 == 0:
            categories = {'Physics': 0, 'Chemistry': 0, 'Biological': 0, 'Unknown': 0}
            for classified in results:
                subject = classified.get('subject_category', 'Unknown')
                # .get() keeps labels outside the preset keys (notably
                # 'Error') from raising KeyError.
                categories[subject] = categories.get(subject, 0) + 1
            print(f"Processed {i+1} questions. Current distribution:")
            for category, count in categories.items():
                print(f"{category}: {count}")

        # API rate limiting: pause between consecutive requests.
        # NOTE(review): per-question placement is inferred from the original
        # comment; confirm it was not meant to run only every 100 items.
        time.sleep(0.5)

    # Save the final results
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)

    # Summarise the results
    print("\nFinal distribution of questions by category:")
    if results:
        df = pd.DataFrame(results)
        category_counts = df['subject_category'].value_counts()
        print(category_counts)
    else:
        # An empty DataFrame has no 'subject_category' column to count.
        print("No questions were classified.")

    print(f"\nTask completed. Results saved to '{output_path}'")


# Run the classification only when executed as a script, not on import.
if __name__ == "__main__":
    main()