# Files
# MatBench/layer1/ALL-merge/classify_muti.py
# 2025-06-03 10:23:41 +08:00
#
# 157 lines
# 5.5 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from openai import OpenAI
import time
import os
from tqdm import tqdm
import pandas as pd
import re
import concurrent.futures
import threading
# SECURITY NOTE(review): a credential is committed in this file. It should be
# rotated and removed from source control. Until then it is kept only as a
# fallback so existing runs keep working; set MATBENCH_API_KEY (and optionally
# MATBENCH_BASE_URL) in the environment to override it without editing code.
API_KEY = os.environ.get(
    "MATBENCH_API_KEY",
    "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d",
)
BASE_URL = os.environ.get("MATBENCH_BASE_URL", "https://vip.apiyi.com/v1")
MODEL_DEEPSEEK_V3 = "deepseek-chat"

# Labels the model is allowed to return.
CATEGORIES = ['Physics', 'Chemistry', 'Biological', 'Unknown']

# Thread-local storage for OpenAI clients (one client per worker thread).
local = threading.local()

# Locks for thread-safe operations: write_lock guards the shared counters
# below, progress_lock serialises console output.
write_lock = threading.Lock()
progress_lock = threading.Lock()

# Running totals shared by all worker threads.
processed_count = 0
category_counts = dict.fromkeys(CATEGORIES, 0)
# Load the JSON data
def load_data(file_path):
    """Parse the UTF-8 encoded JSON file at ``file_path`` and return its content."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return json.load(handle)
def get_client():
    """Return the calling thread's OpenAI client, building it on first use.

    The client is stored on the module-level ``threading.local`` object, so
    each thread that calls this function gets its own instance.
    """
    try:
        return local.client
    except AttributeError:
        # First call from this thread: no client has been attached yet.
        local.client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
        return local.client
def classify_question(idx, total_len, question, options):
    """Ask the chat model which subject category a question belongs to.

    Args:
        idx: Zero-based position of the question; forwarded to
            ``string_extraction`` for progress output.
        total_len: Total number of questions; forwarded for progress output.
        question: The question text inserted into the prompt.
        options: The already formatted answer options inserted into the prompt.

    Returns:
        One of the labels in ``CATEGORIES``, or ``'Error'`` when no valid
        label was obtained within ``max_retries`` attempts.
    """
    # The prompt lists four labels, so it must say "four" (it used to say
    # "three", contradicting its own list).
    prompt = f"""
