340 lines
12 KiB
Python
340 lines
12 KiB
Python
import asyncio
|
||
import json
|
||
import sys
|
||
import os
|
||
from openai import OpenAI
|
||
from rich.console import Console
|
||
from rich.panel import Panel
|
||
from rich.text import Text
|
||
|
||
# 添加项目根目录到sys.path
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
from sci_mcp import *
|
||
|
||
# Shared Rich console used for every piece of terminal output in this script.
console = Console()

# Accumulates one record per executed tool call; written to disk at the end.
tool_results = list()
|
||
|
||
# 定义分隔符样式
|
||
def print_separator(title=""):
    """Print a centered section banner containing *title*.

    The title is rendered in bold magenta inside a cyan-bordered panel that
    shrinks to fit its content.
    """
    banner = Panel(
        f"[bold magenta]{title}[/]",
        border_style="cyan",
        expand=False,
    )
    console.print(banner, justify="center")
|
||
|
||
# SECURITY NOTE(review): a credential and endpoint are committed in this file.
# The environment variables GPUSTACK_API_KEY / GPUSTACK_BASE_URL now take
# precedence so the key can be rotated without editing source. The literals
# are kept only as backward-compatible fallbacks and should be removed once
# the environment is configured.
api_key = os.environ.get("GPUSTACK_API_KEY") or "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
base_url = os.environ.get("GPUSTACK_BASE_URL") or "http://gpustack.ddwtop.team/v1-openai"
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
from sci_mcp import *  # repeats the wildcard import made earlier in the file; kept as-is

# Tool schemas are sent to the model; tool_map is indexed by tool name to find
# the callable to execute. Both are requested for the same two domains.
tools_schemas = get_domain_tool_schemas(["material", 'general'])
tool_map = get_domain_tools(["material", 'general'])
|
||
|
||
|
||
# 打印消息的函数
|
||
def print_message(message):
    """Render one chat message as a colored Rich panel.

    *message* may be an object exposing a ``role`` attribute or a plain dict.
    The panel shows ``"<role>: "`` (or ``"<role> (<tool name>): "`` for tool
    messages that carry a name) in bold, followed by ``str(content)``.
    """
    if hasattr(message, 'role'):
        # Attribute-style message object.
        role = message.role
        content = getattr(message, 'content', "")
        tool_name = getattr(message, 'name', None) if role == "tool" else None
    else:
        # Dict-style message.
        role = message.get("role", "unknown")
        content = message.get("content", "")
        tool_name = message.get("name") if role == "tool" else None

    # One color per known role; anything else falls back to white.
    color = {
        "system": "bright_blue",
        "user": "green",
        "assistant": "yellow",
        "tool": "bright_red",
    }.get(role, "white")

    # tool_name can only be set for tool messages, so testing it is enough.
    if tool_name:
        label = f"{role} ({tool_name}): "
    else:
        label = f"{role}: "

    text = Text()
    text.append(label, style=f"bold {color}")
    text.append(str(content))
    console.print(Panel(text, border_style=color))
