From a7964add00b4711dfa71ed057f77ee8010dd6d36 Mon Sep 17 00:00:00 2001
From: lzy <949777411@qq.com>
Date: Tue, 22 Apr 2025 16:44:26 +0800
Subject: [PATCH] Generate SFT data, set up the OQMD proxy, test mars-t1
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore                                    |   4 +-
 generate_data/calculate_tokens.py             |  79 ++++++
 generate_data/generate_data10000.py           |   1 +
 .../generate_sft_data/address_sft_data.py     |  77 ++++++
 .../generate_sft_data/data_version.md         |   1 +
 .../filter_generate_material_data.py          |   2 +-
 .../generate_sft_data/filter_messages.py      | 129 ++++++++++
 .../generate_llms_ans_multiturn.py            | 218 ++++++++++++++++
 .../generate_tool_observation_multiturn.py}   | 235 +++++++++++-------
 .../sft_utils.py}                             |   2 +-
 .../{ => generate_sft_data}/utils.py          |  36 ++-
 generate_data/grpo_tools.py                   |  91 -------
 generate_data/read_data.py                    |  13 +
 mars_toolkit/__init__.py                      |   2 +-
 .../__pycache__/__init__.cpython-310.pyc      | Bin 1119 -> 1084 bytes
 .../__pycache__/__init__.cpython-310.pyc      | Bin 602 -> 599 bytes
 .../__pycache__/material_gen.cpython-310.pyc  | Bin 10467 -> 10464 bytes
 .../__pycache__/property_pred.cpython-310.pyc | Bin 2156 -> 2153 bytes
 .../__pycache__/structure_opt.cpython-310.pyc | Bin 5233 -> 5230 bytes
 .../core/__pycache__/__init__.cpython-310.pyc | Bin 364 -> 361 bytes
 .../__pycache__/cif_utils.cpython-310.pyc     | Bin 3398 -> 3395 bytes
 .../core/__pycache__/config.cpython-310.pyc   | Bin 2141 -> 2132 bytes
 .../__pycache__/llm_tools.cpython-310.pyc     | Bin 5229 -> 5226 bytes
 .../mattergen_wrapper.cpython-310.pyc         | Bin 1017 -> 1014 bytes
 mars_toolkit/core/config.py                   |   4 +-
 .../misc/__pycache__/__init__.cpython-310.pyc | Bin 324 -> 321 bytes
 .../__pycache__/misc_tools.cpython-310.pyc    | Bin 1211 -> 1208 bytes
 .../__pycache__/__init__.cpython-310.pyc      | Bin 789 -> 786 bytes
 .../__pycache__/dify_search.cpython-310.pyc   | Bin 1725 -> 1722 bytes
 .../__pycache__/mp_query.cpython-310.pyc      | Bin 12948 -> 12945 bytes
 .../__pycache__/oqmd_query.cpython-310.pyc    | Bin 3010 -> 3144 bytes
 .../__pycache__/web_search.cpython-310.pyc    | Bin 2247 -> 2244 bytes
 mars_toolkit/query/oqmd_query.py              |   6 +-
 .../__pycache__/__init__.cpython-310.pyc      | Bin 458 -> 459 bytes
 .../mattergen_service.cpython-310.pyc         | Bin 10140 -> 10138 bytes
 .../material_synthesis.cpython-310.pyc        | Bin 5363 -> 5364 bytes
 test_mars_t1.py                               | 177 +++++++++++++
 test_mars_toolkit.py                          |   2 +-
 38 files changed, 888 insertions(+), 191 deletions(-)
 create mode 100644 generate_data/calculate_tokens.py
 create mode 100644 generate_data/generate_sft_data/address_sft_data.py
 create mode 100644 generate_data/generate_sft_data/data_version.md
 rename generate_data/{ => generate_sft_data}/filter_generate_material_data.py (96%)
 create mode 100644 generate_data/generate_sft_data/filter_messages.py
 create mode 100644 generate_data/generate_sft_data/generate_llms_ans_multiturn.py
 rename generate_data/{generate_tool_observation.py => generate_sft_data/generate_tool_observation_multiturn.py} (63%)
 rename generate_data/{grpo_utils.py => generate_sft_data/sft_utils.py} (99%)
 rename generate_data/{ => generate_sft_data}/utils.py (95%)
 delete mode 100755 generate_data/grpo_tools.py
 create mode 100755 generate_data/read_data.py
 mode change 100755 => 100644 mars_toolkit/__pycache__/__init__.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/compute/__pycache__/__init__.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/compute/__pycache__/property_pred.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/core/__pycache__/__init__.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/core/__pycache__/config.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/misc/__pycache__/__init__.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/misc/__pycache__/misc_tools.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/query/__pycache__/__init__.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/query/__pycache__/web_search.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/services/__pycache__/__init__.cpython-310.pyc
 mode change 100755 => 100644 mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc
 mode change 100755 => 100644 prompts/__pycache__/material_synthesis.cpython-310.pyc
 create mode 100644 test_mars_t1.py

diff --git a/.gitignore b/.gitignore
index 1ce8b20..4b50c46 100755
--- a/.gitignore
+++ b/.gitignore
@@ -7,5 +7,7 @@ pyproject.toml
 /pretrained_models
 /mcp-python-sdk
 /.vscode
+*.jsonl

