Files
mars-mcp/generate_data/grpo_tools.py
2025-04-16 11:15:01 +08:00

92 lines
3.1 KiB
Python
Executable File

import jsonlines
import argparse
import generate_data.utils as utils
import glob
import json
from ase import io
import tempfile
import re
from pymatgen.io.vasp import Poscar
from pymatgen.io.cif import CifParser
import threading
import concurrent.futures
import copy
from grpo_utils import generate_design_question, generate_props_question, generate_obs_response
# Create a lock for file writing.
# Serializes appends to the shared output JSONL so concurrent worker threads
# never interleave partial records.
file_lock = threading.Lock()
def worker(data, output_file_path):
    """Generate an assistant follow-up for one conversation record and append it to the output file.

    Args:
        data: dict with at least 'messages' (a chat-style list of role/content
            dicts) and 'observation' (tool/environment output text).
            NOTE(review): schema inferred from usage here — confirm against producer.
        output_file_path: path of the JSONL file to append the result to.

    Returns:
        "Processed successfully" on success, or an "Error processing: ..."
        string on failure (errors are reported via the return value, not raised,
        so one bad record does not kill the thread pool).
    """
    try:
        # Work on a copy so the answer-extraction below does not clobber the
        # original last assistant turn kept in data['messages'].
        messages = copy.deepcopy(data['messages'])
        obs = data['observation']
        # Keep only the text inside the final <answer>...</answer> tags of the
        # last turn when prompting the model for its next response.
        messages[-1]['content'] = messages[-1]['content'].split("<answer>")[-1].split("</answer>")[0]
        messages.append({"role": "user", "content": obs})
        data['messages'].append({"role": "user", "content": obs})
        reasoning_content, response = generate_obs_response(messages)
        data['messages'].append({"role": "assistant", "content": f"<think>\n{reasoning_content}</think>\n<answer>\n{response}</answer>\n"})
        # BUG FIX: the original wrote `messages`, which never received the
        # generated assistant turn, so the model response was dropped from every
        # output record. Write data['messages'], which carries the full original
        # history plus the new observation and assistant turns.
        with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                writer.write(data['messages'])
        return f"Processed successfully"
    except Exception as e:
        return f"Error processing: {str(e)}"
def main(input_file_path, output_file_path, max_workers=1):
    """Read JSONL records, fan them out to `worker` threads, and report progress.

    Args:
        input_file_path: JSONL file; each line is one conversation record.
        output_file_path: JSONL file that workers append results to.
        max_workers: thread-pool size (default 1 keeps processing sequential).
    """
    # Local import kept function-scoped, matching the original script's style.
    # (Unused `random` and `os` imports from the original were removed.)
    from tqdm import tqdm

    with jsonlines.open(input_file_path, mode='r') as reader:
        datas = [line for line in reader]

    # Create the progress bar.
    pbar = tqdm(total=len(datas), desc="Processing CIF files")

    # Create a thread pool and submit one task per record.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_data = {
            executor.submit(worker, data, output_file_path): data
            for data in datas
        }
        # Tally results as futures complete.
        completed = 0
        failed = 0
        for future in concurrent.futures.as_completed(future_to_data):
            data = future_to_data[future]
            try:
                result = future.result()
                # `worker` signals success via its return string rather than raising.
                if "successfully" in result:
                    completed += 1
                else:
                    failed += 1
                pbar.update(1)
                # Refresh the completed/failed stats every 100 records.
                if (completed + failed) % 100 == 0:
                    pbar.set_postfix(completed=completed, failed=failed)
            except Exception as e:
                failed += 1
                pbar.update(1)
                print(f"\nWorker for {data} generated an exception: {e}")
    pbar.close()
    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
if __name__ == '__main__':
    import datetime

    # Fixed input: filtered question/solution agent-tool records (JSONL).
    origin_file = "/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl"
    # Timestamp the output name so repeated runs never clobber each other.
    run_stamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    output_file = f"agent_questions_solutions_{run_stamp}.jsonl"
    main(origin_file, output_file)