167 lines
5.3 KiB
Python
167 lines
5.3 KiB
Python
import concurrent.futures
import json
import os
import re
import threading
import time

import numpy as np
from openai import OpenAI
from sklearn.metrics import f1_score, precision_score, recall_score
from tqdm import tqdm


# SECURITY: an API key used to be hard-coded on this line. Anything committed
# to source control must be treated as leaked -- revoke/rotate that key and
# provide the replacement via the OPENAI_API_KEY environment variable.
# The base URL defaults to the proxy this script was written against but can be
# overridden with OPENAI_BASE_URL.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    base_url=os.environ.get("OPENAI_BASE_URL", "https://vip.apiyi.com/v1"),
)

# Not referenced anywhere in this script; kept so external importers do not break.
thread_lock = threading.Lock()


def load_json_data(filepath):
    """Load and return the JSON document stored at *filepath*.

    The file is decoded explicitly as UTF-8 (the encoding JSON text uses)
    instead of the platform's locale default. The locale default fails or
    garbles non-ASCII content on systems whose locale is not UTF-8.
    """
    with open(filepath, 'r', encoding='utf-8') as file:
        return json.load(file)


def get_response(input, max_retries=10, model="qwen-max-2025-01-25", retry_delay=1.0):
    """Send *input* as the user message to the chat model and return its reply.

    Args:
        input: Prompt text for the user turn. The name shadows the builtin but
            is kept because it is part of the public signature.
        max_retries: Maximum number of API attempts before giving up.
        model: Chat model identifier passed to the API.
        retry_delay: Base delay in seconds between failed attempts. It doubles
            after each failure and is capped at 60 s, so rate-limit errors are
            not retried in a tight loop.

    Returns:
        ``response.choices[0].message.content`` from the first successful call
        (the SDK types this field as optional). Returns the sentinel string
        ``"error!"`` if every attempt raised.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are an expert in the field of materials science, adept at answering questions related to fundamental aspects of materials science, including material structure, properties, processing, and applications."},
                    {"role": "user", "content": input},
                ],
                # Deterministic decoding for reproducible benchmark runs.
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            # Broad catch is deliberate: any API/network failure is retried.
            print(f"Error in getting LLM response (Attempt {attempt}/{max_retries}): {e}")
            if attempt < max_retries:
                time.sleep(min(retry_delay * 2 ** (attempt - 1), 60.0))

    print(f"Failed to get response after {max_retries} attempts, returning 'error!'.")
    return "error!"


def process_item(item, index):
    """Ask the model one benchmark question and package the outcome.

    The prompt sent to the model has the form
    ``"<question> (<label1>) <text1> (<label2>) <text2>. <prompt>"``. It is built
    from the item's ``question``, its ``choices`` label/text pairs, and its
    ``prompt`` field.

    Returns a dict holding the item's position (``index``), the question, the
    untouched ``choices`` mapping, the whitespace-stripped reference answer,
    and the model's raw reply under ``llm_answer``.
    """
    question = item['question']
    choice_texts = item['choices']['text']
    choice_labels = item['choices']['label']
    instructions = item['prompt']
    reference = item['answer'].strip()

    options = " ".join(
        f"({lab}) {txt}" for lab, txt in zip(choice_labels, choice_texts)
    )
    query = f"{question} {options}. {instructions}"

    return {
        'index': index,
        'question': question,
        'choices': item['choices'],
        'answer': reference,
        'llm_answer': get_response(query),
    }


# Compiled once at import time. DOTALL lets the tagged answer span line breaks
# (e.g. "[ANSWER]\nA, C\n[/ANSWER]"), which models frequently emit.
_ANSWER_PATTERN = re.compile(r'\[ANSWER\](.*?)\[/ANSWER\]', re.DOTALL)


def extract_answer(answer_string):
    """Return the stripped text inside the first ``[ANSWER]...[/ANSWER]`` pair.

    Returns None when no tagged span is found or when *answer_string* is not a
    string (e.g. a None reply). Previously a non-string raised
    ``TypeError`` inside ``re.search`` and aborted the whole scoring run.
    """
    if not isinstance(answer_string, str):
        return None
    match = _ANSWER_PATTERN.search(answer_string)
    return match.group(1).strip() if match else None


def parse_answer(answer):
    """Split a comma-separated answer string into a list of option labels.

    Whitespace around each label is stripped. Empty entries produced by
    trailing or doubled commas, or by an empty answer, are dropped, so
    ``"A, C,"`` yields ``["A", "C"]`` and ``""`` yields ``[]``. Keeping them
    would add a spurious ``""`` label to exact-match comparisons and make an
    empty answer look non-empty. Returns ``[]`` for None.
    """
    if answer is None:
        return []
    return [label for label in (part.strip() for part in answer.split(',')) if label]


def _is_exact_match(true_ans, pred_ans):
    """Return True when both label lists are non-empty and hold the same labels."""
    return bool(true_ans) and bool(pred_ans) and set(true_ans) == set(pred_ans)


def _indicator_matrix(answers, labels):
    """Encode each answer's label list as a 0/1 row, one column per entry of *labels*."""
    return np.array([[1 if label in ans else 0 for label in labels] for ans in answers])


