178 lines
6.1 KiB
Python
178 lines
6.1 KiB
Python
import asyncio
|
||
import json
|
||
from openai import OpenAI
|
||
import sys
|
||
import os
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
from sci_mcp import *
|
||
|
||
# Tool domains loaded from sci_mcp for this pipeline.
_TOOL_DOMAINS = ("material", "general")

# ``tools_schemas`` is advertised to the model through ``tools=`` in
# get_t1_response; ``tool_map`` maps a tool name to the callable that
# execute_tool invokes. Each loader receives its own fresh list.
tools_schemas = get_domain_tool_schemas(list(_TOOL_DOMAINS))
tool_map = get_domain_tools(list(_TOOL_DOMAINS))
# Connection settings for the OpenAI-compatible GPUStack endpoint.
# Both values can be overridden through the environment.
# NOTE(review): an API key is committed in this file and is sent over plain
# http://. Rotate it, move it to the environment (MARS_API_KEY) and drop the
# literal fallback below; it is kept only so existing runs keep working.
api_key = os.environ.get(
    "MARS_API_KEY",
    "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705",
)
base_url = os.environ.get(
    "MARS_BASE_URL",
    "http://gpustack.ddwtop.team/v1-openai",
)

# Shared client used by get_t1_response and get_response_from_r1.
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
# Single-turn conversation sent to the tool-calling model: a CIF description
# of Ti4V followed by a question (in Chinese) asking for the chemical formula,
# crystal system, space group, density, metallicity, magnetism and stability,
# answered in JSON. The CIF body must not contain any stray lines, since it is
# forwarded verbatim to the models (and, via the R1 prompt, re-sent there).
messages = [{"role": "user", "content": """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}]
def get_t1_response(messages):
|
||
completion = client.chat.completions.create(
|
||
model="MARS-T1",
|
||
messages=messages,
|
||
temperature=0.3,
|
||
tools=tools_schemas,
|
||
)
|
||
choice = completion.choices[0]
|
||
reasoning_content = choice.message.content
|
||
#print("Reasoning content:", reasoning_content)
|
||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
|
||
tool_calls_list = []
|
||
for tool_call in choice.message.tool_calls:
|
||
tool_call_dict = {
|
||
"id": tool_call.id,
|
||
"type": "function",
|
||
"function": {
|
||
"name": tool_call.function.name,
|
||
"arguments": tool_call.function.arguments
|
||
}
|
||
}
|
||
tool_calls_list.append(tool_call_dict)
|
||
return reasoning_content, tool_calls_list
|
||
|
||
async def execute_tool(tool_name,tool_arguments):
|
||
|
||
try:
|
||
tool_func = tool_map[tool_name] # 获取工具函数
|
||
arguments = {}
|
||
if tool_arguments:
|
||
# 检查arguments是字符串还是字典
|
||
if isinstance(tool_arguments, dict):
|
||
# 如果已经是字典,直接使用
|
||
arguments = tool_arguments
|
||
elif isinstance(tool_arguments, str):
|
||
# 如果是字符串,尝试解析为JSON
|
||
try:
|
||
# 尝试直接解析为JSON对象
|
||
arguments = json.loads(tool_arguments)
|
||
except json.JSONDecodeError:
|
||
# 如果解析失败,可能是因为字符串中包含转义字符
|
||
# 尝试修复常见的JSON字符串问题
|
||
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
|
||
try:
|
||
arguments = json.loads(fixed_str)
|
||
except json.JSONDecodeError:
|
||
# 如果仍然失败,尝试将字符串作为原始字符串处理
|
||
arguments = {"raw_string": tool_arguments}
|
||
|
||
# 调用工具函数
|
||
if asyncio.iscoroutinefunction(tool_func):
|
||
# 如果是异步函数,使用await调用
|
||
result = await tool_func(**arguments)
|
||
else:
|
||
# 如果是同步函数,直接调用
|
||
result = tool_func(**arguments)
|
||
# if func_name=='generate_material':
|
||
# print("xxxxx",result)
|
||
return result
|
||
finally:
|
||
# 清除LLM调用上下文标记
|
||
pass
|
||
|
||
def get_all_tool_calls_results(tool_calls_list):
|
||
all_results = []
|
||
for tool_call in tool_calls_list:
|
||
tool_name = tool_call['function']['name']
|
||
tool_arguments = tool_call['function']['arguments']
|
||
result = asyncio.run(execute_tool(tool_name,tool_arguments))
|
||
result_str = f"[{tool_name} content begin]\n"+result+f"\n[{tool_name} content end]\n"
|
||
all_results.append(result_str)
|
||
|
||
return all_results
|
||
def get_response_from_r1(messages, model="MARS-R1", temperature=0.3):
    """Send *messages* to the reasoning model and return its reply text.

    Uses the module-level ``client``; no tools are offered. Printing the
    reply is left to the caller.

    Args:
        messages: chat messages in OpenAI format.
        model: model name; defaults to ``"MARS-R1"``.
        temperature: sampling temperature; defaults to ``0.3``.

    Returns:
        The first choice's message ``content``.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
    )
    return completion.choices[0].message.content
if __name__ == '__main__':
    # Stage 1: the tool-calling model decides which tools to run.
    reasoning_content, tool_calls_list = get_t1_response(messages)
    print("Reasoning content:", reasoning_content)

    # Stage 2: execute the requested tools and collect their formatted output.
    tool_call_results = get_all_tool_calls_results(tool_calls_list)
    tool_call_results_str = "\n".join(tool_call_results)

    # Stage 3: ask the reasoning model the original question, with the tool
    # output prepended as context.
    user_message = {
        "role": "user",
        "content": f"# 信息如下:{tool_call_results_str}# 问题如下 {messages[0]['content']}"
    }
    print("user_message_for_r1:", user_message)

    # Show the final answer.
    r1_response = get_response_from_r1([user_message])
    print("R1 RESPONSE:", r1_response)