178 lines
7.3 KiB
Python
178 lines
7.3 KiB
Python
import asyncio
|
||
import json
|
||
from openai import OpenAI
|
||
from rich.console import Console
|
||
from rich.panel import Panel
|
||
from rich.text import Text
|
||
|
||
# 创建Rich控制台对象
|
||
console = Console()
|
||
|
||
# 定义分隔符样式
|
||
def print_separator(title=""):
|
||
console.print(Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False), justify="center")
|
||
|
||
api_key="gpustack_72b0d41ec69eddab_bce1ea964ddc277ac6aed46b67b03960"
|
||
base_url="http://gpustack.ddwtop.team/v1-openai"
|
||
client = OpenAI(
|
||
api_key=api_key,
|
||
base_url=base_url,
|
||
)
|
||
from mars_toolkit import get_tool_schemas,get_tools
|
||
tools_schemas = get_tool_schemas()
|
||
tool_map = get_tools()
|
||
|
||
# 打印消息的函数
|
||
def print_message(message):
|
||
# 处理不同类型的消息对象
|
||
if hasattr(message, 'role'): # ChatCompletionMessage 对象
|
||
role = message.role
|
||
content = message.content if hasattr(message, 'content') else ""
|
||
# 如果是工具消息,获取工具名称
|
||
tool_name = None
|
||
if role == "tool" and hasattr(message, 'name'):
|
||
tool_name = message.name
|
||
else: # 字典类型
|
||
role = message.get("role", "unknown")
|
||
content = message.get("content", "")
|
||
# 如果是工具消息,获取工具名称
|
||
tool_name = message.get("name") if role == "tool" else None
|
||
|
||
# 根据角色选择不同的颜色
|
||
role_colors = {
|
||
"system": "bright_blue",
|
||
"user": "green",
|
||
"assistant": "yellow",
|
||
"tool": "bright_red"
|
||
}
|
||
color = role_colors.get(role, "white")
|
||
|
||
# 创建富文本面板
|
||
text = Text()
|
||
|
||
# 如果是工具消息,添加工具名称
|
||
if role == "tool" and tool_name:
|
||
text.append(f"{role} ({tool_name}): ", style=f"bold {color}")
|
||
else:
|
||
text.append(f"{role}: ", style=f"bold {color}")
|
||
|
||
text.append(str(content))
|
||
console.print(Panel(text, border_style=color))
|
||
|
||
messages = [
|
||
{"role": "system",
|
||
"content": "You are MARS-T1, a professional assistant in materials science. You first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> <structured_answer> </structured_answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here <structured_answer> structured answer here <structured_answer> </answer>'"},
|
||
{"role": "user", "content": "how to synthesize CsPbBr3 at room temperature"}
|
||
]
|
||
|
||
# 打印初始消息
|
||
print_separator("初始消息")
|
||
for message in messages:
|
||
print_message(message)
|
||
finish_reason = None
|
||
|
||
async def execute_tool(tool_name,tool_arguments):
|
||
tool_func = tool_map[tool_name] # 获取工具函数
|
||
arguments = {}
|
||
if tool_arguments:
|
||
# 检查arguments是字符串还是字典
|
||
if isinstance(tool_arguments, dict):
|
||
# 如果已经是字典,直接使用
|
||
arguments = tool_arguments
|
||
elif isinstance(tool_arguments, str):
|
||
# 如果是字符串,尝试解析为JSON
|
||
try:
|
||
# 尝试直接解析为JSON对象
|
||
arguments = json.loads(tool_arguments)
|
||
except json.JSONDecodeError:
|
||
# 如果解析失败,可能是因为字符串中包含转义字符
|
||
# 尝试修复常见的JSON字符串问题
|
||
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
|
||
try:
|
||
arguments = json.loads(fixed_str)
|
||
except json.JSONDecodeError:
|
||
# 如果仍然失败,尝试将字符串作为原始字符串处理
|
||
arguments = {"raw_string": tool_arguments}
|
||
|
||
# 调用工具函数
|
||
if asyncio.iscoroutinefunction(tool_func):
|
||
# 如果是异步函数,使用await调用
|
||
result = await tool_func(**arguments)
|
||
else:
|
||
# 如果是同步函数,直接调用
|
||
result = tool_func(**arguments)
|
||
# if func_name=='generate_material':
|
||
# print("xxxxx",result)
|
||
return result
|
||
|
||
while finish_reason is None or finish_reason == "tool_calls":
|
||
completion = client.chat.completions.create(
|
||
model="MARS-T1",
|
||
messages=messages,
|
||
temperature=0.3,
|
||
tools=tools_schemas, # <-- 我们通过 tools 参数,将定义好的 tools 提交给 Kimi 大模型
|
||
)
|
||
choice = completion.choices[0]
|
||
finish_reason = choice.finish_reason
|
||
if finish_reason == "tool_calls": # <-- 判断当前返回内容是否包含 tool_calls
|
||
# 打印assistant消息
|
||
print_separator("Assistant消息")
|
||
print_message(choice.message)
|
||
|
||
# 将ChatCompletionMessage对象转换为字典
|
||
assistant_message = {
|
||
"role": "assistant",
|
||
"content": choice.message.content if hasattr(choice.message, 'content') else None
|
||
}
|
||
|
||
# 如果有工具调用,添加到字典中
|
||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
|
||
# 将tool_calls对象转换为字典列表
|
||
tool_calls_list = []
|
||
for tool_call in choice.message.tool_calls:
|
||
tool_call_dict = {
|
||
"id": tool_call.id,
|
||
"type": "function",
|
||
"function": {
|
||
"name": tool_call.function.name,
|
||
"arguments": tool_call.function.arguments
|
||
}
|
||
}
|
||
tool_calls_list.append(tool_call_dict)
|
||
assistant_message["tool_calls"] = tool_calls_list
|
||
|
||
# 添加消息到上下文
|
||
messages.append(assistant_message) # <-- 我们将模型返回给我们的 assistant 消息也添加到上下文中,以便于下次请求时模型能理解我们的诉求
|
||
|
||
# 打印工具调用信息
|
||
print_separator("工具调用")
|
||
for tool_call in choice.message.tool_calls:
|
||
console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]")
|
||
console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
|
||
console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
|
||
console.print("")
|
||
|
||
tool_call_name = tool_call.function.name
|
||
tool_call_arguments = json.loads(tool_call.function.arguments) # <-- arguments 是序列化后的 JSON Object,我们需要使用 json.loads 反序列化一下
|
||
tool_result = asyncio.run(execute_tool(tool_name=tool_call_name,tool_arguments=tool_call_arguments)) # <-- 通过 tool_map 快速找到需要执行哪个函数
|
||
|
||
# 构造工具响应消息
|
||
tool_message = {
|
||
"role": "tool",
|
||
"tool_call_id": tool_call.id,
|
||
"name": tool_call_name,
|
||
"content": tool_result, # <-- 我们约定使用字符串格式向模型提交工具调用结果
|
||
}
|
||
|
||
# 打印工具响应
|
||
print_separator(f"工具响应: {tool_call_name}")
|
||
print_message(tool_message)
|
||
|
||
# 添加消息到上下文
|
||
messages.append(tool_message)
|
||
|
||
# 打印最终响应
|
||
if choice.message.content:
|
||
print_separator("最终响应")
|
||
console.print(Panel(choice.message.content, border_style="green"))
|