107 lines
4.0 KiB
Python
107 lines
4.0 KiB
Python
# 判断科学问题是否关于材料
|
||
import json
|
||
import time
|
||
import threading
|
||
import queue
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from openai import OpenAI
|
||
|
||
# Shared state used by the worker threads.
result_lock = threading.Lock()  # guards material_items and error_items
api_semaphore = threading.Semaphore(5)  # limit on concurrent API requests
material_items = []  # items classified as materials-science related
error_items = []  # items whose processing failed

# SECURITY: never commit API keys to source control. The key that used to be
# hard-coded here must be treated as leaked and rotated. When api_key is not
# passed, the OpenAI client reads it from the OPENAI_API_KEY environment
# variable (and raises if that variable is unset).
client = OpenAI(
    base_url="https://vip.apiyi.com/v1",
)
|
||
|
||
def load_qa_data(file_path):
    """Parse the UTF-8 JSON file at *file_path* and return the decoded object."""
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)
|
||
|
||
def classify_qa_type(question, choices):
    """Ask the LLM whether a question (with its choices) is about materials science.

    The prompt instructs the model to answer "1" (related) or "0" (unrelated).

    Returns:
        The model's reply, stripped and lower-cased (it is also printed), or
        the string "2" if the API call or reading the reply raises.
    """
    prompt = f"""
This is a classification task. Please analyze the given input question & choices and determine whether it is related to the field of materials science. If the question pertains to materials science topics, such as material properties, composition, structure, preparation, applications, performance, processing, or other materials-related subjects, strictly return the number 1. If the question is not related to the field of materials science, strictly return the number 0. Do not provide any other explanations or outputs; just return the number 1 or 0.
Question:
{question}
Choices:
{choices}
"""
    chat_messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": prompt},
    ]

    # The semaphore bounds how many requests are in flight across all threads.
    with api_semaphore:
        try:
            completion = client.chat.completions.create(
                model="deepseek-chat",
                messages=chat_messages,
                stream=False,
            )
            reply_text = completion.choices[0].message.content
            answer = reply_text.strip().lower()
            print(answer)
            return answer
        except Exception as exc:
            print(f"API调用错误: {exc}")
            # "2" is the error code that process_item checks for.
            return "2"
|
||
|
||
def process_item(item, index, total):
    """Classify one QA item and file it into the shared result lists.

    Args:
        item: QA record read from the input file. Its "question" and
            "choices" keys are read directly, so a missing key raises KeyError.
        index: Zero-based position of the item; used only for progress output.
        total: Number of items in the run; used only for progress output.

    Outcome, based on the label returned by classify_qa_type:
        * contains "1"         -> appended to material_items
        * contains "2"         -> API failure; tagged and appended to error_items
        * contains "0"         -> not materials-related; dropped
        * anything else        -> unrecognised reply (e.g. empty); tagged and
                                  appended to error_items so it can be re-run
    """
    print(f"处理第 {index+1}/{total} 条数据...")
    question = item["question"]
    choices = item["choices"]

    label = classify_qa_type(question, choices)

    with result_lock:
        if "1" in label:
            material_items.append(item)
        elif "2" in label or "0" not in label:
            # Previously an unrecognised reply was silently discarded as if it
            # were "not materials"; flag it as an error instead.
            item["error"] = "yes"
            error_items.append(item)
|
||
|
||
def save_processed_data(data, output_file):
    """Write *data* to *output_file* as 2-space-indented UTF-8 JSON, overwriting it.

    Non-ASCII characters are written as-is rather than escaped.
    """
    dump_options = {"ensure_ascii": False, "indent": 2}
    with open(output_file, mode="w", encoding="utf-8") as sink:
        json.dump(data, sink, **dump_options)
|
||
|
||
_DATA_DIR = "/home/ubuntu/50T/fsy/mmlu/high_school_physics"


def main(
    input_file=f"{_DATA_DIR}/test.json",
    output_file=f"{_DATA_DIR}/mmlu-test-mat.json",
    error_file=f"{_DATA_DIR}/mmlu-test-error.json",
):
    """Classify every item in *input_file* and save the results.

    Args:
        input_file: JSON file holding a list of QA items.
        output_file: Destination for the materials-related items.
        error_file: Destination for items that failed; written only when
            at least one item failed.

    Items that raise inside process_item (e.g. a missing "question" key) are
    tagged with ``error = "yes"`` and saved with the other failures instead
    of aborting the run. Both output lists are written in input order.
    """
    data = load_qa_data(input_file)
    total = len(data)

    # Process items on a thread pool (api_semaphore separately caps the
    # number of in-flight API requests).
    with ThreadPoolExecutor(max_workers=10) as executor:
        submitted = []
        for i, item in enumerate(data):
            submitted.append((i, item, executor.submit(process_item, item, i, total)))

            # Pause after every 10 submissions to avoid API rate limits.
            if (i + 1) % 10 == 0:
                time.sleep(1)

        # Wait for all tasks. Future.result() re-raises a worker's exception;
        # catching it per item keeps one bad record from discarding every
        # other result before anything is saved.
        for i, item, future in submitted:
            try:
                future.result()
            except Exception as e:
                print(f"处理第 {i+1}/{total} 条数据出错: {e}")
                with result_lock:
                    item["error"] = "yes"
                    error_items.append(item)

    # Workers append in completion order, which varies run to run; restore
    # input order so the output files are deterministic. Items not found in
    # `data` (none are expected) sort last.
    position = {id(item): i for i, item in enumerate(data)}

    def _input_order(entry):
        return position.get(id(entry), total)

    material_items.sort(key=_input_order)
    error_items.sort(key=_input_order)

    # Save the materials-related items.
    save_processed_data(material_items, output_file)
    print(f"处理完成,材料科学相关条目已保存到 {output_file}")

    # Save the failed items, if any.
    if error_items:
        save_processed_data(error_items, error_file)
        print(f"处理出错的条目已保存到 {error_file}")
|
||
|
||
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()