Files
multi_mcp/test_tools/demo_conversation.py
2025-05-09 14:16:33 +08:00

376 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import sys
import os
from openai import OpenAI
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
from rich import box
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp import *
# Seed input for the demo: a CIF description of Ti4V (space group Fmmm),
# followed by a request (in Chinese) to analyse the formula, crystal system,
# space group, density, metallicity, magnetism and stability, answering in JSON.
_TI4V_CIF = """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
"""
_ANALYSIS_REQUEST = "根据上文提供的CIF文件请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性并用JSON格式回答。"

# A one-message conversation; `initial_message` and `messages` name the same list.
messages = [dict(role="user", content=_TI4V_CIF + _ANALYSIS_REQUEST)]
initial_message = messages
# Shared rich console used by every display helper in this script.
console = Console()

# Tool definitions from the "material" and "general" domains: the schemas are
# advertised to MARS-T1, and tool_map (tool name -> callable) is used to run
# the tool calls it requests.
tools_schemas = get_domain_tool_schemas(["material", "general"])
tool_map = get_domain_tools(["material", "general"])

# OpenAI-compatible endpoint configuration. Both values can be overridden via
# the MARS_API_KEY / MARS_BASE_URL environment variables so credentials do not
# have to live in source control.
# NOTE(review): the fallback key below is a credential committed to the
# repository (and sent over plain HTTP); it should be rotated and the fallback
# removed once every user sets MARS_API_KEY.
api_key = os.environ.get(
    "MARS_API_KEY",
    "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705",
)
base_url = os.environ.get("MARS_BASE_URL", "http://gpustack.ddwtop.team/v1-openai")
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
def get_t1_response(messages):
    """Ask MARS-T1 to reason about *messages* and optionally request tools.

    The request advertises ``tools_schemas`` so the model may return tool
    calls. A transient spinner is shown while the request is in flight.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        ``(reasoning_content, tool_calls_list)``:
        ``reasoning_content`` is the assistant text, normalised to ``""`` when
        the API returns no content (e.g. a reply consisting only of tool
        calls). ``tool_calls_list`` holds one plain dict per requested call in
        the ``{"id", "type": "function", "function": {"name", "arguments"}}``
        shape, and is empty when no tools were requested.
    """
    with Progress(
        SpinnerColumn(),
        TextColumn("[bold blue]正在调用MARS-T1模型..."),
        transient=True,
    ) as progress:
        progress.add_task("", total=None)
        completion = client.chat.completions.create(
            model="MARS-T1",
            messages=messages,
            temperature=0.3,
            tools=tools_schemas,
        )

    message = completion.choices[0].message
    # content may be None alongside tool calls; None would later crash the
    # rich Panel in display_message and be replayed to the API as history.
    reasoning_content = message.content or ""

    tool_calls_list = [
        {
            "id": tool_call.id,
            "type": "function",
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
        }
        for tool_call in (getattr(message, "tool_calls", None) or [])
    ]
    return reasoning_content, tool_calls_list
def _parse_tool_arguments(tool_arguments):
    """Turn a tool call's ``arguments`` into a keyword-argument dict.

    A dict is used as-is; a string is decoded as JSON. If the string is not
    valid JSON, one level of backslash escaping is undone and decoding is
    retried (for double-escaped payloads where quotes arrive as
    backslash-quote). If that also fails, the raw text is passed through as
    ``{"raw_string": <text>}``.

    Empty/``None`` input, a JSON ``null`` and values that are neither dict nor
    str all yield ``{}``.

    Raises:
        TypeError: if the JSON decodes to a non-object (list, number, string),
            which cannot be splatted as ``**kwargs``.
    """
    if not tool_arguments:
        return {}
    if isinstance(tool_arguments, dict):
        return tool_arguments
    if not isinstance(tool_arguments, str):
        # Unsupported argument container: call the tool without arguments.
        return {}

    try:
        parsed = json.loads(tool_arguments)
    except json.JSONDecodeError:
        fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
        try:
            parsed = json.loads(fixed_str)
        except json.JSONDecodeError:
            # Last resort: let the tool decide what to do with the raw text.
            return {"raw_string": tool_arguments}

    if parsed is None:
        # JSON ``null`` means "no arguments" rather than ``**None``.
        return {}
    if not isinstance(parsed, dict):
        raise TypeError(
            f"tool arguments must decode to a JSON object, got {type(parsed).__name__}"
        )
    return parsed


