Files
multi_mcp/test_tools/test_mars_t1_.py
2025-05-09 14:16:33 +08:00

178 lines
6.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import inspect
import json
import os
import sys

from openai import OpenAI
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp import *
# Tool domains exposed to the T1 model. The JSON schemas sent to the model and
# the callable registry used to execute its tool calls are built from the same
# list so they cannot drift apart. A fresh list is passed to each call, as before.
TOOL_DOMAINS = ("material", "general")
tools_schemas = get_domain_tool_schemas(list(TOOL_DOMAINS))
tool_map = get_domain_tools(list(TOOL_DOMAINS))

# Endpoint configuration. Environment variables take precedence so the key
# does not need to live in source control.
# NOTE(review): the fallback key below is committed to the repository and
# should be rotated and removed once callers set MARS_API_KEY. The default
# endpoint is plain http://, so the key is sent unencrypted; point
# MARS_BASE_URL at an https endpoint if one is available.
api_key = os.environ.get(
    "MARS_API_KEY",
    "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705",
)
base_url = os.environ.get("MARS_BASE_URL", "http://gpustack.ddwtop.team/v1-openai")
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
# Single-turn user prompt sent to the T1 model. The content is a CIF block for
# Ti4V (space group Fmmm, No. 69) followed by a Chinese instruction that reads
# roughly: "Based on the CIF file provided above, analyse this crystal's
# chemical formula, crystal system, space group, density, whether it is a
# metal, magnetism and stability, and answer in JSON format."
# The same text (messages[0]['content']) is reused verbatim as the question in
# the prompt built for the R1 model in the __main__ section.
# NOTE: everything inside the triple-quoted string is prompt text; editing it,
# including whitespace, changes what the models receive.
messages = [ {"role": "user", "content": """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
根据上文提供的CIF文件请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性并用JSON格式回答。"""}]
def get_t1_response(messages, *, model="MARS-T1", temperature=0.3):
    """Ask the tool-calling model which tools to run for *messages*.

    Args:
        messages: Chat messages in OpenAI ``chat.completions`` format.
        model: Model name on the endpoint (defaults to ``"MARS-T1"``).
        temperature: Sampling temperature (defaults to ``0.3``).

    Returns:
        ``(content, tool_calls)``. ``content`` is the assistant message text
        as returned by the SDK. ``tool_calls`` is a list of plain dicts of the
        form ``{"id", "type": "function", "function": {"name", "arguments"}}``.
        It is always a list, and is empty when the model requested no tools.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        tools=tools_schemas,
    )
    message = completion.choices[0].message
    reasoning_content = message.content
    # Convert the SDK tool-call objects into plain dicts; downstream code
    # indexes them as tool_call['function']['name'] / ['arguments'].
    # Bound unconditionally so a reply without tool calls yields [] instead
    # of an unbound name / None return.
    tool_calls_list = [
        {
            "id": tool_call.id,
            "type": "function",
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
        }
        for tool_call in getattr(message, "tool_calls", None) or []
    ]
    return reasoning_content, tool_calls_list
def _decode_tool_arguments(tool_arguments):
    """Normalise a tool call's ``arguments`` payload into a kwargs dict.

    A falsy payload means "no arguments". A dict is used as-is. A str is
    parsed as JSON, with one retry after undoing common stray escapes. A
    double-encoded payload (JSON text inside a JSON string) is unwrapped.
    Text that still cannot be parsed is passed through as
    ``{"raw_string": <original text>}``.

    Raises:
        TypeError: if the payload is neither a dict nor a str, or if it
            decodes to something other than a JSON object.
    """
    if not tool_arguments:
        return {}
    if isinstance(tool_arguments, dict):
        return tool_arguments
    if not isinstance(tool_arguments, str):
        raise TypeError(
            "tool arguments must be a dict or a JSON string, "
            f"got {type(tool_arguments).__name__}"
        )
    try:
        decoded = json.loads(tool_arguments)
    except json.JSONDecodeError:
        # The payload may contain stray escape characters; undo the common
        # ones and retry before giving up.
        repaired = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
        try:
            decoded = json.loads(repaired)
        except json.JSONDecodeError:
            return {"raw_string": tool_arguments}
    if isinstance(decoded, str):
        # Double-encoded payload: the first decode yields the JSON text itself.
        try:
            decoded = json.loads(decoded)
        except json.JSONDecodeError:
            return {"raw_string": tool_arguments}
    if not isinstance(decoded, dict):
        raise TypeError(
            "tool arguments must decode to a JSON object, "
            f"got {type(decoded).__name__}"
        )
    return decoded


async def execute_tool(tool_name, tool_arguments, *, tools=None):
    """Look up a tool by name, decode its arguments and run it.

    Args:
        tool_name: Name of the tool in the registry.
        tool_arguments: Keyword arguments for the tool, either as a dict or as
            a JSON-encoded string. A falsy value calls the tool with no
            arguments.
        tools: Optional mapping of tool name to callable. Defaults to the
            module-level ``tool_map``.

    Returns:
        The tool's return value. Awaitable results (from async tools, or from
        sync wrappers that return a coroutine) are awaited first.

    Raises:
        KeyError: if ``tool_name`` is not in the registry.
        TypeError: if the arguments cannot be turned into a kwargs dict.
    """
    registry = tool_map if tools is None else tools
    try:
        tool_func = registry[tool_name]
    except KeyError:
        raise KeyError(f"Unknown tool: {tool_name!r}") from None
    arguments = _decode_tool_arguments(tool_arguments)
    result = tool_func(**arguments)
    # Checking the result rather than the function also covers wrappers that
    # asyncio.iscoroutinefunction would not recognise.
    if inspect.isawaitable(result):
        result = await result
    return result
def get_all_tool_calls_results(tool_calls_list):
    """Execute each requested tool call in order and wrap its output for a prompt.

    Args:
        tool_calls_list: Tool-call dicts indexed as
            ``tool_call['function']['name']`` and
            ``tool_call['function']['arguments']``. ``None`` or an empty list
            yields ``[]``.

    Returns:
        One string per call, in the original order, formatted as
        ``"[<name> content begin]\\n<output>\\n[<name> content end]\\n"``.
        String outputs are used verbatim. Other outputs are JSON-encoded, or
        converted with ``str()`` if they are not JSON-serialisable.
    """
    if not tool_calls_list:
        return []

    async def _run_sequentially():
        # Tools run one after another, as before; only the event loop is shared.
        outputs = []
        for tool_call in tool_calls_list:
            tool_name = tool_call['function']['name']
            tool_arguments = tool_call['function']['arguments']
            outputs.append((tool_name, await execute_tool(tool_name, tool_arguments)))
        return outputs

    # A single asyncio.run() for the whole batch. Calling it once per tool would
    # create and close a new event loop each time, which can break tools that
    # cache loop-bound resources between calls.
    all_results = []
    for tool_name, result in asyncio.run(_run_sequentially()):
        if isinstance(result, str):
            text = result
        else:
            try:
                text = json.dumps(result, ensure_ascii=False)
            except (TypeError, ValueError):
                text = str(result)
        all_results.append(f"[{tool_name} content begin]\n{text}\n[{tool_name} content end]\n")
    return all_results
def get_response_from_r1(messages, *, model="MARS-R1", temperature=0.3):
    """Send *messages* to the reasoning model and return its reply text.

    No tools are offered to this model.

    Args:
        messages: Chat messages in OpenAI ``chat.completions`` format.
        model: Model name on the endpoint (defaults to ``"MARS-R1"``).
        temperature: Sampling temperature (defaults to ``0.3``).

    Returns:
        The ``content`` of the first choice's message.
    """
    # The original had a print() after the return statement that could never run;
    # it has been removed. Printing the answer is left to the caller.
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
    )
    return completion.choices[0].message.content
def main():
    """Run the demo pipeline: T1 picks tools, the tools run, R1 answers.

    The tool outputs are prepended to the original question
    (``messages[0]['content']``) and sent to R1 as a single user message.
    The R1 answer is then printed.
    """
    reasoning_content, tool_calls_list = get_t1_response(messages)
    print("Reasoning content:", reasoning_content)
    tool_call_results = get_all_tool_calls_results(tool_calls_list)
    tool_call_results_str = "\n".join(tool_call_results)
    user_message = {
        "role": "user",
        "content": f"# 信息如下:{tool_call_results_str}# 问题如下 {messages[0]['content']}",
    }
    print("user_message_for_r1:", user_message)
    # The original discarded this return value, so the final answer was never shown.
    r1_response = get_response_from_r1([user_message])
    print("R1 RESPONSE:", r1_response)


if __name__ == '__main__':
    main()