-/*filter_ok_questions_solutions_agent*
+# Ignore __pycache__ folders in every directory
+__pycache__/
diff --git a/generate_data/calculate_tokens.py b/generate_data/calculate_tokens.py
new file mode 100644
index 0000000..ead04d7
--- /dev/null
+++ b/generate_data/calculate_tokens.py
@@ -0,0 +1,79 @@
+import json
+import tiktoken
+import numpy as np
+import statistics
+from pathlib import Path
+# Mean: 13716.062458398048
+# Max: 106876
+# Min: 5108
+# Median: 13285.5
+# Sample count: 9014
+def count_tokens_in_string(text):
+    """Count the number of tokens in a string using the tiktoken library."""
+    # Use the cl100k_base encoder, the encoder used by GPT-4
+    encoding = tiktoken.get_encoding("cl100k_base")
+    # Encode and count the tokens
+    tokens = encoding.encode(text)
+    return len(tokens)
+
+def process_jsonl_file(file_path):
+    """Process a JSONL file and compute token statistics."""
+    token_counts = []
+    count=0
+    # Read the JSONL file
+    with open(file_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            try:
+                # Parse the JSON line
+                data = json.loads(line)
+                if len(data['messages'])==4:
+                    # Convert the record back to a string
+                    count+=1
+                    data_str = json.dumps(data)
+                    # Count its tokens
+                    token_count = count_tokens_in_string(data_str)
+                    token_counts.append(token_count)
+                else:
+                    pass
+            except Exception as e:
+                print(f"处理行时出错: {e}")
+    print("countnumber",count)
+    # Compute the statistics
+    if token_counts:
+        mean_value = statistics.mean(token_counts)
+        max_value = max(token_counts)
+        min_value = min(token_counts)
+        median_value = statistics.median(token_counts)
+
+        # Count the samples with fewer than 32k tokens
+        count_less_than_32k = sum(1 for count in token_counts if count < 32000)
+        count_less_than_24k = sum(1 for count in token_counts if count < 24000)
+        count_less_than_16k = sum(1 for count in token_counts if count < 16000)
+        return {
+            "均值": mean_value,
+            "最大值": max_value,
+            "最小值": min_value,
+            "中值": median_value,
+            "样本数": len(token_counts),
+            "token数小于32k的样本数": count_less_than_32k,
"token数小于32k的样本百分比": (count_less_than_32k / len(token_counts)) * 100 if token_counts else 0, + "token数小于24k的样本数": count_less_than_24k, + "token数小于24k的样本百分比": (count_less_than_24k / len(token_counts)) * 100 if token_counts else 0, + "token数小于16k的样本数": count_less_than_16k, + "token数小于16k的样本百分比": (count_less_than_16k / len(token_counts)) * 100 if token_counts else 0 + } + else: + return {"错误": "没有找到有效数据"} + +if __name__ == "__main__": + file_path = "/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl" + + # 确认文件存在 + if not Path(file_path).exists(): + print(f"错误: 文件不存在 - {file_path}") + else: + # 处理文件并打印结果 + results = process_jsonl_file(file_path) + print("\n统计结果:") + for key, value in results.items(): + print(f"{key}: {value}") diff --git a/generate_data/generate_data10000.py b/generate_data/generate_data10000.py index 36e1857..bb9517a 100755 --- a/generate_data/generate_data10000.py +++ b/generate_data/generate_data10000.py @@ -1,3 +1,4 @@ +# 原始数据分为两类 一种是带solution的,一种是没有solution的,这个是提取了各5000条 import json import asyncio import concurrent.futures diff --git a/generate_data/generate_sft_data/address_sft_data.py b/generate_data/generate_sft_data/address_sft_data.py new file mode 100644 index 0000000..9d8d623 --- /dev/null +++ b/generate_data/generate_sft_data/address_sft_data.py @@ -0,0 +1,77 @@ +# 这个代码用于生成/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl +# 往messages中添加完调用工具的消息后的下一步处理, +#address_failed_data用于处理调用大模型返回失败的样本 +#address_non_answer_data 用于处理大模型返回的answer为空的样本 +#address_tool_call_data 用于处理answer 中需要调用工具的代码 +from generate_data.generate_sft_data.utils import read_jsonline_file +from generate_data.generate_sft_data.generate_llms_ans_multiturn import main_for_datas +from generate_data.generate_sft_data.generate_tool_observation_multiturn import main +from collections import Counter +def print_message_len(data_path): + datas = read_jsonline_file(data_path) + data_messages_len = [] + for data in datas: + data_messages_len.append(len(data['messages'])) + ele_counts=Counter(data_messages_len) + print("message len and number:",ele_counts) + +def address_failed_data(original_data_path,failed_data_path): + original_datas=read_jsonline_file(original_data_path) + #original_datas_question = [original_data['messages'][0]['content'] for original_data in original_datas] + failed_datas = read_jsonline_file(failed_data_path) + failed_datas_question = [failed_data['messages'][0]['content'] for failed_data in failed_datas] + need_to_address_datas = [] + for original_data in original_datas: + if original_data['messages'][0]['content'] in failed_datas_question: + pass + else: + need_to_address_datas.append(original_data) + print('need to address',len(need_to_address_datas)) + main_for_datas(need_to_address_datas,failed_data_path,3) + print("after process") + new_datas=read_jsonline_file(failed_data_path) + print("new datas num",len(new_datas)) +from .filter_messages import is_answer_none +def address_non_answer_data(non_answer_data_path,output_path): + non_answer_datas = read_jsonline_file(non_answer_data_path) + new_datas=[] + need_to_address_datas = [] + for data in non_answer_datas: + if is_answer_none(data): + data['messages'].pop() + need_to_address_datas.append(data) + + else: + new_datas.append(data) + print('non answer number',len(need_to_address_datas)) + main_for_datas(non_answer_datas,output_path,3) + print("afer process",) + new_output_datas=read_jsonline_file(output_path) + print("new output data 
num",len(new_output_datas)) + print("answer is none number",len([data for data in new_output_datas if is_answer_none(data)])) + +from .filter_messages import is_tool_call_last +def address_tool_call_data(data_path,output_path): + datas = read_jsonline_file(data_path) + need_to_address_datas = [] + print("total data",len(datas)) + for data in datas: + if is_tool_call_last(data): + need_to_address_datas.append(data) + else: + pass + print("need to address",len(need_to_address_datas)) + # main(datas,output_file_path=output_path,max_workers=7) + # print("after process") + # new_output_datas=read_jsonline_file(output_path) + # print("new output data num",len(new_output_datas)) + # print("need to tool call number",len([data for data in new_output_datas if is_tool_call_last(data)])) + #main_1(need_to_address_datas,non_answer_data_path,48) +original_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl' +failed_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl' +print_message_len(original_data_path) +#address_failed_data(original_data_path,failed_data_path) + + +#address_non_answer_data(original_data_path,'/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl') +address_tool_call_data(original_data_path,'/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn6_obs.jsonl') diff --git a/generate_data/generate_sft_data/data_version.md b/generate_data/generate_sft_data/data_version.md new file mode 100644 index 0000000..cd471c2 --- /dev/null +++ b/generate_data/generate_sft_data/data_version.md @@ -0,0 +1 @@ +大概是345的格式才是统一的,即;obs是最后一条message为工具调用的,ans是工具调用后再调用qwq32b产生的结果,ans-no-none是重复调用大模型重新生成answer中为空的数据 diff --git a/generate_data/filter_generate_material_data.py b/generate_data/generate_sft_data/filter_generate_material_data.py similarity index 96% rename from generate_data/filter_generate_material_data.py rename to generate_data/generate_sft_data/filter_generate_material_data.py index 1d5f2a7..f8c88c1 100755 --- a/generate_data/filter_generate_material_data.py +++ b/generate_data/generate_sft_data/filter_generate_material_data.py @@ -78,7 +78,7 @@ def filter_generate_material(file_path): if __name__ == "__main__": # 默认文件路径 - file_path ='/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl' + file_path ='/home/ubuntu/50T/nfs/lzy/mars-mcp/agent_questions_solutions_test20250416152446.jsonl' # "/home/ubuntu/50T/lzy/mars-mcp/mars-agent_data_20250408205427.jsonl" # 如果提供了命令行参数,则使用命令行参数作为文件路径 diff --git a/generate_data/generate_sft_data/filter_messages.py b/generate_data/generate_sft_data/filter_messages.py new file mode 100644 index 0000000..52e68ff --- /dev/null +++ b/generate_data/generate_sft_data/filter_messages.py @@ -0,0 +1,129 @@ +# 用于处理 +#from generate_tool_observation_multiturn import worker +from generate_data.generate_sft_data.utils import read_jsonline_file +import jsonlines + +def is_data_with_last_obs(data): + + '''用于准备用大模型生成工具结果分析和答案时,判断出数据中最后一条消息是否为工具调用的结果observation, + 是的话,则需要大模型进一步回答;否则也就是上一条消息为答案,即可以跳过大模型的回答''' + if data['messages'][-1]['role']=='user': + return True + else: + return False +def is_obs_none(data): + '''用于检查调用工具生成的结果是否为空''' + if is_data_with_last_obs(data) and data['messages'][-1]['content']==None: + return True + else: + return False + +def is_tool_call_last(data): + """用于检查最后一条消息是否包含工具调用,用于生成observation步骤""" + 
tool_call_str=data['messages'][-1]['content'].split("")[-1] + if data['messages'][-1]['role']=='assistant' and '' in tool_call_str: + return True + else: + return False +def is_answer_none(data): + # 修正换行符表示并获取answer内容 + try: + answer_str = data['messages'][-1]['content'].split("\n")[-1].split("\n")[0] + except: + # 如果分割出错,返回原始内容 + answer_str = data['messages'][-1]['content'] + + # 检查是否为assistant角色且answer_str为空或只包含空白字符 + if data['messages'][-1]['role'] == 'assistant' and (not answer_str or answer_str.strip() == ''): + return True + else: + return False + +from tqdm import tqdm + + # 创建进度条 + + +# for data in ans_datas: +# if is_answer_none(data): +# data['messages'].pop() +# with jsonlines.open(output_file_path, mode='a') as writer: +# writer.write(data) # observation . data + +#print("c",c) +file_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn4_ans.jsonl' +#output_file_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn2_obs_1.jsonl' +output_datas = read_jsonline_file(file_path) +print(len(output_datas)) +from collections import Counter +data_messages_len = [] +for data in output_datas: + data_messages_len.append(len(data['messages'])) +ele_counts=Counter(data_messages_len) +print(ele_counts) + +count=0 +for data in output_datas: + if is_data_with_last_obs(data): + count+=1 +print("last obs",count) + +# for data in output_datas: +# #if is_data_with_last_obs(data): +# if is_answer_none(data): +# #data['messages'].pop() +# c+=1 +# # with jsonlines.open(output_file_path, mode='a') as writer: +# # writer.write(data) # observation . data +# print("c",c) + +# print(c) +# d=0 +# new_data=read_jsonline_file(output_file_path) +# data_lens=[] +# for data in new_data: +# data_lens.append(len(data['messages'])) + +# #if is_data_with_last_obs(data): +# #d+=1 +# print(set(data_lens)) +# print(d) +# data['messages'].pop() + +# print("数据中最后一条消息为observation且为空") + +# worker(data,output_file_path) +# pbar.update(1) +# else: + +# with jsonlines.open(output_file_path, mode='a') as writer: +# writer.write(data) # observation . 
data +# pbar.update(1) +# pbar.close +# from generate_tool_observation_multiturn import main +# raw_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_qwq1.jsonl' +# unfinish_data_path='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn2_20250418113514.jsonl' + +# unfinish_datas = read_jsonline_file(unfinish_data_path) +# raw_datas = read_jsonline_file(raw_data_path) +# unfinish_datas_question =[unfinish_data['messages'][0]['content'] for unfinish_data in unfinish_datas] + +# from generate_tool_observation_multiturn import worker +# filtered_unfinish_datas = [] +# for raw_data in raw_datas: + +# if raw_data['messages'][0]['content'] in unfinish_datas_question: + +# pass +# else: +# filtered_unfinish_datas.append(raw_data) +# # print(raw_data['messages'][-1]['content']) + +# # worker(raw_data,unfinish_data_path) +# #exit(0) +# main(filtered_unfinish_datas,unfinish_data_path,16) +#print(len(filtered_unfinish_datas)) + + + + diff --git a/generate_data/generate_sft_data/generate_llms_ans_multiturn.py b/generate_data/generate_sft_data/generate_llms_ans_multiturn.py new file mode 100644 index 0000000..188d660 --- /dev/null +++ b/generate_data/generate_sft_data/generate_llms_ans_multiturn.py @@ -0,0 +1,218 @@ + +import jsonlines +import argparse +import generate_data.generate_sft_data.utils as utils +import glob +import json +from ase import io +import tempfile +import re +from pymatgen.io.vasp import Poscar +from pymatgen.io.cif import CifParser +import threading +import concurrent.futures +import copy +from generate_data.generate_sft_data.sft_utils import generate_design_question, generate_props_question#, generate_obs_response +import sys +sys.path.append('/home/ubuntu/sas0/lzy/mars-mcp/') + +from mars_toolkit import get_tool_schemas +tools=get_tool_schemas() +tools_description='' +tools_description+='\n'.join (json.dumps(tool) for tool in tools) +# Create a lock for file writing +file_lock = threading.Lock() + + +def generate_deepseek_obs_response(messages,tools_description=None,max_retries=3,initial_backoff=1): + messages_ = copy.deepcopy(messages) + system_prompt = ''' + 你是一个有用的AI助手,可以使用以下工具来帮助用户完成任务。 +可用工具: +{tools_description}, +当用户提供给你工具调用的结果后,请你根据当前所拥有的信息继续深入思考, +并自行判断这些工具调用的结果能否让你直接给出答案,如果能给出答案,请在回答中直接给出答案; +若无法给出答案,你将拥有一次调用工具获取信息的机会,请在思考过程中尽可能的调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题。 +''' + messages_.insert(0, {"role": "system", "content": system_prompt.format(tools_description=tools_description)}) + + +# instrument=''' +# 上面是你请求的工具调用的结果,请根据现有的信息继续深入思考, +# 并自行判断这些工具调用的结果能否让你直接给出答案,如果能给出答案,请在回答中直接给出答案; +# 若无法给出答案,你将拥有一次调用工具获取信息的机会,请在思考过程中尽可能的调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题。 +# ''' +# messages_.append({"role": "user", "content": instrument}) + _reasoning_content, response = utils.get_response_from_deepseek_r1( + messages_,prefix=False,max_retries=max_retries,initial_backoff=initial_backoff) + return _reasoning_content, response +def generate_qwq_obs_response(messages,tools_description=None,max_retries=3,initial_backoff=1): + messages_ = copy.deepcopy(messages) + system_prompt = ''' + 你是一个有用的AI助手,可以使用以下工具来帮助用户完成任务。 +可用工具: +{tools_description}. 
+当用户提供给你工具调用的结果后,请你根据当前所拥有的信息继续深入思考,工具调用失败的原因可能是因为你写的工具调用方式不规范导致调用时解析错误或者是网络原因等等。 +并自行判断这些工具调用的结果能否让你直接给出答案,情况1.如果能给出答案,请在回答中直接给出答案; +情况2.若无法给出答案,你将再拥有一次调用多个工具获取信息的机会,请在思考过程中尽可能地调用多个提供给你的工具查询该问题的相关信息(之前的工具调用如果调用失败你也可以自行决定要不要重新调用),而不是直接回答该问题。 +情况1.的回答内容为思考内容以及答案,情况2.的回答内容为思考内容以及每一个工具调用的名称和参数(不要省略具体的参数内容,因为这会导致无法解析参数)。 +思考和回答时使用和问题相同的语言。 +''' + messages_.insert(0, {"role": "system", "content": system_prompt.format(tools_description=tools_description)}) + + +# instrument=''' +# 上面是你请求的工具调用的结果,请根据现有的信息继续深入思考, +# 并自行判断这些工具调用的结果能否让你直接给出答案,情况1.如果能给出答案,请在回答中直接给出答案; +# 情况2.若无法给出答案,你将拥有一次调用多个工具获取信息的机会,在思考过程中尽可能地描述如何通过调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题,因此,在这种情况下你在回答中只需要一次给出多个工具调用,不需要额外的文字描述。 + +# ''' + # messages_.append({"role": "assistant", "content": instrument}) + think_answer,tool_info=utils.get_response_from_qwq(messages_,model_name='qwq-32b',tools=tools,max_retries=max_retries,initial_backoff=initial_backoff) + return think_answer,tool_info +from .filter_messages import is_answer_none,is_data_with_last_obs + +def worker(data, output_file_path): + if is_data_with_last_obs(data): + try: + messages = copy.deepcopy(data['messages']) + + # print(messages) + # print(obs) + retry=0 + think_answer,tool_info = generate_qwq_obs_response(messages,tools_description) + + #print("reasoning_content",reasoning_content) + #print("response",response) + data['messages'].append({"role": "assistant", "content": think_answer}) + while is_answer_none(data) and retry<5: + # 如果答案为空,则需要重新调用工具 + retry+=1 + # 重新生成observation + # print("数据中最后一条消息为observation且为空") + # print("retry",retry) + # 重新生成observation + messages = copy.deepcopy(data['messages'].pop()) + think_answer,tool_info = generate_qwq_obs_response(messages,tools_description,max_retries=5) + data['messages'].append({"role": "assistant", "content": think_answer}) + + + # Use the lock to safely write to the file + data['observation']='' + data['function_calls']='' + with file_lock: + with jsonlines.open(output_file_path, mode='a') as writer: + writer.write(data) + + return f"Processed successfully" + except Exception as e: + return f"Error processing: {str(e)}" + else: + with file_lock: + with jsonlines.open(output_file_path, mode='a') as writer: + writer.write(data) + return f"Processed successfully" +def main(input_file_path, output_file_path, max_workers=1): + import random + from tqdm import tqdm + import os + + datas = None + with jsonlines.open(input_file_path, mode='r') as reader: + datas = [line for line in reader] + + # 创建进度条 + pbar = tqdm(total=len(datas), desc="Processing CIF files") + + # 创建一个线程池 + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # 提交任务到执行器 + future_to_data = {} + for data in datas: + future = executor.submit(worker, data, output_file_path) + future_to_data[future] = data + + # 处理结果 + completed = 0 + failed = 0 + for future in concurrent.futures.as_completed(future_to_data): + data = future_to_data[future] + try: + result = future.result() + if "successfully" in result: + completed += 1 + else: + failed += 1 + # 更新进度条 + pbar.update(1) + # 每100个文件更新一次统计信息 + if (completed + failed) % 100 == 0: + pbar.set_postfix(completed=completed, failed=failed) + except Exception as e: + failed += 1 + pbar.update(1) + print(f"\nWorker for {data} generated an exception: {e}") + + pbar.close() + print(f"Processing complete. 
Successfully processed: {completed}, Failed: {failed}") + + +def main_for_datas(datas, output_file_path, max_workers=1): + import random + from tqdm import tqdm + import os + + #datas = None + + + # 创建进度条 + pbar = tqdm(total=len(datas), desc="Processing CIF files") + + # 创建一个线程池 + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # 提交任务到执行器 + future_to_data = {} + for data in datas: + future = executor.submit(worker, data, output_file_path) + future_to_data[future] = data + + # 处理结果 + completed = 0 + failed = 0 + for future in concurrent.futures.as_completed(future_to_data): + data = future_to_data[future] + try: + result = future.result() + if "successfully" in result: + completed += 1 + else: + failed += 1 + # 更新进度条 + pbar.update(1) + # 每100个文件更新一次统计信息 + if (completed + failed) % 100 == 0: + pbar.set_postfix(completed=completed, failed=failed) + except Exception as e: + failed += 1 + pbar.update(1) + print(f"\nWorker for {data} generated an exception: {e}") + + pbar.close() + print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}") + +if __name__ == '__main__': + import datetime + origin_file = "/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_obs.jsonl" + #output_file = f"/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_turn2_ans_no_none1{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" + + output_file='/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans.jsonl' + main(origin_file, output_file,max_workers=36) + print("output_file",output_file) + output_datas = utils.read_jsonline_file(output_file) + c=0 + print("len output_datas",len(output_datas)) + for data in output_datas: + if is_answer_none(data): + #print("data len",len(data['messages'])) + c+=1 + print("answer none number",c) diff --git a/generate_data/generate_tool_observation.py b/generate_data/generate_sft_data/generate_tool_observation_multiturn.py similarity index 63% rename from generate_data/generate_tool_observation.py rename to generate_data/generate_sft_data/generate_tool_observation_multiturn.py index 02eb9e1..6b6a208 100755 --- a/generate_data/generate_tool_observation.py +++ b/generate_data/generate_sft_data/generate_tool_observation_multiturn.py @@ -1,7 +1,8 @@ import json import asyncio import concurrent.futures - +import sys +sys.path.append('/home/ubuntu/sas0/lzy/mars-mcp/') import jsonlines from mars_toolkit import * import threading @@ -19,6 +20,37 @@ init(autoreset=True) from typing import Dict, Union, Any, Optional, List +import re +def extract_tool_calls(text): + """ + 提取字符串中所有包裹在\n\n 中的JSON内容并转换为字典列表 + + 参数: + text (str): 包含工具调用的文本 + + 返回: + list: 包含所有工具调用的字典列表 + """ + # 使用正则表达式提取之间的内容 + # (?s)表示让.也匹配换行符,使模式可以跨行匹配 + pattern = r'\n(.*?)' + matches = re.finditer(pattern, text, re.DOTALL) + + tool_calls = [] + + for match in matches: + json_str = match.group(1).strip() + try: + # 将JSON字符串转换为Python字典 + tool_call_dict = json.loads(json_str) + tool_calls.append(tool_call_dict) + except json.JSONDecodeError as e: + tool_calls.append(f"无法解析JSON: {e},问题字符串{json_str}") + + + return tool_calls + + def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]: """ 规范化传递给generate_material函数的参数格式。 @@ -237,7 +269,7 @@ async def execute_tool_from_dict(input_dict: dict): # 检查函数名是否存在于工具函数字典中 if func_name not in tools: - return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"} + return f"函数 '{func_name}' 不存在于工具函数字典中" # 获取对应的工具函数 tool_func = tools[func_name] 
@@ -265,102 +297,117 @@ async def execute_tool_from_dict(input_dict: dict): arguments = {"raw_string": arguments_data} # 调用工具函数 - if asyncio.iscoroutinefunction(tool_func): - # 如果是异步函数,使用await调用 - result = await tool_func(**arguments) - else: - # 如果是同步函数,直接调用 - result = tool_func(**arguments) + try: + if asyncio.iscoroutinefunction(tool_func): + # 如果是异步函数,使用await调用 + result = await tool_func(**arguments) + else: + # 如果是同步函数,直接调用 + result = tool_func(**arguments) + except Exception as e: + result = f'工具函数调用时出错:str{e}' # if func_name=='generate_material': # print("xxxxx",result) - return result + formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" + return formatted_result except json.JSONDecodeError as e: return {"status": "error", "message": f"JSON解析错误: {str(e)}"} except Exception as e: return {"status": "error", "message": f"执行过程中出错: {str(e)}"} - + pass def worker(data, output_file_path): try: - func_contents = data["function_calls"] - func_results = [] - formatted_results = [] # 新增一个列表来存储格式化后的结果 - for func in func_contents: - func_name = func.get("name") - arguments_data = func.get("arguments") + tool_call_str = data['messages'][-1]['content'].split("")[-1] + if '' in tool_call_str: - # 使用富文本打印函数名 - #print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") + func_contents=extract_tool_calls(tool_call_str) + #print(func_contents) - # 使用富文本打印参数 - #print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}") + #func_results = [] + formatted_results = [] # 新增一个列表来存储格式化后的结果 + for func in func_contents: + if isinstance(func,Dict): + func_name = func.get("name") + arguments_data = func.get("arguments") - if func.get("name") == 'retrieval_from_knowledge_base': + # 使用富文本打印函数名 + #print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") - # delay_time = random.uniform(5, 10) - # time.sleep(delay_time) - result = asyncio.run(process_retrieval_from_knowledge_base(data)) - func_results.append({"function": func['name'], "result": result}) - # 格式化结果 - formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" - formatted_results.append(formatted_result) + # 使用富文本打印参数 + #print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}") + + if func.get("name") == 'retrieval_from_knowledge_base': + + # delay_time = random.uniform(5, 10) + # time.sleep(delay_time) + result = asyncio.run(process_retrieval_from_knowledge_base(data)) + + # 格式化结果 + formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" + formatted_results.append(formatted_result) + + elif func.get("name") == 'generate_material': - elif func.get("name") == 'generate_material': - try: - # 确保arguments_data是字典 - if isinstance(arguments_data, str): try: - arguments_data = json.loads(arguments_data) - except json.JSONDecodeError as e: - #print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}") + # 确保arguments_data是字典 + if isinstance(arguments_data, str): + try: + arguments_data = json.loads(arguments_data) + except json.JSONDecodeError as e: + #print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}") + continue + + # 规范化参数 + normalized_args = normalize_material_args(arguments_data) + + # 优先使用mattergen函数 + try: + + output = generate_material(**normalized_args) + + + except Exception as e: + #print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}") + 
formatted_result = f"调用时出错,请检查输入的参数,异常为{e}" + + + + # 格式化结果 + formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]" + formatted_results.append(formatted_result) + except Exception as e: + #print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}") + import traceback + #print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}") continue + else: - # 规范化参数 - normalized_args = normalize_material_args(arguments_data) - - # 优先使用mattergen函数 - try: - - output = generate_material(**normalized_args) + formatted_result = asyncio.run(execute_tool_from_dict(func)) - except Exception as e: - #print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}") - continue - # 将结果添加到func_results - func_results.append({"function": func_name, "result": output}) - - # 格式化结果 - formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]" - formatted_results.append(formatted_result) - except Exception as e: - #print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}") - import traceback - #print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}") - continue - else: - - result = asyncio.run(execute_tool_from_dict(func)) - func_results.append({"function": func['name'], "result": result}) - # 格式化结果 - func_name = func.get("name") - formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" - formatted_results.append(formatted_result) - - # 将所有格式化后的结果连接起来 - final_result = "\n\n\n".join(formatted_results) - data['observation'] = final_result - - #使用富文本打印开始和结束标记 - # print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}") - # print(data['observation']) - # print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}") - with file_lock: - with jsonlines.open(output_file_path, mode='a') as writer: - writer.write(data) # observation . data - return f"Processed successfully" + formatted_results.append(formatted_result) + else: + formatted_results.append(func) + # 将所有格式化后的结果连接起来 + final_result = "\n\n\n".join(formatted_results) + data['messages'].append({"role": "user", "content": final_result}) + #print("last message",data["messages"][-1]) + #使用富文本打印开始和结束标记 + # print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}") + # print(data['observation']) + # print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}") + with file_lock: + with jsonlines.open(output_file_path, mode='a') as writer: + writer.write(data) # observation . data + return f"Processed successfully" + else: + with file_lock: + with jsonlines.open(output_file_path, mode='a') as writer: + writer.write(data) # observation . data + return f"Processed successfully" except Exception as e: #print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}") return f"Error processing: {str(e)}" @@ -407,19 +454,41 @@ def main(datas, output_file_path, max_workers=1): print(f"Processing complete. 
Successfully processed: {completed}, Failed: {failed}") + + + + + if __name__ == '__main__': import datetime import jsonlines datas = [] - with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader: + total_count=0 + filtered_count=0 + with jsonlines.open('/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_qwq1.jsonl') as reader: for obj in reader: - #if obj['solution']!='': datas.append(obj) - print(len(datas)) - # print() - output_file = f"./filter_ok_questions_solutions_agent_data_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" - main(datas, output_file, max_workers=32) + for data in datas: + tool_call_str=data['messages'][-1]['content'].split("\n")[-1] + if '' in tool_call_str: + filtered_count+=1 + total_count+=1 + print("total count",total_count) + print("filtered count",filtered_count) + + # for data in datas[:5]: + # tool_call_str=data['messages'][-1]['content'].split("\n")[-1].split("")[0] + # tool_call_dict_list=extract_tool_calls(tool_call_str) + # for tool_call_dict in tool_call_dict_list: + # print("tool name",tool_call_dict['name']) + # print("tool arguments",tool_call_dict['arguments']) + # print("xxx") + # print("==="*20) + # # print() + # exit() + output_file = f"./agent_questions_solution_turn2_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" + main(datas, output_file, max_workers=48) # 示例1:使用正确的JSON格式 # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}' diff --git a/generate_data/grpo_utils.py b/generate_data/generate_sft_data/sft_utils.py similarity index 99% rename from generate_data/grpo_utils.py rename to generate_data/generate_sft_data/sft_utils.py index 9f16107..2b9e5cf 100755 --- a/generate_data/grpo_utils.py +++ b/generate_data/generate_sft_data/sft_utils.py @@ -1,6 +1,6 @@ import jsonlines import argparse -import generate_data.utils as utils +import generate_data.generate_sft_data.utils as utils import glob import json from ase import io diff --git a/generate_data/utils.py b/generate_data/generate_sft_data/utils.py similarity index 95% rename from generate_data/utils.py rename to generate_data/generate_sft_data/utils.py index 544d958..8fcca14 100755 --- a/generate_data/utils.py +++ b/generate_data/generate_sft_data/utils.py @@ -5,6 +5,7 @@ It uses the OpenAI API and MySQL for storing and retrieving data. 
""" import multiprocessing import sqlite3 +import jsonlines import tiktoken import re from fractions import Fraction @@ -48,9 +49,24 @@ def get_response_from_deepseek_r1(messages: list[dict], prefix: bool = False, ma messages=messages, temperature=0.6 ) + #print("response",response) - # reasoning_content = "null" if prefix else "\n" + response.choices[0].message.model_extra['reasoning_content'] + "\n\n" reasoning_content = response.choices[0].message.content.split("\n")[0].split("\n")[-1] + if reasoning_content=='': + reasoning_content=response.choices[0].message.content.split("\n")[1] + # while reasoning_content == "" : + # if retries\n")[0].split("\n")[-1] + content = response.choices[0].message.content.split("\n")[-1] return reasoning_content, content @@ -173,13 +189,13 @@ def get_response_from_qwq(messages: list[dict], model_name: str, tools: list = N retries = 0 while retries <= max_retries: try: - # client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1") + client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1") # client = OpenAI(api_key="sk-df98afdc6b5b48db8195dcb4a68e804b", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") - import random - if random.random() > 0.5: - client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") - else: - client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") + # import random + # if random.random() > 0.5: + # client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") + # else: + # client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") # messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] if tools is None: response = client.chat.completions.create( @@ -295,8 +311,10 @@ def read_json_file(file_path): print(f"Error reading file {file_path}: {e}") return None - - +def read_jsonline_file(file_path): + with jsonlines.open(file_path, mode='r') as reader: + datas = [line for line in reader] + return datas ################################## utils def clean_all_repetitions_with_details(text, min_length=10, threshold=10): diff --git a/generate_data/grpo_tools.py b/generate_data/grpo_tools.py deleted file mode 100755 index e06cc70..0000000 --- a/generate_data/grpo_tools.py +++ /dev/null @@ -1,91 +0,0 @@ -import jsonlines -import argparse -import generate_data.utils as utils -import glob -import json -from ase import io -import tempfile -import re -from pymatgen.io.vasp import Poscar -from pymatgen.io.cif import CifParser -import threading -import concurrent.futures -import copy -from grpo_utils import generate_design_question, generate_props_question, generate_obs_response -# Create a lock for file writing -file_lock = threading.Lock() - - -def worker(data, output_file_path): - try: - messages = copy.deepcopy(data['messages']) - obs = data['observation'] - messages[-1]['content'] = messages[-1]['content'].split("")[-1].split("")[0] - messages.append({"role": "user", "content": obs}) - data['messages'].append({"role": "user", "content": obs}) - # print(messages) - # print(obs) - - reasoning_content, response = generate_obs_response(messages) - data['messages'].append({"role": "assistant", "content": f"\n{reasoning_content}\n\n{response}\n"}) - 
# Use the lock to safely write to the file - with file_lock: - with jsonlines.open(output_file_path, mode='a') as writer: - writer.write(messages) - - return f"Processed successfully" - except Exception as e: - return f"Error processing: {str(e)}" - - -def main(input_file_path, output_file_path, max_workers=1): - import random - from tqdm import tqdm - import os - - datas = None - with jsonlines.open(input_file_path, mode='r') as reader: - datas = [line for line in reader] - - # 创建进度条 - pbar = tqdm(total=len(datas), desc="Processing CIF files") - - # 创建一个线程池 - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - # 提交任务到执行器 - future_to_data = {} - for data in datas: - future = executor.submit(worker, data, output_file_path) - future_to_data[future] = data - - # 处理结果 - completed = 0 - failed = 0 - for future in concurrent.futures.as_completed(future_to_data): - data = future_to_data[future] - try: - result = future.result() - if "successfully" in result: - completed += 1 - else: - failed += 1 - # 更新进度条 - pbar.update(1) - # 每100个文件更新一次统计信息 - if (completed + failed) % 100 == 0: - pbar.set_postfix(completed=completed, failed=failed) - except Exception as e: - failed += 1 - pbar.update(1) - print(f"\nWorker for {data} generated an exception: {e}") - - pbar.close() - print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}") - - -if __name__ == '__main__': - import datetime - origin_file = "/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl" - output_file = f"agent_questions_solutions_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" - main(origin_file, output_file) - diff --git a/generate_data/read_data.py b/generate_data/read_data.py new file mode 100755 index 0000000..8542dd1 --- /dev/null +++ b/generate_data/read_data.py @@ -0,0 +1,13 @@ +import copy +import jsonlines + +data_path = '/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solution_turn5_ans_no_none.jsonl' +#output_path = './agent_questions_solutions_qwq1.jsonl' +with jsonlines.open(data_path, mode='r') as reader: + datas = [line for line in reader if len(line['messages']) == 4] + +print(datas[0]) + + + + diff --git a/mars_toolkit/__init__.py b/mars_toolkit/__init__.py index 8c760db..563e016 100755 --- a/mars_toolkit/__init__.py +++ b/mars_toolkit/__init__.py @@ -15,7 +15,7 @@ from mars_toolkit.compute.structure_opt import optimize_crystal_structure, conve from mars_toolkit.query.mp_query import ( search_material_property_from_material_project, get_crystal_structures_from_materials_project, - get_mpid_from_formula + #get_mpid_from_formula ) from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD from mars_toolkit.query.dify_search import retrieval_from_knowledge_base diff --git a/mars_toolkit/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/__pycache__/__init__.cpython-310.pyc old mode 100755 new mode 100644 index 10a881077d313868a9940da5f4008018fe8d726c..832449009c3c7b1a9e3b692df0ff85ddc91758a9 GIT binary patch delta 225 zcmcc5v4?{+5y%yc5(Kk3 zbA)n*qlCe1t{jnE(J0Yeu_&=z@hEYi7A0$?^rj$p1( zln|KBnIoJl5+wp=bLEKUibaX#ibsj(N<>Kj#kiv+(-~5@7BNOiO+3`b$UNDTF<$(Z zXnJZ%d~QKzN_<*Ter|kPeo<~|PU7Shj6ICflTDdkFcoo4mSR?B=*9(s$ZN~Y%p1n$rAt>cMUNB delta 47 zcmcc4a*Ks0pO=@50SG4i_?y0w=L4gtlYT~iZmNE1QfXdEslKT}h<;vL@nk+GPXKEg B4?X|@ diff --git a/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc b/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc old mode 100755 new mode 100644 index 
ebc0e1916e939158cb939186195982cbbb2698d5..fbed40be3535dff6ede650999d3c6dc8ad7a5d45 GIT binary patch delta 45 zcmaDH_#lucpO=@50SJ!eFl21xQQ#1A*3Zb#P1P?=D$Oe?)h|vgHrVXQu}A>`G0hF< delta 48 zcmaD5_&AU!pO=@50SG4i_?y0wM}b4sO+O<)H&wqhsWh*oRNvGfL_aUBc(WbHA_V}2 CG7t0s diff --git a/mars_toolkit/compute/__pycache__/property_pred.cpython-310.pyc b/mars_toolkit/compute/__pycache__/property_pred.cpython-310.pyc old mode 100755 new mode 100644 index 55e5c50675e68d03e13bfce58268f6864157445d..79ed1b2d2ce86eb8bf9570cb6cd64d96fbf43b4d GIT binary patch delta 45 zcmaDO@KS&$pO=@50SJ!eFl21xsbCXw(a*@wP1P?=D$Oe?)h|vgHrPCq?J^?(DS-|) delta 48 zcmaDU@J4_qpO=@50SG4i_?y0wr-Dt?T|Xl~H&wqhsWh*oRNvGfL_aUBc=I&2%Zvbg Cn-D_) diff --git a/mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc b/mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc old mode 100755 new mode 100644 index 5062a7b477d01866a5dd884b0c5b7abbb01f5116..0ba60b1f9cc7dfa5a00ea971b635f26e64055013 GIT binary patch delta 45 zcmeyU@lJy$pO=@50SJ!eFl21xv1Ac)(a*@wP1P?=D$Oe?)h|vgHrO1+A}#;`Dn<=K delta 48 zcmaE-@lk^(pO=@50SG4i_?y0w$C5?VT|Xl~H&wqhsWh*oRNvGfL_aUBcyl<5xBvis COAk%} diff --git a/mars_toolkit/core/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/core/__pycache__/__init__.cpython-310.pyc old mode 100755 new mode 100644 index 247ab4236c5930cca951ebb9228df94715c92c7a..007263a13aa71d97b54eb01f71407c956869fcd0 GIT binary patch delta 44 ycmaFE^pc4upO=@50SJ!eFl21xVP_Pw(a*@wP1P?=D$Oe?)h|vgHkhoz=mG!;77TU( delta 47 zcmaFK^oEHipO=@50SNkk{7v7;!_Fvbub+{ho2p-$RGL>(s&8r#qMw&mJXw{|1prqw B4uSvx diff --git a/mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc b/mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc old mode 100755 new mode 100644 index ec7f17d30576fd10cfe7482c5b0f75ae07f1dd41..ccb16a589e95075e513daa5dc448cf0d528493da GIT binary patch delta 45 zcmX>mby$igpO=@50SJ!eFl21xVPqDv)z8S!P1P?=D$Oe?)h|vgHrTAhtilBV50MOA delta 48 zcmX>sbxevUpO=@50SG4i_?y0whml#-K|douH&wqhsWh*oRNvGfL_aUBc(XjS3Ksxi CfDUQ^ diff --git a/mars_toolkit/core/__pycache__/config.cpython-310.pyc b/mars_toolkit/core/__pycache__/config.cpython-310.pyc old mode 100755 new mode 100644 index c47f1a04baf81162594715be10097af57b090737..a52eb09cb2a744b5c91253dea24e62541524de2c GIT binary patch delta 97 zcmcaBa7BPOpO=@50SK&SGGu()$Q#2X6rrDypPQ;*npB!sQmS8^SZpx4jww*c9-FYj TW)|k-j6zlza+@Ep)G`AA$21@2 delta 58 zcmca2a94mgpO=@50SMTH{-^)h$Q#4N7(JPfS!wbcCZWl*nG6}7K)lUwm~JvM+5(v> Ko1e4fG6MkJE)cc= diff --git a/mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc b/mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc old mode 100755 new mode 100644 index 7a9f7f5584d3591c1ee5a75d3093b31d5b9ff308..421af473f5fa3b44a07cf8b6de0e96417bd26b02 GIT binary patch delta 45 zcmaE>@k)aypO=@50SJ!eFl21x`O79`tDljdo2p-$RGL>(s$ZN~Y_M65-Jcf#JLL`Z delta 48 zcmaE*@m7N;pO=@50SNkk{7v7;^OsH3K|douH&wqhsWh*oRNvGfL_aUBc(XLSKQ91~ C><|C| diff --git a/mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc b/mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc old mode 100755 new mode 100644 index b18130bc80cd4fb4a2b5136f86b955713ba8c821..4119ff25b5ce57c4d0e66e245ec4a4e3de16f317 GIT binary patch delta 44 ycmey#{*9d{pO=@50SJ!eFl21xNoN*v)z8S!P1P?=D$Oe?)h|vgHkjPOTnGRezYUZC delta 47 zcmeyy{*#?2pO=@50SNkk{7v7;lg=#ap`VeTo2p-$RGL>(s&8r#qMw&mJh_#*5CCe4 B51;@5 diff --git a/mars_toolkit/core/config.py b/mars_toolkit/core/config.py index 41f1684..03a831c 100755 --- a/mars_toolkit/core/config.py +++ b/mars_toolkit/core/config.py @@ -22,11 +22,11 @@ class Config: HTTPS_PROXY = 'http://192.168.168.1:20171' # 
FairChem - FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/nfs/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt' + FAIRCHEM_MODEL_PATH = '/home/ubuntu/sas0/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt' FMAX = 0.05 # MatterGen - MATTERGENMODEL_ROOT = '/home/ubuntu/50T/nfs/lzy/mars-mcp/pretrained_models/mattergen_ckpt' + MATTERGENMODEL_ROOT = '/home/ubuntu/sas0/lzy/mars-mcp/pretrained_models/mattergen_ckpt' MATTERGEN_ROOT='/home/ubuntu/50T/nfs/lzy/mars-mcp/mattergen' MATTERGENMODEL_RESULT_PATH = 'results/' diff --git a/mars_toolkit/misc/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/misc/__pycache__/__init__.cpython-310.pyc old mode 100755 new mode 100644 index 1bc8892f72cae8133e0a0f0ae283a73b3ef95bf6..32332a55008d72d6e08ad33a88430a1901fabf2b GIT binary patch delta 43 xcmX@YbdZTBpO=@50SJ!eFl0>Rc_?I~pOK%Ns$ZH^npaY)Uz}KMFqw(b2>=Ge47UIP delta 45 zcmX@ebcBf~pO=@50SG4i_?teF=b@;*enx(7s(xuwX(s$ZN~Y%uvY^8)}5d=4rA delta 47 zcmdnNxto(GpO=@50SG4i_?y0wXBV@mqkcwyZmNE1QfXdEslKT}h<;vL@#GuK4**`| B5H|n- diff --git a/mars_toolkit/query/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/query/__pycache__/__init__.cpython-310.pyc old mode 100755 new mode 100644 index 04f283c99595d946b804bfab137c93e57943b4bd..991b559d5c893b4249f32f4e7a0a6e4e6f402e73 GIT binary patch delta 44 ycmbQrHi?ZVpO=@50SJ!eFl21xS(s$ZN~Y%uu@lP3WFLJc4Q delta 47 zcmbQlHkFMhpO=@50SG4i_?y0wXEl?kgMLPSZmNE1QfXdEslKT}h<;vL@#K?Co&Zgd B4=n%y diff --git a/mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc b/mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc old mode 100755 new mode 100644 index e20c07863850f8e8a4376a8e8d8937187273c9f6..1997296904c55faeb28bb1e7e857b530df963913 GIT binary patch delta 45 zcmdnXyNj16pO=@50SJ!eFl21xNoN&u(9g)vP1P?=D$Oe?)h|vgHrU+3YRU)z6r~Le delta 48 zcmdnRyO)o-$kDU@y(~fCtqa$1(DTrC)I4hoy(m0osauFzjJ5i$NW#@ zb~Bwefv%C9w!N7-wWsjpba|r9IB;|mC54{MP1e+&_C%Nic z^;9VpW6e$1?UIe)KsV!LdsrHBY_4;Io5ze>j_C~Z)D|lMPdmd~3a8wMN(W0L95-Mo z^MKTxezTH!zz8g1q~2&Ywsxwgzy03+ECKL7`cT8YBuAgs5}d68`O04U%ulQVR|MRJ zdn~{cJRJ3tHB|~c#^>P;;M!IA{I7rs*3|^|upbE$BGJ`*fFKez!+Xg$<+Foa-v#ew z-{OQnK+&dPp!g}FaI?dLWEBEcB)e&D^)S$c<`sq7wS`tQx|tsI5TT*V2v79|$q3bF zdAf(M!n^B%E_ClZj)5Wc+7Lyq1g1YMOm26W7YegB!ZW*sc$RSP7PJWtQ94_IJOn5V zPR0FZcY%)xB%8aJd=UcUW<4evC895d@oyC$%R@ga41ZLZ9h5>JiAwqZiNDE7niUV2 zrvX!5u`fUp+n^31qjk2-7RqPVl?4T@Y2VHdn48<<;6r7h^W$U5z}0h6VWUSXN@jW{-bk^ z76-iX^iK@@BD5dVh^D`N{~XPegKaU*Qv?68a^;}`g(c&?Kj`qj4X*a`c62jevo@@$!B%~qs2jaxfjv*BeZ5`L<$U_SgozlKI=8(*Lmer@b77UE*Q6`}JL zc(0N7=atp%X5Crce{a9x@2|EhtxMr$b9GdwoTP4jOorXb+F3L)JThM>sK{`PoN>(3 z%F(UgVGqru8vbHln#o0)+pJc18dYzKc6u{pB6Y9Q;FXrQ5{_95ZRJMOb7Jw4E#Wh+ zTQB>~rrYu~S#4Au{~B>PUZj`HjY{1qm%ZmC*|IcJSuFh#^Sa!puhpA;-*qk{d967z YI_M}@s=rtYO=en*k$UzwW7&uQ0G4rQ)&Kwi delta 1340 zcmZ`(O>7%g5T1G7zqe~|66eQulBWEmr53af5n2o99AG+QU zHF}Ffmm?7gD=saSY>5kJE^tGfICAz4PCdd6keV{PPH_3#r=2%%e!ltU&1!!reqV6A zxttCB9RBv8@uPd>E~7enda;oqAkos=dQC@QJ<~GUX3b1>t!1@kO{TiuvfED0LGWOZ z^>Rb3xt>D?nPeRZlAiRY$lekqAjg~9(kM&D$6AE7X~GRC=Z?NxJ9_wY-_HPigoFz2 z!aWhAt0zsQ}fG#p6n5lz777x zz#t>S+nCrwEijcvW+Xzr0imJvfkoC3LsOZ4{|#jfq%xb@z#hU75js|4Avq?LSUABU zdx)>Y^(&#JEdNV#LaD4~mMxbx$>r)>9Yg6oXOHhCim85C5 zr?u_^%_!{imcOUJ2%)@rJq4O&pwE;%4WqdNjJ-+*oE%|7^bu=9U!Ib`d6lU-X@~RN zVOC}RU-wpnnL$37V`e6T=;2w4^*YzT6OBC$6MRHDTL z44o=;Rx}8Tw|+qgT?8IeX|PC34C_Xi3@^DWT23Ob@z6r|L${-k_g<Ix|>O23M2H z@sIf?R+LkKe$tzgRS=w_fd)}#JYR$=Y{0W8^FPF1ewX2sb&cODM{jU6e<{m{cdazO|nvy3zPGIqwGStych55_v6U; zR2zYBem!UYgm^a2%V%&lzAV?#j6acIpp47*_Jw)=Rg45}jnMlF{8#w&$7_}Mx^1tr z^Uh8u*r`0fcD2&k>c!vMl^KKGS$=!BS%`hxjUU)Alr$8$fHoF!8eN=YY3+++Ti4=s 
zXRS0dGF#n7<3^|97ddNWyuaC@&7OZgzT>R+m)e_tuO4)}t?e7ZrFXl&cd6BCr-q)d zC-Ji31?vpp`6H`d?`*ccdfmUs9{*e-5=trUN5rd1p7m0@OFJ!(zal2Q%)$Z-EzFLW Puz3l|w8cXqrjP#w=?6mz diff --git a/mars_toolkit/query/__pycache__/web_search.cpython-310.pyc b/mars_toolkit/query/__pycache__/web_search.cpython-310.pyc old mode 100755 new mode 100644 index 43b40a916b915af7a239e016cc4290daac91fcce..ba3ce69510edd0743945a65519d33222cb6f25b9 GIT binary patch delta 45 zcmX>uctnsVpO=@50SJ!eFl21xsbdwg*U!k$P1P?=D$Oe?)h|vgHrPC$Rf8D-8}JQE delta 48 zcmX>icwCStpO=@50SG4i_?y0wr;b(BNk1b$H&wqhsWh*oRNvGfL_aUBc=H@q4Q2pt CG7nb( diff --git a/mars_toolkit/query/oqmd_query.py b/mars_toolkit/query/oqmd_query.py index 1a582a2..b95f3e0 100755 --- a/mars_toolkit/query/oqmd_query.py +++ b/mars_toolkit/query/oqmd_query.py @@ -1,11 +1,13 @@ import logging +import os import httpx import pandas as pd from bs4 import BeautifulSoup from io import StringIO from typing import Annotated -from mars_toolkit.core.llm_tools import llm_tool +from ..core import config +from ..core.llm_tools import llm_tool logger = logging.getLogger(__name__) @@ -23,6 +25,8 @@ async def fetch_chemical_composition_from_OQMD( Formatted text with material information and property tables """ # Fetch data from OQMD + os.environ['HTTP_PROXY'] = config.HTTP_PROXY or '' + os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or '' url = f"https://www.oqmd.org/materials/composition/{composition}" try: async with httpx.AsyncClient(timeout=100.0) as client: diff --git a/mars_toolkit/services/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/services/__pycache__/__init__.cpython-310.pyc old mode 100755 new mode 100644 index cad462695e3722ac1d6fc1fa088601c877790354..9cad6cf37b5567c2b23ffb84e9da10dc46c9cf2a GIT binary patch delta 44 ycmX@be43djpO=@50SJ!eFl21xNo5pr(9g)vP1P?=D$Oe?)h|vgHkjPR=m7u-J`DH( delta 43 xcmX@je2SSTpO=@50SNLRy-DB5lgcP)ub+{ho2p-$RGL>(s&8r#GP#-20{|5)4MP9` diff --git a/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc b/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc old mode 100755 new mode 100644 index cb8e3dc1af4feb71c8e2ae971c3c8c2e4130a4f1..75356aa6b07ab8e1b70334adf6ba2a841c0a4bf7 GIT binary patch delta 1172 zcmZva%TE(Q9LGD`?b7W7=mQZEP-u~gKn3I5Xndgf090bIYUm1N zcrbI);2+>gP2<&g){9`EE%>;wJOi-~8tLoBSTLI`e!cb-n7;^pwj;}HCAwz*;s9%W>;%t6|0hNM2v%45Qp4F zqH)DvAtcuLVD8q86#P2{2_?v|7!1u#4l0KLfTw4Mqyji79)6zzP8WqGd{k0wc=lE_!?E?;ou9$70 zAeGRQ*q(Z`SnN*$7&xB=PMbyjVNeRUa5_y zrEXixO#1FvrP@DZXk6`k>?KR$NB>b$5Xphmf?JtpmHxFUvSw*RxEuxX4Xr|h zq^{;?OPw=|@@o)X2W|iv;Hr2%&~w6lC;w>#UxLQ1oWeQ;WEEshEC0j z_QZq!n>e>{Zm5^DT&`|>WRJw##23STDHp}l-(^%&G;4@+912zzE<2yuf`x2UZlMZ?WW9RbGI@03HE)!ur}T$>>ZR~;i zM@bjiMuRj&n`m=RZFjUgNGok2+vRGb;hdMWlO4R+I39`t{w`2U3`_H?T=-52UzwR@%BlrgVJD=AjR0q?8^K9e>nQ9f z5RjN;NM#Ct-;(yQE@>xy{Cms(X4@8?N}-f5(peVc>2N~X#b?6zl_UIPc(`>?YNy!& zWFH2ON<3`i>;4E&MP@VV$A+E_+v{#Np3l3rc diff --git a/test_mars_t1.py b/test_mars_t1.py new file mode 100644 index 0000000..f5f3140 --- /dev/null +++ b/test_mars_t1.py @@ -0,0 +1,177 @@ +import asyncio +import json +from openai import OpenAI +from rich.console import Console +from rich.panel import Panel +from rich.text import Text + +# 创建Rich控制台对象 +console = Console() + +# 定义分隔符样式 +def print_separator(title=""): + console.print(Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False), justify="center") + +api_key="gpustack_72b0d41ec69eddab_bce1ea964ddc277ac6aed46b67b03960" +base_url="http://gpustack.ddwtop.team/v1-openai" +client = OpenAI( + api_key=api_key, + base_url=base_url, +) +from mars_toolkit import 
get_tool_schemas,get_tools +tools_schemas = get_tool_schemas() +tool_map = get_tools() + +# 打印消息的函数 +def print_message(message): + # 处理不同类型的消息对象 + if hasattr(message, 'role'): # ChatCompletionMessage 对象 + role = message.role + content = message.content if hasattr(message, 'content') else "" + # 如果是工具消息,获取工具名称 + tool_name = None + if role == "tool" and hasattr(message, 'name'): + tool_name = message.name + else: # 字典类型 + role = message.get("role", "unknown") + content = message.get("content", "") + # 如果是工具消息,获取工具名称 + tool_name = message.get("name") if role == "tool" else None + + # 根据角色选择不同的颜色 + role_colors = { + "system": "bright_blue", + "user": "green", + "assistant": "yellow", + "tool": "bright_red" + } + color = role_colors.get(role, "white") + + # 创建富文本面板 + text = Text() + + # 如果是工具消息,添加工具名称 + if role == "tool" and tool_name: + text.append(f"{role} ({tool_name}): ", style=f"bold {color}") + else: + text.append(f"{role}: ", style=f"bold {color}") + + text.append(str(content)) + console.print(Panel(text, border_style=color)) + +messages = [ + {"role": "system", + "content": "You are MARS-T1, a professional assistant in materials science. You first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here structured answer here '"}, + {"role": "user", "content": "how to synthesize CsPbBr3 at room temperature"} +] + +# 打印初始消息 +print_separator("初始消息") +for message in messages: + print_message(message) +finish_reason = None + +async def execute_tool(tool_name,tool_arguments): + tool_func = tool_map[tool_name] # 获取工具函数 + arguments = {} + if tool_arguments: + # 检查arguments是字符串还是字典 + if isinstance(tool_arguments, dict): + # 如果已经是字典,直接使用 + arguments = tool_arguments + elif isinstance(tool_arguments, str): + # 如果是字符串,尝试解析为JSON + try: + # 尝试直接解析为JSON对象 + arguments = json.loads(tool_arguments) + except json.JSONDecodeError: + # 如果解析失败,可能是因为字符串中包含转义字符 + # 尝试修复常见的JSON字符串问题 + fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\') + try: + arguments = json.loads(fixed_str) + except json.JSONDecodeError: + # 如果仍然失败,尝试将字符串作为原始字符串处理 + arguments = {"raw_string": tool_arguments} + + # 调用工具函数 + if asyncio.iscoroutinefunction(tool_func): + # 如果是异步函数,使用await调用 + result = await tool_func(**arguments) + else: + # 如果是同步函数,直接调用 + result = tool_func(**arguments) + # if func_name=='generate_material': + # print("xxxxx",result) + return result + +while finish_reason is None or finish_reason == "tool_calls": + completion = client.chat.completions.create( + model="MARS-T1", + messages=messages, + temperature=0.3, + tools=tools_schemas, # <-- 我们通过 tools 参数,将定义好的 tools 提交给 Kimi 大模型 + ) + choice = completion.choices[0] + finish_reason = choice.finish_reason + if finish_reason == "tool_calls": # <-- 判断当前返回内容是否包含 tool_calls + # 打印assistant消息 + print_separator("Assistant消息") + print_message(choice.message) + + # 将ChatCompletionMessage对象转换为字典 + assistant_message = { + "role": "assistant", + "content": choice.message.content if hasattr(choice.message, 'content') else None + } + + # 如果有工具调用,添加到字典中 + if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls: + # 将tool_calls对象转换为字典列表 + tool_calls_list = [] + for tool_call in choice.message.tool_calls: + tool_call_dict = { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments + } + } + 
tool_calls_list.append(tool_call_dict) + assistant_message["tool_calls"] = tool_calls_list + + # 添加消息到上下文 + messages.append(assistant_message) # <-- 我们将模型返回给我们的 assistant 消息也添加到上下文中,以便于下次请求时模型能理解我们的诉求 + + # 打印工具调用信息 + print_separator("工具调用") + for tool_call in choice.message.tool_calls: + console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]") + console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]") + console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]") + console.print("") + + tool_call_name = tool_call.function.name + tool_call_arguments = json.loads(tool_call.function.arguments) # <-- arguments 是序列化后的 JSON Object,我们需要使用 json.loads 反序列化一下 + tool_result = asyncio.run(execute_tool(tool_name=tool_call_name,tool_arguments=tool_call_arguments)) # <-- 通过 tool_map 快速找到需要执行哪个函数 + + # 构造工具响应消息 + tool_message = { + "role": "tool", + "tool_call_id": tool_call.id, + "name": tool_call_name, + "content": tool_result, # <-- 我们约定使用字符串格式向模型提交工具调用结果 + } + + # 打印工具响应 + print_separator(f"工具响应: {tool_call_name}") + print_message(tool_message) + + # 添加消息到上下文 + messages.append(tool_message) + +# 打印最终响应 +if choice.message.content: + print_separator("最终响应") + console.print(Panel(choice.message.content, border_style="green")) diff --git a/test_mars_toolkit.py b/test_mars_toolkit.py index baa1bc9..f73561f 100755 --- a/test_mars_toolkit.py +++ b/test_mars_toolkit.py @@ -172,7 +172,7 @@ if __name__ == "__main__": ] # 选择要测试的工具 - tool_name = tools_to_test[2] # 测试 search_online 工具 + tool_name = tools_to_test[7] # 测试 search_online 工具 # 运行测试 result = asyncio.run(test_tool(tool_name))