Files
multi_mcp/test_tools/test_mars_t1_r1.py
2025-05-09 14:16:33 +08:00

340 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import sys
import os
from openai import OpenAI
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
# 添加项目根目录到sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp import *
# 创建Rich控制台对象
console = Console()
# 创建一个列表来存储工具调用结果
tool_results = []
# 定义分隔符样式
def print_separator(title=""):
console.print(Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False), justify="center")
api_key="gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
base_url="http://gpustack.ddwtop.team/v1-openai"
client = OpenAI(
api_key=api_key,
base_url=base_url,
)
from sci_mcp import *
tools_schemas = get_domain_tool_schemas(["material",'general'])
tool_map = get_domain_tools(["material",'general'])
# 打印消息的函数
def print_message(message):
# 处理不同类型的消息对象
if hasattr(message, 'role'): # ChatCompletionMessage 对象
role = message.role
content = message.content if hasattr(message, 'content') else ""
# 如果是工具消息,获取工具名称
tool_name = None
if role == "tool" and hasattr(message, 'name'):
tool_name = message.name
else: # 字典类型
role = message.get("role", "unknown")
content = message.get("content", "")
# 如果是工具消息,获取工具名称
tool_name = message.get("name") if role == "tool" else None
# 根据角色选择不同的颜色
role_colors = {
"system": "bright_blue",
"user": "green",
"assistant": "yellow",
"tool": "bright_red"
}
color = role_colors.get(role, "white")
# 创建富文本面板
text = Text()
# 如果是工具消息,添加工具名称
if role == "tool" and tool_name:
text.append(f"{role} ({tool_name}): ", style=f"bold {color}")
else:
text.append(f"{role}: ", style=f"bold {color}")
text.append(str(content))
console.print(Panel(text, border_style=color))
messages = [
{"role": "user", "content": """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
根据上文提供的CIF文件请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性并用JSON格式回答。"""}
]
#how to synthesize CsPbBr3 at room temperature
#
# 打印初始消息
print_separator("初始消息")
for message in messages:
print_message(message)
finish_reason = None
async def execute_tool(tool_name,tool_arguments):
# 设置LLM调用上下文标记
try:
tool_func = tool_map[tool_name] # 获取工具函数
arguments = {}
if tool_arguments:
# 检查arguments是字符串还是字典
if isinstance(tool_arguments, dict):
# 如果已经是字典,直接使用
arguments = tool_arguments
elif isinstance(tool_arguments, str):
# 如果是字符串尝试解析为JSON
try:
# 尝试直接解析为JSON对象
arguments = json.loads(tool_arguments)
except json.JSONDecodeError:
# 如果解析失败,可能是因为字符串中包含转义字符
# 尝试修复常见的JSON字符串问题
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
try:
arguments = json.loads(fixed_str)
except json.JSONDecodeError:
# 如果仍然失败,尝试将字符串作为原始字符串处理
arguments = {"raw_string": tool_arguments}
# 调用工具函数
if asyncio.iscoroutinefunction(tool_func):
# 如果是异步函数使用await调用
result = await tool_func(**arguments)
else:
# 如果是同步函数,直接调用
result = tool_func(**arguments)
# if func_name=='generate_material':
# print("xxxxx",result)
return result
finally:
# 清除LLM调用上下文标记
pass
# 定义一个函数来估算消息的token数量粗略估计
def estimate_tokens(messages):
# 简单估计每个英文单词约1.3个token每个中文字符约1个token
total = 0
for msg in messages:
content = msg.get("content", "") if isinstance(msg, dict) else (msg.content if hasattr(msg, "content") else "")
if content:
# 粗略估计内容的token数
total += len(content) * 1.3
# 估计工具调用的token数
tool_calls = msg.get("tool_calls", []) if isinstance(msg, dict) else (msg.tool_calls if hasattr(msg, "tool_calls") else [])
for tool_call in tool_calls:
if isinstance(tool_call, dict):
args = tool_call.get("function", {}).get("arguments", "")
total += len(args) * 1.3
else:
args = tool_call.function.arguments if hasattr(tool_call, "function") else ""
total += len(args) * 1.3
return int(total)
# 管理消息历史保持在token限制以内
def manage_message_history(messages, max_tokens=7000):
# 保留第一条消息(通常是系统消息或初始用户消息)
if len(messages) <= 1:
return messages
# 估算当前消息的token数
current_tokens = estimate_tokens(messages)
# 如果当前token数已经接近限制开始裁剪历史消息
if current_tokens > max_tokens:
# 保留第一条消息和最近的消息
preserved_messages = [messages[0]]
# 从最新的消息开始添加直到接近但不超过token限制
temp_messages = []
for msg in reversed(messages[1:]):
temp_messages.insert(0, msg)
if estimate_tokens([messages[0]] + temp_messages) > max_tokens:
# 如果添加这条消息会超过限制,则停止添加
temp_messages.pop(0)
break
# 如果裁剪后的消息太少,至少保留最近的几条消息
if len(temp_messages) < 4 and len(messages) > 4:
temp_messages = messages[-4:]
return preserved_messages + temp_messages
return messages
while finish_reason is None or finish_reason == "tool_calls":
# 在发送请求前管理消息历史
managed_messages = manage_message_history(messages)
if len(managed_messages) < len(messages):
print_separator(f"消息历史已裁剪,从{len(messages)}条减少到{len(managed_messages)}")
messages = managed_messages
completion = client.chat.completions.create(
model="MARS-T1",
messages=messages,
temperature=0.3,
tools=tools_schemas,
timeout=120,
)
choice = completion.choices[0]
finish_reason = choice.finish_reason
if finish_reason == "tool_calls": # <-- 判断当前返回内容是否包含 tool_calls
# 打印assistant消息
print_separator("Assistant消息")
print_message(choice.message)
# 将ChatCompletionMessage对象转换为字典
assistant_message = {
"role": "assistant",
"content": choice.message.content if hasattr(choice.message, 'content') else None
}
# 如果有工具调用,添加到字典中
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
# 将tool_calls对象转换为字典列表
tool_calls_list = []
for tool_call in choice.message.tool_calls:
tool_call_dict = {
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
tool_calls_list.append(tool_call_dict)
assistant_message["tool_calls"] = tool_calls_list
# 添加消息到上下文
messages.append(assistant_message) # <-- 我们将模型返回给我们的 assistant 消息也添加到上下文中,以便于下次请求时模型能理解我们的诉求
# 打印工具调用信息
print_separator("工具调用")
for tool_call in choice.message.tool_calls:
console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]")
console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
console.print("")
tool_call_name = tool_call.function.name
tool_call_arguments = json.loads(tool_call.function.arguments) # <-- arguments 是序列化后的 JSON Object我们需要使用 json.loads 反序列化一下
try:
tool_result = asyncio.run(execute_tool(tool_name=tool_call_name,tool_arguments=tool_call_arguments)) # <-- 通过 tool_map 快速找到需要执行哪个函数
except Exception as e:
tool_result=f'工具调用失败{e}'
# 将工具调用结果保存到单独的列表中,使用指定格式包裹
formatted_result = f"[{tool_call_name} content begin]{tool_result}[{tool_call_name} content end]"
tool_results.append({
"tool_name": tool_call_name,
"tool_id": tool_call.id,
"formatted_result": formatted_result,
"raw_result": tool_result
})
# 打印保存的工具结果信息
console.print(f"[bold green]已保存工具结果:[/] [cyan]{tool_call_name}[/]")
# 构造工具响应消息
tool_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"name": tool_call_name,
"content": tool_result, # <-- 我们约定使用字符串格式向模型提交工具调用结果
}
# 打印工具响应
print_separator(f"工具响应: {tool_call_name}")
print_message(tool_message)
# 添加消息到上下文
#messages.append(tool_message)
# 打印最终响应
if choice.message.content:
print_separator("最终响应")
console.print(Panel(choice.message.content, border_style="green"))
# 打印收集的所有工具调用结果
if tool_results:
print_separator("所有工具调用结果")
console.print(f"[bold cyan]共收集了 {len(tool_results)} 个工具调用结果[/]")
# 将所有格式化的结果写入文件
with open("tool_results.txt", "w", encoding="utf-8") as f:
for result in tool_results:
f.write(f"{result['formatted_result']}\n\n")
console.print(f"[bold green]所有工具调用结果已保存到 tool_results.txt 文件中[/]")