Files
mars-mcp/generate_data/generate_sft_data/generate_llms_ans_multiturn.py

219 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import jsonlines
import argparse
import generate_data.generate_sft_data.utils as utils
import glob
import json
from ase import io
import tempfile
import re
from pymatgen.io.vasp import Poscar
from pymatgen.io.cif import CifParser
import threading
import concurrent.futures
import copy
from generate_data.generate_sft_data.sft_utils import generate_design_question, generate_props_question#, generate_obs_response
import sys
sys.path.append('/home/ubuntu/sas0/lzy/mars-mcp/')
from mars_toolkit import get_tool_schemas
tools=get_tool_schemas()
tools_description=''
tools_description+='\n'.join (json.dumps(tool) for tool in tools)
# Create a lock for file writing
file_lock = threading.Lock()
def generate_deepseek_obs_response(messages,tools_description=None,max_retries=3,initial_backoff=1):
messages_ = copy.deepcopy(messages)
system_prompt = '''
你是一个有用的AI助手可以使用以下工具来帮助用户完成任务。
可用工具:
{tools_description}
当用户提供给你工具调用的结果后,请你根据当前所拥有的信息继续深入思考,
并自行判断这些工具调用的结果能否让你直接给出答案,如果能给出答案,请在回答中直接给出答案;
若无法给出答案,你将拥有一次调用工具获取信息的机会,请在思考过程中尽可能的调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题。
'''
messages_.insert(0, {"role": "system", "content": system_prompt.format(tools_description=tools_description)})
# instrument='''
# 上面是你请求的工具调用的结果,请根据现有的信息继续深入思考,
# 并自行判断这些工具调用的结果能否让你直接给出答案,如果能给出答案,请在回答中直接给出答案;
# 若无法给出答案,你将拥有一次调用工具获取信息的机会,请在思考过程中尽可能的调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题。
# '''
# messages_.append({"role": "user", "content": instrument})
_reasoning_content, response = utils.get_response_from_deepseek_r1(
messages_,prefix=False,max_retries=max_retries,initial_backoff=initial_backoff)
return _reasoning_content, response
def generate_qwq_obs_response(messages,tools_description=None,max_retries=3,initial_backoff=1):
messages_ = copy.deepcopy(messages)
system_prompt = '''
你是一个有用的AI助手可以使用以下工具来帮助用户完成任务。
可用工具:
{tools_description}.
当用户提供给你工具调用的结果后,请你根据当前所拥有的信息继续深入思考,工具调用失败的原因可能是因为你写的工具调用方式不规范导致调用时解析错误或者是网络原因等等。
并自行判断这些工具调用的结果能否让你直接给出答案情况1.如果能给出答案,请在回答中直接给出答案;
情况2.若无法给出答案,你将再拥有一次调用多个工具获取信息的机会,请在思考过程中尽可能地调用多个提供给你的工具查询该问题的相关信息(之前的工具调用如果调用失败你也可以自行决定要不要重新调用),而不是直接回答该问题。
情况1.的回答内容为思考内容以及答案情况2.的回答内容为思考内容以及每一个工具调用的名称和参数(不要省略具体的参数内容,因为这会导致无法解析参数)。
思考和回答时使用和问题相同的语言。
'''
messages_.insert(0, {"role": "system", "content": system_prompt.format(tools_description=tools_description)})
# instrument='''
# 上面是你请求的工具调用的结果,请根据现有的信息继续深入思考,
# 并自行判断这些工具调用的结果能否让你直接给出答案情况1.如果能给出答案,请在回答中直接给出答案;
# 情况2.若无法给出答案,你将拥有一次调用多个工具获取信息的机会,在思考过程中尽可能地描述如何通过调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题,因此,在这种情况下你在回答中只需要一次给出多个工具调用,不需要额外的文字描述。
# '''
# messages_.append({"role": "assistant", "content": instrument})
think_answer,tool_info=utils.get_response_from_qwq(messages_,model_name='qwq-32b',tools=tools,max_retries=max_retries,initial_backoff=initial_backoff)
return think_answer,tool_info
from .filter_messages import is_answer_none,is_data_with_last_obs
def worker(data, output_file_path):
if is_data_with_last_obs(data):
try:
messages = copy.deepcopy(data['messages'])
# print(messages)
# print(obs)
retry=0
think_answer,tool_info = generate_qwq_obs_response(messages,tools_description)
#print("reasoning_content",reasoning_content)
#print("response",response)
data['messages'].append({"role": "assistant", "content": think_answer})
while is_answer_none(data) and retry<5:
# 如果答案为空,则需要重新调用工具
retry+=1
# 重新生成observation
# print("数据中最后一条消息为observation且为空")
# print("retry",retry)
# 重新生成observation
messages = copy.deepcopy(data['messages'].pop())
think_answer,tool_info = generate_qwq_obs_response(messages,tools_description,max_retries=5)
data['messages'].append({"role": "assistant", "content": think_answer})
# Use the lock to safely write to the file
data['observation']=''
data['function_calls']=''
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data)
return f"Processed successfully"
except Exception as e:
return f"Error processing: {str(e)}"
else:
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data)
return f"Processed successfully"
def main(input_file_path, output_file_path, max_workers=1):
import random
from tqdm import tqdm
import os
datas = None
with jsonlines.open(input_file_path, mode='r') as reader:
datas = [line for line in reader]
# 创建进度条
pbar = tqdm(total=len(datas), desc="Processing CIF files")
# 创建一个线程池
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交任务到执行器
future_to_data = {}
for data in datas:
future = executor.submit(worker, data, output_file_path)
future_to_data[future] = data
# 处理结果
completed = 0
failed = 0
for future in concurrent.futures.as_completed(future_to_data):
data = future_to_data[future]
try:
result = future.result()
if "successfully" in result:
completed += 1
else:
failed += 1
# 更新进度条
pbar.update(1)
# 每100个文件更新一次统计信息
if (completed + failed) % 100 == 0:
pbar.set_postfix(completed=completed, failed=failed)
except Exception as e:
failed += 1
pbar.update(1)
print(f"\nWorker for {data} generated an exception: {e}")
pbar.close()
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
def main_for_datas(datas, output_file_path, max_workers=1):
import random
from tqdm import tqdm
import os
#datas = None
# 创建进度条
pbar = tqdm(total=len(datas), desc="Processing CIF files")
# 创建一个线程池
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交任务到执行器
future_to_data = {}
for data in datas:
future = executor.submit(worker, data, output_file_path)
future_to_data[future] = data
# 处理结果
completed = 0
failed = 0
for future in concurrent.futures.as_completed(future_to_data):
data = future_to_data[future]
try:
result = future.result()
if "successfully" in result:
completed += 1
else:
failed += 1
# 更新进度条
pbar.update(1)
# 每100个文件更新一次统计信息
if (completed + failed) % 100 == 0:
pbar.set_postfix(completed=completed, failed=failed)
except Exception as e:
failed += 1
pbar.update(1)
print(f"\nWorker for {data} generated an exception: {e}")
pbar.close()
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
if __name__ == '__main__':
import datetime
origin_file = "/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_obs.jsonl"
#output_file = f"/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_turn2_ans_no_none1{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
output_file='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans.jsonl'
main(origin_file, output_file,max_workers=36)
print("output_file",output_file)
output_datas = utils.read_jsonline_file(output_file)
c=0
print("len output_datas",len(output_datas))
for data in output_datas:
if is_answer_none(data):
#print("data len",len(data['messages']))
c+=1
print("answer none number",c)