|
||
|
||
# Initial user prompt: a CIF description of Ti4V followed by the question.
INITIAL_PROMPT = """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""

messages = [
    {"role": "user", "content": INITIAL_PROMPT}
]
# Alternative prompt kept from the original file:
#how to synthesize CsPbBr3 at room temperature
#

# Show the conversation as it stands before the first request.
print_separator("初始消息")
for message in messages:
    print_message(message)

# No completion has been received yet; the request loop runs while this is
# None or "tool_calls".
finish_reason = None
|
||
|
||
def _parse_tool_arguments(tool_arguments):
    """Normalize raw tool arguments into a keyword-argument dict.

    * falsy input (``None``, ``""``, ``{}``) -> ``{}``
    * dict -> returned unchanged
    * str -> parsed as JSON; if that fails, parsed again after undoing one
      level of backslash escaping; if no attempt yields a JSON object the
      original text is passed through as ``{"raw_string": <text>}``
    * any other type -> ``{}``
    """
    if not tool_arguments:
        return {}
    if isinstance(tool_arguments, dict):
        return tool_arguments
    if not isinstance(tool_arguments, str):
        return {}

    # Second candidate undoes escaped quotes/backslashes, a common way for
    # model-produced argument strings to be malformed.
    unescaped = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
    for candidate in (tool_arguments, unescaped):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Only a JSON object can be expanded with ** below; anything else
        # (list, number, string, null) is treated like unparseable text.
        if isinstance(parsed, dict):
            return parsed
        break
    return {"raw_string": tool_arguments}


async def execute_tool(tool_name, tool_arguments):
    """Look up *tool_name* in ``tool_map`` and call it with *tool_arguments*.

    Args:
        tool_name: Key into the module-level ``tool_map``.
        tool_arguments: A dict of keyword arguments, or a JSON string
            encoding one (see ``_parse_tool_arguments`` for the fallbacks).

    Returns:
        Whatever the tool returns. If the call produces an awaitable (an
        ``async def`` tool, or a plain wrapper that returns a coroutine), it
        is awaited and its result is returned.

    Raises:
        KeyError: If *tool_name* is not registered in ``tool_map``.
        Exception: Anything raised by the tool itself propagates unchanged.
    """
    import inspect

    try:
        tool_func = tool_map[tool_name]
    except KeyError:
        raise KeyError(f"Unknown tool: {tool_name!r}") from None

    arguments = _parse_tool_arguments(tool_arguments)

    result = tool_func(**arguments)
    # Checking the result rather than the function also covers wrappers
    # (partials, decorators) that are not themselves coroutine functions.
    if inspect.isawaitable(result):
        result = await result
    return result
|
||
|
||
# 定义一个函数来估算消息的token数量(粗略估计)
|
||
def estimate_tokens(messages):
    """Return a rough token estimate for a list of chat messages.

    Each message may be a dict or an object with ``content`` / ``tool_calls``
    attributes. The estimate is 1.3 tokens per character of the message
    content plus 1.3 per character of every tool call's argument string,
    truncated to an int. This is a heuristic, not a real tokenizer.

    Missing or ``None`` ``content``, ``tool_calls`` and ``arguments`` count as
    zero. Non-string values are measured by the length of their ``str()``
    form so that they cannot raise or be counted by element.
    """

    def _char_count(value):
        # Empty / None contributes nothing; non-strings are stringified.
        if not value:
            return 0
        return len(value if isinstance(value, str) else str(value))

    total = 0
    for msg in messages:
        if isinstance(msg, dict):
            content = msg.get("content", "")
            tool_calls = msg.get("tool_calls")
        else:
            content = getattr(msg, "content", "")
            tool_calls = getattr(msg, "tool_calls", None)

        if content:
            total += _char_count(content) * 1.3

        # tool_calls is None on messages without tool calls, so normalize it
        # before iterating.
        for tool_call in tool_calls or []:
            if isinstance(tool_call, dict):
                args = (tool_call.get("function") or {}).get("arguments", "")
            else:
                function = getattr(tool_call, "function", None)
                args = getattr(function, "arguments", "") if function is not None else ""
            total += _char_count(args) * 1.3

    return int(total)
|
||
|
||
# 管理消息历史,保持在token限制以内
|
||
def manage_message_history(messages, max_tokens=7000):
    """Trim *messages* so its estimated size fits within *max_tokens*.

    The first message is always kept. Later messages are kept newest-first
    for as long as ``estimate_tokens`` of the first message plus the kept
    tail does not exceed *max_tokens*. If fewer than four tail messages
    survive and the history has more than four messages, the last four are
    kept instead, even if that exceeds the limit.

    Returns the original list object when it has at most one message or is
    already within the limit; otherwise returns a new list.
    """
    if len(messages) <= 1:
        return messages
    if estimate_tokens(messages) <= max_tokens:
        return messages

    first = messages[0]

    # Walk backwards from the newest message, growing the kept tail until the
    # next addition would push the estimate over the limit.
    tail = []
    for msg in reversed(messages[1:]):
        candidate = [msg] + tail
        if estimate_tokens([first] + candidate) > max_tokens:
            break
        tail = candidate

    # Keep a minimum amount of recent context regardless of the limit.
    if len(tail) < 4 and len(messages) > 4:
        tail = messages[-4:]

    return [first] + tail
|
||
|
||
def _assistant_message_to_dict(message):
    """Convert an assistant reply object into a plain dict for the next request.

    Copies ``content`` and, when the reply has tool calls, each call's id,
    function name and raw argument string.
    """
    assistant_message = {
        "role": "assistant",
        "content": getattr(message, "content", None),
    }
    tool_calls = getattr(message, "tool_calls", None)
    if tool_calls:
        assistant_message["tool_calls"] = [
            {
                "id": tool_call.id,
                "type": "function",
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
            }
            for tool_call in tool_calls
        ]
    return assistant_message


# Conversation loop: keep requesting completions for as long as the model
# answers with tool calls, executing each requested tool in between.
while finish_reason is None or finish_reason == "tool_calls":
    # Trim the history before every request.
    managed_messages = manage_message_history(messages)
    if len(managed_messages) < len(messages):
        print_separator(f"消息历史已裁剪,从{len(messages)}条减少到{len(managed_messages)}条")
        messages = managed_messages

    completion = client.chat.completions.create(
        model="MARS-T1",
        messages=messages,
        temperature=0.3,
        tools=tools_schemas,
        timeout=120,
    )
    choice = completion.choices[0]
    finish_reason = choice.finish_reason
    if finish_reason != "tool_calls":
        # Nothing to execute; the loop condition decides whether to go on.
        continue

    print_separator("Assistant消息")
    print_message(choice.message)

    # The assistant turn (including its tool calls) goes back into the
    # context so the following tool messages have something to answer.
    messages.append(_assistant_message_to_dict(choice.message))

    print_separator("工具调用")
    # tool_calls may be None even when finish_reason says "tool_calls".
    for tool_call in choice.message.tool_calls or []:
        tool_call_name = tool_call.function.name
        console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call_name}[/]")
        console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
        console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
        console.print("")

        try:
            # The raw argument string is handed to execute_tool, which parses
            # it itself; decoding here, outside the try, used to abort the
            # whole script on malformed JSON.
            tool_result = asyncio.run(
                execute_tool(
                    tool_name=tool_call_name,
                    tool_arguments=tool_call.function.arguments,
                )
            )
        except Exception as e:
            # Best effort: report the failure to the model instead of crashing.
            tool_result = f'工具调用失败{e}'

        # Keep a wrapped copy of every result for the summary file.
        formatted_result = f"[{tool_call_name} content begin]{tool_result}[{tool_call_name} content end]"
        tool_results.append({
            "tool_name": tool_call_name,
            "tool_id": tool_call.id,
            "formatted_result": formatted_result,
            "raw_result": tool_result,
        })
        console.print(f"[bold green]已保存工具结果:[/] [cyan]{tool_call_name}[/]")

        tool_message = {
            "role": "tool",
            "tool_call_id": tool_call.id,
            "name": tool_call_name,
            # Tool results are submitted to the model as strings.
            "content": str(tool_result),
        }

        print_separator(f"工具响应: {tool_call_name}")
        print_message(tool_message)

        # Feed the result back to the model. Previously this append was
        # commented out, so the next request carried tool calls with no
        # responses and the model never saw any tool output.
        # NOTE(review): manage_message_history trims per message and can
        # separate a tool message from its assistant turn — confirm the
        # backend tolerates that.
        messages.append(tool_message)

# Final answer: the content of the last completion, if it has any.
if choice.message.content:
    print_separator("最终响应")
    console.print(Panel(choice.message.content, border_style="green"))
|
||
|
||
# Summarize the collected tool results and persist them, if there are any.
if tool_results:
    print_separator("所有工具调用结果")
    console.print(f"[bold cyan]共收集了 {len(tool_results)} 个工具调用结果[/]")

    # One wrapped result per entry, each followed by a blank line.
    with open("tool_results.txt", "w", encoding="utf-8") as out_file:
        out_file.writelines(
            f"{entry['formatted_result']}\n\n" for entry in tool_results
        )

    console.print("[bold green]所有工具调用结果已保存到 tool_results.txt 文件中[/]")
|