async def execute_tool(tool_name, tool_arguments, tools=None):
    """Look up *tool_name* and call it with the decoded *tool_arguments*.

    Sync and async tools are both supported: the tool is called and, if the
    result is awaitable (coroutine function, async callable object, partial
    of an async function, ...), it is awaited.

    Args:
        tool_name: Name of the tool in the registry.
        tool_arguments: Dict or JSON string of keyword arguments (see
            ``_parse_tool_arguments`` for the accepted forms).
        tools: Optional ``name -> callable`` registry; defaults to the
            module-level ``tool_map``.

    Returns:
        Whatever the tool returns.

    Raises:
        KeyError: if *tool_name* is not in the registry.
        TypeError: if the arguments do not decode to a mapping or do not match
            the tool's signature.
    """
    import inspect

    registry = tool_map if tools is None else tools
    tool_func = registry[tool_name]
    arguments = _parse_tool_arguments(tool_arguments)
    result = tool_func(**arguments)
    if inspect.isawaitable(result):
        result = await result
    return result
def get_all_tool_calls_results(tool_calls_list):
    """Execute every requested tool call and format its output for the model.

    Args:
        tool_calls_list: Tool calls in the dict shape produced by
            ``get_t1_response`` (``{"function": {"name", "arguments"}, ...}``).

    Returns:
        One string per tool call, in request order, formatted as
        ``"[<name> content begin]\\n<result>\\n[<name> content end]\\n"``.
        If a tool raises, its result is an ``"Error: <Type>: <message>"``
        line instead. The remaining calls still run, and the answering model
        can see which call failed.
    """
    all_results = []
    with Progress(
        SpinnerColumn(),
        # {task.description} is what makes the per-tool description below visible.
        TextColumn("[bold green]正在执行工具调用...[/bold green] {task.description}"),
        transient=True,
    ) as progress:
        task = progress.add_task("", total=len(tool_calls_list))
        for tool_call in tool_calls_list:
            function = tool_call['function']
            tool_name = function['name']
            progress.update(task, description=f"执行 {tool_name}")
            try:
                result = asyncio.run(execute_tool(tool_name, function['arguments']))
            except Exception as exc:
                # LLM-generated calls often have bad names/arguments; report the
                # failure instead of aborting the whole conversation.
                result = f"Error: {type(exc).__name__}: {exc}"
            all_results.append(
                f"[{tool_name} content begin]\n{result}\n[{tool_name} content end]\n"
            )
            progress.update(task, advance=1)
    return all_results
def get_response_from_r1(messages):
    """Send *messages* to MARS-R1 (without tools) and return its reply text.

    A transient purple spinner is shown while waiting for the completion.
    The returned value is ``message.content`` of the first choice, which may
    be ``None`` if the API returns no text.
    """
    spinner = Progress(
        SpinnerColumn(),
        TextColumn("[bold purple]正在调用MARS-R1模型..."),
        transient=True,
    )
    with spinner:
        spinner.add_task("", total=None)
        response = client.chat.completions.create(
            model="MARS-R1",
            messages=messages,
            temperature=0.3,
        )
    first_choice = response.choices[0]
    return first_choice.message.content
def display_message(role, content, model=None, title_style="bold blue", border_style="blue"):
    """Render one conversation message as a rich panel.

    Styling depends on the message source:

    * ``"user"``: panel titled in *title_style*, bordered with *border_style*.
    * ``"assistant"`` from ``"MARS-T1"``: yellow panel with the raw text.
    * ``"tool"``: green panel holding a one-column table, one row per line.
    * ``"assistant"`` from ``"MARS-R1"``: purple panel rendered as Markdown,
      falling back to raw text if Markdown parsing or rendering fails.
    * anything else: a generic panel using *title_style* / *border_style*
      rather than being silently skipped.

    Args:
        role: Message role (``"user"``, ``"assistant"`` or ``"tool"``).
        content: Message text; ``None`` is shown as an empty panel.
        model: Model name; selects the assistant style and prefixes the title.
        title_style: rich style for the title of user / fallback panels.
        border_style: rich border style for user / fallback panels.
    """
    from rich.markup import escape

    text = "" if content is None else str(content)
    # Escape plain text so "[...]" sequences (e.g. the "[tool content begin]"
    # markers) are shown literally instead of being parsed as rich markup,
    # which hides them or raises MarkupError on a stray "[/...]".
    plain = escape(text)

    title = role.capitalize()
    if model:
        title = f"{model} {title}"

    if role == "user":
        console.print(Panel(
            plain,
            title=f"[{title_style}]{title}[/{title_style}]",
            border_style=border_style,
            expand=False,
        ))
    elif role == "assistant" and model == "MARS-T1":
        console.print(Panel(
            plain,
            title=f"[bold yellow]{title}[/bold yellow]",
            border_style="yellow",
            expand=False,
        ))
    elif role == "tool":
        table = Table(box=box.ROUNDED, expand=False, show_header=False)
        table.add_column("内容", style="green")
        for line in text.split("\n"):
            table.add_row(escape(line))
        console.print(Panel(
            table,
            title=f"[bold green]{title}[/bold green]",
            border_style="green",
            expand=False,
        ))
    elif role == "assistant" and model == "MARS-R1":
        purple_title = f"[bold purple]{title}[/bold purple]"
        try:
            console.print(Panel(
                Markdown(text),
                title=purple_title,
                border_style="purple",
                expand=False,
            ))
        except Exception:
            # Markdown failed; show the text verbatim instead. (Not a bare
            # except, so Ctrl-C still interrupts.)
            console.print(Panel(
                plain,
                title=purple_title,
                border_style="purple",
                expand=False,
            ))
    else:
        console.print(Panel(
            plain,
            title=f"[{title_style}]{title}[/{title_style}]",
            border_style=border_style,
            expand=False,
        ))