Please classify the given question into one of these four categories: 'Physics', 'Chemistry', 'Biological' or 'Unknown'.\n
Please format your response by wrapping the category name with the tags [CATEGORY] and [/CATEGORY]. For example, your response should look like one of these:\n
- [CATEGORY]Physics[/CATEGORY]
- [CATEGORY]Chemistry[/CATEGORY]
- [CATEGORY]Biological[/CATEGORY]
- [CATEGORY]Unknown[/CATEGORY]
Question: {question}\n
Options: {options}\n
"""
    client = get_client()
    # Retry mechanism
    max_retries = 3
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=MODEL_DEEPSEEK_V3,
                messages=[
                    {"role": "system", "content": "You are a helpful educational assistant."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.3,
                stream=False,
            )
            classification = response.choices[0].message.content.strip()
            # Trim the extracted label so "[CATEGORY] Physics [/CATEGORY]" is
            # accepted instead of burning a retry on a whitespace mismatch.
            extracted_category = string_extraction(idx, total_len, classification).strip()
            if extracted_category in CATEGORIES:
                return extracted_category
            with progress_lock:
                print(f"Invalid category '{extracted_category}' returned. Retrying. {attempt + 1}/{max_retries}")
        except Exception as e:
            # Broad on purpose: any failure of the remote call (or a reply
            # without text content) is treated as a retryable attempt.
            with progress_lock:
                print(f"Error on attempt {attempt + 1}/{max_retries}: {e}")
            if attempt == max_retries - 1:
                return 'Error'  # Maximum number of retries reached
            # Wait before retrying
            time.sleep(2)
    return 'Error'
def string_extraction(idx, total_len, classification):
    """Pull the label out of a ``[CATEGORY]...[/CATEGORY]`` tagged reply.

    Args:
        idx: Zero-based position of the question, used for the progress line.
        total_len: Total number of questions, used for the progress line.
        classification: Raw reply text to search.

    Returns:
        The text between the first pair of tags with surrounding whitespace
        removed, or ``'Unknown'`` when no tag pair is found.
    """
    # ``\s*`` on both sides of the group swallows blanks and line breaks the
    # model may put around the label, so a padded answer still yields a clean
    # label that can be compared against the allowed categories.
    pattern = r'\[CATEGORY\]\s*(.*?)\s*\[/CATEGORY\]'
    match = re.search(pattern, classification)
    extracted = match.group(1) if match else 'Unknown'
    with progress_lock:
        print(f"{idx + 1}/{total_len}: {extracted}")
    return extracted
def process_item(args):
    """Classify one dataset item and update the shared progress counters.

    Args:
        args: A ``(idx, total_len, item)`` tuple. ``item`` is a dict read with
            ``item.get('question', '')`` and ``item['choices']['text']`` /
            ``item['choices']['label']``.

    Returns:
        A shallow copy of ``item`` with an added ``'subject_category'`` key
        holding the result of ``classify_question``.
    """
    idx, total_len, item = args
    question = item.get('question', '')
    text = item['choices']['text']
    label = item['choices']['label']
    formatted_choices = " ".join(f"({lbl}) {txt}" for lbl, txt in zip(label, text))
    classification = classify_question(idx, total_len, question, formatted_choices)
    # Attach the classification result
    item_with_classification = item.copy()
    item_with_classification['subject_category'] = classification
    # Update global counters
    global processed_count
    with write_lock:
        processed_count += 1
        # classify_question can return 'Error', which is not pre-seeded in
        # category_counts; a plain ``+= 1`` raised KeyError and aborted the
        # whole run, so count unseen labels starting from zero instead.
        category_counts[classification] = category_counts.get(classification, 0) + 1
        # Print an intermediate summary every 100 processed questions
        if processed_count % 100 == 0:
            print(f"\nProcessed {processed_count} questions. Current distribution:")
            for category, count in category_counts.items():
                print(f"{category}: {count}")
    # API rate limiting - short sleep only, since multithreading already
    # provides a natural delay between requests
    time.sleep(0.1)
    return item_with_classification
def main(
    file_path='/home/ubuntu/50T/fsy/MatBench/layer1/ALL-merge/merged.json',
    output_path='/home/ubuntu/50T/fsy/MatBench/layer1/ALL-merge/merged_classified.json',
    num_threads=10,
):
    """Classify every question in a JSON dataset and save the labelled copy.

    Args:
        file_path: JSON file to read; its items are passed one by one to
            ``process_item``.
        output_path: Destination JSON file for the classified items.
        num_threads: Number of worker threads; adjust to the API rate limit
            and machine capacity.
    """
    # Load the data
    data = load_data(file_path)
    data_length = len(data)
    # Build the per-item argument tuples
    args_list = [(i, data_length, item) for i, item in enumerate(data)]
    print(f"Starting classification with {num_threads} threads...")
    # Process in parallel with a ThreadPoolExecutor; executor.map yields the
    # results in input order, and tqdm displays progress while they arrive.
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        results = list(tqdm(executor.map(process_item, args_list), total=data_length, desc="Classifying questions"))
    # Save the final results
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)
    # Analyse the results
    print("\nFinal distribution of questions by category:")
    if results:
        df = pd.DataFrame(results)
        category_counts_final = df['subject_category'].value_counts()
        print(category_counts_final)
    else:
        # An empty dataset yields a frame without a 'subject_category' column,
        # so there is nothing to count.
        print("(no questions were classified)")
    print(f"\nTask completed. Results saved to '{output_path}'")


if __name__ == "__main__":
    main()