def compute_metrics(data):
    """Score model answers against reference answers.

    Each item must provide ``"answer"`` and ``"llm_answer"`` (text containing
    an ``[ANSWER]...[/ANSWER]`` span) and ``item["choices"]["label"]``.

    Returns a dict of plain Python floats:

    - ``accuracy``: fraction of items whose predicted label set exactly equals
      a non-empty reference label set.
    - ``precision_*`` / ``recall_*`` / ``f1_*`` with ``micro`` and ``macro``
      averaging. These are computed on a multi-label indicator matrix whose
      columns are the sorted union of every item's choice labels. Predicted
      labels outside that union are ignored.

    With empty *data*, or data with no choice labels, the affected metrics are
    0.0. Previously the mean of an empty list was NaN and scikit-learn raised.
    """
    results = {
        "accuracy": 0.0,
        "precision_micro": 0.0,
        "recall_micro": 0.0,
        "f1_micro": 0.0,
        "precision_macro": 0.0,
        "recall_macro": 0.0,
        "f1_macro": 0.0,
    }
    if not data:
        return results

    true_answers = [parse_answer(extract_answer(item["answer"])) for item in data]
    pred_answers = [parse_answer(extract_answer(item["llm_answer"])) for item in data]

    # float() keeps NumPy scalar types out of the returned dict.
    results["accuracy"] = float(np.mean(
        [_is_exact_match(t, p) for t, p in zip(true_answers, pred_answers)]
    ))

    all_labels = sorted({label for item in data for label in item["choices"]["label"]})
    if not all_labels:
        return results

    y_true = _indicator_matrix(true_answers, all_labels)
    y_pred = _indicator_matrix(pred_answers, all_labels)

    for average in ("micro", "macro"):
        results[f"precision_{average}"] = float(
            precision_score(y_true, y_pred, average=average, zero_division=0))
        results[f"recall_{average}"] = float(
            recall_score(y_true, y_pred, average=average, zero_division=0))
        results[f"f1_{average}"] = float(
            f1_score(y_true, y_pred, average=average, zero_division=0))

    return results


def calculate_accuracy_multithreaded(data, max_workers=5):
    """Run process_item over *data* on a thread pool, then score the replies.

    Items are submitted with their position in *data* as the index. Replies
    are collected in completion order, with a tqdm bar advancing per item,
    and then re-sorted by that index before scoring.

    Returns ``(metrics, results)``: the compute_metrics dict and the
    per-item result dicts in original dataset order. An exception raised by
    any worker propagates out of this call.
    """
    collected = []

    with tqdm(total=len(data), desc="Processing items") as progress:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
            pending = [pool.submit(process_item, entry, idx) for idx, entry in enumerate(data)]
            for done in concurrent.futures.as_completed(pending):
                collected.append(done.result())
                progress.update(1)

    ordered = sorted(collected, key=lambda row: row['index'])
    return compute_metrics(ordered), ordered


def main(
    filepath='/home/ubuntu/50T/fsy/benchmark-dataset-third/ALL-merge/merged.json',
    output_path='qwen-max-2025-01-25.json',
    max_workers=8,
):
    """Run the benchmark end to end: load items, query the model, score, save.

    The defaults reproduce the original hard-coded run. Pass arguments to
    evaluate another dataset file, write results elsewhere, or change the
    thread-pool size.
    """
    data = load_json_data(filepath)

    metric, results = calculate_accuracy_multithreaded(data, max_workers)
    print(f"Accuracy of qwen-max-2025-01-25: {metric}")

    # Explicit encoding so the output file does not depend on the locale.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)


if __name__ == "__main__":
    main()