92 lines
3.1 KiB
Python
Executable File
92 lines
3.1 KiB
Python
Executable File
import jsonlines
|
|
import argparse
|
|
import generate_data.utils as utils
|
|
import glob
|
|
import json
|
|
from ase import io
|
|
import tempfile
|
|
import re
|
|
from pymatgen.io.vasp import Poscar
|
|
from pymatgen.io.cif import CifParser
|
|
import threading
|
|
import concurrent.futures
|
|
import copy
|
|
from grpo_utils import generate_design_question, generate_props_question, generate_obs_response
|
|
# Create a lock for file writing
|
|
# Held by `worker` while it appends a record to the output file, so that
# concurrent worker threads never interleave their writes.
file_lock = threading.Lock()
|
|
|
|
|
|
def worker(data, output_file_path):
    """Generate the assistant's reply to an observation and append the conversation to a file.

    A trimmed copy of the conversation is used as the prompt: the content of
    its last turn is reduced to the text between the last ``<answer>`` and the
    following ``</answer>`` (the whole content when the tags are absent), and
    the observation is appended as a user turn. The reply returned by
    ``generate_obs_response`` is then added to ``data['messages']`` wrapped in
    ``<think>``/``<answer>`` tags, and that full conversation is appended to
    ``output_file_path`` as one JSON line.

    Args:
        data: Record with a ``'messages'`` list of ``{"role", "content"}``
            dicts and an ``'observation'`` value. On success it is mutated in
            place (observation turn and assistant turn appended); on failure
            it is left untouched.
        output_file_path: Path of the JSON-lines file to append to.

    Returns:
        ``"Processed successfully"`` on success, otherwise
        ``"Error processing: <message>"``. Exceptions are never propagated.
    """
    try:
        # Build the prompt on a copy so trimming the last turn does not alter
        # the stored conversation.
        messages = copy.deepcopy(data['messages'])
        obs = data['observation']

        last_content = messages[-1]['content']
        messages[-1]['content'] = last_content.split("<answer>")[-1].split("</answer>")[0]
        messages.append({"role": "user", "content": obs})

        reasoning_content, response = generate_obs_response(messages)

        # Only touch the caller's record once generation has succeeded, so a
        # failed call does not leave a half-updated conversation behind.
        data['messages'].append({"role": "user", "content": obs})
        data['messages'].append({
            "role": "assistant",
            "content": f"<think>\n{reasoning_content}</think>\n<answer>\n{response}</answer>\n",
        })

        # Persist the full conversation, including the newly generated turn.
        # Writing the trimmed prompt copy instead would discard the response.
        # NOTE(review): the line type stays a list of messages; confirm the
        # downstream consumer does not expect the whole `data` dict.
        with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                writer.write(data['messages'])

        return "Processed successfully"
    except Exception as e:
        # Best-effort by design: the caller counts failures from the returned text.
        return f"Error processing: {str(e)}"
|
|
|
|
|
|
def main(input_file_path, output_file_path, max_workers=1):
    """Run `worker` over every record of a JSON-lines file using a thread pool.

    All records are read into memory first, then each one is submitted to a
    ``ThreadPoolExecutor``. Progress is shown with a tqdm bar, failures are
    reported as they are detected, and a summary line is printed at the end.

    Args:
        input_file_path: JSON-lines file holding the records to process.
        output_file_path: JSON-lines file the workers append their results to.
        max_workers: Number of worker threads.
    """
    from tqdm import tqdm

    with jsonlines.open(input_file_path, mode='r') as reader:
        datas = list(reader)

    completed = 0
    failed = 0

    # The context manager closes the bar even if an unexpected error escapes.
    with tqdm(total=len(datas), desc="Processing CIF files") as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit every record; remember which record each future belongs to.
            future_to_data = {
                executor.submit(worker, data, output_file_path): data
                for data in datas
            }

            # Handle results in completion order.
            for future in concurrent.futures.as_completed(future_to_data):
                data = future_to_data[future]
                try:
                    result = future.result()
                except Exception as e:
                    failed += 1
                    print(f"\nWorker for {data} generated an exception: {e}")
                else:
                    # `worker` reports its outcome as text; match the success
                    # prefix so an error message that merely mentions
                    # "successfully" is not counted as a success.
                    if result.startswith("Processed successfully"):
                        completed += 1
                    else:
                        failed += 1
                        # Surface the reason without breaking the bar.
                        tqdm.write(str(result))

                pbar.update(1)

                # Refresh the displayed statistics every 100 records.
                if (completed + failed) % 100 == 0:
                    pbar.set_postfix(completed=completed, failed=failed)

        # Make sure the final counts are shown even when the total is not a
        # multiple of 100.
        pbar.set_postfix(completed=completed, failed=failed)

    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
|
|
|
|
|
if __name__ == '__main__':
    import datetime

    # The defaults reproduce the previous hard-coded run exactly; any of them
    # can be overridden on the command line.
    origin_file = "/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl"
    output_file = f"agent_questions_solutions_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"

    parser = argparse.ArgumentParser(
        description="Generate assistant replies to observations for every record of a JSON-lines file."
    )
    parser.add_argument("--input", default=origin_file,
                        help="input JSON-lines file with 'messages' and 'observation' per record")
    parser.add_argument("--output", default=output_file,
                        help="output JSON-lines file (appended to)")
    parser.add_argument("--max-workers", type=int, default=1,
                        help="number of worker threads")
    args = parser.parse_args()

    main(args.input, args.output, max_workers=args.max_workers)
|
|
|