Files
mars-mcp/generate_data/generate_sft_data/address_sft_data.py

78 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

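# This script is used to generate /home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl
# It is the next processing step after the tool-call messages have been added to `messages`.
# address_failed_data: handles samples for which the LLM call failed.
# address_non_answer_data: handles samples for which the LLM returned an empty answer.
# address_tool_call_data: handles samples whose answer requires a tool call.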
from generate_data.generate_sft_data.utils import read_jsonline_file
from generate_data.generate_sft_data.generate_llms_ans_multiturn import main_for_datas
from generate_data.generate_sft_data.generate_tool_observation_multiturn import main
from collections import Counter
def print_message_len(data_path):
    """Print how many samples exist for each conversation length.

    Args:
        data_path: Path handed to ``read_jsonline_file``; every record it
            yields must carry a ``messages`` sequence.
    """
    samples = read_jsonline_file(data_path)
    # Tally the samples by the number of messages each one contains.
    length_histogram = Counter(len(sample['messages']) for sample in samples)
    print("message len and number:", length_histogram)
def address_failed_data(original_data_path, failed_data_path):
    """Re-submit the original samples that are missing from a second file.

    A sample counts as already present when the ``content`` of its first
    message equals the first-message ``content`` of some record read from
    ``failed_data_path``. Every other sample is handed to ``main_for_datas``
    together with ``failed_data_path``; that file is then re-read so its
    record count can be printed.

    Args:
        original_data_path: JSONL file holding the full sample set.
        failed_data_path: JSONL file the originals are compared against; it
            is also the path passed on to ``main_for_datas``.
    """
    original_datas = read_jsonline_file(original_data_path)
    failed_datas = read_jsonline_file(failed_data_path)

    # Build the lookup once as a set so each membership test is O(1) instead
    # of a linear scan over every record of the second file.
    # NOTE(review): assumes the first message's content is hashable (a plain
    # string in chat-format data) -- confirm against the dataset.
    failed_datas_question = {
        failed_data['messages'][0]['content'] for failed_data in failed_datas
    }

    need_to_address_datas = [
        original_data
        for original_data in original_datas
        if original_data['messages'][0]['content'] not in failed_datas_question
    ]
    print('need to address', len(need_to_address_datas))

    # NOTE(review): the meaning of the third argument (3) is defined by
    # main_for_datas, which is not visible here -- presumably a worker or
    # retry count; verify before changing it.
    main_for_datas(need_to_address_datas, failed_data_path, 3)

    print("after process")
    new_datas = read_jsonline_file(failed_data_path)
    print("new datas num", len(new_datas))
from .filter_messages import is_answer_none
def address_non_answer_data(non_answer_data_path, output_path):
    """Strip the last message from flagged samples and re-run the data set.

    Every record for which ``is_answer_none`` is true has its final message
    removed in place. The complete list of records (modified and unmodified
    alike) is then passed to ``main_for_datas`` with ``output_path``, after
    which ``output_path`` is re-read to print the record count and how many
    records are still flagged by ``is_answer_none``.

    Args:
        non_answer_data_path: JSONL file with the samples to inspect.
        output_path: Path passed to ``main_for_datas`` and re-read afterwards.
    """
    non_answer_datas = read_jsonline_file(non_answer_data_path)

    need_to_address_datas = []
    for data in non_answer_datas:
        if is_answer_none(data):
            # Remove the trailing message (presumably the empty answer, so it
            # can be produced again -- confirm against is_answer_none).
            data['messages'].pop()
            need_to_address_datas.append(data)
    print('non answer number', len(need_to_address_datas))

    # NOTE(review): the full list is passed here, not need_to_address_datas,
    # unlike address_failed_data. Kept as in the original because whether
    # main_for_datas skips already-answered samples cannot be seen from this
    # file; verify that this is intended.
    main_for_datas(non_answer_datas, output_path, 3)

    print("afer process")
    new_output_datas = read_jsonline_file(output_path)
    print("new output data num", len(new_output_datas))
    still_none = sum(1 for data in new_output_datas if is_answer_none(data))
    print("answer is none number", still_none)
from .filter_messages import is_tool_call_last
def address_tool_call_data(data_path, output_path):
    """Print the total sample count and how many satisfy ``is_tool_call_last``.

    Args:
        data_path: JSONL file with the samples to inspect.
        output_path: Not used by the active statements of this function.
    """
    datas = read_jsonline_file(data_path)
    print("total data", len(datas))
    # Keep only the samples flagged by is_tool_call_last.
    need_to_address_datas = [data for data in datas if is_tool_call_last(data)]
    print("need to address", len(need_to_address_datas))
# main(datas,output_file_path=output_path,max_workers=7)
# print("after process")
# new_output_datas=read_jsonline_file(output_path)
# print("new output data num",len(new_output_datas))
# print("need to tool call number",len([data for data in new_output_datas if is_tool_call_last(data)]))
#main_1(need_to_address_datas,non_answer_data_path,48)
# Script entry: both paths currently point at the same JSONL file.
original_data_path = (
    '/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl'
)
failed_data_path = (
    '/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl'
)

# Report the distribution of conversation lengths in the input file.
print_message_len(data_path=original_data_path)

# Earlier pipeline steps, kept disabled:
#address_failed_data(original_data_path,failed_data_path)
#address_non_answer_data(original_data_path,'/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl')

# Count the samples whose last message still requires a tool call.
address_tool_call_data(
    data_path=original_data_path,
    output_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn6_obs.jsonl',
)