Files
MatBench/layer2/eval/eval.py
2025-05-28 11:00:24 +08:00

101 lines
3.3 KiB
Python

# Evaluate an LLM on a question/answer benchmark using multiple threads.
import concurrent.futures
import json
import os
import threading

from openai import OpenAI
from tqdm import tqdm
# Credentials are read from the environment rather than embedded in source.
# The key that was previously hard-coded here has been exposed and should be
# revoked. If OPENAI_API_KEY is unset, the client is given api_key=None.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    # The endpoint stays the same by default but can be overridden per run.
    base_url=os.environ.get("OPENAI_BASE_URL", "https://vip.apiyi.com/v1"),
)
# Module-level lock available for protecting state shared between threads.
thread_lock = threading.Lock()
def load_json_data(filepath):
    """Load and return the parsed contents of a JSON file.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The deserialized value produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode as UTF-8 explicitly: relying on the platform locale encoding
    # breaks on non-ASCII content when the default is not UTF-8.
    with open(filepath, 'r', encoding='utf-8') as file:
        return json.load(file)
def get_response(question, max_retries=10):
    """Send one question to the chat model, retrying when the call fails.

    Args:
        question: Text sent as the single user message.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The text content of the first choice on success (an empty string if
        the response carries no text content), or the sentinel string
        ``"error!"`` when every attempt raised an exception.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = client.chat.completions.create(
                model="claude-3-7-sonnet-20250219-thinking",
                messages=[
                    {"role": "system", "content": "You are an expert in the field of materials science, adept at answering questions related to fundamental aspects of materials science, including material structure, properties, processing, and applications."},
                    {"role": "user", "content": question},
                ],
                temperature=0,
            )
            answer = response.choices[0].message.content
            # Always hand back a string so callers can run substring checks
            # without having to special-case a missing content field.
            return answer if answer is not None else ""
        except Exception as e:
            # Retry boundary: any failure of the remote call is reported and
            # the request is attempted again until the budget is exhausted.
            print(f"Error in getting LLM response (Attempt {attempt}/{max_retries}): {e}")
    # The message names the value that is actually returned below.
    print(f"Failed to get response after {max_retries} attempts, returning 'error!'.")
    return "error!"
def process_item(item, index):
    """Query the model for one benchmark item and grade its answer.

    Args:
        item: Mapping with a ``'question'`` string and an ``'answer'`` string.
        index: Position of the item in the dataset, echoed in the result so
            the caller can restore the original order.

    Returns:
        A dict with keys ``'index'``, ``'question'``, ``'expected_answer'``,
        ``'llm_answer'`` and ``'is_correct'``. An answer is graded correct
        when the stripped expected answer occurs as a substring of it.

    Raises:
        KeyError: If ``item`` lacks ``'question'`` or ``'answer'``.
    """
    question = item['question']
    expected_answer = item['answer'].strip()
    llm_answer = get_response(question)
    # A missing answer or the "error!" failure sentinel from get_response is
    # never correct; without this guard the sentinel would match expected
    # answers such as "e" or "r", and a None answer would raise TypeError.
    is_correct = (
        isinstance(llm_answer, str)
        and llm_answer != "error!"
        and expected_answer in llm_answer
    )
    return {
        'index': index,
        'question': question,
        'expected_answer': expected_answer,
        'llm_answer': llm_answer,
        'is_correct': is_correct,
    }
def calculate_accuracy_multithreaded(data, max_workers=5):
    """Grade every item concurrently and compute the overall accuracy.

    Args:
        data: Sequence of benchmark items, each passed to ``process_item``.
        max_workers: Size of the thread pool used to process items.

    Returns:
        A tuple ``(accuracy, results)`` where ``accuracy`` is the percentage
        of items graded correct (``0.0`` for an empty dataset) and
        ``results`` is the list of per-item result dicts sorted by their
        original index.
    """
    results = []
    # Track progress with a progress bar.
    with tqdm(total=len(data), desc="Processing items") as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all tasks up front.
            futures = [
                executor.submit(process_item, item, i)
                for i, item in enumerate(data)
            ]
            # Results are consumed only here, in the calling thread, so the
            # bookkeeping below needs no lock.
            for future in concurrent.futures.as_completed(futures):
                results.append(future.result())
                pbar.update(1)
    # Restore the original dataset order.
    results.sort(key=lambda x: x['index'])
    total_questions = len(data)
    if total_questions == 0:
        # Nothing to grade: avoid dividing by zero.
        return 0.0, results
    correct_answers = sum(1 for result in results if result['is_correct'])
    accuracy = (correct_answers / total_questions) * 100
    return accuracy, results
def main():
    """Run the benchmark, print the accuracy and save per-item results.

    Loads the dataset from a fixed path, grades it with eight worker
    threads, prints the accuracy percentage and writes the result list to a
    JSON file in the current working directory.
    """
    dataset = load_json_data('/home/ubuntu/50T/fsy/benchmark/1200ckjtest/1200ckj.json')
    accuracy, graded_items = calculate_accuracy_multithreaded(dataset, max_workers=8)
    print(f"Accuracy of claude-3-7-sonnet-20250219-thinking: {accuracy:.2f}%")
    with open('claude-3-7-sonnet-20250219-thinking.json', 'w') as out_file:
        json.dump(graded_items, out_file, indent=2)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()