layer2 commit
layer2/process/prompts.py (new file, 64 lines)
@@ -0,0 +1,64 @@
'''
Keep the calculation process for calculation questions: - Fully preserve the step-by-step calculation process along with the final results
Keep only the result for calculation questions: - Preserve final calculation results
'''
SINGLE_QUESTION_PROMPTS = """
Follow these instructions strictly to perform question decomposition:

Input requirements:
- Question text: {question}
- Answer text: {answer}

Output rules:
1. Single-issue determination criteria:
   - The question contains only one clear technical inquiry point
   - The answer content cannot be divided into independent parts
   → Return: "It's a single issue."
2. Compound-question decomposition criteria (must satisfy all):
   a) The question contains multiple technically independent sub-questions
   b) The answer contains independent solution paragraphs corresponding to the sub-questions
   c) Each sub-question's answer does not depend on context from other sub-questions
3. Decomposition format standards:
[
    {{
        "question": "[Complete sub-question 1] (including necessary shared parameters)",
        "answer": "[Corresponding complete answer]"
    }},
    {{
        "question": "[Complete sub-question 2] (including necessary shared parameters)",
        "answer": "[Corresponding complete answer]"
    }},
    ......
]

Key control points:
1. Context integrity:
   - Each sub-question must include shared parameters from the original question
2. Answer integrity:
   - Fully preserve the step-by-step calculation process along with the final results
   - Maintain original units and precision (e.g., 6.02×10²³ cannot be simplified to 6.02e23)
3. Format prohibitions:
   - No explanatory text additions
   - No modifications to original technical terminology
   - Returned data must not use Markdown or LaTeX formats (like \\times, \\mathrm)
   - Use scientific notation for data representation
"""

QA_TYPE_PROMPTS = """
Please analyze the following question and its answer, and classify the question type into one of the following four categories:

1. Calculation: A question that requires mathematical operations to derive the result.
2. Multiple choice: A question that provides multiple options (e.g., A/B/C/D) for the respondent to choose from.
3. True/False: A question that only requires answering true/false, yes/no, or correct/incorrect.
4. Other: A question that does not fall under the above three categories.

Question:
{question}
Answer:
{answer}

Please respond with the corresponding numeric code directly (without any explanation):
1. For Calculation, respond: 1
2. For Multiple choice, respond: 2
3. For True/False, respond: 3
4. For Other, respond: 4
"""
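Note on how these templates are consumed: step1and2.py below fills them with str.replace rather than str.format, so the doubled braces {{ }} in SINGLE_QUESTION_PROMPTS reach the model verbatim (with str.format they would collapse to single braces). A minimal sketch, with invented question/answer strings:

prompt = SINGLE_QUESTION_PROMPTS.replace("{question}", "Why is copper ductile?") \
                                .replace("{answer}", "Metallic bonding lets close-packed planes slide.")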
layer2/process/step0.py (new file, 41 lines)
@@ -0,0 +1,41 @@
# step0: split each text segment into a question part and an answer part
import json

input_file_path = '/home/ubuntu/50T/fsy/benchmark/dataset-ours/[Solution]qa_segment_all.json'
with open(input_file_path, 'r', encoding='utf-8') as infile:
    data = json.load(infile)

# Iterate over the items and split each segment at its marker word
processed_data = []
for item in data:
    segment = item.get("segment", "")
    if "Solution" in segment:
        question, answer = segment.split("Solution", 1)
        processed_data.append({
            "idx": item.get("idx"),
            "question": question.strip(),
            "answer": answer.strip(),
        })
    elif "Answer" in segment:
        question, answer = segment.split("Answer", 1)
        processed_data.append({
            "idx": item.get("idx"),
            "question": question.strip(),
            "answer": answer.strip(),
        })
    else:
        # No recognizable marker: emit the "000" placeholder used downstream
        processed_data.append({
            "idx": item.get("idx"),
            "question": "000",
            "answer": "000",
        })

output_file_path = '[Solution]qa_segment.json'
with open(output_file_path, 'w', encoding='utf-8') as outfile:
    json.dump(processed_data, outfile, ensure_ascii=False, indent=4)

print(output_file_path)
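A quick illustration of the split used above, on an invented segment. str.split consumes the marker word itself, so any punctuation that followed it stays at the front of the answer:

segment = "Compute the density of copper. Solution: rho = m/V = 8.96 g/cm^3"
question, answer = segment.split("Solution", 1)
print(question.strip())  # Compute the density of copper.
print(answer.strip())    # : rho = m/V = 8.96 g/cm^3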
layer2/process/step1and2.py (new file, 132 lines)
@@ -0,0 +1,132 @@
"""
Process the 821 English questions:
1. Decide whether each item contains multiple sub-questions and split compound questions into complete sub-questions (dropping the reasoning, keeping only the final result).
2. Classify the question type.
3. Turn each question into a multiple-choice question:
   - For calculation questions, randomly generate three similar values near the correct number as wrong options.
   - For short-answer questions, use the answers of the other questions most similar to the reference answer as the three wrong options.
4. Shuffle the correct and wrong options to build A/B/C/D multiple-choice items.
5. Add the prompt and wrap the choice in [ANSWER]<answer>[/ANSWER] tags.
6. Score the models.
"""
import json
import time
import re
from openai import OpenAI
from prompts import SINGLE_QUESTION_PROMPTS, QA_TYPE_PROMPTS

API_KEY = "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"
BASE_URL = "https://vip.apiyi.com/v1"
MODEL_DEEPSEEK_V3 = "deepseek-chat"
def load_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data


def process_response(response):
    """Extract and parse JSON from a model response."""
    # Prefer the contents of a fenced ```json ... ``` block if one is present
    json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
    json_str = json_match.group(1) if json_match else response.strip()
    # Double the backslashes inside $...$ spans so LaTeX survives json.loads
    json_str = re.sub(r'(\$[^\$]*\$)', lambda m: m.group(1).replace('\\', '\\\\'), json_str)
    # Undo escaped quotes that would otherwise break parsing
    json_str = json_str.replace('\\"', '"').replace("\\'", "'")
    return json_str
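For reference, process_response on a typical fenced reply (invented response text) strips the fence and returns the raw JSON string, ready for json.loads:

raw = '```json\n[{"question": "Q1", "answer": "A1"}]\n```'
print(process_response(raw))  # [{"question": "Q1", "answer": "A1"}]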

def save_data(data, output_file):
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def split_complex_question(question, answer):
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
    try:
        response = client.chat.completions.create(
            model=MODEL_DEEPSEEK_V3,
            messages=[
                {"role": "system", "content": "You are an expert in decomposing complex technical questions into independent sub-questions and providing corresponding complete answers with preserved context, precision, and technical terminology."},
                {"role": "user", "content": SINGLE_QUESTION_PROMPTS.replace("{question}", question).replace("{answer}", answer)}
            ],
            stream=False,
            temperature=0
        )
        result = response.choices[0].message.content.strip()
        # Returns 1 for a single question, otherwise the parsed list of sub-questions
        return 1 if "It's a single issue." in result else json.loads(process_response(result))
    except Exception as e:
        print(f"API call error: {e}")
        return [{"question": question, "answer": answer}]
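The two return shapes drive the isinstance check in single_question_process below: 1 means the item is already a single question, while a list carries the decomposed sub-questions. A hypothetical call (the actual outcome depends on the model):

result = split_complex_question("State the Young's modulus of Cu and of Al.", "Cu: 110 GPa. Al: 69 GPa.")
if isinstance(result, list):
    print(len(result))       # e.g. 2 sub-questions
else:
    print("single issue")    # result == 1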

def single_question_process(data):
    single_question_data = []
    total = len(data)
    for i, item in enumerate(data):
        print(f"Processing item {i+1}/{total}...")
        question = item["question"]
        answer = item["answer"]
        split_data = split_complex_question(question, answer)

        if isinstance(split_data, list):
            for q_data in split_data:
                single_question_data.append({
                    "idx": item["idx"],
                    "question": q_data["question"],
                    "answer": q_data["answer"]
                })
        else:
            single_question_data.append({
                "idx": item["idx"],
                "question": question,
                "answer": answer
            })

        # Pause briefly every 10 items to stay under the API rate limit
        if (i + 1) % 10 == 0:
            time.sleep(2)
    return single_question_data

def classify_qa_type(question, answer):
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
    try:
        response = client.chat.completions.create(
            model=MODEL_DEEPSEEK_V3,
            messages=[
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": QA_TYPE_PROMPTS.replace("{question}", question).replace("{answer}", answer)}
            ],
            stream=False
        )
        result = response.choices[0].message.content.strip().lower()
        # Map the numeric code back to a label; anything unexpected falls through to "Other"
        return {"1": "Calculation", "2": "Multiple choice", "3": "True/False"}.get(result, "Other")
    except Exception as e:
        print(f"API call error: {e}")
        return "Other"

def qa_type_process(data):
    total = len(data)
    for i, item in enumerate(data):
        print(f"Processing item {i+1}/{total}...")
        label = classify_qa_type(item["question"], item["answer"])
        item["type"] = label

        if (i + 1) % 10 == 0:
            time.sleep(2)
    return data

def main():
    input_file = "/home/ubuntu/50T/fsy/layer2/QA/code/821.json"
    output_file = "/home/ubuntu/50T/fsy/layer2/QA/code/processed_data.json"
    data = load_data(input_file)

    # step 1: split compound questions into single questions
    single_question_data = single_question_process(data)
    # step 2: classify question types (step 3 onward lives in step3.py)
    qa_type_data = qa_type_process(single_question_data)

    save_data(qa_type_data, output_file)
    print(f"Processing finished; results saved to {output_file}")


if __name__ == "__main__":
    main()
layer2/process/step3.py (new file, 89 lines)
@@ -0,0 +1,89 @@
"""
Process the 821 English questions:
1. Decide whether each item contains multiple sub-questions and split compound questions into complete sub-questions (dropping the reasoning, keeping only the final result).
2. Classify the question type.
3. Turn each question into a multiple-choice question:
   - For calculation questions, randomly generate three similar values near the correct number as wrong options.
   - For short-answer questions, use the answers of the other questions most similar to the reference answer as the three wrong options.
4. Shuffle the correct and wrong options to build A/B/C/D multiple-choice items.
5. Add the prompt and wrap the choice in [ANSWER]<answer>[/ANSWER] tags.
6. Score the models.
"""
import json
import re
import random


def generate_wrong_answers(json_file_path):
    # Load the JSON file
    with open(json_file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Process each item
    for item in data:
        if item.get('type') in (1, "Calculation"):  # calculation questions (numeric code or the label written by step1and2)
            answer = item['answer']
            if any(char.isdigit() for char in answer):
                wrong_answers = [generate_wrong_answer(answer) for _ in range(3)]
                item['wrong_answers_1'] = wrong_answers[0]
                item['wrong_answers_2'] = wrong_answers[1]
                item['wrong_answers_3'] = wrong_answers[2]

    with open(json_file_path.replace('.json', '_with_wrong_answers.json'), 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=2)

    return data

def generate_wrong_answer(correct_answer):
    # Hardened regex: plain numbers, scientific notation, Unicode minus signs,
    # LaTeX-style exponents, and thousands separators
    number_pattern = (
        r'([-+]?\d{1,3}(?:,\d{3})*(?:\.\d+)?|\d*\.?\d+)'      # main number part
        r'(?:\s*[×x*]?\s*10(?:\^|\^{|{)?[-−⁻]?\d+(?:\})?)?'   # optional scientific-notation part
    )
    matches = list(re.finditer(number_pattern, correct_answer, re.IGNORECASE))
    if not matches:
        return correct_answer  # no number found, return the text unchanged

    wrong_answer = correct_answer
    for match in matches[::-1]:  # replace right-to-left so earlier spans stay valid
        full_match = match.group(0)
        base = match.group(1).replace(',', '')  # strip commas for the numeric value

        try:
            base_value = float(base)
            perturbed_value = perturb_number(base_value)

            # Keep the exponent part (e.g. "x 10^6") and swap only the mantissa
            wrong_value_str = full_match.replace(match.group(1), format_similar(base, perturbed_value))
            start, end = match.span()
            wrong_answer = wrong_answer[:start] + wrong_value_str + wrong_answer[end:]
        except ValueError:
            continue

    return wrong_answer

def perturb_number(value):
    # Perturb proportionally to the magnitude: uniform relative noise with a random sign
    magnitude = abs(value)
    noise = random.uniform(0.03, 0.15)  # 3%-15% relative perturbation
    direction = random.choice([-1, 1])
    new_value = value + direction * magnitude * noise

    # Guard against the result collapsing to (nearly) zero
    if abs(new_value) < 1e-10:
        new_value = value * 1.1
    return new_value

def format_similar(original_str, value):
    # Render the perturbed value with the same number of decimal places as the original
    if '.' in original_str:
        decimal_places = len(original_str.split('.')[-1].rstrip('^}'))  # ignore trailing "^}" left by LaTeX exponents
        return f"{value:.{decimal_places}f}"
    else:
        return str(int(round(value)))


if __name__ == "__main__":
    data = generate_wrong_answers('/home/ubuntu/50T/fsy/benchmark/4is_type.json')
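Since perturb_number shifts a value by direction * |value| * noise with noise drawn from uniform(0.03, 0.15), every wrong option differs from the true number by 3-15% in relative terms (the near-zero guard, a 10% bump, also stays in range). A small sanity check, assuming the functions above are in scope; the printed value is illustrative:

value = 8.96
wrong = perturb_number(value)
assert 0.03 <= abs(wrong - value) / abs(value) <= 0.15
print(generate_wrong_answer("mass = 8.96 kg"))  # e.g. "mass = 9.72 kg"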
layer2/process/step4.py (new file, 78 lines)
@@ -0,0 +1,78 @@
"""
Process the 821 English questions:
1. Decide whether each item contains multiple sub-questions and split compound questions into complete sub-questions (dropping the reasoning, keeping only the final result).
2. Classify the question type.
3. Turn each question into a multiple-choice question:
   - For calculation questions, randomly generate three similar values near the correct number as wrong options.
   - For short-answer questions, use the answers of the other questions most similar to the reference answer as the three wrong options.
4. Shuffle the correct and wrong options to build A/B/C/D multiple-choice items.
5. Add the prompt and wrap the choice in [ANSWER]<answer>[/ANSWER] tags.
6. Score the models.
"""
import json
import random
from typing import List, Dict


def process_json_file(file_path: str) -> List[Dict]:
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for item in data:
        # Collect all options
        options = [
            item['answer'],
            item.get('wrong_answers_1', ''),
            item.get('wrong_answers_2', ''),
            item.get('wrong_answers_3', '')
        ]

        # Drop empty options
        options = [opt for opt in options if opt]

        # Shuffle the options
        random.shuffle(options)

        # Locate the correct answer (index() finds the first occurrence,
        # so this assumes the wrong options differ from the correct answer)
        correct_answer_index = options.index(item['answer'])
        correct_answer_letter = chr(65 + correct_answer_index)  # A, B, C, or D

        # Build the options text, e.g. "(A)... (B)... (C)... (D)..."
        options_text = " ".join(f"({chr(65 + i)}){option}" for i, option in enumerate(options))

        # Update the question and answer
        instruction = (
            "You MUST include the letter(s) of the correct answer (separated by comma if there are many) "
            "within the following tags: [ANSWER] and [/ANSWER].\n"
            "For example, '[ANSWER]<answer>[/ANSWER]', where <answer> is comma- or space-separated list "
            "of the correct letters. Always answer in exactly this format of comma-separated letters "
            "between the two tags, even if you are unsure. We require this because we use automatic parsing."
        )
        item['question'] = (
            "The following is a question about Fundamentals of Materials Science. "
            f"{item['question']} {options_text} {instruction}"
        )
        item['answer'] = f"[ANSWER]{correct_answer_letter}[/ANSWER]"

        # Remove the original wrong options
        for key in ('wrong_answers_1', 'wrong_answers_2', 'wrong_answers_3'):
            item.pop(key, None)

    return data
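To make the transformation concrete, here is a hypothetical item before and after process_json_file (values invented; one possible shuffle shown):

before = {
    "idx": 7,
    "question": "Calculate the density of copper.",
    "answer": "8.96 g/cm^3",
    "wrong_answers_1": "9.43 g/cm^3",
    "wrong_answers_2": "8.12 g/cm^3",
    "wrong_answers_3": "10.05 g/cm^3",
}
# After processing, the wrong_answers_* keys are gone and (for this shuffle):
# question: "The following is a question about Fundamentals of Materials Science.
#            Calculate the density of copper. (A)8.12 g/cm^3 (B)8.96 g/cm^3
#            (C)10.05 g/cm^3 (D)9.43 g/cm^3 You MUST include the letter(s) ..."
# answer:   "[ANSWER]B[/ANSWER]"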

def save_processed_data(data: List[Dict], output_path: str) -> None:
    """Save the processed data to a new JSON file."""
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


# Usage example
if __name__ == "__main__":
    input_file = "/home/ubuntu/50T/fsy/5_1.json"  # replace with your input file path
    output_file = "output.json"                   # replace with your desired output file path

    try:
        processed_data = process_json_file(input_file)
        save_processed_data(processed_data, output_file)
        print(f"Processing finished! Results saved to {output_file}")
    except Exception as e:
        print(f"Error during processing: {e}")
layer2/process/step6.py (new file, 177 lines)
@@ -0,0 +1,177 @@
"""
Process the 821 English questions:
1. Decide whether each item contains multiple sub-questions and split compound questions into complete sub-questions (dropping the reasoning, keeping only the final result).
2. Classify the question type.
3. Turn each question into a multiple-choice question:
   - For calculation questions, randomly generate three similar values near the correct number as wrong options.
   - For short-answer questions, use the answers of the other questions most similar to the reference answer as the three wrong options.
4. Shuffle the correct and wrong options to build A/B/C/D multiple-choice items.
5. Add the prompt and wrap the choice in [ANSWER]<answer>[/ANSWER] tags.
6. Score the models.
"""
import json
import re
import concurrent.futures
from tqdm import tqdm
from openai import OpenAI
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score


client = OpenAI(
    api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d",
    base_url="https://vip.apiyi.com/v1"
)

def load_json_data(filepath):
    with open(filepath, 'r') as file:
        data = json.load(file)
    return data


def get_response(prompt, max_retries=10):
    retries = 0
    while retries < max_retries:
        try:
            response = client.chat.completions.create(
                model="qwen-max-2025-01-25",
                messages=[
                    {"role": "system", "content": "You are an expert in the field of materials science, adept at answering questions related to fundamental aspects of materials science, including material structure, properties, processing, and applications."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0
            )
            answer = response.choices[0].message.content
            return answer
        except Exception as e:
            print(f"Error in getting LLM response (Attempt {retries + 1}/{max_retries}): {e}")
            retries += 1

    print(f"Failed to get response after {max_retries} attempts, returning a sentinel.")
    return "error!"

def process_item(item, index):
    question = item['question']
    text = item['choices']['text']
    label = item['choices']['label']
    prompt = item['prompt']
    expected_answer = item['answer'].strip()

    formatted_choices = " ".join(f"({l}) {t}" for l, t in zip(label, text))
    model_input = f"{question} {formatted_choices}. {prompt}"

    llm_answer = get_response(model_input)

    return {
        'index': index,
        'question': question,
        'choices': item['choices'],
        'answer': expected_answer,
        'llm_answer': llm_answer
    }

def extract_answer(answer_string):
    match = re.search(r'\[ANSWER\](.*?)\[/ANSWER\]', answer_string)
    if match:
        return match.group(1).strip()
    return None
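extract_answer returns the text between the tags, or None when the tags are missing; parse_answer below then maps None to an empty list:

print(extract_answer("Reasoning... [ANSWER]B[/ANSWER]"))  # B
print(extract_answer("no tags here"))                     # None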

def parse_answer(answer):
    # "A, C" -> ["A", "C"]; None (no tags found) -> []
    if answer is None:
        return []
    return [a.strip() for a in answer.split(',')]

def compute_metrics(data):
    true_answers = []
    pred_answers = []

    for item in data:
        true_ans = extract_answer(item["answer"])
        pred_ans = extract_answer(item["llm_answer"])

        true_answers.append(parse_answer(true_ans))
        pred_answers.append(parse_answer(pred_ans))

    # Exact-match accuracy: the predicted letter set must equal the true letter set
    correct_counts = []
    for true_ans, pred_ans in zip(true_answers, pred_answers):
        if true_ans and pred_ans and set(true_ans) == set(pred_ans):
            correct_counts.append(1)
        else:
            correct_counts.append(0)

    accuracy = np.mean(correct_counts)

    # Multi-label encoding: one indicator vector per item over all observed labels
    y_true_multi = []
    y_pred_multi = []
    all_labels = set()

    for item in data:
        for label in item["choices"]["label"]:
            all_labels.add(label)

    all_labels = sorted(all_labels)

    for true_ans, pred_ans in zip(true_answers, pred_answers):
        true_vector = [1 if label in true_ans else 0 for label in all_labels]
        pred_vector = [1 if label in pred_ans else 0 for label in all_labels]
        y_true_multi.append(true_vector)
        y_pred_multi.append(pred_vector)

    y_true_multi = np.array(y_true_multi)
    y_pred_multi = np.array(y_pred_multi)

    precision_micro = precision_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)
    recall_micro = recall_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)
    f1_micro = f1_score(y_true_multi, y_pred_multi, average='micro', zero_division=0)

    precision_macro = precision_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)
    recall_macro = recall_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)
    f1_macro = f1_score(y_true_multi, y_pred_multi, average='macro', zero_division=0)

    return {
        "accuracy": accuracy,
        "precision_micro": precision_micro,
        "recall_micro": recall_micro,
        "f1_micro": f1_micro,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro
    }
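To make the label-vector encoding concrete: with labels A-D, a true answer of "A,C" against a prediction of "A" yields the indicator rows below, and the micro-averaged scores then pool every label slot across items:

y_true = [[1, 0, 1, 0]]  # correct answer "A,C"
y_pred = [[1, 0, 0, 0]]  # model answered only "A"
print(precision_score(y_true, y_pred, average='micro', zero_division=0))  # 1.0
print(recall_score(y_true, y_pred, average='micro', zero_division=0))     # 0.5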

def calculate_accuracy_multithreaded(data, max_workers=5):
    results = []

    with tqdm(total=len(data), desc="Processing items") as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {executor.submit(process_item, item, i): i for i, item in enumerate(data)}

            for future in concurrent.futures.as_completed(future_to_index):
                result = future.result()
                results.append(result)
                pbar.update(1)

    # Restore the original item order, since futures complete out of order
    results.sort(key=lambda x: x['index'])

    metric = compute_metrics(results)

    return metric, results

def main():
    filepath = '/home/ubuntu/50T/fsy/benchmark-dataset-third/ALL-merge/merged.json'
    data = load_json_data(filepath)
    max_workers = 8

    metric, results = calculate_accuracy_multithreaded(data, max_workers)
    print(f"Metrics for qwen-max-2025-01-25: {metric}")

    with open('qwen-max-2025-01-25.json', 'w') as f:
        json.dump(results, f, indent=2)


if __name__ == "__main__":
    main()