初次提交
This commit is contained in:
521
test_tools/multi_round_conversation.py
Normal file
521
test_tools/multi_round_conversation.py
Normal file
@@ -0,0 +1,521 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
from rich.table import Table
|
||||
from rich import box
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
|
||||
initial_message=messages = [ {"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}]
|
||||
# 初始化rich控制台
|
||||
console = Console()
|
||||
|
||||
# 获取工具模式和映射
|
||||
tools_schemas = get_domain_tool_schemas(["material", 'general'])
|
||||
tool_map = get_domain_tools(["material", 'general'])
|
||||
|
||||
# API配置
|
||||
api_key = "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
|
||||
base_url = "http://gpustack.ddwtop.team/v1-openai"
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
def get_t1_response(messages):
|
||||
"""获取T1模型的响应"""
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold blue]正在调用MARS-T1模型..."),
|
||||
transient=True,
|
||||
) as progress:
|
||||
progress.add_task("", total=None)
|
||||
completion = client.chat.completions.create(
|
||||
model="MARS-T1",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
tools=tools_schemas,
|
||||
)
|
||||
|
||||
choice = completion.choices[0]
|
||||
reasoning_content = choice.message.content
|
||||
|
||||
tool_calls_list = []
|
||||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
tool_calls_list.append(tool_call_dict)
|
||||
|
||||
return reasoning_content, tool_calls_list
|
||||
|
||||
async def execute_tool(tool_name, tool_arguments):
|
||||
"""执行工具调用"""
|
||||
try:
|
||||
tool_func = tool_map[tool_name] # 获取工具函数
|
||||
arguments = {}
|
||||
if tool_arguments:
|
||||
# 检查arguments是字符串还是字典
|
||||
if isinstance(tool_arguments, dict):
|
||||
# 如果已经是字典,直接使用
|
||||
arguments = tool_arguments
|
||||
elif isinstance(tool_arguments, str):
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
try:
|
||||
# 尝试直接解析为JSON对象
|
||||
arguments = json.loads(tool_arguments)
|
||||
except json.JSONDecodeError:
|
||||
# 如果解析失败,可能是因为字符串中包含转义字符
|
||||
# 尝试修复常见的JSON字符串问题
|
||||
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
|
||||
try:
|
||||
arguments = json.loads(fixed_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果仍然失败,尝试将字符串作为原始字符串处理
|
||||
arguments = {"raw_string": tool_arguments}
|
||||
|
||||
# 调用工具函数
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
# 如果是异步函数,使用await调用
|
||||
result = await tool_func(**arguments)
|
||||
else:
|
||||
# 如果是同步函数,直接调用
|
||||
result = tool_func(**arguments)
|
||||
return result
|
||||
finally:
|
||||
# 清除LLM调用上下文标记
|
||||
pass
|
||||
|
||||
def get_all_tool_calls_results(tool_calls_list):
|
||||
"""获取所有工具调用的结果"""
|
||||
all_results = []
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold green]正在执行工具调用..."),
|
||||
transient=True,
|
||||
) as progress:
|
||||
task = progress.add_task("", total=len(tool_calls_list))
|
||||
|
||||
for tool_call in tool_calls_list:
|
||||
tool_name = tool_call['function']['name']
|
||||
tool_arguments = tool_call['function']['arguments']
|
||||
|
||||
# 显示当前执行的工具
|
||||
progress.update(task, description=f"执行 {tool_name}")
|
||||
|
||||
result = asyncio.run(execute_tool(tool_name, tool_arguments))
|
||||
result_str = f"[{tool_name} content begin]\n{result}\n[{tool_name} content end]\n"
|
||||
all_results.append(result_str)
|
||||
|
||||
# 更新进度
|
||||
progress.update(task, advance=1)
|
||||
|
||||
return all_results
|
||||
|
||||
def get_response_from_r1(messages):
|
||||
"""获取R1模型的响应"""
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold purple]正在调用MARS-R1模型..."),
|
||||
transient=True,
|
||||
) as progress:
|
||||
progress.add_task("", total=None)
|
||||
completion = client.chat.completions.create(
|
||||
model="MARS-R1",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
|
||||
)
|
||||
|
||||
choice = completion.choices[0]
|
||||
return choice.message.content
|
||||
|
||||
def display_message(role, content, model=None, title_style="bold blue", border_style="blue"):
|
||||
"""显示单条消息"""
|
||||
title = role.capitalize()
|
||||
if model:
|
||||
title = f"{model} {title}"
|
||||
|
||||
if role == "user":
|
||||
console.print(Panel(
|
||||
content,
|
||||
title=f"[{title_style}]{title}[/{title_style}]",
|
||||
border_style=border_style,
|
||||
expand=False
|
||||
))
|
||||
elif role == "assistant" and model == "MARS-T1":
|
||||
console.print(Panel(
|
||||
content,
|
||||
title=f"[bold yellow]{title}[/bold yellow]",
|
||||
border_style="yellow",
|
||||
expand=False
|
||||
))
|
||||
elif role == "tool":
|
||||
# 创建一个表格来显示工具调用结果
|
||||
table = Table(box=box.ROUNDED, expand=False, show_header=False)
|
||||
table.add_column("内容", style="green")
|
||||
|
||||
# 分割工具调用结果并添加到表格
|
||||
results = content.split("\n")
|
||||
for result in results:
|
||||
table.add_row(result)
|
||||
|
||||
console.print(Panel(
|
||||
table,
|
||||
title=f"[bold green]{title}[/bold green]",
|
||||
border_style="green",
|
||||
expand=False
|
||||
))
|
||||
elif role == "assistant" and model == "MARS-R1":
|
||||
try:
|
||||
# 尝试将内容解析为Markdown
|
||||
md = Markdown(content)
|
||||
console.print(Panel(
|
||||
md,
|
||||
title=f"[bold purple]{title}[/bold purple]",
|
||||
border_style="purple",
|
||||
expand=False
|
||||
))
|
||||
except:
|
||||
# 如果解析失败,直接显示文本
|
||||
console.print(Panel(
|
||||
content,
|
||||
title=f"[bold purple]{title}[/bold purple]",
|
||||
border_style="purple",
|
||||
expand=False
|
||||
))
|
||||
|
||||
def process_conversation_round(user_input, conversation_history=None):
|
||||
"""处理一轮对话,返回更新后的对话历史"""
|
||||
if conversation_history is None:
|
||||
conversation_history = []
|
||||
|
||||
# 添加用户消息到外部历史
|
||||
conversation_history.append({
|
||||
"role": "user",
|
||||
"content": user_input
|
||||
})
|
||||
|
||||
# 显示用户消息
|
||||
display_message("user", user_input)
|
||||
|
||||
# 内部循环变量
|
||||
max_iterations = 3 # 防止无限循环
|
||||
iterations = 0
|
||||
|
||||
# 分别管理T1和R1的对话历史
|
||||
t1_messages = []
|
||||
r1_messages = []
|
||||
|
||||
# 初始化T1消息历史(从外部历史中提取用户和助手消息)
|
||||
for msg in conversation_history:
|
||||
if msg["role"] in ["user", "assistant"]:
|
||||
t1_messages.append({
|
||||
"role": msg["role"],
|
||||
"content": msg["content"]
|
||||
})
|
||||
|
||||
# 当前问题(初始为用户输入)
|
||||
current_question = user_input
|
||||
|
||||
while iterations < max_iterations:
|
||||
iterations += 1
|
||||
|
||||
# 如果不是第一次迭代,添加R1生成的后续问题作为新的用户消息
|
||||
if iterations > 1:
|
||||
# 显示后续问题
|
||||
display_message("user", f"[后续问题] {current_question}")
|
||||
|
||||
# 添加到T1消息历史
|
||||
t1_messages.append({
|
||||
"role": "user",
|
||||
"content": current_question
|
||||
})
|
||||
|
||||
# 获取T1模型的响应
|
||||
reasoning_content, tool_calls_list = get_t1_response(t1_messages)
|
||||
|
||||
# 显示T1推理
|
||||
display_message("assistant", reasoning_content, model="MARS-T1")
|
||||
|
||||
# 添加T1的回答到外部历史
|
||||
conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": reasoning_content,
|
||||
"model": "MARS-T1"
|
||||
})
|
||||
|
||||
# 添加T1的回答到T1消息历史
|
||||
t1_messages.append({
|
||||
"role": "assistant",
|
||||
"content": reasoning_content
|
||||
})
|
||||
|
||||
# 如果没有工具调用,使用T1的推理作为最终答案
|
||||
if not tool_calls_list:
|
||||
# 添加相同的回答作为R1的回答(因为没有工具调用)
|
||||
conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": reasoning_content,
|
||||
"model": "MARS-R1"
|
||||
})
|
||||
|
||||
display_message("assistant", reasoning_content, model="MARS-R1")
|
||||
break
|
||||
|
||||
# 执行工具调用并获取结果
|
||||
tool_call_results = get_all_tool_calls_results(tool_calls_list)
|
||||
tool_call_results_str = "\n".join(tool_call_results)
|
||||
|
||||
# 添加工具调用结果到外部历史
|
||||
conversation_history.append({
|
||||
"role": "tool",
|
||||
"content": tool_call_results_str
|
||||
})
|
||||
|
||||
# 显示工具调用结果
|
||||
display_message("tool", tool_call_results_str)
|
||||
|
||||
# 重置R1消息历史(每次迭代都重新构建)
|
||||
r1_messages = []
|
||||
|
||||
# 添加系统消息,指导R1如何处理信息
|
||||
r1_messages.append({
|
||||
"role": "system",
|
||||
"content": """你是一个能够分析工具调用结果并回答问题的助手。
|
||||
请分析提供的信息,并执行以下操作之一:
|
||||
1. 如果你能够基于提供的工具调用信息直接回答原始问题,请提供完整的回答。
|
||||
2. 如果目前的工具调用信息不足以让你回答原始问题,请明确说明缺少哪些信息,并生成一个新的问题来获取这些信息。
|
||||
新问题格式:<FOLLOW_UP_QUESTION>你的问题</FOLLOW_UP_QUESTION>
|
||||
|
||||
注意:如果你生成了后续问题,系统将自动将其发送给工具调用模型以获取更多信息。"""
|
||||
})
|
||||
|
||||
# 构建R1的用户消息,包含原始问题、工具调用信息和结果
|
||||
r1_user_message = f"""# 原始问题
|
||||
{user_input}
|
||||
|
||||
# 工具调用信息
|
||||
{reasoning_content}
|
||||
|
||||
# 工具调用结果
|
||||
{tool_call_results_str}"""
|
||||
|
||||
# 如果有后续问题,添加到R1用户消息
|
||||
if iterations > 1:
|
||||
r1_user_message += f"\n\n# 后续问题\n{current_question}"
|
||||
|
||||
# 添加构建好的用户消息
|
||||
r1_messages.append({
|
||||
"role": "user",
|
||||
"content": r1_user_message
|
||||
})
|
||||
|
||||
# 获取R1模型的响应
|
||||
r1_response = get_response_from_r1(r1_messages)
|
||||
|
||||
# 显示R1回答
|
||||
display_message("assistant", r1_response, model="MARS-R1")
|
||||
|
||||
# 检查R1是否生成了后续问题
|
||||
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', r1_response, re.DOTALL)
|
||||
|
||||
if follow_up_match:
|
||||
# 提取后续问题
|
||||
follow_up_question = follow_up_match.group(1).strip()
|
||||
|
||||
# 将后续问题作为新的当前问题
|
||||
current_question = follow_up_question
|
||||
|
||||
# 添加R1的回答到外部历史(不包括后续问题标记)
|
||||
clean_response = r1_response.replace(follow_up_match.group(0), "")
|
||||
conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": clean_response,
|
||||
"model": "MARS-R1"
|
||||
})
|
||||
|
||||
# 继续循环,使用新问题调用T1
|
||||
else:
|
||||
# R1能够回答问题,添加回答到历史并结束循环
|
||||
conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": r1_response,
|
||||
"model": "MARS-R1"
|
||||
})
|
||||
break
|
||||
|
||||
return conversation_history
|
||||
|
||||
def run_demo():
|
||||
"""运行演示,使用初始消息作为第一个用户问题"""
|
||||
console.print(Panel.fit(
|
||||
"[bold cyan]多轮对话演示[/bold cyan] - 使用 MARS-T1 和 MARS-R1 模型",
|
||||
border_style="cyan"
|
||||
))
|
||||
|
||||
# 获取初始用户问题
|
||||
initial_user_input = initial_message[0]["content"]
|
||||
|
||||
# 处理第一轮对话
|
||||
conversation_history = process_conversation_round(initial_user_input)
|
||||
|
||||
# 检查R1是否生成了后续问题并自动处理
|
||||
auto_process_follow_up_questions(conversation_history)
|
||||
|
||||
# 多轮对话循环
|
||||
while True:
|
||||
console.print("\n[bold cyan]输入问题继续对话,或输入 'exit' 或 'quit' 退出[/bold cyan]")
|
||||
user_input = input("> ")
|
||||
|
||||
# 检查是否退出
|
||||
if user_input.lower() in ['exit', 'quit', '退出']:
|
||||
console.print("[bold cyan]演示结束,再见![/bold cyan]")
|
||||
break
|
||||
|
||||
# 处理用户输入
|
||||
conversation_history = process_conversation_round(user_input, conversation_history)
|
||||
|
||||
# 检查R1是否生成了后续问题并自动处理
|
||||
auto_process_follow_up_questions(conversation_history)
|
||||
|
||||
def auto_process_follow_up_questions(conversation_history):
|
||||
"""自动处理R1生成的后续问题"""
|
||||
# 检查最后一条消息是否是R1的回答
|
||||
if not conversation_history or len(conversation_history) == 0:
|
||||
return
|
||||
|
||||
last_message = conversation_history[-1]
|
||||
if last_message["role"] == "assistant" and last_message.get("model") == "MARS-R1":
|
||||
# 检查是否包含后续问题
|
||||
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', last_message["content"], re.DOTALL)
|
||||
if follow_up_match:
|
||||
# 提取后续问题
|
||||
follow_up_question = follow_up_match.group(1).strip()
|
||||
|
||||
# 显示检测到的后续问题
|
||||
console.print(Panel(
|
||||
f"[bold yellow]检测到后续问题: {follow_up_question}[/bold yellow]",
|
||||
border_style="yellow",
|
||||
expand=False
|
||||
))
|
||||
|
||||
# 自动处理后续问题
|
||||
console.print("[bold cyan]自动处理后续问题...[/bold cyan]")
|
||||
|
||||
# 递归处理后续问题,直到没有更多后续问题或达到最大迭代次数
|
||||
max_auto_iterations = 3
|
||||
current_iterations = 0
|
||||
|
||||
while current_iterations < max_auto_iterations:
|
||||
current_iterations += 1
|
||||
|
||||
# 处理后续问题
|
||||
conversation_history = process_conversation_round(follow_up_question, conversation_history)
|
||||
|
||||
# 检查是否还有后续问题
|
||||
if len(conversation_history) > 0:
|
||||
last_message = conversation_history[-1]
|
||||
if last_message["role"] == "assistant" and last_message.get("model") == "MARS-R1":
|
||||
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', last_message["content"], re.DOTALL)
|
||||
if follow_up_match:
|
||||
# 提取后续问题
|
||||
follow_up_question = follow_up_match.group(1).strip()
|
||||
|
||||
# 显示检测到的后续问题
|
||||
console.print(Panel(
|
||||
f"[bold yellow]检测到后续问题: {follow_up_question}[/bold yellow]",
|
||||
border_style="yellow",
|
||||
expand=False
|
||||
))
|
||||
|
||||
# 自动处理后续问题
|
||||
console.print("[bold cyan]自动处理后续问题...[/bold cyan]")
|
||||
continue
|
||||
|
||||
# 如果没有更多后续问题,退出循环
|
||||
break
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
run_demo()
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
|
||||
except Exception as e:
|
||||
console.print(f"\n[bold red]发生错误: {str(e)}[/bold red]")
|
||||
import traceback
|
||||
console.print(traceback.format_exc())
|
||||
Reference in New Issue
Block a user