# NOTE: removed duplicated file-viewer metadata artifact ("522 lines / 18 KiB / Python").
import asyncio
import json
import os
import re
import sys

from openai import OpenAI
from rich import box
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

# Make the parent directory importable so sci_mcp resolves when run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp import *
# Seed message for the demo: a CIF description of Ti4V followed by an analysis
# request (formula, crystal system, space group, density, metallicity,
# magnetism, stability — answered as JSON).
# Was a confusing chained assignment (`initial_message=messages = [...]`);
# the two names are now bound explicitly, preserving both globals.
initial_message = [{"role": "user", "content": """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}]
messages = initial_message  # legacy alias kept for backward compatibility
# Rich console used for all terminal output.
console = Console()

# Tool schemas (passed as the OpenAI `tools` parameter) and the
# name -> callable map for the material-science and general tool domains.
tools_schemas = get_domain_tool_schemas(["material", 'general'])
tool_map = get_domain_tools(["material", 'general'])

# API configuration.
# SECURITY NOTE: the key and endpoint were hard-coded in source; they can now
# be overridden via environment variables. The embedded fallback key should be
# rotated if this file has ever been shared.
api_key = os.environ.get(
    "GPUSTACK_API_KEY",
    "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705",
)
base_url = os.environ.get("GPUSTACK_BASE_URL", "http://gpustack.ddwtop.team/v1-openai")
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
def get_t1_response(messages):
    """Call the MARS-T1 (tool-calling) model.

    Args:
        messages: chat history in OpenAI message-dict format.

    Returns:
        A tuple ``(content, tool_calls)`` where ``content`` is the model's
        text and ``tool_calls`` is a list of plain-dict tool-call requests
        (possibly empty).
    """
    with Progress(
        SpinnerColumn(),
        TextColumn("[bold blue]正在调用MARS-T1模型..."),
        transient=True,
    ) as progress:
        progress.add_task("", total=None)
        completion = client.chat.completions.create(
            model="MARS-T1",
            messages=messages,
            temperature=0.3,
            tools=tools_schemas,
        )

    message = completion.choices[0].message
    reasoning_content = message.content

    # Normalize SDK tool-call objects into plain dicts for downstream use.
    raw_calls = getattr(message, 'tool_calls', None) or []
    tool_calls_list = [
        {
            "id": call.id,
            "type": "function",
            "function": {
                "name": call.function.name,
                "arguments": call.function.arguments,
            },
        }
        for call in raw_calls
    ]

    return reasoning_content, tool_calls_list
def _coerce_tool_arguments(raw):
    """Normalize raw tool arguments into a kwargs dict.

    Accepts a dict (used as-is), a JSON string (parsed, with a second attempt
    after undoing common over-escaping), or anything else (empty kwargs).
    Unparseable strings are wrapped as ``{"raw_string": raw}``.
    """
    if not raw:
        return {}
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str):
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            pass
        # The model sometimes emits over-escaped JSON; undo the usual
        # escaping and try once more.
        unescaped = raw.replace('\\"', '"').replace('\\\\', '\\')
        try:
            return json.loads(unescaped)
        except json.JSONDecodeError:
            return {"raw_string": raw}
    return {}


async def execute_tool(tool_name, tool_arguments):
    """Execute one tool call and return its result.

    Args:
        tool_name: key into the module-level ``tool_map`` (KeyError if unknown,
            matching the original behavior).
        tool_arguments: dict, JSON string, or falsy value.

    Returns:
        Whatever the tool function returns; async tools are awaited.
    """
    tool_func = tool_map[tool_name]
    kwargs = _coerce_tool_arguments(tool_arguments)
    if asyncio.iscoroutinefunction(tool_func):
        return await tool_func(**kwargs)
    return tool_func(**kwargs)
def get_all_tool_calls_results(tool_calls_list):
    """Run every requested tool call in order.

    Args:
        tool_calls_list: list of tool-call dicts as produced by
            ``get_t1_response``.

    Returns:
        A list of result strings, each wrapped in
        ``[<name> content begin] ... [<name> content end]`` markers.
    """
    collected = []

    with Progress(
        SpinnerColumn(),
        TextColumn("[bold green]正在执行工具调用..."),
        transient=True,
    ) as progress:
        task = progress.add_task("", total=len(tool_calls_list))

        for call in tool_calls_list:
            fn = call['function']
            name = fn['name']
            args = fn['arguments']

            # Surface which tool is currently running.
            progress.update(task, description=f"执行 {name}")

            # NOTE: each call spins up a fresh event loop via asyncio.run.
            output = asyncio.run(execute_tool(name, args))
            collected.append(
                f"[{name} content begin]\n{output}\n[{name} content end]\n"
            )

            progress.update(task, advance=1)

    return collected
def get_response_from_r1(messages):
    """Call the MARS-R1 (answering) model — no tools — and return its text.

    Args:
        messages: chat history in OpenAI message-dict format.
    """
    with Progress(
        SpinnerColumn(),
        TextColumn("[bold purple]正在调用MARS-R1模型..."),
        transient=True,
    ) as progress:
        progress.add_task("", total=None)
        reply = client.chat.completions.create(
            model="MARS-R1",
            messages=messages,
            temperature=0.3,
        )

    return reply.choices[0].message.content
def display_message(role, content, model=None, title_style="bold blue", border_style="blue"):
    """Render one conversation message to the console.

    Args:
        role: "user", "assistant", or "tool".
        content: message text; for "tool", newline-separated result lines.
        model: for assistant messages, "MARS-T1" or "MARS-R1" — selects the
            panel styling. Assistant messages with any other model value are
            not displayed (unchanged from the original behavior).
        title_style: rich markup style for the user-message title.
        border_style: rich border style for the user-message panel.
    """
    title = role.capitalize()
    if model:
        title = f"{model} {title}"

    if role == "user":
        console.print(Panel(
            content,
            title=f"[{title_style}]{title}[/{title_style}]",
            border_style=border_style,
            expand=False
        ))
    elif role == "assistant" and model == "MARS-T1":
        console.print(Panel(
            content,
            title=f"[bold yellow]{title}[/bold yellow]",
            border_style="yellow",
            expand=False
        ))
    elif role == "tool":
        # Show tool-call output line by line in a single-column table.
        table = Table(box=box.ROUNDED, expand=False, show_header=False)
        table.add_column("内容", style="green")

        results = content.split("\n")
        for result in results:
            table.add_row(result)

        console.print(Panel(
            table,
            title=f"[bold green]{title}[/bold green]",
            border_style="green",
            expand=False
        ))
    elif role == "assistant" and model == "MARS-R1":
        try:
            # Prefer Markdown rendering for the final answer.
            md = Markdown(content)
            console.print(Panel(
                md,
                title=f"[bold purple]{title}[/bold purple]",
                border_style="purple",
                expand=False
            ))
        # BUGFIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. Narrowed to Exception; on failure
        # fall back to plain-text rendering.
        except Exception:
            console.print(Panel(
                content,
                title=f"[bold purple]{title}[/bold purple]",
                border_style="purple",
                expand=False
            ))
def process_conversation_round(user_input, conversation_history=None):
    """Run one conversation round and return the updated history.

    Flow: the T1 (tool-calling) model reasons and may request tool calls;
    the tools are executed; the R1 model digests the tool output and either
    answers directly or emits a <FOLLOW_UP_QUESTION> that is fed back to T1,
    for at most max_iterations cycles.

    Args:
        user_input: the user's question for this round.
        conversation_history: prior messages (mutated in place); a new list
            is created when None.

    Returns:
        The conversation history list, including this round's messages.
    """
    if conversation_history is None:
        conversation_history = []

    # Record the user message in the external (display) history.
    conversation_history.append({
        "role": "user",
        "content": user_input
    })

    # Show the user message.
    display_message("user", user_input)

    # Inner loop bookkeeping.
    max_iterations = 3  # guard against an endless T1 <-> R1 loop
    iterations = 0

    # T1 and R1 keep separate message histories.
    t1_messages = []
    r1_messages = []

    # Seed T1's history with the user/assistant messages from the external
    # history (tool messages are excluded).
    for msg in conversation_history:
        if msg["role"] in ["user", "assistant"]:
            t1_messages.append({
                "role": msg["role"],
                "content": msg["content"]
            })

    # Question currently being worked on (starts as the user input; later
    # replaced by R1-generated follow-up questions).
    current_question = user_input

    while iterations < max_iterations:
        iterations += 1

        # From the second iteration on, current_question is a follow-up
        # generated by R1 — append it to T1's history as a new user message.
        if iterations > 1:
            # Show the follow-up question.
            display_message("user", f"[后续问题] {current_question}")

            # Add it to T1's history.
            t1_messages.append({
                "role": "user",
                "content": current_question
            })

        # Ask T1 for reasoning plus any tool-call requests.
        reasoning_content, tool_calls_list = get_t1_response(t1_messages)

        # Show T1's reasoning.
        display_message("assistant", reasoning_content, model="MARS-T1")

        # Record T1's answer in the external history.
        conversation_history.append({
            "role": "assistant",
            "content": reasoning_content,
            "model": "MARS-T1"
        })

        # Record T1's answer in T1's own history.
        t1_messages.append({
            "role": "assistant",
            "content": reasoning_content
        })

        # No tool calls: T1's reasoning is the final answer for this round.
        if not tool_calls_list:
            # Duplicate it as the R1 answer so the history shape is uniform.
            conversation_history.append({
                "role": "assistant",
                "content": reasoning_content,
                "model": "MARS-R1"
            })

            display_message("assistant", reasoning_content, model="MARS-R1")
            break

        # Execute the requested tool calls and collect their results.
        tool_call_results = get_all_tool_calls_results(tool_calls_list)
        tool_call_results_str = "\n".join(tool_call_results)

        # Record the tool results in the external history.
        conversation_history.append({
            "role": "tool",
            "content": tool_call_results_str
        })

        # Show the tool results.
        display_message("tool", tool_call_results_str)

        # R1's history is rebuilt from scratch on every iteration.
        r1_messages = []

        # System prompt telling R1 how to use the tool output.
        r1_messages.append({
            "role": "system",
            "content": """你是一个能够分析工具调用结果并回答问题的助手。
请分析提供的信息,并执行以下操作之一:
1. 如果你能够基于提供的工具调用信息直接回答原始问题,请提供完整的回答。
2. 如果目前的工具调用信息不足以让你回答原始问题,请明确说明缺少哪些信息,并生成一个新的问题来获取这些信息。
新问题格式:<FOLLOW_UP_QUESTION>你的问题</FOLLOW_UP_QUESTION>

注意:如果你生成了后续问题,系统将自动将其发送给工具调用模型以获取更多信息。"""
        })

        # Build R1's user message: original question, T1's reasoning, and
        # the tool results.
        r1_user_message = f"""# 原始问题
{user_input}

# 工具调用信息
{reasoning_content}

# 工具调用结果
{tool_call_results_str}"""

        # On later iterations also include the current follow-up question.
        if iterations > 1:
            r1_user_message += f"\n\n# 后续问题\n{current_question}"

        # Append the composed user message.
        r1_messages.append({
            "role": "user",
            "content": r1_user_message
        })

        # Ask R1 for the final (or follow-up-generating) response.
        r1_response = get_response_from_r1(r1_messages)

        # Show R1's answer.
        display_message("assistant", r1_response, model="MARS-R1")

        # Did R1 emit a follow-up question?
        follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', r1_response, re.DOTALL)

        if follow_up_match:
            # Extract the follow-up question text.
            follow_up_question = follow_up_match.group(1).strip()

            # It becomes the question for the next iteration.
            current_question = follow_up_question

            # Record R1's answer with the follow-up markup stripped out.
            clean_response = r1_response.replace(follow_up_match.group(0), "")
            conversation_history.append({
                "role": "assistant",
                "content": clean_response,
                "model": "MARS-R1"
            })

            # Loop again, sending the new question to T1.
        else:
            # R1 answered directly: record it and end the round.
            conversation_history.append({
                "role": "assistant",
                "content": r1_response,
                "model": "MARS-R1"
            })
            break

    return conversation_history
def run_demo():
    """Run the demo: answer the built-in seed question, then go interactive."""
    console.print(Panel.fit(
        "[bold cyan]多轮对话演示[/bold cyan] - 使用 MARS-T1 和 MARS-R1 模型",
        border_style="cyan"
    ))

    # First round uses the bundled CIF question.
    first_question = initial_message[0]["content"]
    history = process_conversation_round(first_question)

    # Auto-answer any follow-up question R1 produced.
    auto_process_follow_up_questions(history)

    # Interactive loop until the user quits.
    while True:
        console.print("\n[bold cyan]输入问题继续对话,或输入 'exit' 或 'quit' 退出[/bold cyan]")
        user_input = input("> ")

        if user_input.lower() in ['exit', 'quit', '退出']:
            console.print("[bold cyan]演示结束,再见![/bold cyan]")
            return

        history = process_conversation_round(user_input, history)
        auto_process_follow_up_questions(history)
def _extract_follow_up(message):
    """Return the follow-up question embedded in an R1 message, or None.

    Only assistant messages tagged model == "MARS-R1" are inspected; the
    question must be wrapped in <FOLLOW_UP_QUESTION> tags.
    """
    if message["role"] != "assistant" or message.get("model") != "MARS-R1":
        return None
    match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', message["content"], re.DOTALL)
    if match:
        return match.group(1).strip()
    return None


def _announce_follow_up(question):
    """Show the detected follow-up question and the auto-processing notice."""
    console.print(Panel(
        f"[bold yellow]检测到后续问题: {question}[/bold yellow]",
        border_style="yellow",
        expand=False
    ))
    console.print("[bold cyan]自动处理后续问题...[/bold cyan]")


def auto_process_follow_up_questions(conversation_history):
    """Automatically answer follow-up questions generated by MARS-R1.

    While the last message of *conversation_history* carries a
    <FOLLOW_UP_QUESTION> tag, the question is fed back through
    process_conversation_round, up to a fixed iteration cap.

    Refactor note: the original duplicated the detection/announcement logic
    verbatim in two deeply nested places; it is now shared via the private
    helpers above, with identical behavior (including announcing — but not
    processing — a follow-up found on the final capped iteration).

    Args:
        conversation_history: message list as built by
            process_conversation_round; may be None/empty (no-op).
    """
    if not conversation_history:
        return

    follow_up_question = _extract_follow_up(conversation_history[-1])
    if follow_up_question is None:
        return
    _announce_follow_up(follow_up_question)

    max_auto_iterations = 3  # guard against endless follow-up chains
    for _ in range(max_auto_iterations):
        conversation_history = process_conversation_round(follow_up_question, conversation_history)
        if not conversation_history:
            break
        follow_up_question = _extract_follow_up(conversation_history[-1])
        if follow_up_question is None:
            break
        _announce_follow_up(follow_up_question)
# Script entry point: run the interactive demo, exiting quietly on Ctrl-C and
# printing the error plus a full traceback on any other failure.
if __name__ == '__main__':
    try:
        run_demo()
    except KeyboardInterrupt:
        # User pressed Ctrl-C — say goodbye instead of dumping a traceback.
        console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
    except Exception as e:
        console.print(f"\n[bold red]发生错误: {str(e)}[/bold red]")
        # Imported lazily so the error path is self-contained.
        import traceback
        console.print(traceback.format_exc())