Files
multi_mcp/test_tools/multi_round_conversation.py
2025-05-09 14:16:33 +08:00

522 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import sys
import os
import re
from openai import OpenAI
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
from rich import box
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp import *
initial_message=messages = [ {"role": "user", "content": """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
根据上文提供的CIF文件请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性并用JSON格式回答。"""}]
# 初始化rich控制台
console = Console()
# 获取工具模式和映射
tools_schemas = get_domain_tool_schemas(["material", 'general'])
tool_map = get_domain_tools(["material", 'general'])
# API配置
api_key = "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
base_url = "http://gpustack.ddwtop.team/v1-openai"
client = OpenAI(
api_key=api_key,
base_url=base_url,
)
def get_t1_response(messages):
"""获取T1模型的响应"""
with Progress(
SpinnerColumn(),
TextColumn("[bold blue]正在调用MARS-T1模型..."),
transient=True,
) as progress:
progress.add_task("", total=None)
completion = client.chat.completions.create(
model="MARS-T1",
messages=messages,
temperature=0.3,
tools=tools_schemas,
)
choice = completion.choices[0]
reasoning_content = choice.message.content
tool_calls_list = []
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
for tool_call in choice.message.tool_calls:
tool_call_dict = {
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
tool_calls_list.append(tool_call_dict)
return reasoning_content, tool_calls_list
async def execute_tool(tool_name, tool_arguments):
"""执行工具调用"""
try:
tool_func = tool_map[tool_name] # 获取工具函数
arguments = {}
if tool_arguments:
# 检查arguments是字符串还是字典
if isinstance(tool_arguments, dict):
# 如果已经是字典,直接使用
arguments = tool_arguments
elif isinstance(tool_arguments, str):
# 如果是字符串尝试解析为JSON
try:
# 尝试直接解析为JSON对象
arguments = json.loads(tool_arguments)
except json.JSONDecodeError:
# 如果解析失败,可能是因为字符串中包含转义字符
# 尝试修复常见的JSON字符串问题
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
try:
arguments = json.loads(fixed_str)
except json.JSONDecodeError:
# 如果仍然失败,尝试将字符串作为原始字符串处理
arguments = {"raw_string": tool_arguments}
# 调用工具函数
if asyncio.iscoroutinefunction(tool_func):
# 如果是异步函数使用await调用
result = await tool_func(**arguments)
else:
# 如果是同步函数,直接调用
result = tool_func(**arguments)
return result
finally:
# 清除LLM调用上下文标记
pass
def get_all_tool_calls_results(tool_calls_list):
"""获取所有工具调用的结果"""
all_results = []
with Progress(
SpinnerColumn(),
TextColumn("[bold green]正在执行工具调用..."),
transient=True,
) as progress:
task = progress.add_task("", total=len(tool_calls_list))
for tool_call in tool_calls_list:
tool_name = tool_call['function']['name']
tool_arguments = tool_call['function']['arguments']
# 显示当前执行的工具
progress.update(task, description=f"执行 {tool_name}")
result = asyncio.run(execute_tool(tool_name, tool_arguments))
result_str = f"[{tool_name} content begin]\n{result}\n[{tool_name} content end]\n"
all_results.append(result_str)
# 更新进度
progress.update(task, advance=1)
return all_results
def get_response_from_r1(messages):
"""获取R1模型的响应"""
with Progress(
SpinnerColumn(),
TextColumn("[bold purple]正在调用MARS-R1模型..."),
transient=True,
) as progress:
progress.add_task("", total=None)
completion = client.chat.completions.create(
model="MARS-R1",
messages=messages,
temperature=0.3,
)
choice = completion.choices[0]
return choice.message.content
def display_message(role, content, model=None, title_style="bold blue", border_style="blue"):
"""显示单条消息"""
title = role.capitalize()
if model:
title = f"{model} {title}"
if role == "user":
console.print(Panel(
content,
title=f"[{title_style}]{title}[/{title_style}]",
border_style=border_style,
expand=False
))
elif role == "assistant" and model == "MARS-T1":
console.print(Panel(
content,
title=f"[bold yellow]{title}[/bold yellow]",
border_style="yellow",
expand=False
))
elif role == "tool":
# 创建一个表格来显示工具调用结果
table = Table(box=box.ROUNDED, expand=False, show_header=False)
table.add_column("内容", style="green")
# 分割工具调用结果并添加到表格
results = content.split("\n")
for result in results:
table.add_row(result)
console.print(Panel(
table,
title=f"[bold green]{title}[/bold green]",
border_style="green",
expand=False
))
elif role == "assistant" and model == "MARS-R1":
try:
# 尝试将内容解析为Markdown
md = Markdown(content)
console.print(Panel(
md,
title=f"[bold purple]{title}[/bold purple]",
border_style="purple",
expand=False
))
except:
# 如果解析失败,直接显示文本
console.print(Panel(
content,
title=f"[bold purple]{title}[/bold purple]",
border_style="purple",
expand=False
))
def process_conversation_round(user_input, conversation_history=None):
"""处理一轮对话,返回更新后的对话历史"""
if conversation_history is None:
conversation_history = []
# 添加用户消息到外部历史
conversation_history.append({
"role": "user",
"content": user_input
})
# 显示用户消息
display_message("user", user_input)
# 内部循环变量
max_iterations = 3 # 防止无限循环
iterations = 0
# 分别管理T1和R1的对话历史
t1_messages = []
r1_messages = []
# 初始化T1消息历史从外部历史中提取用户和助手消息
for msg in conversation_history:
if msg["role"] in ["user", "assistant"]:
t1_messages.append({
"role": msg["role"],
"content": msg["content"]
})
# 当前问题(初始为用户输入)
current_question = user_input
while iterations < max_iterations:
iterations += 1
# 如果不是第一次迭代添加R1生成的后续问题作为新的用户消息
if iterations > 1:
# 显示后续问题
display_message("user", f"[后续问题] {current_question}")
# 添加到T1消息历史
t1_messages.append({
"role": "user",
"content": current_question
})
# 获取T1模型的响应
reasoning_content, tool_calls_list = get_t1_response(t1_messages)
# 显示T1推理
display_message("assistant", reasoning_content, model="MARS-T1")
# 添加T1的回答到外部历史
conversation_history.append({
"role": "assistant",
"content": reasoning_content,
"model": "MARS-T1"
})
# 添加T1的回答到T1消息历史
t1_messages.append({
"role": "assistant",
"content": reasoning_content
})
# 如果没有工具调用使用T1的推理作为最终答案
if not tool_calls_list:
# 添加相同的回答作为R1的回答因为没有工具调用
conversation_history.append({
"role": "assistant",
"content": reasoning_content,
"model": "MARS-R1"
})
display_message("assistant", reasoning_content, model="MARS-R1")
break
# 执行工具调用并获取结果
tool_call_results = get_all_tool_calls_results(tool_calls_list)
tool_call_results_str = "\n".join(tool_call_results)
# 添加工具调用结果到外部历史
conversation_history.append({
"role": "tool",
"content": tool_call_results_str
})
# 显示工具调用结果
display_message("tool", tool_call_results_str)
# 重置R1消息历史每次迭代都重新构建
r1_messages = []
# 添加系统消息指导R1如何处理信息
r1_messages.append({
"role": "system",
"content": """你是一个能够分析工具调用结果并回答问题的助手。
请分析提供的信息,并执行以下操作之一:
1. 如果你能够基于提供的工具调用信息直接回答原始问题,请提供完整的回答。
2. 如果目前的工具调用信息不足以让你回答原始问题,请明确说明缺少哪些信息,并生成一个新的问题来获取这些信息。
新问题格式:<FOLLOW_UP_QUESTION>你的问题</FOLLOW_UP_QUESTION>
注意:如果你生成了后续问题,系统将自动将其发送给工具调用模型以获取更多信息。"""
})
# 构建R1的用户消息包含原始问题、工具调用信息和结果
r1_user_message = f"""# 原始问题
{user_input}
# 工具调用信息
{reasoning_content}
# 工具调用结果
{tool_call_results_str}"""
# 如果有后续问题添加到R1用户消息
if iterations > 1:
r1_user_message += f"\n\n# 后续问题\n{current_question}"
# 添加构建好的用户消息
r1_messages.append({
"role": "user",
"content": r1_user_message
})
# 获取R1模型的响应
r1_response = get_response_from_r1(r1_messages)
# 显示R1回答
display_message("assistant", r1_response, model="MARS-R1")
# 检查R1是否生成了后续问题
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', r1_response, re.DOTALL)
if follow_up_match:
# 提取后续问题
follow_up_question = follow_up_match.group(1).strip()
# 将后续问题作为新的当前问题
current_question = follow_up_question
# 添加R1的回答到外部历史不包括后续问题标记
clean_response = r1_response.replace(follow_up_match.group(0), "")
conversation_history.append({
"role": "assistant",
"content": clean_response,
"model": "MARS-R1"
})
# 继续循环使用新问题调用T1
else:
# R1能够回答问题添加回答到历史并结束循环
conversation_history.append({
"role": "assistant",
"content": r1_response,
"model": "MARS-R1"
})
break
return conversation_history
def run_demo():
"""运行演示,使用初始消息作为第一个用户问题"""
console.print(Panel.fit(
"[bold cyan]多轮对话演示[/bold cyan] - 使用 MARS-T1 和 MARS-R1 模型",
border_style="cyan"
))
# 获取初始用户问题
initial_user_input = initial_message[0]["content"]
# 处理第一轮对话
conversation_history = process_conversation_round(initial_user_input)
# 检查R1是否生成了后续问题并自动处理
auto_process_follow_up_questions(conversation_history)
# 多轮对话循环
while True:
console.print("\n[bold cyan]输入问题继续对话,或输入 'exit''quit' 退出[/bold cyan]")
user_input = input("> ")
# 检查是否退出
if user_input.lower() in ['exit', 'quit', '退出']:
console.print("[bold cyan]演示结束,再见![/bold cyan]")
break
# 处理用户输入
conversation_history = process_conversation_round(user_input, conversation_history)
# 检查R1是否生成了后续问题并自动处理
auto_process_follow_up_questions(conversation_history)
def auto_process_follow_up_questions(conversation_history):
"""自动处理R1生成的后续问题"""
# 检查最后一条消息是否是R1的回答
if not conversation_history or len(conversation_history) == 0:
return
last_message = conversation_history[-1]
if last_message["role"] == "assistant" and last_message.get("model") == "MARS-R1":
# 检查是否包含后续问题
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', last_message["content"], re.DOTALL)
if follow_up_match:
# 提取后续问题
follow_up_question = follow_up_match.group(1).strip()
# 显示检测到的后续问题
console.print(Panel(
f"[bold yellow]检测到后续问题: {follow_up_question}[/bold yellow]",
border_style="yellow",
expand=False
))
# 自动处理后续问题
console.print("[bold cyan]自动处理后续问题...[/bold cyan]")
# 递归处理后续问题,直到没有更多后续问题或达到最大迭代次数
max_auto_iterations = 3
current_iterations = 0
while current_iterations < max_auto_iterations:
current_iterations += 1
# 处理后续问题
conversation_history = process_conversation_round(follow_up_question, conversation_history)
# 检查是否还有后续问题
if len(conversation_history) > 0:
last_message = conversation_history[-1]
if last_message["role"] == "assistant" and last_message.get("model") == "MARS-R1":
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', last_message["content"], re.DOTALL)
if follow_up_match:
# 提取后续问题
follow_up_question = follow_up_match.group(1).strip()
# 显示检测到的后续问题
console.print(Panel(
f"[bold yellow]检测到后续问题: {follow_up_question}[/bold yellow]",
border_style="yellow",
expand=False
))
# 自动处理后续问题
console.print("[bold cyan]自动处理后续问题...[/bold cyan]")
continue
# 如果没有更多后续问题,退出循环
break
if __name__ == '__main__':
try:
run_demo()
except KeyboardInterrupt:
console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
except Exception as e:
console.print(f"\n[bold red]发生错误: {str(e)}[/bold red]")
import traceback
console.print(traceback.format_exc())