生成sft数据,设置OQMD的代理,测试mars-t1
This commit is contained in:
79
generate_data/calculate_tokens.py
Normal file
79
generate_data/calculate_tokens.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import json
|
||||
import tiktoken
|
||||
import numpy as np
|
||||
import statistics
|
||||
from pathlib import Path
|
||||
# 均值: 13716.062458398048
|
||||
# 最大值: 106876
|
||||
# 最小值: 5108
|
||||
# 中值: 13285.5
|
||||
# 样本数: 9014
|
||||
def count_tokens_in_string(text):
|
||||
"""使用tiktoken库计算字符串中的token数量"""
|
||||
# 使用cl100k_base编码器,这是GPT-4使用的编码器
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
# 计算tokens
|
||||
tokens = encoding.encode(text)
|
||||
return len(tokens)
|
||||
|
||||
def process_jsonl_file(file_path):
|
||||
"""处理JSONL文件并计算token统计信息"""
|
||||
token_counts = []
|
||||
count=0
|
||||
# 读取JSONL文件
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
# 解析JSON行
|
||||
data = json.loads(line)
|
||||
if len(data['messages'])==4:
|
||||
# 将数据转换为字符串
|
||||
count+=1
|
||||
data_str = json.dumps(data)
|
||||
# 计算tokens
|
||||
token_count = count_tokens_in_string(data_str)
|
||||
token_counts.append(token_count)
|
||||
else:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"处理行时出错: {e}")
|
||||
print("countnumber",count)
|
||||
# 计算统计信息
|
||||
if token_counts:
|
||||
mean_value = statistics.mean(token_counts)
|
||||
max_value = max(token_counts)
|
||||
min_value = min(token_counts)
|
||||
median_value = statistics.median(token_counts)
|
||||
|
||||
# 计算token数小于32k的样本数量
|
||||
count_less_than_32k = sum(1 for count in token_counts if count < 32000)
|
||||
count_less_than_24k = sum(1 for count in token_counts if count < 24000)
|
||||
count_less_than_16k = sum(1 for count in token_counts if count < 16000)
|
||||
return {
|
||||
"均值": mean_value,
|
||||
"最大值": max_value,
|
||||
"最小值": min_value,
|
||||
"中值": median_value,
|
||||
"样本数": len(token_counts),
|
||||
"token数小于32k的样本数": count_less_than_32k,
|
||||
"token数小于32k的样本百分比": (count_less_than_32k / len(token_counts)) * 100 if token_counts else 0,
|
||||
"token数小于24k的样本数": count_less_than_24k,
|
||||
"token数小于24k的样本百分比": (count_less_than_24k / len(token_counts)) * 100 if token_counts else 0,
|
||||
"token数小于16k的样本数": count_less_than_16k,
|
||||
"token数小于16k的样本百分比": (count_less_than_16k / len(token_counts)) * 100 if token_counts else 0
|
||||
}
|
||||
else:
|
||||
return {"错误": "没有找到有效数据"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
file_path = "/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl"
|
||||
|
||||
# 确认文件存在
|
||||
if not Path(file_path).exists():
|
||||
print(f"错误: 文件不存在 - {file_path}")
|
||||
else:
|
||||
# 处理文件并打印结果
|
||||
results = process_jsonl_file(file_path)
|
||||
print("\n统计结果:")
|
||||
for key, value in results.items():
|
||||
print(f"{key}: {value}")
|
||||
@@ -1,3 +1,4 @@
|
||||
# 原始数据分为两类 一种是带solution的,一种是没有solution的,这个是提取了各5000条
|
||||
import json
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
77
generate_data/generate_sft_data/address_sft_data.py
Normal file
77
generate_data/generate_sft_data/address_sft_data.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# 这个代码用于生成/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl
|
||||
# 往messages中添加完调用工具的消息后的下一步处理,
|
||||
#address_failed_data用于处理调用大模型返回失败的样本
|
||||
#address_non_answer_data 用于处理大模型返回的answer为空的样本
|
||||
#address_tool_call_data 用于处理answer 中需要调用工具的代码
|
||||
from generate_data.generate_sft_data.utils import read_jsonline_file
|
||||
from generate_data.generate_sft_data.generate_llms_ans_multiturn import main_for_datas
|
||||
from generate_data.generate_sft_data.generate_tool_observation_multiturn import main
|
||||
from collections import Counter
|
||||
def print_message_len(data_path):
|
||||
datas = read_jsonline_file(data_path)
|
||||
data_messages_len = []
|
||||
for data in datas:
|
||||
data_messages_len.append(len(data['messages']))
|
||||
ele_counts=Counter(data_messages_len)
|
||||
print("message len and number:",ele_counts)
|
||||
|
||||
def address_failed_data(original_data_path,failed_data_path):
|
||||
original_datas=read_jsonline_file(original_data_path)
|
||||
#original_datas_question = [original_data['messages'][0]['content'] for original_data in original_datas]
|
||||
failed_datas = read_jsonline_file(failed_data_path)
|
||||
failed_datas_question = [failed_data['messages'][0]['content'] for failed_data in failed_datas]
|
||||
need_to_address_datas = []
|
||||
for original_data in original_datas:
|
||||
if original_data['messages'][0]['content'] in failed_datas_question:
|
||||
pass
|
||||
else:
|
||||
need_to_address_datas.append(original_data)
|
||||
print('need to address',len(need_to_address_datas))
|
||||
main_for_datas(need_to_address_datas,failed_data_path,3)
|
||||
print("after process")
|
||||
new_datas=read_jsonline_file(failed_data_path)
|
||||
print("new datas num",len(new_datas))
|
||||
from .filter_messages import is_answer_none
|
||||
def address_non_answer_data(non_answer_data_path,output_path):
|
||||
non_answer_datas = read_jsonline_file(non_answer_data_path)
|
||||
new_datas=[]
|
||||
need_to_address_datas = []
|
||||
for data in non_answer_datas:
|
||||
if is_answer_none(data):
|
||||
data['messages'].pop()
|
||||
need_to_address_datas.append(data)
|
||||
|
||||
else:
|
||||
new_datas.append(data)
|
||||
print('non answer number',len(need_to_address_datas))
|
||||
main_for_datas(non_answer_datas,output_path,3)
|
||||
print("afer process",)
|
||||
new_output_datas=read_jsonline_file(output_path)
|
||||
print("new output data num",len(new_output_datas))
|
||||
print("answer is none number",len([data for data in new_output_datas if is_answer_none(data)]))
|
||||
|
||||
from .filter_messages import is_tool_call_last
|
||||
def address_tool_call_data(data_path,output_path):
|
||||
datas = read_jsonline_file(data_path)
|
||||
need_to_address_datas = []
|
||||
print("total data",len(datas))
|
||||
for data in datas:
|
||||
if is_tool_call_last(data):
|
||||
need_to_address_datas.append(data)
|
||||
else:
|
||||
pass
|
||||
print("need to address",len(need_to_address_datas))
|
||||
# main(datas,output_file_path=output_path,max_workers=7)
|
||||
# print("after process")
|
||||
# new_output_datas=read_jsonline_file(output_path)
|
||||
# print("new output data num",len(new_output_datas))
|
||||
# print("need to tool call number",len([data for data in new_output_datas if is_tool_call_last(data)]))
|
||||
#main_1(need_to_address_datas,non_answer_data_path,48)
|
||||
original_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl'
|
||||
failed_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl'
|
||||
print_message_len(original_data_path)
|
||||
#address_failed_data(original_data_path,failed_data_path)
|
||||
|
||||
|
||||
#address_non_answer_data(original_data_path,'/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl')
|
||||
address_tool_call_data(original_data_path,'/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn6_obs.jsonl')
|
||||
1
generate_data/generate_sft_data/data_version.md
Normal file
1
generate_data/generate_sft_data/data_version.md
Normal file
@@ -0,0 +1 @@
|
||||
大概是345的格式才是统一的,即;obs是最后一条message为工具调用的,ans是工具调用后再调用qwq32b产生的结果,ans-no-none是重复调用大模型重新生成answer中为空的数据
|
||||
@@ -78,7 +78,7 @@ def filter_generate_material(file_path):
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 默认文件路径
|
||||
file_path ='/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl'
|
||||
file_path ='/home/ubuntu/50T/nfs/lzy/mars-mcp/agent_questions_solutions_test20250416152446.jsonl'
|
||||
# "/home/ubuntu/50T/lzy/mars-mcp/mars-agent_data_20250408205427.jsonl"
|
||||
|
||||
# 如果提供了命令行参数,则使用命令行参数作为文件路径
|
||||
129
generate_data/generate_sft_data/filter_messages.py
Normal file
129
generate_data/generate_sft_data/filter_messages.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# 用于处理
|
||||
#from generate_tool_observation_multiturn import worker
|
||||
from generate_data.generate_sft_data.utils import read_jsonline_file
|
||||
import jsonlines
|
||||
|
||||
def is_data_with_last_obs(data):
|
||||
|
||||
'''用于准备用大模型生成工具结果分析和答案时,判断出数据中最后一条消息是否为工具调用的结果observation,
|
||||
是的话,则需要大模型进一步回答;否则也就是上一条消息为答案,即可以跳过大模型的回答'''
|
||||
if data['messages'][-1]['role']=='user':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
def is_obs_none(data):
|
||||
'''用于检查调用工具生成的结果是否为空'''
|
||||
if is_data_with_last_obs(data) and data['messages'][-1]['content']==None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_tool_call_last(data):
|
||||
"""用于检查最后一条消息是否包含工具调用,用于生成observation步骤"""
|
||||
tool_call_str=data['messages'][-1]['content'].split("<answer>")[-1]
|
||||
if data['messages'][-1]['role']=='assistant' and '<tool_call>' in tool_call_str:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
def is_answer_none(data):
|
||||
# 修正换行符表示并获取answer内容
|
||||
try:
|
||||
answer_str = data['messages'][-1]['content'].split("<answer>\n")[-1].split("</answer>\n")[0]
|
||||
except:
|
||||
# 如果分割出错,返回原始内容
|
||||
answer_str = data['messages'][-1]['content']
|
||||
|
||||
# 检查是否为assistant角色且answer_str为空或只包含空白字符
|
||||
if data['messages'][-1]['role'] == 'assistant' and (not answer_str or answer_str.strip() == ''):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
# 创建进度条
|
||||
|
||||
|
||||
# for data in ans_datas:
|
||||
# if is_answer_none(data):
|
||||
# data['messages'].pop()
|
||||
# with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
# writer.write(data) # observation . data
|
||||
|
||||
#print("c",c)
|
||||
file_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn4_ans.jsonl'
|
||||
#output_file_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn2_obs_1.jsonl'
|
||||
output_datas = read_jsonline_file(file_path)
|
||||
print(len(output_datas))
|
||||
from collections import Counter
|
||||
data_messages_len = []
|
||||
for data in output_datas:
|
||||
data_messages_len.append(len(data['messages']))
|
||||
ele_counts=Counter(data_messages_len)
|
||||
print(ele_counts)
|
||||
|
||||
count=0
|
||||
for data in output_datas:
|
||||
if is_data_with_last_obs(data):
|
||||
count+=1
|
||||
print("last obs",count)
|
||||
|
||||
# for data in output_datas:
|
||||
# #if is_data_with_last_obs(data):
|
||||
# if is_answer_none(data):
|
||||
# #data['messages'].pop()
|
||||
# c+=1
|
||||
# # with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
# # writer.write(data) # observation . data
|
||||
# print("c",c)
|
||||
|
||||
# print(c)
|
||||
# d=0
|
||||
# new_data=read_jsonline_file(output_file_path)
|
||||
# data_lens=[]
|
||||
# for data in new_data:
|
||||
# data_lens.append(len(data['messages']))
|
||||
|
||||
# #if is_data_with_last_obs(data):
|
||||
# #d+=1
|
||||
# print(set(data_lens))
|
||||
# print(d)
|
||||
# data['messages'].pop()
|
||||
|
||||
# print("数据中最后一条消息为observation且为空")
|
||||
|
||||
# worker(data,output_file_path)
|
||||
# pbar.update(1)
|
||||
# else:
|
||||
|
||||
# with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
# writer.write(data) # observation . data
|
||||
# pbar.update(1)
|
||||
# pbar.close
|
||||
# from generate_tool_observation_multiturn import main
|
||||
# raw_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_qwq1.jsonl'
|
||||
# unfinish_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn2_20250418113514.jsonl'
|
||||
|
||||
# unfinish_datas = read_jsonline_file(unfinish_data_path)
|
||||
# raw_datas = read_jsonline_file(raw_data_path)
|
||||
# unfinish_datas_question =[unfinish_data['messages'][0]['content'] for unfinish_data in unfinish_datas]
|
||||
|
||||
# from generate_tool_observation_multiturn import worker
|
||||
# filtered_unfinish_datas = []
|
||||
# for raw_data in raw_datas:
|
||||
|
||||
# if raw_data['messages'][0]['content'] in unfinish_datas_question:
|
||||
|
||||
# pass
|
||||
# else:
|
||||
# filtered_unfinish_datas.append(raw_data)
|
||||
# # print(raw_data['messages'][-1]['content'])
|
||||
|
||||
# # worker(raw_data,unfinish_data_path)
|
||||
# #exit(0)
|
||||
# main(filtered_unfinish_datas,unfinish_data_path,16)
|
||||
#print(len(filtered_unfinish_datas))
|
||||
|
||||
|
||||
|
||||
|
||||
218
generate_data/generate_sft_data/generate_llms_ans_multiturn.py
Normal file
218
generate_data/generate_sft_data/generate_llms_ans_multiturn.py
Normal file
@@ -0,0 +1,218 @@
|
||||
|
||||
import jsonlines
|
||||
import argparse
|
||||
import generate_data.generate_sft_data.utils as utils
|
||||
import glob
|
||||
import json
|
||||
from ase import io
|
||||
import tempfile
|
||||
import re
|
||||
from pymatgen.io.vasp import Poscar
|
||||
from pymatgen.io.cif import CifParser
|
||||
import threading
|
||||
import concurrent.futures
|
||||
import copy
|
||||
from generate_data.generate_sft_data.sft_utils import generate_design_question, generate_props_question#, generate_obs_response
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/mars-mcp/')
|
||||
|
||||
from mars_toolkit import get_tool_schemas
|
||||
tools=get_tool_schemas()
|
||||
tools_description=''
|
||||
tools_description+='\n'.join (json.dumps(tool) for tool in tools)
|
||||
# Create a lock for file writing
|
||||
file_lock = threading.Lock()
|
||||
|
||||
|
||||
def generate_deepseek_obs_response(messages,tools_description=None,max_retries=3,initial_backoff=1):
|
||||
messages_ = copy.deepcopy(messages)
|
||||
system_prompt = '''
|
||||
你是一个有用的AI助手,可以使用以下工具来帮助用户完成任务。
|
||||
可用工具:
|
||||
{tools_description},
|
||||
当用户提供给你工具调用的结果后,请你根据当前所拥有的信息继续深入思考,
|
||||
并自行判断这些工具调用的结果能否让你直接给出答案,如果能给出答案,请在回答中直接给出答案;
|
||||
若无法给出答案,你将拥有一次调用工具获取信息的机会,请在思考过程中尽可能的调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题。
|
||||
'''
|
||||
messages_.insert(0, {"role": "system", "content": system_prompt.format(tools_description=tools_description)})
|
||||
|
||||
|
||||
# instrument='''
|
||||
# 上面是你请求的工具调用的结果,请根据现有的信息继续深入思考,
|
||||
# 并自行判断这些工具调用的结果能否让你直接给出答案,如果能给出答案,请在回答中直接给出答案;
|
||||
# 若无法给出答案,你将拥有一次调用工具获取信息的机会,请在思考过程中尽可能的调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题。
|
||||
# '''
|
||||
# messages_.append({"role": "user", "content": instrument})
|
||||
_reasoning_content, response = utils.get_response_from_deepseek_r1(
|
||||
messages_,prefix=False,max_retries=max_retries,initial_backoff=initial_backoff)
|
||||
return _reasoning_content, response
|
||||
def generate_qwq_obs_response(messages,tools_description=None,max_retries=3,initial_backoff=1):
|
||||
messages_ = copy.deepcopy(messages)
|
||||
system_prompt = '''
|
||||
你是一个有用的AI助手,可以使用以下工具来帮助用户完成任务。
|
||||
可用工具:
|
||||
{tools_description}.
|
||||
当用户提供给你工具调用的结果后,请你根据当前所拥有的信息继续深入思考,工具调用失败的原因可能是因为你写的工具调用方式不规范导致调用时解析错误或者是网络原因等等。
|
||||
并自行判断这些工具调用的结果能否让你直接给出答案,情况1.如果能给出答案,请在回答中直接给出答案;
|
||||
情况2.若无法给出答案,你将再拥有一次调用多个工具获取信息的机会,请在思考过程中尽可能地调用多个提供给你的工具查询该问题的相关信息(之前的工具调用如果调用失败你也可以自行决定要不要重新调用),而不是直接回答该问题。
|
||||
情况1.的回答内容为思考内容以及答案,情况2.的回答内容为思考内容以及每一个工具调用的名称和参数(不要省略具体的参数内容,因为这会导致无法解析参数)。
|
||||
思考和回答时使用和问题相同的语言。
|
||||
'''
|
||||
messages_.insert(0, {"role": "system", "content": system_prompt.format(tools_description=tools_description)})
|
||||
|
||||
|
||||
# instrument='''
|
||||
# 上面是你请求的工具调用的结果,请根据现有的信息继续深入思考,
|
||||
# 并自行判断这些工具调用的结果能否让你直接给出答案,情况1.如果能给出答案,请在回答中直接给出答案;
|
||||
# 情况2.若无法给出答案,你将拥有一次调用多个工具获取信息的机会,在思考过程中尽可能地描述如何通过调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题,因此,在这种情况下你在回答中只需要一次给出多个工具调用,不需要额外的文字描述。
|
||||
|
||||
# '''
|
||||
# messages_.append({"role": "assistant", "content": instrument})
|
||||
think_answer,tool_info=utils.get_response_from_qwq(messages_,model_name='qwq-32b',tools=tools,max_retries=max_retries,initial_backoff=initial_backoff)
|
||||
return think_answer,tool_info
|
||||
from .filter_messages import is_answer_none,is_data_with_last_obs
|
||||
|
||||
def worker(data, output_file_path):
|
||||
if is_data_with_last_obs(data):
|
||||
try:
|
||||
messages = copy.deepcopy(data['messages'])
|
||||
|
||||
# print(messages)
|
||||
# print(obs)
|
||||
retry=0
|
||||
think_answer,tool_info = generate_qwq_obs_response(messages,tools_description)
|
||||
|
||||
#print("reasoning_content",reasoning_content)
|
||||
#print("response",response)
|
||||
data['messages'].append({"role": "assistant", "content": think_answer})
|
||||
while is_answer_none(data) and retry<5:
|
||||
# 如果答案为空,则需要重新调用工具
|
||||
retry+=1
|
||||
# 重新生成observation
|
||||
# print("数据中最后一条消息为observation且为空")
|
||||
# print("retry",retry)
|
||||
# 重新生成observation
|
||||
messages = copy.deepcopy(data['messages'].pop())
|
||||
think_answer,tool_info = generate_qwq_obs_response(messages,tools_description,max_retries=5)
|
||||
data['messages'].append({"role": "assistant", "content": think_answer})
|
||||
|
||||
|
||||
# Use the lock to safely write to the file
|
||||
data['observation']=''
|
||||
data['function_calls']=''
|
||||
with file_lock:
|
||||
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
writer.write(data)
|
||||
|
||||
return f"Processed successfully"
|
||||
except Exception as e:
|
||||
return f"Error processing: {str(e)}"
|
||||
else:
|
||||
with file_lock:
|
||||
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
writer.write(data)
|
||||
return f"Processed successfully"
|
||||
def main(input_file_path, output_file_path, max_workers=1):
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
datas = None
|
||||
with jsonlines.open(input_file_path, mode='r') as reader:
|
||||
datas = [line for line in reader]
|
||||
|
||||
# 创建进度条
|
||||
pbar = tqdm(total=len(datas), desc="Processing CIF files")
|
||||
|
||||
# 创建一个线程池
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交任务到执行器
|
||||
future_to_data = {}
|
||||
for data in datas:
|
||||
future = executor.submit(worker, data, output_file_path)
|
||||
future_to_data[future] = data
|
||||
|
||||
# 处理结果
|
||||
completed = 0
|
||||
failed = 0
|
||||
for future in concurrent.futures.as_completed(future_to_data):
|
||||
data = future_to_data[future]
|
||||
try:
|
||||
result = future.result()
|
||||
if "successfully" in result:
|
||||
completed += 1
|
||||
else:
|
||||
failed += 1
|
||||
# 更新进度条
|
||||
pbar.update(1)
|
||||
# 每100个文件更新一次统计信息
|
||||
if (completed + failed) % 100 == 0:
|
||||
pbar.set_postfix(completed=completed, failed=failed)
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
pbar.update(1)
|
||||
print(f"\nWorker for {data} generated an exception: {e}")
|
||||
|
||||
pbar.close()
|
||||
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
||||
|
||||
|
||||
def main_for_datas(datas, output_file_path, max_workers=1):
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
#datas = None
|
||||
|
||||
|
||||
# 创建进度条
|
||||
pbar = tqdm(total=len(datas), desc="Processing CIF files")
|
||||
|
||||
# 创建一个线程池
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交任务到执行器
|
||||
future_to_data = {}
|
||||
for data in datas:
|
||||
future = executor.submit(worker, data, output_file_path)
|
||||
future_to_data[future] = data
|
||||
|
||||
# 处理结果
|
||||
completed = 0
|
||||
failed = 0
|
||||
for future in concurrent.futures.as_completed(future_to_data):
|
||||
data = future_to_data[future]
|
||||
try:
|
||||
result = future.result()
|
||||
if "successfully" in result:
|
||||
completed += 1
|
||||
else:
|
||||
failed += 1
|
||||
# 更新进度条
|
||||
pbar.update(1)
|
||||
# 每100个文件更新一次统计信息
|
||||
if (completed + failed) % 100 == 0:
|
||||
pbar.set_postfix(completed=completed, failed=failed)
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
pbar.update(1)
|
||||
print(f"\nWorker for {data} generated an exception: {e}")
|
||||
|
||||
pbar.close()
|
||||
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
import datetime
|
||||
origin_file = "/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_obs.jsonl"
|
||||
#output_file = f"/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_turn2_ans_no_none1{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||
|
||||
output_file='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans.jsonl'
|
||||
main(origin_file, output_file,max_workers=36)
|
||||
print("output_file",output_file)
|
||||
output_datas = utils.read_jsonline_file(output_file)
|
||||
c=0
|
||||
print("len output_datas",len(output_datas))
|
||||
for data in output_datas:
|
||||
if is_answer_none(data):
|
||||
#print("data len",len(data['messages']))
|
||||
c+=1
|
||||
print("answer none number",c)
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/mars-mcp/')
|
||||
import jsonlines
|
||||
from mars_toolkit import *
|
||||
import threading
|
||||
@@ -19,6 +20,37 @@ init(autoreset=True)
|
||||
|
||||
from typing import Dict, Union, Any, Optional, List
|
||||
|
||||
import re
|
||||
def extract_tool_calls(text):
|
||||
"""
|
||||
提取字符串中所有包裹在<tool_call>\n</tool_call>\n 中的JSON内容并转换为字典列表
|
||||
|
||||
参数:
|
||||
text (str): 包含工具调用的文本
|
||||
|
||||
返回:
|
||||
list: 包含所有工具调用的字典列表
|
||||
"""
|
||||
# 使用正则表达式提取<tool_call>和</tool_call>之间的内容
|
||||
# (?s)表示让.也匹配换行符,使模式可以跨行匹配
|
||||
pattern = r'<tool_call>\n(.*?)</tool_call>'
|
||||
matches = re.finditer(pattern, text, re.DOTALL)
|
||||
|
||||
tool_calls = []
|
||||
|
||||
for match in matches:
|
||||
json_str = match.group(1).strip()
|
||||
try:
|
||||
# 将JSON字符串转换为Python字典
|
||||
tool_call_dict = json.loads(json_str)
|
||||
tool_calls.append(tool_call_dict)
|
||||
except json.JSONDecodeError as e:
|
||||
tool_calls.append(f"无法解析JSON: {e},问题字符串{json_str}")
|
||||
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
规范化传递给generate_material函数的参数格式。
|
||||
@@ -237,7 +269,7 @@ async def execute_tool_from_dict(input_dict: dict):
|
||||
|
||||
# 检查函数名是否存在于工具函数字典中
|
||||
if func_name not in tools:
|
||||
return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"}
|
||||
return f"函数 '{func_name}' 不存在于工具函数字典中"
|
||||
|
||||
# 获取对应的工具函数
|
||||
tool_func = tools[func_name]
|
||||
@@ -265,102 +297,117 @@ async def execute_tool_from_dict(input_dict: dict):
|
||||
arguments = {"raw_string": arguments_data}
|
||||
|
||||
# 调用工具函数
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
# 如果是异步函数,使用await调用
|
||||
result = await tool_func(**arguments)
|
||||
else:
|
||||
# 如果是同步函数,直接调用
|
||||
result = tool_func(**arguments)
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
# 如果是异步函数,使用await调用
|
||||
result = await tool_func(**arguments)
|
||||
else:
|
||||
# 如果是同步函数,直接调用
|
||||
result = tool_func(**arguments)
|
||||
except Exception as e:
|
||||
result = f'工具函数调用时出错:str{e}'
|
||||
# if func_name=='generate_material':
|
||||
# print("xxxxx",result)
|
||||
return result
|
||||
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||
return formatted_result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
return {"status": "error", "message": f"JSON解析错误: {str(e)}"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": f"执行过程中出错: {str(e)}"}
|
||||
|
||||
pass
|
||||
|
||||
def worker(data, output_file_path):
|
||||
try:
|
||||
func_contents = data["function_calls"]
|
||||
func_results = []
|
||||
formatted_results = [] # 新增一个列表来存储格式化后的结果
|
||||
for func in func_contents:
|
||||
func_name = func.get("name")
|
||||
arguments_data = func.get("arguments")
|
||||
tool_call_str = data['messages'][-1]['content'].split("<answer>")[-1]
|
||||
if '<tool_call>' in tool_call_str:
|
||||
|
||||
# 使用富文本打印函数名
|
||||
#print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
||||
func_contents=extract_tool_calls(tool_call_str)
|
||||
#print(func_contents)
|
||||
|
||||
# 使用富文本打印参数
|
||||
#print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
|
||||
#func_results = []
|
||||
formatted_results = [] # 新增一个列表来存储格式化后的结果
|
||||
for func in func_contents:
|
||||
if isinstance(func,Dict):
|
||||
func_name = func.get("name")
|
||||
arguments_data = func.get("arguments")
|
||||
|
||||
if func.get("name") == 'retrieval_from_knowledge_base':
|
||||
# 使用富文本打印函数名
|
||||
#print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
||||
|
||||
# delay_time = random.uniform(5, 10)
|
||||
# time.sleep(delay_time)
|
||||
result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
||||
func_results.append({"function": func['name'], "result": result})
|
||||
# 格式化结果
|
||||
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||
formatted_results.append(formatted_result)
|
||||
# 使用富文本打印参数
|
||||
#print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
|
||||
|
||||
if func.get("name") == 'retrieval_from_knowledge_base':
|
||||
|
||||
# delay_time = random.uniform(5, 10)
|
||||
# time.sleep(delay_time)
|
||||
result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
||||
|
||||
# 格式化结果
|
||||
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||
formatted_results.append(formatted_result)
|
||||
|
||||
elif func.get("name") == 'generate_material':
|
||||
|
||||
elif func.get("name") == 'generate_material':
|
||||
try:
|
||||
# 确保arguments_data是字典
|
||||
if isinstance(arguments_data, str):
|
||||
try:
|
||||
arguments_data = json.loads(arguments_data)
|
||||
except json.JSONDecodeError as e:
|
||||
#print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
|
||||
# 确保arguments_data是字典
|
||||
if isinstance(arguments_data, str):
|
||||
try:
|
||||
arguments_data = json.loads(arguments_data)
|
||||
except json.JSONDecodeError as e:
|
||||
#print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
|
||||
continue
|
||||
|
||||
# 规范化参数
|
||||
normalized_args = normalize_material_args(arguments_data)
|
||||
|
||||
# 优先使用mattergen函数
|
||||
try:
|
||||
|
||||
output = generate_material(**normalized_args)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
#print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
|
||||
formatted_result = f"调用时出错,请检查输入的参数,异常为{e}"
|
||||
|
||||
|
||||
|
||||
# 格式化结果
|
||||
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
|
||||
formatted_results.append(formatted_result)
|
||||
except Exception as e:
|
||||
#print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
|
||||
import traceback
|
||||
#print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
|
||||
continue
|
||||
else:
|
||||
|
||||
# 规范化参数
|
||||
normalized_args = normalize_material_args(arguments_data)
|
||||
|
||||
# 优先使用mattergen函数
|
||||
try:
|
||||
|
||||
output = generate_material(**normalized_args)
|
||||
formatted_result = asyncio.run(execute_tool_from_dict(func))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
#print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
|
||||
continue
|
||||
# 将结果添加到func_results
|
||||
func_results.append({"function": func_name, "result": output})
|
||||
|
||||
# 格式化结果
|
||||
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
|
||||
formatted_results.append(formatted_result)
|
||||
except Exception as e:
|
||||
#print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
|
||||
import traceback
|
||||
#print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
|
||||
continue
|
||||
else:
|
||||
|
||||
result = asyncio.run(execute_tool_from_dict(func))
|
||||
func_results.append({"function": func['name'], "result": result})
|
||||
# 格式化结果
|
||||
func_name = func.get("name")
|
||||
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||
formatted_results.append(formatted_result)
|
||||
|
||||
# 将所有格式化后的结果连接起来
|
||||
final_result = "\n\n\n".join(formatted_results)
|
||||
data['observation'] = final_result
|
||||
|
||||
#使用富文本打印开始和结束标记
|
||||
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
|
||||
# print(data['observation'])
|
||||
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
|
||||
with file_lock:
|
||||
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
writer.write(data) # observation . data
|
||||
return f"Processed successfully"
|
||||
formatted_results.append(formatted_result)
|
||||
else:
|
||||
formatted_results.append(func)
|
||||
# 将所有格式化后的结果连接起来
|
||||
final_result = "\n\n\n".join(formatted_results)
|
||||
data['messages'].append({"role": "user", "content": final_result})
|
||||
|
||||
#print("last message",data["messages"][-1])
|
||||
#使用富文本打印开始和结束标记
|
||||
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
|
||||
# print(data['observation'])
|
||||
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
|
||||
with file_lock:
|
||||
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
writer.write(data) # observation . data
|
||||
return f"Processed successfully"
|
||||
else:
|
||||
with file_lock:
|
||||
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
writer.write(data) # observation . data
|
||||
return f"Processed successfully"
|
||||
except Exception as e:
|
||||
#print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
|
||||
return f"Error processing: {str(e)}"
|
||||
@@ -407,19 +454,41 @@ def main(datas, output_file_path, max_workers=1):
|
||||
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import datetime
|
||||
import jsonlines
|
||||
datas = []
|
||||
with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
|
||||
total_count=0
|
||||
filtered_count=0
|
||||
with jsonlines.open('/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_qwq1.jsonl') as reader:
|
||||
for obj in reader:
|
||||
#if obj['solution']!='':
|
||||
datas.append(obj)
|
||||
|
||||
print(len(datas))
|
||||
# print()
|
||||
output_file = f"./filter_ok_questions_solutions_agent_data_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||
main(datas, output_file, max_workers=32)
|
||||
for data in datas:
|
||||
tool_call_str=data['messages'][-1]['content'].split("<answer>\n")[-1]
|
||||
if '<tool_call>' in tool_call_str:
|
||||
filtered_count+=1
|
||||
total_count+=1
|
||||
print("total count",total_count)
|
||||
print("filtered count",filtered_count)
|
||||
|
||||
# for data in datas[:5]:
|
||||
# tool_call_str=data['messages'][-1]['content'].split("<answer>\n")[-1].split("<answer>")[0]
|
||||
# tool_call_dict_list=extract_tool_calls(tool_call_str)
|
||||
# for tool_call_dict in tool_call_dict_list:
|
||||
# print("tool name",tool_call_dict['name'])
|
||||
# print("tool arguments",tool_call_dict['arguments'])
|
||||
# print("xxx")
|
||||
# print("==="*20)
|
||||
# # print()
|
||||
# exit()
|
||||
output_file = f"./agent_questions_solution_turn2_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||
main(datas, output_file, max_workers=48)
|
||||
|
||||
# 示例1:使用正确的JSON格式
|
||||
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
|
||||
@@ -1,6 +1,6 @@
|
||||
import jsonlines
|
||||
import argparse
|
||||
import generate_data.utils as utils
|
||||
import generate_data.generate_sft_data.utils as utils
|
||||
import glob
|
||||
import json
|
||||
from ase import io
|
||||
@@ -5,6 +5,7 @@ It uses the OpenAI API and MySQL for storing and retrieving data.
|
||||
"""
|
||||
import multiprocessing
|
||||
import sqlite3
|
||||
import jsonlines
|
||||
import tiktoken
|
||||
import re
|
||||
from fractions import Fraction
|
||||
@@ -48,9 +49,24 @@ def get_response_from_deepseek_r1(messages: list[dict], prefix: bool = False, ma
|
||||
messages=messages,
|
||||
temperature=0.6
|
||||
)
|
||||
#print("response",response)
|
||||
|
||||
# reasoning_content = "null" if prefix else "<think>\n" + response.choices[0].message.model_extra['reasoning_content'] + "\n</think>\n"
|
||||
reasoning_content = response.choices[0].message.content.split("</think>\n")[0].split("<think>\n")[-1]
|
||||
if reasoning_content=='':
|
||||
reasoning_content=response.choices[0].message.content.split("</think>\n")[1]
|
||||
# while reasoning_content == "" :
|
||||
# if retries<max_retries:
|
||||
# response = client.chat.completions.create(
|
||||
# model="deepseek-r1",
|
||||
# messages=messages,
|
||||
# temperature=0.6
|
||||
# )
|
||||
# retries+=1
|
||||
# else:
|
||||
# print(f"Max retries exceeded for RateLimitError: {rate_error}")
|
||||
# return 'apierror', 'apierror'
|
||||
# reasoning_content = response.choices[0].message.content.split("</think>\n")[0].split("<think>\n")[-1]
|
||||
|
||||
content = response.choices[0].message.content.split("</think>\n")[-1]
|
||||
return reasoning_content, content
|
||||
|
||||
@@ -173,13 +189,13 @@ def get_response_from_qwq(messages: list[dict], model_name: str, tools: list = N
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
# client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
|
||||
client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
|
||||
# client = OpenAI(api_key="sk-df98afdc6b5b48db8195dcb4a68e804b", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||
import random
|
||||
if random.random() > 0.5:
|
||||
client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||
else:
|
||||
client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||
# import random
|
||||
# if random.random() > 0.5:
|
||||
# client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||
# else:
|
||||
# client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||
# messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
|
||||
if tools is None:
|
||||
response = client.chat.completions.create(
|
||||
@@ -295,8 +311,10 @@ def read_json_file(file_path):
|
||||
print(f"Error reading file {file_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def read_jsonline_file(file_path):
|
||||
with jsonlines.open(file_path, mode='r') as reader:
|
||||
datas = [line for line in reader]
|
||||
return datas
|
||||
################################## utils
|
||||
|
||||
def clean_all_repetitions_with_details(text, min_length=10, threshold=10):
|
||||
@@ -1,91 +0,0 @@
|
||||
import jsonlines
|
||||
import argparse
|
||||
import generate_data.utils as utils
|
||||
import glob
|
||||
import json
|
||||
from ase import io
|
||||
import tempfile
|
||||
import re
|
||||
from pymatgen.io.vasp import Poscar
|
||||
from pymatgen.io.cif import CifParser
|
||||
import threading
|
||||
import concurrent.futures
|
||||
import copy
|
||||
from grpo_utils import generate_design_question, generate_props_question, generate_obs_response
|
||||
# Create a lock for file writing
|
||||
file_lock = threading.Lock()
|
||||
|
||||
|
||||
def worker(data, output_file_path):
|
||||
try:
|
||||
messages = copy.deepcopy(data['messages'])
|
||||
obs = data['observation']
|
||||
messages[-1]['content'] = messages[-1]['content'].split("<answer>")[-1].split("</answer>")[0]
|
||||
messages.append({"role": "user", "content": obs})
|
||||
data['messages'].append({"role": "user", "content": obs})
|
||||
# print(messages)
|
||||
# print(obs)
|
||||
|
||||
reasoning_content, response = generate_obs_response(messages)
|
||||
data['messages'].append({"role": "assistant", "content": f"<think>\n{reasoning_content}</think>\n<answer>\n{response}</answer>\n"})
|
||||
# Use the lock to safely write to the file
|
||||
with file_lock:
|
||||
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
writer.write(messages)
|
||||
|
||||
return f"Processed successfully"
|
||||
except Exception as e:
|
||||
return f"Error processing: {str(e)}"
|
||||
|
||||
|
||||
def main(input_file_path, output_file_path, max_workers=1):
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
datas = None
|
||||
with jsonlines.open(input_file_path, mode='r') as reader:
|
||||
datas = [line for line in reader]
|
||||
|
||||
# 创建进度条
|
||||
pbar = tqdm(total=len(datas), desc="Processing CIF files")
|
||||
|
||||
# 创建一个线程池
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交任务到执行器
|
||||
future_to_data = {}
|
||||
for data in datas:
|
||||
future = executor.submit(worker, data, output_file_path)
|
||||
future_to_data[future] = data
|
||||
|
||||
# 处理结果
|
||||
completed = 0
|
||||
failed = 0
|
||||
for future in concurrent.futures.as_completed(future_to_data):
|
||||
data = future_to_data[future]
|
||||
try:
|
||||
result = future.result()
|
||||
if "successfully" in result:
|
||||
completed += 1
|
||||
else:
|
||||
failed += 1
|
||||
# 更新进度条
|
||||
pbar.update(1)
|
||||
# 每100个文件更新一次统计信息
|
||||
if (completed + failed) % 100 == 0:
|
||||
pbar.set_postfix(completed=completed, failed=failed)
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
pbar.update(1)
|
||||
print(f"\nWorker for {data} generated an exception: {e}")
|
||||
|
||||
pbar.close()
|
||||
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import datetime
|
||||
origin_file = "/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl"
|
||||
output_file = f"agent_questions_solutions_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||
main(origin_file, output_file)
|
||||
|
||||
13
generate_data/read_data.py
Executable file
13
generate_data/read_data.py
Executable file
@@ -0,0 +1,13 @@
|
||||
import copy
|
||||
import jsonlines
|
||||
|
||||
data_path = '/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl'
|
||||
#output_path = './agent_questions_solutions_qwq1.jsonl'
|
||||
with jsonlines.open(data_path, mode='r') as reader:
|
||||
datas = [line for line in reader if len(line['messages']) == 4]
|
||||
|
||||
print(datas[0])
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user