Generate SFT data, set up the OQMD proxy, test mars-t1

This commit is contained in:
lzy
2025-04-22 16:44:26 +08:00
parent 6b92e54a41
commit a7964add00
38 changed files with 888 additions and 191 deletions

4
.gitignore vendored
View File

@@ -7,5 +7,7 @@ pyproject.toml
/pretrained_models
/mcp-python-sdk
/.vscode
*.jsonl
/*filter_ok_questions_solutions_agent*
# Ignore __pycache__ folders in all directories
__pycache__/

View File

@@ -0,0 +1,79 @@
import json
import tiktoken
import numpy as np
import statistics
from pathlib import Path
# Mean: 13716.062458398048
# Max: 106876
# Min: 5108
# Median: 13285.5
# Sample count: 9014
def count_tokens_in_string(text):
"""Count the number of tokens in a string with tiktoken."""
# Use the cl100k_base encoder (the encoding used by GPT-4)
encoding = tiktoken.get_encoding("cl100k_base")
# Encode the text into tokens
tokens = encoding.encode(text)
return len(tokens)
def process_jsonl_file(file_path):
"""Process a JSONL file and compute token statistics."""
token_counts = []
count=0
# Read the JSONL file
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
try:
# Parse the JSON line
data = json.loads(line)
if len(data['messages'])==4:
# Serialize the record back to a string
count+=1
data_str = json.dumps(data)
# Count its tokens
token_count = count_tokens_in_string(data_str)
token_counts.append(token_count)
else:
pass
except Exception as e:
print(f"Error while processing a line: {e}")
print("Number of 4-message records:", count)
# Compute summary statistics
if token_counts:
mean_value = statistics.mean(token_counts)
max_value = max(token_counts)
min_value = min(token_counts)
median_value = statistics.median(token_counts)
# Count samples below the 32k / 24k / 16k token thresholds
count_less_than_32k = sum(1 for count in token_counts if count < 32000)
count_less_than_24k = sum(1 for count in token_counts if count < 24000)
count_less_than_16k = sum(1 for count in token_counts if count < 16000)
return {
"均值": mean_value,
"最大值": max_value,
"最小值": min_value,
"中值": median_value,
"样本数": len(token_counts),
"token数小于32k的样本数": count_less_than_32k,
"token数小于32k的样本百分比": (count_less_than_32k / len(token_counts)) * 100 if token_counts else 0,
"token数小于24k的样本数": count_less_than_24k,
"token数小于24k的样本百分比": (count_less_than_24k / len(token_counts)) * 100 if token_counts else 0,
"token数小于16k的样本数": count_less_than_16k,
"token数小于16k的样本百分比": (count_less_than_16k / len(token_counts)) * 100 if token_counts else 0
}
else:
return {"错误": "没有找到有效数据"}
if __name__ == "__main__":
file_path = "/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl"
# Make sure the file exists
if not Path(file_path).exists():
print(f"Error: file does not exist - {file_path}")
else:
# Process the file and print the results
results = process_jsonl_file(file_path)
print("\nStatistics:")
for key, value in results.items():
print(f"{key}: {value}")

View File

@@ -1,3 +1,4 @@
# The raw data falls into two categories, with and without a solution; 5000 records were extracted from each
import json
import asyncio
import concurrent.futures

View File

@@ -0,0 +1,77 @@
# This script produces /home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl
# It covers the step that follows appending the tool-call results to `messages`:
# address_failed_data handles samples for which the LLM call failed
# address_non_answer_data handles samples whose LLM answer came back empty
# address_tool_call_data handles samples whose answer still requests a tool call
from generate_data.generate_sft_data.utils import read_jsonline_file
from generate_data.generate_sft_data.generate_llms_ans_multiturn import main_for_datas
from generate_data.generate_sft_data.generate_tool_observation_multiturn import main
from collections import Counter
def print_message_len(data_path):
datas = read_jsonline_file(data_path)
data_messages_len = []
for data in datas:
data_messages_len.append(len(data['messages']))
ele_counts=Counter(data_messages_len)
print("message len and number:",ele_counts)
def address_failed_data(original_data_path,failed_data_path):
original_datas=read_jsonline_file(original_data_path)
#original_datas_question = [original_data['messages'][0]['content'] for original_data in original_datas]
failed_datas = read_jsonline_file(failed_data_path)
failed_datas_question = [failed_data['messages'][0]['content'] for failed_data in failed_datas]
need_to_address_datas = []
for original_data in original_datas:
if original_data['messages'][0]['content'] in failed_datas_question:
pass
else:
need_to_address_datas.append(original_data)
print('need to address',len(need_to_address_datas))
main_for_datas(need_to_address_datas,failed_data_path,3)
print("after process")
new_datas=read_jsonline_file(failed_data_path)
print("new datas num",len(new_datas))
from .filter_messages import is_answer_none
def address_non_answer_data(non_answer_data_path,output_path):
non_answer_datas = read_jsonline_file(non_answer_data_path)
new_datas=[]
need_to_address_datas = []
for data in non_answer_datas:
if is_answer_none(data):
data['messages'].pop()
need_to_address_datas.append(data)
else:
new_datas.append(data)
print('non answer number',len(need_to_address_datas))
main_for_datas(non_answer_datas,output_path,3)
print("afer process",)
new_output_datas=read_jsonline_file(output_path)
print("new output data num",len(new_output_datas))
print("answer is none number",len([data for data in new_output_datas if is_answer_none(data)]))
from .filter_messages import is_tool_call_last
def address_tool_call_data(data_path,output_path):
datas = read_jsonline_file(data_path)
need_to_address_datas = []
print("total data",len(datas))
for data in datas:
if is_tool_call_last(data):
need_to_address_datas.append(data)
else:
pass
print("need to address",len(need_to_address_datas))
# main(datas,output_file_path=output_path,max_workers=7)
# print("after process")
# new_output_datas=read_jsonline_file(output_path)
# print("new output data num",len(new_output_datas))
# print("need to tool call number",len([data for data in new_output_datas if is_tool_call_last(data)]))
#main_1(need_to_address_datas,non_answer_data_path,48)
original_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl'
failed_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl'
print_message_len(original_data_path)
#address_failed_data(original_data_path,failed_data_path)
#address_non_answer_data(original_data_path,'/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl')
address_tool_call_data(original_data_path,'/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn6_obs.jsonl')

View File

@@ -0,0 +1 @@
Roughly, only the turn 3/4/5 files share a unified format: in *_obs the last message is a tool call, *_ans holds the result of calling qwq-32b again after the tool call, and *_ans_no_none is the data whose empty answers were regenerated by repeatedly re-calling the model.
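For illustration only (the role layout, tags and observation markers follow the scripts in this commit, while the question, tool call and contents below are invented), a finished record in the unified four-message format looks roughly like this:

# Hypothetical example of a unified four-message record; all contents are invented.
example_record = {
    "messages": [
        # 1) the original user question
        {"role": "user", "content": "Predict a stable composition in the Cs-Pb-Br system."},
        # 2) assistant turn whose <answer> still contains a tool call
        {"role": "assistant", "content": "<think>...</think>\n<answer>\n<tool_call>\n{\"name\": \"fetch_chemical_composition_from_OQMD\", \"arguments\": {\"composition\": \"Cs-Pb-Br\"}}\n</tool_call>\n</answer>\n"},
        # 3) the tool observation, wrapped in "[<tool> content begin] ... [<tool> content end]" markers
        {"role": "user", "content": "[fetch_chemical_composition_from_OQMD content begin]...[fetch_chemical_composition_from_OQMD content end]"},
        # 4) the final assistant answer produced by qwq-32b
        {"role": "assistant", "content": "<think>...</think>\n<answer>\nCsPbBr3 is the most stable composition found.\n</answer>\n"},
    ]
}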

View File

@@ -78,7 +78,7 @@ def filter_generate_material(file_path):
if __name__ == "__main__":
# Default file path
file_path ='/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl'
file_path ='/home/ubuntu/50T/nfs/lzy/mars-mcp/agent_questions_solutions_test20250416152446.jsonl'
# "/home/ubuntu/50T/lzy/mars-mcp/mars-agent_data_20250408205427.jsonl"
# If a command-line argument is provided, use it as the file path

View File

@@ -0,0 +1,129 @@
# Helper checks for processing the multi-turn records
#from generate_tool_observation_multiturn import worker
from generate_data.generate_sft_data.utils import read_jsonline_file
import jsonlines
def is_data_with_last_obs(data):
'''When preparing to have the LLM analyze tool results and produce an answer, check whether the last message in the record is a tool-call result (observation).
If so, the LLM still needs to respond; otherwise the last message is already the answer and the LLM call can be skipped.'''
if data['messages'][-1]['role']=='user':
return True
else:
return False
def is_obs_none(data):
'''Check whether the observation produced by the tool call is empty.'''
if is_data_with_last_obs(data) and data['messages'][-1]['content'] is None:
return True
else:
return False
def is_tool_call_last(data):
"""用于检查最后一条消息是否包含工具调用用于生成observation步骤"""
tool_call_str=data['messages'][-1]['content'].split("<answer>")[-1]
if data['messages'][-1]['role']=='assistant' and '<tool_call>' in tool_call_str:
return True
else:
return False
def is_answer_none(data):
# Extract the answer content between the <answer> tags
try:
answer_str = data['messages'][-1]['content'].split("<answer>\n")[-1].split("</answer>\n")[0]
except:
# If splitting fails, fall back to the raw content
answer_str = data['messages'][-1]['content']
# True if the last message is from the assistant and the answer is empty or whitespace-only
if data['messages'][-1]['role'] == 'assistant' and (not answer_str or answer_str.strip() == ''):
return True
else:
return False
from tqdm import tqdm
# Create a progress bar
# for data in ans_datas:
# if is_answer_none(data):
# data['messages'].pop()
# with jsonlines.open(output_file_path, mode='a') as writer:
# writer.write(data) # observation . data
#print("c",c)
file_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn4_ans.jsonl'
#output_file_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn2_obs_1.jsonl'
output_datas = read_jsonline_file(file_path)
print(len(output_datas))
from collections import Counter
data_messages_len = []
for data in output_datas:
data_messages_len.append(len(data['messages']))
ele_counts=Counter(data_messages_len)
print(ele_counts)
count=0
for data in output_datas:
if is_data_with_last_obs(data):
count+=1
print("last obs",count)
# for data in output_datas:
# #if is_data_with_last_obs(data):
# if is_answer_none(data):
# #data['messages'].pop()
# c+=1
# # with jsonlines.open(output_file_path, mode='a') as writer:
# # writer.write(data) # observation . data
# print("c",c)
# print(c)
# d=0
# new_data=read_jsonline_file(output_file_path)
# data_lens=[]
# for data in new_data:
# data_lens.append(len(data['messages']))
# #if is_data_with_last_obs(data):
# #d+=1
# print(set(data_lens))
# print(d)
# data['messages'].pop()
# print("数据中最后一条消息为observation且为空")
# worker(data,output_file_path)
# pbar.update(1)
# else:
# with jsonlines.open(output_file_path, mode='a') as writer:
# writer.write(data) # observation . data
# pbar.update(1)
# pbar.close
# from generate_tool_observation_multiturn import main
# raw_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_qwq1.jsonl'
# unfinish_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn2_20250418113514.jsonl'
# unfinish_datas = read_jsonline_file(unfinish_data_path)
# raw_datas = read_jsonline_file(raw_data_path)
# unfinish_datas_question =[unfinish_data['messages'][0]['content'] for unfinish_data in unfinish_datas]
# from generate_tool_observation_multiturn import worker
# filtered_unfinish_datas = []
# for raw_data in raw_datas:
# if raw_data['messages'][0]['content'] in unfinish_datas_question:
# pass
# else:
# filtered_unfinish_datas.append(raw_data)
# # print(raw_data['messages'][-1]['content'])
# # worker(raw_data,unfinish_data_path)
# #exit(0)
# main(filtered_unfinish_datas,unfinish_data_path,16)
#print(len(filtered_unfinish_datas))

View File

@@ -0,0 +1,218 @@
import jsonlines
import argparse
import generate_data.generate_sft_data.utils as utils
import glob
import json
from ase import io
import tempfile
import re
from pymatgen.io.vasp import Poscar
from pymatgen.io.cif import CifParser
import threading
import concurrent.futures
import copy
from generate_data.generate_sft_data.sft_utils import generate_design_question, generate_props_question#, generate_obs_response
import sys
sys.path.append('/home/ubuntu/sas0/lzy/mars-mcp/')
from mars_toolkit import get_tool_schemas
tools=get_tool_schemas()
tools_description=''
tools_description+='\n'.join (json.dumps(tool) for tool in tools)
# Create a lock for file writing
file_lock = threading.Lock()
def generate_deepseek_obs_response(messages,tools_description=None,max_retries=3,initial_backoff=1):
messages_ = copy.deepcopy(messages)
system_prompt = '''
你是一个有用的AI助手可以使用以下工具来帮助用户完成任务。
可用工具:
{tools_description}
当用户提供给你工具调用的结果后,请你根据当前所拥有的信息继续深入思考,
并自行判断这些工具调用的结果能否让你直接给出答案,如果能给出答案,请在回答中直接给出答案;
若无法给出答案,你将拥有一次调用工具获取信息的机会,请在思考过程中尽可能的调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题。
'''
messages_.insert(0, {"role": "system", "content": system_prompt.format(tools_description=tools_description)})
# instrument='''
# 上面是你请求的工具调用的结果,请根据现有的信息继续深入思考,
# 并自行判断这些工具调用的结果能否让你直接给出答案,如果能给出答案,请在回答中直接给出答案;
# 若无法给出答案,你将拥有一次调用工具获取信息的机会,请在思考过程中尽可能的调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题。
# '''
# messages_.append({"role": "user", "content": instrument})
_reasoning_content, response = utils.get_response_from_deepseek_r1(
messages_,prefix=False,max_retries=max_retries,initial_backoff=initial_backoff)
return _reasoning_content, response
def generate_qwq_obs_response(messages,tools_description=None,max_retries=3,initial_backoff=1):
messages_ = copy.deepcopy(messages)
system_prompt = '''
你是一个有用的AI助手可以使用以下工具来帮助用户完成任务。
可用工具:
{tools_description}.
当用户提供给你工具调用的结果后,请你根据当前所拥有的信息继续深入思考,工具调用失败的原因可能是因为你写的工具调用方式不规范导致调用时解析错误或者是网络原因等等。
并自行判断这些工具调用的结果能否让你直接给出答案情况1.如果能给出答案,请在回答中直接给出答案;
情况2.若无法给出答案,你将再拥有一次调用多个工具获取信息的机会,请在思考过程中尽可能地调用多个提供给你的工具查询该问题的相关信息(之前的工具调用如果调用失败你也可以自行决定要不要重新调用),而不是直接回答该问题。
情况1.的回答内容为思考内容以及答案情况2.的回答内容为思考内容以及每一个工具调用的名称和参数(不要省略具体的参数内容,因为这会导致无法解析参数)。
思考和回答时使用和问题相同的语言。
'''
messages_.insert(0, {"role": "system", "content": system_prompt.format(tools_description=tools_description)})
# instrument='''
# 上面是你请求的工具调用的结果,请根据现有的信息继续深入思考,
# 并自行判断这些工具调用的结果能否让你直接给出答案情况1.如果能给出答案,请在回答中直接给出答案;
# 情况2.若无法给出答案,你将拥有一次调用多个工具获取信息的机会,在思考过程中尽可能地描述如何通过调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题,因此,在这种情况下你在回答中只需要一次给出多个工具调用,不需要额外的文字描述。
# '''
# messages_.append({"role": "assistant", "content": instrument})
think_answer,tool_info=utils.get_response_from_qwq(messages_,model_name='qwq-32b',tools=tools,max_retries=max_retries,initial_backoff=initial_backoff)
return think_answer,tool_info
from .filter_messages import is_answer_none,is_data_with_last_obs
def worker(data, output_file_path):
if is_data_with_last_obs(data):
try:
messages = copy.deepcopy(data['messages'])
# print(messages)
# print(obs)
retry=0
think_answer,tool_info = generate_qwq_obs_response(messages,tools_description)
#print("reasoning_content",reasoning_content)
#print("response",response)
data['messages'].append({"role": "assistant", "content": think_answer})
while is_answer_none(data) and retry<5:
# If the answer is empty, drop it and have the model answer again
retry+=1
# print("retry",retry)
# Pop the empty answer, then regenerate the assistant answer from the remaining messages
data['messages'].pop()
messages = copy.deepcopy(data['messages'])
think_answer,tool_info = generate_qwq_obs_response(messages,tools_description,max_retries=5)
data['messages'].append({"role": "assistant", "content": think_answer})
# Use the lock to safely write to the file
data['observation']=''
data['function_calls']=''
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data)
return f"Processed successfully"
except Exception as e:
return f"Error processing: {str(e)}"
else:
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data)
return f"Processed successfully"
def main(input_file_path, output_file_path, max_workers=1):
import random
from tqdm import tqdm
import os
datas = None
with jsonlines.open(input_file_path, mode='r') as reader:
datas = [line for line in reader]
# Create a progress bar
pbar = tqdm(total=len(datas), desc="Processing records")
# Create a thread pool
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit tasks to the executor
future_to_data = {}
for data in datas:
future = executor.submit(worker, data, output_file_path)
future_to_data[future] = data
# Collect the results
completed = 0
failed = 0
for future in concurrent.futures.as_completed(future_to_data):
data = future_to_data[future]
try:
result = future.result()
if "successfully" in result:
completed += 1
else:
failed += 1
# Update the progress bar
pbar.update(1)
# Refresh the postfix stats every 100 records
if (completed + failed) % 100 == 0:
pbar.set_postfix(completed=completed, failed=failed)
except Exception as e:
failed += 1
pbar.update(1)
print(f"\nWorker for {data} generated an exception: {e}")
pbar.close()
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
def main_for_datas(datas, output_file_path, max_workers=1):
import random
from tqdm import tqdm
import os
#datas = None
# Create a progress bar
pbar = tqdm(total=len(datas), desc="Processing records")
# Create a thread pool
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit tasks to the executor
future_to_data = {}
for data in datas:
future = executor.submit(worker, data, output_file_path)
future_to_data[future] = data
# Collect the results
completed = 0
failed = 0
for future in concurrent.futures.as_completed(future_to_data):
data = future_to_data[future]
try:
result = future.result()
if "successfully" in result:
completed += 1
else:
failed += 1
# Update the progress bar
pbar.update(1)
# Refresh the postfix stats every 100 records
if (completed + failed) % 100 == 0:
pbar.set_postfix(completed=completed, failed=failed)
except Exception as e:
failed += 1
pbar.update(1)
print(f"\nWorker for {data} generated an exception: {e}")
pbar.close()
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
if __name__ == '__main__':
import datetime
origin_file = "/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_obs.jsonl"
#output_file = f"/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_turn2_ans_no_none1{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
output_file='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans.jsonl'
main(origin_file, output_file,max_workers=36)
print("output_file",output_file)
output_datas = utils.read_jsonline_file(output_file)
c=0
print("len output_datas",len(output_datas))
for data in output_datas:
if is_answer_none(data):
#print("data len",len(data['messages']))
c+=1
print("answer none number",c)

View File

@@ -1,7 +1,8 @@
import json
import asyncio
import concurrent.futures
import sys
sys.path.append('/home/ubuntu/sas0/lzy/mars-mcp/')
import jsonlines
from mars_toolkit import *
import threading
@@ -19,6 +20,37 @@ init(autoreset=True)
from typing import Dict, Union, Any, Optional, List
import re
def extract_tool_calls(text):
"""
提取字符串中所有包裹在<tool_call>\n</tool_call>\n 中的JSON内容并转换为字典列表
参数:
text (str): 包含工具调用的文本
返回:
list: 包含所有工具调用的字典列表
"""
# 使用正则表达式提取<tool_call>和</tool_call>之间的内容
# (?s)表示让.也匹配换行符,使模式可以跨行匹配
pattern = r'<tool_call>\n(.*?)</tool_call>'
matches = re.finditer(pattern, text, re.DOTALL)
tool_calls = []
for match in matches:
json_str = match.group(1).strip()
try:
# Convert the JSON string into a Python dict
tool_call_dict = json.loads(json_str)
tool_calls.append(tool_call_dict)
except json.JSONDecodeError as e:
tool_calls.append(f"无法解析JSON: {e},问题字符串{json_str}")
return tool_calls
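A quick editorial illustration of the parser above (the tool call below is invented): extract_tool_calls turns tagged text into plain dicts.

# Usage sketch for extract_tool_calls; the payload is made up for illustration.
sample = 'reasoning...\n<tool_call>\n{"name": "fetch_chemical_composition_from_OQMD", "arguments": {"composition": "CsPbBr3"}}\n</tool_call>\n'
print(extract_tool_calls(sample))
# -> [{'name': 'fetch_chemical_composition_from_OQMD', 'arguments': {'composition': 'CsPbBr3'}}]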
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
Normalize the argument format passed to the generate_material function
@@ -237,7 +269,7 @@ async def execute_tool_from_dict(input_dict: dict):
# Check that the function name exists in the tool-function map
if func_name not in tools:
return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"}
return f"函数 '{func_name}' 不存在于工具函数字典中"
# Look up the corresponding tool function
tool_func = tools[func_name]
@@ -265,102 +297,117 @@ async def execute_tool_from_dict(input_dict: dict):
arguments = {"raw_string": arguments_data}
# Invoke the tool function
if asyncio.iscoroutinefunction(tool_func):
# Await async functions
result = await tool_func(**arguments)
else:
# Call sync functions directly
result = tool_func(**arguments)
try:
if asyncio.iscoroutinefunction(tool_func):
# Await async functions
result = await tool_func(**arguments)
else:
# Call sync functions directly
result = tool_func(**arguments)
except Exception as e:
result = f'工具函数调用时出错str{e}'
# if func_name=='generate_material':
# print("xxxxx",result)
return result
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
return formatted_result
except json.JSONDecodeError as e:
return {"status": "error", "message": f"JSON解析错误: {str(e)}"}
except Exception as e:
return {"status": "error", "message": f"执行过程中出错: {str(e)}"}
pass
def worker(data, output_file_path):
try:
func_contents = data["function_calls"]
func_results = []
formatted_results = [] # List to collect the formatted results
for func in func_contents:
func_name = func.get("name")
arguments_data = func.get("arguments")
tool_call_str = data['messages'][-1]['content'].split("<answer>")[-1]
if '<tool_call>' in tool_call_str:
# Rich-print the function name
#print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
func_contents=extract_tool_calls(tool_call_str)
#print(func_contents)
# Rich-print the arguments
#print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
#func_results = []
formatted_results = [] # List to collect the formatted results
for func in func_contents:
if isinstance(func,Dict):
func_name = func.get("name")
arguments_data = func.get("arguments")
if func.get("name") == 'retrieval_from_knowledge_base':
# Rich-print the function name
#print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
# delay_time = random.uniform(5, 10)
# time.sleep(delay_time)
result = asyncio.run(process_retrieval_from_knowledge_base(data))
func_results.append({"function": func['name'], "result": result})
# Format the result
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result)
# Rich-print the arguments
#print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
if func.get("name") == 'retrieval_from_knowledge_base':
# delay_time = random.uniform(5, 10)
# time.sleep(delay_time)
result = asyncio.run(process_retrieval_from_knowledge_base(data))
# Format the result
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result)
elif func.get("name") == 'generate_material':
elif func.get("name") == 'generate_material':
try:
# Make sure arguments_data is a dict
if isinstance(arguments_data, str):
try:
arguments_data = json.loads(arguments_data)
except json.JSONDecodeError as e:
#print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
# Make sure arguments_data is a dict
if isinstance(arguments_data, str):
try:
arguments_data = json.loads(arguments_data)
except json.JSONDecodeError as e:
#print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
continue
# Normalize the arguments
normalized_args = normalize_material_args(arguments_data)
# Prefer the mattergen-backed generate_material
try:
output = generate_material(**normalized_args)
except Exception as e:
#print(f"{Fore.RED}mattergen出错尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
formatted_result = f"调用时出错,请检查输入的参数,异常为{e}"
# Format the result
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
formatted_results.append(formatted_result)
except Exception as e:
#print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
import traceback
#print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
continue
else:
# Normalize the arguments
normalized_args = normalize_material_args(arguments_data)
# Prefer the mattergen-backed generate_material
try:
output = generate_material(**normalized_args)
formatted_result = asyncio.run(execute_tool_from_dict(func))
except Exception as e:
#print(f"{Fore.RED}mattergen出错尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
continue
# Append the result to func_results
func_results.append({"function": func_name, "result": output})
# Format the result
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
formatted_results.append(formatted_result)
except Exception as e:
#print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
import traceback
#print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
continue
else:
result = asyncio.run(execute_tool_from_dict(func))
func_results.append({"function": func['name'], "result": result})
# Format the result
func_name = func.get("name")
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result)
# Join all the formatted results
final_result = "\n\n\n".join(formatted_results)
data['observation'] = final_result
# Rich-print the begin/end markers
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
# print(data['observation'])
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data) # observation . data
return f"Processed successfully"
formatted_results.append(formatted_result)
else:
formatted_results.append(func)
# Join all the formatted results
final_result = "\n\n\n".join(formatted_results)
data['messages'].append({"role": "user", "content": final_result})
#print("last message",data["messages"][-1])
# Rich-print the begin/end markers
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
# print(data['observation'])
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data) # observation . data
return f"Processed successfully"
else:
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data) # observation . data
return f"Processed successfully"
except Exception as e:
#print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
return f"Error processing: {str(e)}"
@@ -407,19 +454,41 @@ def main(datas, output_file_path, max_workers=1):
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
if __name__ == '__main__':
import datetime
import jsonlines
datas = []
with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
total_count=0
filtered_count=0
with jsonlines.open('/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_qwq1.jsonl') as reader:
for obj in reader:
#if obj['solution']!='':
datas.append(obj)
print(len(datas))
# print()
output_file = f"./filter_ok_questions_solutions_agent_data_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(datas, output_file, max_workers=32)
for data in datas:
tool_call_str=data['messages'][-1]['content'].split("<answer>\n")[-1]
if '<tool_call>' in tool_call_str:
filtered_count+=1
total_count+=1
print("total count",total_count)
print("filtered count",filtered_count)
# for data in datas[:5]:
# tool_call_str=data['messages'][-1]['content'].split("<answer>\n")[-1].split("<answer>")[0]
# tool_call_dict_list=extract_tool_calls(tool_call_str)
# for tool_call_dict in tool_call_dict_list:
# print("tool name",tool_call_dict['name'])
# print("tool arguments",tool_call_dict['arguments'])
# print("xxx")
# print("==="*20)
# # print()
# exit()
output_file = f"./agent_questions_solution_turn2_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(datas, output_file, max_workers=48)
# Example 1: a correctly formatted JSON argument string
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'

View File

@@ -1,6 +1,6 @@
import jsonlines
import argparse
import generate_data.utils as utils
import generate_data.generate_sft_data.utils as utils
import glob
import json
from ase import io

View File

@@ -5,6 +5,7 @@ It uses the OpenAI API and MySQL for storing and retrieving data.
"""
import multiprocessing
import sqlite3
import jsonlines
import tiktoken
import re
from fractions import Fraction
@@ -48,9 +49,24 @@ def get_response_from_deepseek_r1(messages: list[dict], prefix: bool = False, ma
messages=messages,
temperature=0.6
)
#print("response",response)
# reasoning_content = "null" if prefix else "<think>\n" + response.choices[0].message.model_extra['reasoning_content'] + "\n</think>\n"
reasoning_content = response.choices[0].message.content.split("</think>\n")[0].split("<think>\n")[-1]
if reasoning_content=='':
reasoning_content=response.choices[0].message.content.split("</think>\n")[1]
# while reasoning_content == "" :
# if retries<max_retries:
# response = client.chat.completions.create(
# model="deepseek-r1",
# messages=messages,
# temperature=0.6
# )
# retries+=1
# else:
# print(f"Max retries exceeded for RateLimitError: {rate_error}")
# return 'apierror', 'apierror'
# reasoning_content = response.choices[0].message.content.split("</think>\n")[0].split("<think>\n")[-1]
content = response.choices[0].message.content.split("</think>\n")[-1]
return reasoning_content, content
@@ -173,13 +189,13 @@ def get_response_from_qwq(messages: list[dict], model_name: str, tools: list = N
retries = 0
while retries <= max_retries:
try:
# client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
# client = OpenAI(api_key="sk-df98afdc6b5b48db8195dcb4a68e804b", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
import random
if random.random() > 0.5:
client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
else:
client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
# import random
# if random.random() > 0.5:
# client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
# else:
# client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
# messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
if tools is None:
response = client.chat.completions.create(
@@ -295,8 +311,10 @@ def read_json_file(file_path):
print(f"Error reading file {file_path}: {e}")
return None
def read_jsonline_file(file_path):
with jsonlines.open(file_path, mode='r') as reader:
datas = [line for line in reader]
return datas
################################## utils
def clean_all_repetitions_with_details(text, min_length=10, threshold=10):

View File

@@ -1,91 +0,0 @@
import jsonlines
import argparse
import generate_data.utils as utils
import glob
import json
from ase import io
import tempfile
import re
from pymatgen.io.vasp import Poscar
from pymatgen.io.cif import CifParser
import threading
import concurrent.futures
import copy
from grpo_utils import generate_design_question, generate_props_question, generate_obs_response
# Create a lock for file writing
file_lock = threading.Lock()
def worker(data, output_file_path):
try:
messages = copy.deepcopy(data['messages'])
obs = data['observation']
messages[-1]['content'] = messages[-1]['content'].split("<answer>")[-1].split("</answer>")[0]
messages.append({"role": "user", "content": obs})
data['messages'].append({"role": "user", "content": obs})
# print(messages)
# print(obs)
reasoning_content, response = generate_obs_response(messages)
data['messages'].append({"role": "assistant", "content": f"<think>\n{reasoning_content}</think>\n<answer>\n{response}</answer>\n"})
# Use the lock to safely write to the file
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(messages)
return f"Processed successfully"
except Exception as e:
return f"Error processing: {str(e)}"
def main(input_file_path, output_file_path, max_workers=1):
import random
from tqdm import tqdm
import os
datas = None
with jsonlines.open(input_file_path, mode='r') as reader:
datas = [line for line in reader]
# Create a progress bar
pbar = tqdm(total=len(datas), desc="Processing CIF files")
# Create a thread pool
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit tasks to the executor
future_to_data = {}
for data in datas:
future = executor.submit(worker, data, output_file_path)
future_to_data[future] = data
# Collect the results
completed = 0
failed = 0
for future in concurrent.futures.as_completed(future_to_data):
data = future_to_data[future]
try:
result = future.result()
if "successfully" in result:
completed += 1
else:
failed += 1
# Update the progress bar
pbar.update(1)
# Refresh the postfix stats every 100 files
if (completed + failed) % 100 == 0:
pbar.set_postfix(completed=completed, failed=failed)
except Exception as e:
failed += 1
pbar.update(1)
print(f"\nWorker for {data} generated an exception: {e}")
pbar.close()
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
if __name__ == '__main__':
import datetime
origin_file = "/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl"
output_file = f"agent_questions_solutions_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(origin_file, output_file)

13
generate_data/read_data.py Executable file
View File

@@ -0,0 +1,13 @@
import copy
import jsonlines
data_path = '/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl'
#output_path = './agent_questions_solutions_qwq1.jsonl'
with jsonlines.open(data_path, mode='r') as reader:
datas = [line for line in reader if len(line['messages']) == 4]
print(datas[0])

View File

@@ -15,7 +15,7 @@ from mars_toolkit.compute.structure_opt import optimize_crystal_structure, conve
from mars_toolkit.query.mp_query import (
search_material_property_from_material_project,
get_crystal_structures_from_materials_project,
get_mpid_from_formula
#get_mpid_from_formula
)
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base

BIN
mars_toolkit/__pycache__/__init__.cpython-310.pyc Executable file → Normal file

Binary file not shown.

BIN
mars_toolkit/compute/__pycache__/__init__.cpython-310.pyc Executable file → Normal file

Binary file not shown.

BIN
mars_toolkit/core/__pycache__/__init__.cpython-310.pyc Executable file → Normal file

Binary file not shown.

BIN
mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc Executable file → Normal file

Binary file not shown.

BIN
mars_toolkit/core/__pycache__/config.cpython-310.pyc Executable file → Normal file

Binary file not shown.

BIN
mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc Executable file → Normal file

Binary file not shown.

View File

@@ -22,11 +22,11 @@ class Config:
HTTPS_PROXY = 'http://192.168.168.1:20171'
# FairChem
FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/nfs/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
FAIRCHEM_MODEL_PATH = '/home/ubuntu/sas0/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
FMAX = 0.05
# MatterGen
MATTERGENMODEL_ROOT = '/home/ubuntu/50T/nfs/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
MATTERGENMODEL_ROOT = '/home/ubuntu/sas0/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
MATTERGEN_ROOT='/home/ubuntu/50T/nfs/lzy/mars-mcp/mattergen'
MATTERGENMODEL_RESULT_PATH = 'results/'

BIN
mars_toolkit/misc/__pycache__/__init__.cpython-310.pyc Executable file → Normal file

Binary file not shown.

BIN
mars_toolkit/misc/__pycache__/misc_tools.cpython-310.pyc Executable file → Normal file

Binary file not shown.

BIN
mars_toolkit/query/__pycache__/__init__.cpython-310.pyc Executable file → Normal file

Binary file not shown.

BIN
mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc Executable file → Normal file

Binary file not shown.

BIN
mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc Executable file → Normal file

Binary file not shown.

BIN
mars_toolkit/query/__pycache__/web_search.cpython-310.pyc Executable file → Normal file

Binary file not shown.

View File

@@ -1,11 +1,13 @@
import logging
import os
import httpx
import pandas as pd
from bs4 import BeautifulSoup
from io import StringIO
from typing import Annotated
from mars_toolkit.core.llm_tools import llm_tool
from ..core import config
from ..core.llm_tools import llm_tool
logger = logging.getLogger(__name__)
@@ -23,6 +25,8 @@ async def fetch_chemical_composition_from_OQMD(
Formatted text with material information and property tables
"""
# Fetch data from OQMD
os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''
url = f"https://www.oqmd.org/materials/composition/{composition}"
try:
async with httpx.AsyncClient(timeout=100.0) as client:
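Setting HTTP_PROXY / HTTPS_PROXY in os.environ works here because httpx clients trust the environment proxy variables by default (trust_env=True). As an editorial aside, a sketch of scoping the proxy to this one client instead of mutating the process environment (the keyword is proxy= in recent httpx releases and proxies= in older ones):

# Sketch only: pass the configured proxy straight to the client rather than via os.environ.
# Uses the same config.HTTPS_PROXY value as above; the keyword name depends on the installed httpx version.
async with httpx.AsyncClient(timeout=100.0, proxy=config.HTTPS_PROXY) as client:
    response = await client.get(url)  # the rest of the request handling stays as in the function above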

BIN
prompts/__pycache__/material_synthesis.cpython-310.pyc Executable file → Normal file

Binary file not shown.

177
test_mars_t1.py Normal file
View File

@@ -0,0 +1,177 @@
import asyncio
import json
from openai import OpenAI
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
# Create a Rich console object
console = Console()
# Helper that prints a styled separator panel
def print_separator(title=""):
console.print(Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False), justify="center")
api_key="gpustack_72b0d41ec69eddab_bce1ea964ddc277ac6aed46b67b03960"
base_url="http://gpustack.ddwtop.team/v1-openai"
client = OpenAI(
api_key=api_key,
base_url=base_url,
)
from mars_toolkit import get_tool_schemas,get_tools
tools_schemas = get_tool_schemas()
tool_map = get_tools()
# Helper that pretty-prints a message
def print_message(message):
# Handle both kinds of message objects
if hasattr(message, 'role'): # ChatCompletionMessage object
role = message.role
content = message.content if hasattr(message, 'content') else ""
# For tool messages, grab the tool name
tool_name = None
if role == "tool" and hasattr(message, 'name'):
tool_name = message.name
else: # dict-style message
role = message.get("role", "unknown")
content = message.get("content", "")
# For tool messages, grab the tool name
tool_name = message.get("name") if role == "tool" else None
# Pick a color based on the role
role_colors = {
"system": "bright_blue",
"user": "green",
"assistant": "yellow",
"tool": "bright_red"
}
color = role_colors.get(role, "white")
# Build the rich text panel
text = Text()
# For tool messages, include the tool name
if role == "tool" and tool_name:
text.append(f"{role} ({tool_name}): ", style=f"bold {color}")
else:
text.append(f"{role}: ", style=f"bold {color}")
text.append(str(content))
console.print(Panel(text, border_style=color))
messages = [
{"role": "system",
"content": "You are MARS-T1, a professional assistant in materials science. You first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> <structured_answer> </structured_answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here <structured_answer> structured answer here <structured_answer> </answer>'"},
{"role": "user", "content": "how to synthesize CsPbBr3 at room temperature"}
]
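The system prompt above fixes the <think> / <answer> / <structured_answer> tag layout. A small editorial sketch (not used by this test script) of pulling those parts out of a finished reply:

# Sketch: split a tagged MARS-T1 reply into its parts; any missing section comes back as None.
import re

def split_reply(reply: str):
    think = re.search(r"<think>(.*?)</think>", reply, re.DOTALL)
    answer = re.search(r"<answer>(.*?)</answer>", reply, re.DOTALL)
    structured = re.search(r"<structured_answer>(.*?)</structured_answer>", reply, re.DOTALL)
    return tuple(m.group(1).strip() if m else None for m in (think, answer, structured))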
# Print the initial messages
print_separator("Initial messages")
for message in messages:
print_message(message)
finish_reason = None
async def execute_tool(tool_name,tool_arguments):
tool_func = tool_map[tool_name] # Look up the tool function
arguments = {}
if tool_arguments:
# arguments may arrive as a dict or as a JSON string
if isinstance(tool_arguments, dict):
# Already a dict, use it as-is
arguments = tool_arguments
elif isinstance(tool_arguments, str):
# A string: try to parse it as JSON
try:
# First try to parse it directly as a JSON object
arguments = json.loads(tool_arguments)
except json.JSONDecodeError:
# Parsing can fail when the string contains escaped characters,
# so try to repair common JSON string issues
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
try:
arguments = json.loads(fixed_str)
except json.JSONDecodeError:
# If it still fails, fall back to treating it as a raw string
arguments = {"raw_string": tool_arguments}
# Invoke the tool function
if asyncio.iscoroutinefunction(tool_func):
# Await async functions
result = await tool_func(**arguments)
else:
# Call sync functions directly
result = tool_func(**arguments)
# if func_name=='generate_material':
# print("xxxxx",result)
return result
while finish_reason is None or finish_reason == "tool_calls":
completion = client.chat.completions.create(
model="MARS-T1",
messages=messages,
temperature=0.3,
tools=tools_schemas, # <-- submit the defined tool schemas to the model via the tools parameter
)
choice = completion.choices[0]
finish_reason = choice.finish_reason
if finish_reason == "tool_calls": # <-- 判断当前返回内容是否包含 tool_calls
# 打印assistant消息
print_separator("Assistant消息")
print_message(choice.message)
# 将ChatCompletionMessage对象转换为字典
assistant_message = {
"role": "assistant",
"content": choice.message.content if hasattr(choice.message, 'content') else None
}
# If there are tool calls, add them to the dict
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
# Convert the tool_calls objects into a list of dicts
tool_calls_list = []
for tool_call in choice.message.tool_calls:
tool_call_dict = {
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
tool_calls_list.append(tool_call_dict)
assistant_message["tool_calls"] = tool_calls_list
# Append the message to the context
messages.append(assistant_message) # <-- keep the assistant message in the context so the model can follow up on the next request
# Print the tool-call details
print_separator("Tool calls")
for tool_call in choice.message.tool_calls:
console.print(f"[bold cyan]Tool name:[/] [yellow]{tool_call.function.name}[/]")
console.print(f"[bold cyan]Tool call ID:[/] [yellow]{tool_call.id}[/]")
console.print(f"[bold cyan]Arguments:[/] [yellow]{tool_call.function.arguments}[/]")
console.print("")
tool_call_name = tool_call.function.name
tool_call_arguments = json.loads(tool_call.function.arguments) # <-- arguments is a serialized JSON object, so deserialize it with json.loads
tool_result = asyncio.run(execute_tool(tool_name=tool_call_name,tool_arguments=tool_call_arguments)) # <-- execute_tool uses tool_map to find which function to run
# Build the tool response message
tool_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"name": tool_call_name,
"content": tool_result, # <-- by convention, tool results are submitted to the model as strings
}
# Print the tool response
print_separator(f"Tool response: {tool_call_name}")
print_message(tool_message)
# Append the message to the context
messages.append(tool_message)
# Print the final response
if choice.message.content:
print_separator("Final response")
console.print(Panel(choice.message.content, border_style="green"))

View File

@@ -172,7 +172,7 @@ if __name__ == "__main__":
]
# Choose the tool to test
tool_name = tools_to_test[2] # test the search_online tool
tool_name = tools_to_test[7] # test the tool at index 7
# Run the test
result = asyncio.run(test_tool(tool_name))