Files
MatBench/layer2/PGEE/code/classify_muti.py
2025-06-03 11:17:01 +08:00

169 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from openai import OpenAI
import time
import os
from tqdm import tqdm
import pandas as pd
import re
import concurrent.futures
import threading
# SECURITY: the API key is read from the environment instead of being stored
# in source control. Any key that was previously committed in this file must
# be treated as leaked and revoked. The value is None when the variable is unset.
API_KEY = os.environ.get("OPENAI_API_KEY")
BASE_URL = "https://vip.apiyi.com/v1"
MODEL_DEEPSEEK_V3 = "deepseek-chat"
# Labels accepted as a valid classification result.
CATEGORIES = [
    'Atomic Structure and Interatomic Bonding',
    'The Structure of Solids',
    'Imperfections in Solids',
    'Mechanical Properties of Metals',
    'Dislocations and Strengthening Mechanisms',
    'Failure',
    'Phase Transformations: Development of Microstructure and Alteration of Mechanical Properties',
    'Applications and Processing of Materials',
    'Corrosion and Degradation of Materials',
    'Functional Properties of Materials',
    'Unknown',
]
FILE_PATH = '/home/ubuntu/50T/fsy/A/MatBench/layer2/PGEE/code/stepz_final_choice_questions_filtered_only_hard.json'
OUTPUT_PATH = '/home/ubuntu/50T/fsy/A/MatBench/layer2/PGEE/code/stepz_classified_only_hard.json'
# Thread-local storage for OpenAI clients
local = threading.local()
# Locks for thread-safe operations: write_lock guards the shared counters,
# progress_lock serializes progress/diagnostic printing.
write_lock = threading.Lock()
progress_lock = threading.Lock()
processed_count = 0
# Per-category tally, derived from CATEGORIES so the two cannot drift apart
# (same keys, same order, all starting at zero).
category_counts = dict.fromkeys(CATEGORIES, 0)
# Load the JSON dataset
def load_data(file_path):
    """Parse the UTF-8 encoded JSON file at ``file_path`` and return its content."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return json.load(handle)
def get_client():
    """Return the calling thread's OpenAI client, creating it on first use.

    The client is cached on the module-level ``threading.local`` object, so
    each thread builds its own instance from ``API_KEY`` and ``BASE_URL``.
    """
    try:
        return local.client
    except AttributeError:
        # First call on this thread: nothing cached yet.
        local.client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
        return local.client
def classify_question(idx, total_len, question, answer):
    """Ask the chat model which Materials Science category a question belongs to.

    Args:
        idx: Zero-based index of the question; forwarded to
            ``string_extraction`` for its progress output.
        total_len: Total number of questions; forwarded likewise.
        question: Question text inserted after ``QUESTIONS:`` in the prompt.
        answer: Text inserted after ``ANSWER:`` in the prompt.

    Returns:
        The extracted label when it is a member of ``CATEGORIES``; otherwise
        the string ``'Error'`` once ``max_retries`` attempts have failed.
    """
    # The prompt asks for the exact category NAME only: the list below is not
    # numbered and the reply is validated by exact membership in CATEGORIES,
    # so a reply containing a number could never be accepted.
    prompt = f"""
