初次提交
This commit is contained in:
339
test_tools/test_mars_t1_r1.py
Normal file
339
test_tools/test_mars_t1_r1.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
|
||||
# 创建Rich控制台对象
|
||||
console = Console()
|
||||
|
||||
# 创建一个列表来存储工具调用结果
|
||||
tool_results = []
|
||||
|
||||
# 定义分隔符样式
|
||||
def print_separator(title=""):
|
||||
console.print(Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False), justify="center")
|
||||
|
||||
api_key="gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
|
||||
base_url="http://gpustack.ddwtop.team/v1-openai"
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
from sci_mcp import *
|
||||
|
||||
tools_schemas = get_domain_tool_schemas(["material",'general'])
|
||||
tool_map = get_domain_tools(["material",'general'])
|
||||
|
||||
|
||||
# 打印消息的函数
|
||||
def print_message(message):
|
||||
# 处理不同类型的消息对象
|
||||
if hasattr(message, 'role'): # ChatCompletionMessage 对象
|
||||
role = message.role
|
||||
content = message.content if hasattr(message, 'content') else ""
|
||||
# 如果是工具消息,获取工具名称
|
||||
tool_name = None
|
||||
if role == "tool" and hasattr(message, 'name'):
|
||||
tool_name = message.name
|
||||
else: # 字典类型
|
||||
role = message.get("role", "unknown")
|
||||
content = message.get("content", "")
|
||||
# 如果是工具消息,获取工具名称
|
||||
tool_name = message.get("name") if role == "tool" else None
|
||||
|
||||
# 根据角色选择不同的颜色
|
||||
role_colors = {
|
||||
"system": "bright_blue",
|
||||
"user": "green",
|
||||
"assistant": "yellow",
|
||||
"tool": "bright_red"
|
||||
}
|
||||
color = role_colors.get(role, "white")
|
||||
|
||||
# 创建富文本面板
|
||||
text = Text()
|
||||
|
||||
# 如果是工具消息,添加工具名称
|
||||
if role == "tool" and tool_name:
|
||||
text.append(f"{role} ({tool_name}): ", style=f"bold {color}")
|
||||
else:
|
||||
text.append(f"{role}: ", style=f"bold {color}")
|
||||
|
||||
text.append(str(content))
|
||||
console.print(Panel(text, border_style=color))
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}
|
||||
]
|
||||
#how to synthesize CsPbBr3 at room temperature
|
||||
#
|
||||
# 打印初始消息
|
||||
print_separator("初始消息")
|
||||
for message in messages:
|
||||
print_message(message)
|
||||
finish_reason = None
|
||||
|
||||
async def execute_tool(tool_name,tool_arguments):
|
||||
# 设置LLM调用上下文标记
|
||||
|
||||
try:
|
||||
tool_func = tool_map[tool_name] # 获取工具函数
|
||||
arguments = {}
|
||||
if tool_arguments:
|
||||
# 检查arguments是字符串还是字典
|
||||
if isinstance(tool_arguments, dict):
|
||||
# 如果已经是字典,直接使用
|
||||
arguments = tool_arguments
|
||||
elif isinstance(tool_arguments, str):
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
try:
|
||||
# 尝试直接解析为JSON对象
|
||||
arguments = json.loads(tool_arguments)
|
||||
except json.JSONDecodeError:
|
||||
# 如果解析失败,可能是因为字符串中包含转义字符
|
||||
# 尝试修复常见的JSON字符串问题
|
||||
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
|
||||
try:
|
||||
arguments = json.loads(fixed_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果仍然失败,尝试将字符串作为原始字符串处理
|
||||
arguments = {"raw_string": tool_arguments}
|
||||
|
||||
# 调用工具函数
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
# 如果是异步函数,使用await调用
|
||||
result = await tool_func(**arguments)
|
||||
else:
|
||||
# 如果是同步函数,直接调用
|
||||
result = tool_func(**arguments)
|
||||
# if func_name=='generate_material':
|
||||
# print("xxxxx",result)
|
||||
return result
|
||||
finally:
|
||||
# 清除LLM调用上下文标记
|
||||
pass
|
||||
|
||||
# 定义一个函数来估算消息的token数量(粗略估计)
|
||||
def estimate_tokens(messages):
|
||||
# 简单估计:每个英文单词约1.3个token,每个中文字符约1个token
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = msg.get("content", "") if isinstance(msg, dict) else (msg.content if hasattr(msg, "content") else "")
|
||||
if content:
|
||||
# 粗略估计内容的token数
|
||||
total += len(content) * 1.3
|
||||
|
||||
# 估计工具调用的token数
|
||||
tool_calls = msg.get("tool_calls", []) if isinstance(msg, dict) else (msg.tool_calls if hasattr(msg, "tool_calls") else [])
|
||||
for tool_call in tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
args = tool_call.get("function", {}).get("arguments", "")
|
||||
total += len(args) * 1.3
|
||||
else:
|
||||
args = tool_call.function.arguments if hasattr(tool_call, "function") else ""
|
||||
total += len(args) * 1.3
|
||||
|
||||
return int(total)
|
||||
|
||||
# 管理消息历史,保持在token限制以内
|
||||
def manage_message_history(messages, max_tokens=7000):
|
||||
# 保留第一条消息(通常是系统消息或初始用户消息)
|
||||
if len(messages) <= 1:
|
||||
return messages
|
||||
|
||||
# 估算当前消息的token数
|
||||
current_tokens = estimate_tokens(messages)
|
||||
|
||||
# 如果当前token数已经接近限制,开始裁剪历史消息
|
||||
if current_tokens > max_tokens:
|
||||
# 保留第一条消息和最近的消息
|
||||
preserved_messages = [messages[0]]
|
||||
|
||||
# 从最新的消息开始添加,直到接近但不超过token限制
|
||||
temp_messages = []
|
||||
for msg in reversed(messages[1:]):
|
||||
temp_messages.insert(0, msg)
|
||||
if estimate_tokens([messages[0]] + temp_messages) > max_tokens:
|
||||
# 如果添加这条消息会超过限制,则停止添加
|
||||
temp_messages.pop(0)
|
||||
break
|
||||
|
||||
# 如果裁剪后的消息太少,至少保留最近的几条消息
|
||||
if len(temp_messages) < 4 and len(messages) > 4:
|
||||
temp_messages = messages[-4:]
|
||||
|
||||
return preserved_messages + temp_messages
|
||||
|
||||
return messages
|
||||
|
||||
while finish_reason is None or finish_reason == "tool_calls":
|
||||
# 在发送请求前管理消息历史
|
||||
managed_messages = manage_message_history(messages)
|
||||
if len(managed_messages) < len(messages):
|
||||
print_separator(f"消息历史已裁剪,从{len(messages)}条减少到{len(managed_messages)}条")
|
||||
messages = managed_messages
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model="MARS-T1",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
tools=tools_schemas,
|
||||
timeout=120,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
finish_reason = choice.finish_reason
|
||||
if finish_reason == "tool_calls": # <-- 判断当前返回内容是否包含 tool_calls
|
||||
# 打印assistant消息
|
||||
print_separator("Assistant消息")
|
||||
print_message(choice.message)
|
||||
|
||||
# 将ChatCompletionMessage对象转换为字典
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": choice.message.content if hasattr(choice.message, 'content') else None
|
||||
}
|
||||
|
||||
# 如果有工具调用,添加到字典中
|
||||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
|
||||
# 将tool_calls对象转换为字典列表
|
||||
tool_calls_list = []
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
tool_calls_list.append(tool_call_dict)
|
||||
assistant_message["tool_calls"] = tool_calls_list
|
||||
|
||||
# 添加消息到上下文
|
||||
messages.append(assistant_message) # <-- 我们将模型返回给我们的 assistant 消息也添加到上下文中,以便于下次请求时模型能理解我们的诉求
|
||||
|
||||
# 打印工具调用信息
|
||||
print_separator("工具调用")
|
||||
for tool_call in choice.message.tool_calls:
|
||||
console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]")
|
||||
console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
|
||||
console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
|
||||
console.print("")
|
||||
|
||||
tool_call_name = tool_call.function.name
|
||||
tool_call_arguments = json.loads(tool_call.function.arguments) # <-- arguments 是序列化后的 JSON Object,我们需要使用 json.loads 反序列化一下
|
||||
try:
|
||||
tool_result = asyncio.run(execute_tool(tool_name=tool_call_name,tool_arguments=tool_call_arguments)) # <-- 通过 tool_map 快速找到需要执行哪个函数
|
||||
except Exception as e:
|
||||
tool_result=f'工具调用失败{e}'
|
||||
|
||||
# 将工具调用结果保存到单独的列表中,使用指定格式包裹
|
||||
formatted_result = f"[{tool_call_name} content begin]{tool_result}[{tool_call_name} content end]"
|
||||
tool_results.append({
|
||||
"tool_name": tool_call_name,
|
||||
"tool_id": tool_call.id,
|
||||
"formatted_result": formatted_result,
|
||||
"raw_result": tool_result
|
||||
})
|
||||
|
||||
# 打印保存的工具结果信息
|
||||
console.print(f"[bold green]已保存工具结果:[/] [cyan]{tool_call_name}[/]")
|
||||
|
||||
# 构造工具响应消息
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call_name,
|
||||
"content": tool_result, # <-- 我们约定使用字符串格式向模型提交工具调用结果
|
||||
}
|
||||
|
||||
# 打印工具响应
|
||||
print_separator(f"工具响应: {tool_call_name}")
|
||||
print_message(tool_message)
|
||||
|
||||
# 添加消息到上下文
|
||||
#messages.append(tool_message)
|
||||
|
||||
# 打印最终响应
|
||||
if choice.message.content:
|
||||
print_separator("最终响应")
|
||||
console.print(Panel(choice.message.content, border_style="green"))
|
||||
|
||||
# 打印收集的所有工具调用结果
|
||||
if tool_results:
|
||||
print_separator("所有工具调用结果")
|
||||
console.print(f"[bold cyan]共收集了 {len(tool_results)} 个工具调用结果[/]")
|
||||
|
||||
# 将所有格式化的结果写入文件
|
||||
with open("tool_results.txt", "w", encoding="utf-8") as f:
|
||||
for result in tool_results:
|
||||
f.write(f"{result['formatted_result']}\n\n")
|
||||
|
||||
console.print(f"[bold green]所有工具调用结果已保存到 tool_results.txt 文件中[/]")
|
||||
Reference in New Issue
Block a user