254 lines
9.5 KiB
Python
254 lines
9.5 KiB
Python
import asyncio
|
||
import json
|
||
import sys
|
||
import os
|
||
from openai import OpenAI
|
||
from rich.console import Console
|
||
from rich.panel import Panel
|
||
from rich.text import Text
|
||
|
||
# 添加项目根目录到sys.path
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
from sci_mcp_server.core.llm_tools import set_llm_context, clear_llm_context
|
||
|
||
# 创建Rich控制台对象
|
||
console = Console()
|
||
|
||
# 定义分隔符样式
|
||
def print_separator(title=""):
|
||
console.print(Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False), justify="center")
|
||
|
||
api_key="gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
|
||
base_url="http://gpustack.ddwtop.team/v1-openai"
|
||
client = OpenAI(
|
||
api_key=api_key,
|
||
base_url=base_url,
|
||
)
|
||
from sci_mcp import *
|
||
|
||
tools_schemas = get_domain_tool_schemas(["material",'general'])
|
||
tool_map = get_domain_tools(["material",'general'])
|
||
|
||
|
||
# 打印消息的函数
|
||
def print_message(message):
|
||
# 处理不同类型的消息对象
|
||
if hasattr(message, 'role'): # ChatCompletionMessage 对象
|
||
role = message.role
|
||
content = message.content if hasattr(message, 'content') else ""
|
||
# 如果是工具消息,获取工具名称
|
||
tool_name = None
|
||
if role == "tool" and hasattr(message, 'name'):
|
||
tool_name = message.name
|
||
else: # 字典类型
|
||
role = message.get("role", "unknown")
|
||
content = message.get("content", "")
|
||
# 如果是工具消息,获取工具名称
|
||
tool_name = message.get("name") if role == "tool" else None
|
||
|
||
# 根据角色选择不同的颜色
|
||
role_colors = {
|
||
"system": "bright_blue",
|
||
"user": "green",
|
||
"assistant": "yellow",
|
||
"tool": "bright_red"
|
||
}
|
||
color = role_colors.get(role, "white")
|
||
|
||
# 创建富文本面板
|
||
text = Text()
|
||
|
||
# 如果是工具消息,添加工具名称
|
||
if role == "tool" and tool_name:
|
||
text.append(f"{role} ({tool_name}): ", style=f"bold {color}")
|
||
else:
|
||
text.append(f"{role}: ", style=f"bold {color}")
|
||
|
||
text.append(str(content))
|
||
console.print(Panel(text, border_style=color))
|
||
|
||
messages = [
|
||
{"role": "system",
|
||
"content": "You are MARS-R1, a professional assistant in materials science. You first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> <structured_answer> </structured_answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here <structured_answer> structured answer here <structured_answer> </answer>'"},
|
||
{"role": "user", "content": """data_Ti4V
|
||
_symmetry_space_group_name_H-M Fmmm
|
||
_cell_length_a 3.18353600
|
||
_cell_length_b 4.52677200
|
||
_cell_length_c 22.74397000
|
||
_cell_angle_alpha 90.00000000
|
||
_cell_angle_beta 90.00000000
|
||
_cell_angle_gamma 90.00000000
|
||
_symmetry_Int_Tables_number 69
|
||
_chemical_formula_structural Ti4V
|
||
_chemical_formula_sum 'Ti16 V4'
|
||
_cell_volume 327.76657340
|
||
_cell_formula_units_Z 4
|
||
loop_
|
||
_symmetry_equiv_pos_site_id
|
||
_symmetry_equiv_pos_as_xyz
|
||
1 'x, y, z'
|
||
2 '-x, -y, -z'
|
||
3 '-x, -y, z'
|
||
4 'x, y, -z'
|
||
5 'x, -y, -z'
|
||
6 '-x, y, z'
|
||
7 '-x, y, -z'
|
||
8 'x, -y, z'
|
||
9 'x+1/2, y, z+1/2'
|
||
10 '-x+1/2, -y, -z+1/2'
|
||
11 '-x+1/2, -y, z+1/2'
|
||
12 'x+1/2, y, -z+1/2'
|
||
13 'x+1/2, -y, -z+1/2'
|
||
14 '-x+1/2, y, z+1/2'
|
||
15 '-x+1/2, y, -z+1/2'
|
||
16 'x+1/2, -y, z+1/2'
|
||
17 'x+1/2, y+1/2, z'
|
||
18 '-x+1/2, -y+1/2, -z'
|
||
19 '-x+1/2, -y+1/2, z'
|
||
20 'x+1/2, y+1/2, -z'
|
||
21 'x+1/2, -y+1/2, -z'
|
||
22 '-x+1/2, y+1/2, z'
|
||
23 '-x+1/2, y+1/2, -z'
|
||
24 'x+1/2, -y+1/2, z'
|
||
25 'x, y+1/2, z+1/2'
|
||
26 '-x, -y+1/2, -z+1/2'
|
||
27 '-x, -y+1/2, z+1/2'
|
||
28 'x, y+1/2, -z+1/2'
|
||
29 'x, -y+1/2, -z+1/2'
|
||
30 '-x, y+1/2, z+1/2'
|
||
31 '-x, y+1/2, -z+1/2'
|
||
32 'x, -y+1/2, z+1/2'
|
||
loop_
|
||
_atom_site_type_symbol
|
||
_atom_site_label
|
||
_atom_site_symmetry_multiplicity
|
||
_atom_site_fract_x
|
||
_atom_site_fract_y
|
||
_atom_site_fract_z
|
||
_atom_site_occupancy
|
||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}
|
||
]
|
||
#how to synthesize CsPbBr3 at room temperature
|
||
#
|
||
# 打印初始消息
|
||
print_separator("初始消息")
|
||
for message in messages:
|
||
print_message(message)
|
||
finish_reason = None
|
||
|
||
async def execute_tool(tool_name,tool_arguments):
|
||
# 设置LLM调用上下文标记
|
||
set_llm_context(True)
|
||
try:
|
||
tool_func = tool_map[tool_name] # 获取工具函数
|
||
arguments = {}
|
||
if tool_arguments:
|
||
# 检查arguments是字符串还是字典
|
||
if isinstance(tool_arguments, dict):
|
||
# 如果已经是字典,直接使用
|
||
arguments = tool_arguments
|
||
elif isinstance(tool_arguments, str):
|
||
# 如果是字符串,尝试解析为JSON
|
||
try:
|
||
# 尝试直接解析为JSON对象
|
||
arguments = json.loads(tool_arguments)
|
||
except json.JSONDecodeError:
|
||
# 如果解析失败,可能是因为字符串中包含转义字符
|
||
# 尝试修复常见的JSON字符串问题
|
||
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
|
||
try:
|
||
arguments = json.loads(fixed_str)
|
||
except json.JSONDecodeError:
|
||
# 如果仍然失败,尝试将字符串作为原始字符串处理
|
||
arguments = {"raw_string": tool_arguments}
|
||
|
||
# 调用工具函数
|
||
if asyncio.iscoroutinefunction(tool_func):
|
||
# 如果是异步函数,使用await调用
|
||
result = await tool_func(**arguments)
|
||
else:
|
||
# 如果是同步函数,直接调用
|
||
result = tool_func(**arguments)
|
||
# if func_name=='generate_material':
|
||
# print("xxxxx",result)
|
||
return result
|
||
finally:
|
||
# 清除LLM调用上下文标记
|
||
clear_llm_context()
|
||
|
||
while finish_reason is None or finish_reason == "tool_calls":
|
||
completion = client.chat.completions.create(
|
||
model="MARS-R1",
|
||
messages=messages,
|
||
temperature=0.3,
|
||
tools=tools_schemas, # <-- 我们通过 tools 参数,将定义好的 tools 提交给 Kimi 大模型
|
||
)
|
||
choice = completion.choices[0]
|
||
finish_reason = choice.finish_reason
|
||
if finish_reason == "tool_calls": # <-- 判断当前返回内容是否包含 tool_calls
|
||
# 打印assistant消息
|
||
print_separator("Assistant消息")
|
||
print_message(choice.message)
|
||
|
||
# 将ChatCompletionMessage对象转换为字典
|
||
assistant_message = {
|
||
"role": "assistant",
|
||
"content": choice.message.content if hasattr(choice.message, 'content') else None
|
||
}
|
||
|
||
# 如果有工具调用,添加到字典中
|
||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
|
||
# 将tool_calls对象转换为字典列表
|
||
tool_calls_list = []
|
||
for tool_call in choice.message.tool_calls:
|
||
tool_call_dict = {
|
||
"id": tool_call.id,
|
||
"type": "function",
|
||
"function": {
|
||
"name": tool_call.function.name,
|
||
"arguments": tool_call.function.arguments
|
||
}
|
||
}
|
||
tool_calls_list.append(tool_call_dict)
|
||
assistant_message["tool_calls"] = tool_calls_list
|
||
|
||
# 添加消息到上下文
|
||
messages.append(assistant_message) # <-- 我们将模型返回给我们的 assistant 消息也添加到上下文中,以便于下次请求时模型能理解我们的诉求
|
||
|
||
# 打印工具调用信息
|
||
print_separator("工具调用")
|
||
for tool_call in choice.message.tool_calls:
|
||
console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]")
|
||
console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
|
||
console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
|
||
console.print("")
|
||
|
||
tool_call_name = tool_call.function.name
|
||
tool_call_arguments = json.loads(tool_call.function.arguments) # <-- arguments 是序列化后的 JSON Object,我们需要使用 json.loads 反序列化一下
|
||
try:
|
||
tool_result = asyncio.run(execute_tool(tool_name=tool_call_name,tool_arguments=tool_call_arguments)) # <-- 通过 tool_map 快速找到需要执行哪个函数
|
||
except Exception as e:
|
||
tool_result=f'工具调用失败{e}'
|
||
# 构造工具响应消息
|
||
tool_message = {
|
||
"role": "tool",
|
||
"tool_call_id": tool_call.id,
|
||
"name": tool_call_name,
|
||
"content": tool_result, # <-- 我们约定使用字符串格式向模型提交工具调用结果
|
||
}
|
||
|
||
# 打印工具响应
|
||
print_separator(f"工具响应: {tool_call_name}")
|
||
print_message(tool_message)
|
||
|
||
# 添加消息到上下文
|
||
messages.append(tool_message)
|
||
|
||
# 打印最终响应
|
||
if choice.message.content:
|
||
print_separator("最终响应")
|
||
console.print(Panel(choice.message.content, border_style="green"))
|