Given a question and its answer from the field of Materials Science fundamentals, identify which chapter or category of Materials Science the question belongs to. Choose from the following 10 categories:
-- Atomic Structure and Interatomic Bonding
-- The Structure of Solids
-- Imperfections in Solids
-- Mechanical Properties of Metals
-- Dislocations and Strengthening Mechanisms
-- Failure
-- Phase Transformations: Development of Microstructure and Alteration of Mechanical Properties
-- Applications and Processing of Materials
-- Corrosion and Degradation of Materials
-- Functional Properties of Materials
QUESTIONS:{question}\n
ANSWER:{answer}\n
Provide your response by enclosing the exact category name within [CATEGORY] and [/CATEGORY] tags. For example: [CATEGORY]Atomic Structure and Interatomic Bonding[/CATEGORY]
Analyze both the question and answer carefully to determine the most appropriate category based on the question and options.
"""
    client = get_client()
    # Retry mechanism
    max_retries = 3
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=MODEL_DEEPSEEK_V3,
                messages=[
                    {"role": "system", "content": "You are a helpful educational assistant."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.3,
                stream=False,
            )
            classification = response.choices[0].message.content.strip()
            # Strip so stray whitespace inside the tags does not fail validation.
            extracted_category = string_extraction(idx, total_len, classification).strip()
        except Exception as e:
            # Broad on purpose: any failure of the request or of reading the
            # reply counts as one failed attempt instead of killing the worker.
            with progress_lock:
                print(f"Error on attempt {attempt + 1}/{max_retries}: {e}")
            # Wait before retrying; no point sleeping after the final attempt.
            if attempt < max_retries - 1:
                time.sleep(2)
            continue
        if extracted_category in CATEGORIES:
            return extracted_category
        with progress_lock:
            print(f"Invalid category '{extracted_category}' returned. Retrying. {attempt + 1}/{max_retries}")
    # Every attempt either raised or produced a label outside CATEGORIES.
    return 'Error'
def string_extraction(idx, total_len, classification):
    """Extract the category label from a model reply and print progress.

    Args:
        idx: Zero-based index of the question, printed as ``idx + 1``.
        total_len: Total number of questions, printed alongside the index.
        classification: Raw reply text expected to contain a
            ``[CATEGORY]...[/CATEGORY]`` pair.

    Returns:
        The text between the first tag pair with surrounding whitespace
        removed, or ``'Unknown'`` when no tag pair is present.
    """
    pattern = r'\[CATEGORY\](.*?)\[\/CATEGORY\]'
    # DOTALL lets the label sit on its own line between the tags; without it
    # such a reply would not match and be reported as 'Unknown'.
    match = re.search(pattern, classification, flags=re.DOTALL)
    extracted = match.group(1).strip() if match else 'Unknown'
    with progress_lock:
        print(f"{idx + 1}/{total_len}: {extracted}")
    return extracted
def process_item(args):
    """Classify one dataset item and update the shared progress counters.

    Args:
        args: Tuple ``(idx, total_len, item)`` where ``item`` is a dict with
            an optional ``'question'`` entry and a ``'choices'`` dict holding
            parallel ``'label'`` and ``'text'`` sequences.

    Returns:
        A shallow copy of ``item`` with the classification stored under
        ``'subject_category'``.

    Raises:
        KeyError: If ``item`` lacks ``'choices'``, ``'text'`` or ``'label'``.
    """
    global processed_count
    idx, total_len, item = args
    question = item.get('question', '')
    text = item['choices']['text']
    label = item['choices']['label']
    # The rendered options are what gets passed to the classifier as "answer".
    formatted_choices = " ".join(f"({lbl}) {txt}" for lbl, txt in zip(label, text))
    classification = classify_question(idx, total_len, question, formatted_choices)
    # Attach the classification result
    item_with_classification = item.copy()
    item_with_classification['subject_category'] = classification
    # Update global counters
    with write_lock:
        processed_count += 1
        # classify_question may return 'Error', which has no pre-seeded entry
        # in category_counts; .get avoids a KeyError that would abort the run.
        category_counts[classification] = category_counts.get(classification, 0) + 1
        # Print the running distribution every 100 processed questions
        if processed_count % 100 == 0:
            print(f"\nProcessed {processed_count} questions. Current distribution:")
            for category, count in category_counts.items():
                print(f"{category}: {count}")
    # Short pause per item as a crude API rate limit
    time.sleep(0.1)
    return item_with_classification
def main():
    """Classify every item in ``FILE_PATH`` and write the results to ``OUTPUT_PATH``.

    Items are processed concurrently by ``process_item``; the annotated items
    are dumped as JSON and the per-category distribution is printed.
    """
    # Load the data
    data = load_data(FILE_PATH)
    data_length = len(data)
    # Build one argument tuple per item
    args_list = [(i, data_length, item) for i, item in enumerate(data)]
    # Thread count; tune to the API rate limit and machine capacity
    num_threads = 20
    print(f"Starting classification with {num_threads} threads...")
    # Fan the items out over a thread pool; Executor.map yields results in
    # input order, and tqdm displays progress as they are consumed.
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        results = list(tqdm(executor.map(process_item, args_list), total=data_length, desc="Classifying questions"))
    # Save the final results
    with open(OUTPUT_PATH, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)
    # Summarize the results
    if results:
        df = pd.DataFrame(results)
        category_counts_final = df['subject_category'].value_counts()
        print("\nFinal distribution of questions by category:")
        print(category_counts_final)
    else:
        # An empty DataFrame has no 'subject_category' column, so indexing it
        # would raise KeyError; report the empty input instead.
        print("\nNo questions to classify; the input contained no items.")
    print(f"\nTask completed. Results saved to {OUTPUT_PATH}")
# Run the classification pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()