def _build_t1_messages(conversation_history):
    """Project the history onto the user/assistant messages sent to MARS-T1.

    Tool entries are skipped and the bookkeeping ``"model"`` key is dropped,
    so only ``role``/``content`` pairs are sent; ``None`` content becomes ``""``.
    """
    return [
        {"role": msg["role"], "content": msg["content"] or ""}
        for msg in conversation_history
        if msg["role"] in ("user", "assistant")
    ]


def process_conversation_round(user_input, conversation_history=None):
    """Run one user turn through MARS-T1, the requested tools and MARS-R1.

    1. Record and display the user message.
    2. Send the user/assistant history to MARS-T1, which may request tools.
    3. If tools were requested, run them and ask MARS-R1 to answer
       *user_input* given only this round's tool output (R1 does not see
       earlier turns).
    4. Otherwise T1's reply is the final answer and is shown again in the
       R1 (Markdown) style.

    Args:
        user_input: The user's question for this round.
        conversation_history: History list to extend in place; a new list is
            created when ``None``.

    Returns:
        The same (mutated) history list. Entries are dicts with ``role`` and
        ``content``; assistant entries also carry ``model``.
    """
    if conversation_history is None:
        conversation_history = []

    conversation_history.append({"role": "user", "content": user_input})
    display_message("user", user_input)

    reasoning_content, tool_calls_list = get_t1_response(
        _build_t1_messages(conversation_history)
    )
    # T1 may return no text alongside tool calls; keep history/rendering str-only.
    reasoning_content = reasoning_content or ""
    conversation_history.append({
        "role": "assistant",
        "content": reasoning_content,
        "model": "MARS-T1",
    })
    display_message("assistant", reasoning_content, model="MARS-T1")

    if not tool_calls_list:
        # No tools: T1's reply is already the answer. Show it in the R1 style,
        # but do not store it again; a duplicate assistant turn would be
        # replayed to T1 on every later round.
        display_message("assistant", reasoning_content, model="MARS-R1")
        return conversation_history

    tool_call_results_str = "\n".join(get_all_tool_calls_results(tool_calls_list))
    conversation_history.append({
        "role": "tool",
        "content": tool_call_results_str,
    })
    display_message("tool", tool_call_results_str)

    r1_prompt = {
        "role": "user",
        "content": f"# 信息如下:\n{tool_call_results_str}\n# 问题如下:\n{user_input}",
    }
    r1_response = get_response_from_r1([r1_prompt])
    conversation_history.append({
        "role": "assistant",
        "content": r1_response,
        "model": "MARS-R1",
    })
    display_message("assistant", r1_response, model="MARS-R1")
    return conversation_history
def run_demo():
    """Run the interactive demo, seeded with the CIF question in ``initial_message``.

    After the first round, the user is prompted repeatedly. Blank input is
    ignored. ``exit``, ``quit`` or ``退出`` (case-insensitive, surrounding
    whitespace ignored) ends the demo, and so does end-of-input (Ctrl-D or
    exhausted piped stdin).
    """
    console.print(Panel.fit(
        "[bold cyan]多轮对话演示[/bold cyan] - 使用 MARS-T1 和 MARS-R1 模型",
        border_style="cyan"
    ))

    conversation_history = process_conversation_round(initial_message[0]["content"])

    while True:
        console.print("\n[bold cyan]输入问题继续对话,或输入 'exit'、'quit' 或 '退出' 退出[/bold cyan]")
        try:
            user_input = input("> ").strip()
        except EOFError:
            # stdin closed: treat like an explicit exit instead of an error.
            user_input = "exit"

        if not user_input:
            # Don't send empty questions to the models.
            continue
        if user_input.lower() in ('exit', 'quit', '退出'):
            console.print("[bold cyan]演示结束,再见![/bold cyan]")
            break

        conversation_history = process_conversation_round(user_input, conversation_history)
if __name__ == '__main__':
    try:
        run_demo()
    except KeyboardInterrupt:
        console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
    except Exception as e:
        import traceback
        from rich.markup import escape

        # Exception text and tracebacks may contain "[...]" that rich would
        # parse as markup (raising MarkupError on "[/...]" inside this handler).
        console.print(f"\n[bold red]发生错误: {escape(str(e))}[/bold red]")
        console.print(traceback.format_exc(), markup=False)
        # Signal failure to the shell instead of exiting with status 0.
        sys.exit(1)