"""Classify multiple-choice questions into subject categories with an LLM.

Reads a JSON list of question items, asks a DeepSeek chat model (through an
OpenAI-compatible endpoint) to label each question as 'Physics', 'Chemistry',
'Biological' or 'Unknown', and writes the items back out with an added
``subject_category`` field ('Error' when no valid label could be obtained).
"""
import concurrent.futures
import json
import os
import re
import threading
import time

import pandas as pd
from openai import OpenAI
from tqdm import tqdm

# SECURITY: an API key used to be hard-coded here. Secrets must not live in
# source control -- the previously committed key should be considered leaked
# and rotated. If the variable is unset, OpenAI() receives None and falls back
# to its own OPENAI_API_KEY lookup, raising if no key is configured.
API_KEY = os.environ.get("OPENAI_API_KEY")
BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://vip.apiyi.com/v1")
MODEL_DEEPSEEK_V3 = "deepseek-chat"
CATEGORIES = ['Physics', 'Chemistry', 'Biological', 'Unknown']
# Sentinel label returned when every retry failed; deliberately NOT in
# CATEGORIES so it is never accepted as a model answer.
ERROR_CATEGORY = 'Error'

DEFAULT_INPUT_PATH = '/home/ubuntu/50T/fsy/MatBench/layer1/ALL-merge/merged.json'
DEFAULT_OUTPUT_PATH = '/home/ubuntu/50T/fsy/MatBench/layer1/ALL-merge/merged_classified.json'

# Extracts the text between [CATEGORY] and [/CATEGORY]; DOTALL tolerates a
# newline inside the tags.
CATEGORY_PATTERN = re.compile(r'\[CATEGORY\](.*?)\[/CATEGORY\]', re.DOTALL)

# Thread-local storage: each worker thread gets its own OpenAI client.
local = threading.local()

# write_lock guards the shared counters below; progress_lock serialises
# console output. progress_lock is never held while acquiring write_lock,
# so the two cannot deadlock.
write_lock = threading.Lock()
progress_lock = threading.Lock()
processed_count = 0
# Includes ERROR_CATEGORY so a failed classification can be counted instead
# of raising KeyError inside a worker thread.
category_counts = {'Physics': 0, 'Chemistry': 0, 'Biological': 0, 'Unknown': 0, ERROR_CATEGORY: 0}


def load_data(file_path):
    """Load and return the parsed contents of a UTF-8 JSON file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def get_client():
    """Return this thread's OpenAI client, creating it on first use."""
    if not hasattr(local, 'client'):
        local.client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
    return local.client


def classify_question(idx, total_len, question, options, max_retries=3, retry_delay=2):
    """Ask the model to classify one question.

    Args:
        idx: Zero-based index of the question (used only for progress output).
        total_len: Total number of questions (used only for progress output).
        question: The question text.
        options: The answer options, already formatted as a single string.
        max_retries: Number of attempts before giving up.
        retry_delay: Seconds to wait after an API error before retrying.

    Returns:
        One of CATEGORIES, or ERROR_CATEGORY if no attempt produced a valid
        category.
    """
    prompt = (
        "Please classify the given question into one of these four categories: "
        "'Physics', 'Chemistry', 'Biological' or 'Unknown'.\n"
        "Please format your response by wrapping the category name with the "
        "tags [CATEGORY] and [/CATEGORY].\n"
        "For example, your response should look like one of these:\n"
        "- [CATEGORY]Physics[/CATEGORY]\n"
        "- [CATEGORY]Chemistry[/CATEGORY]\n"
        "- [CATEGORY]Biological[/CATEGORY]\n"
        "- [CATEGORY]Unknown[/CATEGORY]\n"
        f"Question: {question}\n"
        f"Options: {options}\n"
    )
    client = get_client()

    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=MODEL_DEEPSEEK_V3,
                messages=[
                    {"role": "system", "content": "You are a helpful educational assistant."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.3,
                stream=False,
            )
            # A None content (empty response) raises AttributeError here and
            # is retried like any other API failure.
            classification = response.choices[0].message.content.strip()
        except Exception as e:  # SDK/network errors are retried, not fatal
            with progress_lock:
                print(f"Error on attempt {attempt + 1}/{max_retries}: {e}")
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
            continue

        extracted_category = string_extraction(idx, total_len, classification)
        if extracted_category in CATEGORIES:
            return extracted_category
        with progress_lock:
            print(f"Invalid category '{extracted_category}' returned. Retrying. {attempt + 1}/{max_retries}")

    return ERROR_CATEGORY


def string_extraction(idx, total_len, classification):
    """Extract the category between [CATEGORY] tags and log progress.

    Surrounding whitespace inside the tags is stripped. Returns 'Unknown'
    when the response contains no tagged category at all.
    """
    match = CATEGORY_PATTERN.search(classification)
    extracted = match.group(1).strip() if match else 'Unknown'
    with progress_lock:
        print(f"{idx + 1}/{total_len}: {extracted}")
    return extracted


def process_item(args):
    """Classify one item and return a copy with 'subject_category' added.

    Args:
        args: An ``(idx, total_len, item)`` tuple (a single argument so the
            function can be passed straight to ``executor.map``). ``item`` is
            assumed to be a dict with 'question' and a 'choices' dict holding
            parallel 'label' and 'text' lists -- missing fields are treated
            as empty rather than aborting the whole run.
    """
    global processed_count
    idx, total_len, item = args

    question = item.get('question', '')
    choices = item.get('choices') or {}
    labels = choices.get('label', [])
    texts = choices.get('text', [])
    formatted_choices = " ".join(f"({lbl}) {txt}" for lbl, txt in zip(labels, texts))

    classification = classify_question(idx, total_len, question, formatted_choices)

    # Shallow copy so the caller's item is not mutated.
    item_with_classification = {**item, 'subject_category': classification}

    # Update the shared counters; take a snapshot every 100 items so the
    # summary can be printed without holding write_lock.
    with write_lock:
        processed_count += 1
        category_counts[classification] = category_counts.get(classification, 0) + 1
        count = processed_count
        snapshot = dict(category_counts) if count % 100 == 0 else None

    if snapshot is not None:
        summary = [f"\nProcessed {count} questions. Current distribution:"]
        summary.extend(f"{category}: {n}" for category, n in snapshot.items())
        with progress_lock:
            print("\n".join(summary))

    # Light rate limiting; the thread pool already spreads requests out.
    time.sleep(0.1)
    return item_with_classification


def main(file_path=DEFAULT_INPUT_PATH, output_path=DEFAULT_OUTPUT_PATH, num_threads=10):
    """Classify every question in ``file_path`` and save to ``output_path``.

    Args:
        file_path: JSON file containing a list of question items.
        output_path: Where the classified items are written as JSON.
        num_threads: Worker thread count; tune to the API's rate limits.
    """
    data = load_data(file_path)
    data_length = len(data)
    args_list = [(i, data_length, item) for i, item in enumerate(data)]

    print(f"Starting classification with {num_threads} threads...")

    # executor.map preserves input order, so results line up with data.
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        results = list(tqdm(executor.map(process_item, args_list),
                            total=data_length, desc="Classifying questions"))

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)

    # Built from a Series (not a DataFrame column) so an empty result list
    # does not raise KeyError.
    category_counts_final = pd.Series(
        [r['subject_category'] for r in results], name='subject_category', dtype=object
    ).value_counts()
    print("\nFinal distribution of questions by category:")
    print(category_counts_final)
    print(f"\nTask completed. Results saved to '{output_path}'")


if __name__ == "__main__":
    main()