生成sft数据,设置OQMD的代理,测试mars-t1
This commit is contained in:
218
generate_data/generate_sft_data/generate_llms_ans_multiturn.py
Normal file
218
generate_data/generate_sft_data/generate_llms_ans_multiturn.py
Normal file
@@ -0,0 +1,218 @@
|
||||
|
||||
import jsonlines
|
||||
import argparse
|
||||
import generate_data.generate_sft_data.utils as utils
|
||||
import glob
|
||||
import json
|
||||
from ase import io
|
||||
import tempfile
|
||||
import re
|
||||
from pymatgen.io.vasp import Poscar
|
||||
from pymatgen.io.cif import CifParser
|
||||
import threading
|
||||
import concurrent.futures
|
||||
import copy
|
||||
from generate_data.generate_sft_data.sft_utils import generate_design_question, generate_props_question#, generate_obs_response
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/mars-mcp/')
|
||||
|
||||
from mars_toolkit import get_tool_schemas
|
||||
tools=get_tool_schemas()
|
||||
tools_description=''
|
||||
tools_description+='\n'.join (json.dumps(tool) for tool in tools)
|
||||
# Create a lock for file writing
|
||||
file_lock = threading.Lock()
|
||||
|
||||
|
||||
def generate_deepseek_obs_response(messages,tools_description=None,max_retries=3,initial_backoff=1):
|
||||
messages_ = copy.deepcopy(messages)
|
||||
system_prompt = '''
|
||||
你是一个有用的AI助手,可以使用以下工具来帮助用户完成任务。
|
||||
可用工具:
|
||||
{tools_description},
|
||||
当用户提供给你工具调用的结果后,请你根据当前所拥有的信息继续深入思考,
|
||||
并自行判断这些工具调用的结果能否让你直接给出答案,如果能给出答案,请在回答中直接给出答案;
|
||||
若无法给出答案,你将拥有一次调用工具获取信息的机会,请在思考过程中尽可能的调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题。
|
||||
'''
|
||||
messages_.insert(0, {"role": "system", "content": system_prompt.format(tools_description=tools_description)})
|
||||
|
||||
|
||||
# instrument='''
|
||||
# 上面是你请求的工具调用的结果,请根据现有的信息继续深入思考,
|
||||
# 并自行判断这些工具调用的结果能否让你直接给出答案,如果能给出答案,请在回答中直接给出答案;
|
||||
# 若无法给出答案,你将拥有一次调用工具获取信息的机会,请在思考过程中尽可能的调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题。
|
||||
# '''
|
||||
# messages_.append({"role": "user", "content": instrument})
|
||||
_reasoning_content, response = utils.get_response_from_deepseek_r1(
|
||||
messages_,prefix=False,max_retries=max_retries,initial_backoff=initial_backoff)
|
||||
return _reasoning_content, response
|
||||
def generate_qwq_obs_response(messages,tools_description=None,max_retries=3,initial_backoff=1):
|
||||
messages_ = copy.deepcopy(messages)
|
||||
system_prompt = '''
|
||||
你是一个有用的AI助手,可以使用以下工具来帮助用户完成任务。
|
||||
可用工具:
|
||||
{tools_description}.
|
||||
当用户提供给你工具调用的结果后,请你根据当前所拥有的信息继续深入思考,工具调用失败的原因可能是因为你写的工具调用方式不规范导致调用时解析错误或者是网络原因等等。
|
||||
并自行判断这些工具调用的结果能否让你直接给出答案,情况1.如果能给出答案,请在回答中直接给出答案;
|
||||
情况2.若无法给出答案,你将再拥有一次调用多个工具获取信息的机会,请在思考过程中尽可能地调用多个提供给你的工具查询该问题的相关信息(之前的工具调用如果调用失败你也可以自行决定要不要重新调用),而不是直接回答该问题。
|
||||
情况1.的回答内容为思考内容以及答案,情况2.的回答内容为思考内容以及每一个工具调用的名称和参数(不要省略具体的参数内容,因为这会导致无法解析参数)。
|
||||
思考和回答时使用和问题相同的语言。
|
||||
'''
|
||||
messages_.insert(0, {"role": "system", "content": system_prompt.format(tools_description=tools_description)})
|
||||
|
||||
|
||||
# instrument='''
|
||||
# 上面是你请求的工具调用的结果,请根据现有的信息继续深入思考,
|
||||
# 并自行判断这些工具调用的结果能否让你直接给出答案,情况1.如果能给出答案,请在回答中直接给出答案;
|
||||
# 情况2.若无法给出答案,你将拥有一次调用多个工具获取信息的机会,在思考过程中尽可能地描述如何通过调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题,因此,在这种情况下你在回答中只需要一次给出多个工具调用,不需要额外的文字描述。
|
||||
|
||||
# '''
|
||||
# messages_.append({"role": "assistant", "content": instrument})
|
||||
think_answer,tool_info=utils.get_response_from_qwq(messages_,model_name='qwq-32b',tools=tools,max_retries=max_retries,initial_backoff=initial_backoff)
|
||||
return think_answer,tool_info
|
||||
from .filter_messages import is_answer_none,is_data_with_last_obs
|
||||
|
||||
def worker(data, output_file_path):
|
||||
if is_data_with_last_obs(data):
|
||||
try:
|
||||
messages = copy.deepcopy(data['messages'])
|
||||
|
||||
# print(messages)
|
||||
# print(obs)
|
||||
retry=0
|
||||
think_answer,tool_info = generate_qwq_obs_response(messages,tools_description)
|
||||
|
||||
#print("reasoning_content",reasoning_content)
|
||||
#print("response",response)
|
||||
data['messages'].append({"role": "assistant", "content": think_answer})
|
||||
while is_answer_none(data) and retry<5:
|
||||
# 如果答案为空,则需要重新调用工具
|
||||
retry+=1
|
||||
# 重新生成observation
|
||||
# print("数据中最后一条消息为observation且为空")
|
||||
# print("retry",retry)
|
||||
# 重新生成observation
|
||||
messages = copy.deepcopy(data['messages'].pop())
|
||||
think_answer,tool_info = generate_qwq_obs_response(messages,tools_description,max_retries=5)
|
||||
data['messages'].append({"role": "assistant", "content": think_answer})
|
||||
|
||||
|
||||
# Use the lock to safely write to the file
|
||||
data['observation']=''
|
||||
data['function_calls']=''
|
||||
with file_lock:
|
||||
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
writer.write(data)
|
||||
|
||||
return f"Processed successfully"
|
||||
except Exception as e:
|
||||
return f"Error processing: {str(e)}"
|
||||
else:
|
||||
with file_lock:
|
||||
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
writer.write(data)
|
||||
return f"Processed successfully"
|
||||
def main(input_file_path, output_file_path, max_workers=1):
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
datas = None
|
||||
with jsonlines.open(input_file_path, mode='r') as reader:
|
||||
datas = [line for line in reader]
|
||||
|
||||
# 创建进度条
|
||||
pbar = tqdm(total=len(datas), desc="Processing CIF files")
|
||||
|
||||
# 创建一个线程池
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交任务到执行器
|
||||
future_to_data = {}
|
||||
for data in datas:
|
||||
future = executor.submit(worker, data, output_file_path)
|
||||
future_to_data[future] = data
|
||||
|
||||
# 处理结果
|
||||
completed = 0
|
||||
failed = 0
|
||||
for future in concurrent.futures.as_completed(future_to_data):
|
||||
data = future_to_data[future]
|
||||
try:
|
||||
result = future.result()
|
||||
if "successfully" in result:
|
||||
completed += 1
|
||||
else:
|
||||
failed += 1
|
||||
# 更新进度条
|
||||
pbar.update(1)
|
||||
# 每100个文件更新一次统计信息
|
||||
if (completed + failed) % 100 == 0:
|
||||
pbar.set_postfix(completed=completed, failed=failed)
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
pbar.update(1)
|
||||
print(f"\nWorker for {data} generated an exception: {e}")
|
||||
|
||||
pbar.close()
|
||||
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
||||
|
||||
|
||||
def main_for_datas(datas, output_file_path, max_workers=1):
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
#datas = None
|
||||
|
||||
|
||||
# 创建进度条
|
||||
pbar = tqdm(total=len(datas), desc="Processing CIF files")
|
||||
|
||||
# 创建一个线程池
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交任务到执行器
|
||||
future_to_data = {}
|
||||
for data in datas:
|
||||
future = executor.submit(worker, data, output_file_path)
|
||||
future_to_data[future] = data
|
||||
|
||||
# 处理结果
|
||||
completed = 0
|
||||
failed = 0
|
||||
for future in concurrent.futures.as_completed(future_to_data):
|
||||
data = future_to_data[future]
|
||||
try:
|
||||
result = future.result()
|
||||
if "successfully" in result:
|
||||
completed += 1
|
||||
else:
|
||||
failed += 1
|
||||
# 更新进度条
|
||||
pbar.update(1)
|
||||
# 每100个文件更新一次统计信息
|
||||
if (completed + failed) % 100 == 0:
|
||||
pbar.set_postfix(completed=completed, failed=failed)
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
pbar.update(1)
|
||||
print(f"\nWorker for {data} generated an exception: {e}")
|
||||
|
||||
pbar.close()
|
||||
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
import datetime
|
||||
origin_file = "/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_obs.jsonl"
|
||||
#output_file = f"/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_turn2_ans_no_none1{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||
|
||||
output_file='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans.jsonl'
|
||||
main(origin_file, output_file,max_workers=36)
|
||||
print("output_file",output_file)
|
||||
output_datas = utils.read_jsonline_file(output_file)
|
||||
c=0
|
||||
print("len output_datas",len(output_datas))
|
||||
for data in output_datas:
|
||||
if is_answer_none(data):
|
||||
#print("data len",len(data['messages']))
|
||||
c+=1
|
||||
print("answer none number",c)
|
||||
Reference in New Issue
Block a user