MatBench/layer2/process/step6.py

"""
对821道英文问题进行处理
1. 判断是否包含多个子问题,将问题拆分为完整子问题(去掉推理过程,只保留最后结果)
2. 判断题目类型
3. 将题目做成选择题
对计算题,在数值附近随机生成三个相似答案作为错误选项
对简答题,与标准答案最相近的其他问题的答案作为三个错误选项
4. 将正确和错误选项随机打乱生成ABCD选择题的模型
5. 添加prompt并将选择题包裹在[ANSWER]<answer>[/ANSWER]标签中
6. 模型打分
"""
import json
import threading
from tqdm import tqdm
import concurrent.futures
from openai import OpenAI
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
import re
client = OpenAI(
    api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d",
    base_url="https://vip.apiyi.com/v1"
)
thread_lock = threading.Lock()
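
# ---------------------------------------------------------------------------
# Illustrative sketches only (assumed reconstruction): the distractor
# generation and option shuffling described in steps 3-4 of the module
# docstring are performed upstream and are not part of this script's pipeline.
# Nothing below calls these two helpers; the exact upstream logic is unknown,
# so the schemes here are assumptions.
# ---------------------------------------------------------------------------
import random


def make_numeric_distractors(correct_value, n=3, spread=0.3):
    # Sketch of step 3 for calculation questions: perturb the (assumed nonzero)
    # numeric answer to obtain n nearby but distinct wrong options.
    distractors = set()
    while len(distractors) < n:
        candidate = round(correct_value * (1 + random.uniform(-spread, spread)), 4)
        if candidate != correct_value:
            distractors.add(candidate)
    return list(distractors)


def shuffle_into_choices(correct_text, distractor_texts):
    # Sketch of step 4: shuffle the correct answer with its distractors and
    # assign A/B/C/D labels, returning a choices dict in the same shape this
    # script reads (choices["label"] / choices["text"]) plus the correct label.
    options = [correct_text] + list(distractor_texts)
    random.shuffle(options)
    labels = ["A", "B", "C", "D"][:len(options)]
    correct_label = labels[options.index(correct_text)]
    return {"label": labels, "text": [str(o) for o in options]}, correct_label
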
def load_json_data(filepath):
    with open(filepath, 'r') as file:
        data = json.load(file)
    return data

def get_response(input, max_retries=10):
    """Query the LLM, retrying up to max_retries times on API errors."""
    retries = 0
    while retries < max_retries:
        try:
            response = client.chat.completions.create(
                model="qwen-max-2025-01-25",
                messages=[
                    {"role": "system", "content": "You are an expert in the field of materials science, adept at answering questions related to fundamental aspects of materials science, including material structure, properties, processing, and applications."},
                    {"role": "user", "content": input}
                ],
                temperature=0
            )
            answer = response.choices[0].message.content
            return answer
        except Exception as e:
            print(f"Error in getting LLM response (Attempt {retries + 1}/{max_retries}): {e}")
            retries += 1
    print(f"Failed to get response after {max_retries} attempts, returning 'error!'.")
    return "error!"

def process_item(item, index):
    """Build the multiple-choice prompt for one item and collect the LLM's answer."""
    question = item['question']
    text = item['choices']['text']
    label = item['choices']['label']
    prompt = item['prompt']
    expected_answer = item['answer'].strip()
    # Render the options as "(A) ... (B) ..." before appending the answering prompt.
    formatted_choices = " ".join([f"({label}) {text}" for label, text in zip(label, text)])
    input = f"{question} {formatted_choices}. {prompt}"
    llm_answer = get_response(input)
    return {
        'index': index,
        'question': question,
        'choices': item['choices'],
        'answer': expected_answer,
        'llm_answer': llm_answer
    }

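# Expected shape of each input item (inferred from the fields this script reads;
# the concrete values below are placeholders, not real benchmark data):
# {
#   "question": "<question text>",
#   "choices": {"label": ["A", "B", "C", "D"],
#               "text": ["<option 1>", "<option 2>", "<option 3>", "<option 4>"]},
#   "prompt": "<instruction telling the model how to answer>",
#   "answer": "[ANSWER]A[/ANSWER]"
# }
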
def extract_answer(answer_string):
    match = re.search(r'\[ANSWER\](.*?)\[/ANSWER\]', answer_string)
    if match:
        return match.group(1).strip()
    return None


def parse_answer(answer):
    if answer is None:
        return []
    return [a.strip() for a in answer.split(',')]

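# Usage sketch (the strings are made up for illustration):
#   extract_answer("The best match is [ANSWER]B[/ANSWER].")  -> "B"
#   extract_answer("no tags here")                            -> None
#   parse_answer("A, C")  -> ["A", "C"]   (multi-select answers are comma-separated)
#   parse_answer(None)    -> []
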
def compute_metrics(data):
    """Compute exact-match accuracy plus micro/macro precision, recall, and F1
    over the (possibly multi-label) letter answers."""
    true_answers = []
    pred_answers = []
    for item in data:
        true_ans = extract_answer(item["answer"])
        pred_ans = extract_answer(item["llm_answer"])
        true_answers.append(parse_answer(true_ans))
        pred_answers.append(parse_answer(pred_ans))
    # Exact-match accuracy: the predicted label set must equal the reference set.
    correct_counts = []
    for true_ans, pred_ans in zip(true_answers, pred_answers):
        if true_ans and pred_ans and set(true_ans) == set(pred_ans):
            correct_counts.append(1)
        else:
            correct_counts.append(0)
    accuracy = np.mean(correct_counts)
    # Build multi-hot indicator vectors over every label that appears in the choices.
    y_true_multi = []
    y_pred_multi = []
    all_labels = set()
    for item in data:
        choices = item["choices"]["label"]
        for label in choices:
            all_labels.add(label)
    all_labels = sorted(list(all_labels))
    for true_ans, pred_ans in zip(true_answers, pred_answers):
        true_vector = [1 if label in true_ans else 0 for label in all_labels]
        pred_vector = [1 if label in pred_ans else 0 for label in all_labels]
        y_true_multi.append(true_vector)
        y_pred_multi.append(pred_vector)
    y_true_multi = np.array(y_true_multi)
    y_pred_multi = np.array(y_pred_multi)
    precision_micro = precision_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)
    recall_micro = recall_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)
    f1_micro = f1_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)
    precision_macro = precision_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)
    recall_macro = recall_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)
    f1_macro = f1_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)
    return {
        "accuracy": accuracy,
        "precision_micro": precision_micro,
        "recall_micro": recall_micro,
        "f1_micro": f1_micro,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro
    }

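# Worked toy example (made-up data): with labels [A, B, C, D], reference answer
# sets {A} and {B, C}, and predicted sets {A} and {B}, the indicator matrices are
# y_true = [[1,0,0,0], [0,1,1,0]] and y_pred = [[1,0,0,0], [0,1,0,0]].
# Micro averaging pools all label decisions (precision 2/2, recall 2/3), while
# macro averaging takes the unweighted mean of per-label scores, so the missed C
# (and the unused D, scored 0 via zero_division=0) pull precision and recall down
# to 0.5 each. Exact-match accuracy is 1/2, since only the first item matches.
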
def calculate_accuracy_multithreaded(data, max_workers=5):
    """Query the LLM for every item in parallel, then score the collected answers."""
    results = []
    with tqdm(total=len(data), desc="Processing items") as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {executor.submit(process_item, item, i): i for i, item in enumerate(data)}
            for future in concurrent.futures.as_completed(future_to_index):
                result = future.result()
                results.append(result)
                pbar.update(1)
    # Restore the original question order before scoring.
    results.sort(key=lambda x: x['index'])
    metric = compute_metrics(results)
    return metric, results

def main():
    filepath = '/home/ubuntu/50T/fsy/benchmark-dataset-third/ALL-merge/merged.json'
    data = load_json_data(filepath)
    max_workers = 8
    metric, results = calculate_accuracy_multithreaded(data, max_workers)
    print(f"Metrics for qwen-max-2025-01-25: {metric}")
    with open('qwen-max-2025-01-25.json', 'w') as f:
        json.dump(results, f, indent=2)


if __name__ == "__main__":
    main()