402 lines
13 KiB
Python
Executable File
402 lines
13 KiB
Python
Executable File
import asyncio
|
|
from tools_for_ms import *
|
|
from api_key import *
|
|
from openai import OpenAI
|
|
import json
|
|
from typing import Dict, List, Any, Union, Optional
|
|
from rich.console import Console
|
|
|
|
# 获取工具
|
|
tools = get_tools()
|
|
tools_schema = get_tool_schemas()
|
|
tool_map = {tool_name: tool for tool_name, tool in tools.items()}
|
|
|
|
console = Console()
|
|
console.print(tools_schema)
|
|
|
|
class DualModelAgent:
|
|
"""
|
|
同时支持 qwq 和 gpt4o 两种大模型的代理类
|
|
处理两种不同的返回值格式并提供统一的工具调用接口
|
|
"""
|
|
|
|
def __init__(self, qwq_model_name: str = "qwq-32b", gpt_model_name: str = "gpt-4o"):
|
|
"""
|
|
初始化两个模型的客户端
|
|
|
|
Args:
|
|
qwq_model_name: qwq 模型名称
|
|
gpt_model_name: gpt 模型名称
|
|
"""
|
|
# 初始化 qwq 客户端
|
|
self.qwq_client = OpenAI(
|
|
api_key=OPENAI_API_KEY,
|
|
base_url=OPENAI_API_URL,
|
|
)
|
|
|
|
# 初始化 gpt 客户端 (可以使用不同的 API 密钥和 URL)
|
|
self.gpt_client = OpenAI(
|
|
api_key=OPENAI_API_KEY,
|
|
base_url=OPENAI_API_URL,
|
|
)
|
|
|
|
# 模型名称
|
|
self.qwq_model_name = qwq_model_name
|
|
self.gpt_model_name = gpt_model_name
|
|
|
|
# 定义工具列表
|
|
self.tools = tools_schema
|
|
|
|
def get_qwq_response(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|
"""
|
|
获取 qwq 模型的响应(返回字典格式)
|
|
|
|
Args:
|
|
messages: 消息列表
|
|
|
|
Returns:
|
|
字典格式的响应
|
|
"""
|
|
completion = self.qwq_client.chat.completions.create(
|
|
model=self.qwq_model_name,
|
|
messages=messages,
|
|
tools=self.tools,
|
|
temperature=0.6,
|
|
tool_choice='auto'
|
|
)
|
|
# qwq 返回的是对象,需要转换为字典
|
|
return completion.model_dump()
|
|
|
|
def get_gpt_response(self, messages: List[Dict[str, Any]]) -> Any:
|
|
"""
|
|
获取 gpt 模型的响应(返回对象格式)
|
|
|
|
Args:
|
|
messages: 消息列表
|
|
|
|
Returns:
|
|
对象格式的响应
|
|
"""
|
|
completion = self.gpt_client.chat.completions.create(
|
|
model=self.gpt_model_name,
|
|
messages=messages,
|
|
tools=self.tools,
|
|
tool_choice="auto",
|
|
temperature=0.6,
|
|
)
|
|
# gpt 返回的是对象,直接返回
|
|
return completion
|
|
|
|
def extract_tool_calls_from_qwq(self, response: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
|
|
"""
|
|
从 qwq 响应中提取工具调用信息
|
|
|
|
Args:
|
|
response: qwq 响应字典
|
|
|
|
Returns:
|
|
工具调用列表,如果没有则返回 None
|
|
"""
|
|
assistant_message = response['choices'][0]['message']
|
|
console.print("assistant_message",assistant_message)
|
|
return assistant_message.get('tool_calls')
|
|
|
|
def extract_tool_calls_from_gpt(self, response: Any) -> Optional[List[Any]]:
|
|
"""
|
|
从 gpt 响应中提取工具调用信息
|
|
|
|
Args:
|
|
response: gpt 响应对象
|
|
|
|
Returns:
|
|
工具调用列表,如果没有则返回 None
|
|
"""
|
|
if hasattr(response.choices[0].message, 'tool_calls'):
|
|
return response.choices[0].message.tool_calls
|
|
return None
|
|
|
|
def extract_content_from_qwq(self, response: Dict[str, Any]) -> str:
|
|
"""
|
|
从 qwq 响应中提取内容
|
|
|
|
Args:
|
|
response: qwq 响应字典
|
|
|
|
Returns:
|
|
内容字符串
|
|
"""
|
|
content = response['choices'][0]['message'].get('content')
|
|
return content if content is not None else ""
|
|
|
|
def extract_content_from_gpt(self, response: Any) -> str:
|
|
"""
|
|
从 gpt 响应中提取内容
|
|
|
|
Args:
|
|
response: gpt 响应对象
|
|
|
|
Returns:
|
|
内容字符串
|
|
"""
|
|
content = response.choices[0].message.content
|
|
return content if content is not None else ""
|
|
|
|
def extract_finish_reason_from_qwq(self, response: Dict[str, Any]) -> str:
|
|
"""
|
|
从 qwq 响应中提取完成原因
|
|
|
|
Args:
|
|
response: qwq 响应字典
|
|
|
|
Returns:
|
|
完成原因
|
|
"""
|
|
return response['choices'][0]['finish_reason']
|
|
|
|
def extract_finish_reason_from_gpt(self, response: Any) -> str:
|
|
"""
|
|
从 gpt 响应中提取完成原因
|
|
|
|
Args:
|
|
response: gpt 响应对象
|
|
|
|
Returns:
|
|
完成原因
|
|
"""
|
|
return response.choices[0].finish_reason
|
|
|
|
async def call_tool(self, tool_name: str, tool_arguments: Dict[str, Any]) -> str:
|
|
"""
|
|
调用工具函数
|
|
|
|
Args:
|
|
tool_name: 工具名称
|
|
tool_arguments: 工具参数
|
|
|
|
Returns:
|
|
工具执行结果
|
|
"""
|
|
if tool_name in tools:
|
|
tool_function = tools[tool_name]
|
|
try:
|
|
# 异步调用工具函数
|
|
tool_result = await tool_function(**tool_arguments)
|
|
return tool_result
|
|
except Exception as e:
|
|
return f"工具调用错误: {str(e)}"
|
|
else:
|
|
return f"未找到工具: {tool_name}"
|
|
|
|
async def chat_with_qwq(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
|
|
"""
|
|
与 qwq 模型对话,支持工具调用
|
|
|
|
Args:
|
|
messages: 初始消息列表
|
|
max_turns: 最大对话轮数
|
|
|
|
Returns:
|
|
最终回答
|
|
"""
|
|
|
|
current_messages = messages.copy()
|
|
turn = 0
|
|
|
|
while turn < max_turns:
|
|
turn += 1
|
|
console.print(f"\n[bold cyan]第 {turn} 轮 QWQ 对话[/bold cyan]")
|
|
console.print("message",current_messages)
|
|
# 获取 qwq 响应
|
|
response = self.get_qwq_response(current_messages)
|
|
assistant_message = response['choices'][0]['message']
|
|
if assistant_message['content'] is None:
|
|
assistant_message['content'] = ""
|
|
console.print("message", assistant_message)
|
|
# 将助手消息添加到上下文
|
|
current_messages.append(assistant_message)
|
|
|
|
# 提取内容和工具调用
|
|
content = self.extract_content_from_qwq(response)
|
|
tool_calls = self.extract_tool_calls_from_qwq(response)
|
|
finish_reason = self.extract_finish_reason_from_qwq(response)
|
|
|
|
console.print(f"[green]助手回复:[/green] {content}")
|
|
|
|
# 如果没有工具调用或已完成,返回内容
|
|
if tool_calls is None or finish_reason != "tool_calls":
|
|
return content
|
|
|
|
# 处理工具调用
|
|
for tool_call in tool_calls:
|
|
tool_call_name = tool_call['function']['name']
|
|
tool_call_arguments = json.loads(tool_call['function']['arguments'])
|
|
|
|
console.print(f"[yellow]调用工具:[/yellow] {tool_call_name}")
|
|
console.print(f"[yellow]工具参数:[/yellow] {tool_call_arguments}")
|
|
|
|
# 执行工具调用
|
|
tool_result = await self.call_tool(tool_call_name, tool_call_arguments)
|
|
console.print(f"[blue]工具结果:[/blue] {tool_result}")
|
|
|
|
# 添加工具结果到上下文
|
|
current_messages.append({
|
|
"role": "tool",
|
|
"name": tool_call_name,
|
|
"content": tool_result
|
|
})
|
|
|
|
return "达到最大对话轮数限制"
|
|
|
|
async def chat_with_gpt(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
|
|
"""
|
|
与 gpt 模型对话,支持工具调用
|
|
|
|
Args:
|
|
messages: 初始消息列表
|
|
max_turns: 最大对话轮数
|
|
|
|
Returns:
|
|
最终回答
|
|
"""
|
|
current_messages = messages.copy()
|
|
turn = 0
|
|
|
|
while turn < max_turns:
|
|
turn += 1
|
|
console.print(f"\n[bold magenta]第 {turn} 轮 GPT 对话[/bold magenta]")
|
|
|
|
# 获取 gpt 响应
|
|
response = self.get_gpt_response(current_messages)
|
|
assistant_message = response.choices[0].message
|
|
|
|
# 将助手消息添加到上下文
|
|
current_messages.append(assistant_message.model_dump())
|
|
|
|
# 提取内容和工具调用
|
|
content = self.extract_content_from_gpt(response)
|
|
tool_calls = self.extract_tool_calls_from_gpt(response)
|
|
finish_reason = self.extract_finish_reason_from_gpt(response)
|
|
|
|
console.print(f"[green]助手回复:[/green] {content}")
|
|
|
|
# 如果没有工具调用或已完成,返回内容
|
|
if tool_calls is None or finish_reason != "tool_calls":
|
|
return content
|
|
|
|
# 处理工具调用
|
|
for tool_call in tool_calls:
|
|
tool_call_name = tool_call.function.name
|
|
tool_call_arguments = json.loads(tool_call.function.arguments)
|
|
|
|
console.print(f"[yellow]调用工具:[/yellow] {tool_call_name}")
|
|
console.print(f"[yellow]工具参数:[/yellow] {tool_call_arguments}")
|
|
|
|
# 执行工具调用
|
|
tool_result = await self.call_tool(tool_call_name, tool_call_arguments)
|
|
console.print(f"[blue]工具结果:[/blue] {tool_result}")
|
|
|
|
# 添加工具结果到上下文
|
|
current_messages.append({
|
|
"role": "tool",
|
|
"tool_call_id": tool_call.id,
|
|
"name": tool_call_name,
|
|
"content": tool_result
|
|
})
|
|
|
|
return "达到最大对话轮数限制"
|
|
|
|
async def chat_with_both_models(self, user_input: str, system_prompt: str = "你是一个有用的助手。") -> Dict[str, str]:
|
|
"""
|
|
同时与两个模型对话,比较它们的回答
|
|
|
|
Args:
|
|
user_input: 用户输入
|
|
system_prompt: 系统提示
|
|
|
|
Returns:
|
|
包含两个模型回答的字典
|
|
"""
|
|
messages = [
|
|
{"role": "system", "content": system_prompt},
|
|
{"role": "user", "content": user_input}
|
|
]
|
|
|
|
# 并行调用两个模型
|
|
qwq_task = asyncio.create_task(self.chat_with_qwq(messages))
|
|
gpt_task = asyncio.create_task(self.chat_with_gpt(messages))
|
|
|
|
# 等待两个任务完成
|
|
qwq_response, gpt_response = await asyncio.gather(qwq_task, gpt_task)
|
|
|
|
return {
|
|
"qwq": qwq_response,
|
|
"gpt": gpt_response
|
|
}
|
|
|
|
|
|
async def main():
|
|
"""主函数"""
|
|
# 创建双模型代理
|
|
agent = DualModelAgent()
|
|
|
|
while True:
|
|
# 获取用户输入
|
|
user_input = input("\n请输入问题(输入 'exit' 退出): ")
|
|
#user_input='现在的时间'
|
|
if user_input.lower() == 'exit':
|
|
break
|
|
|
|
# 选择模型
|
|
model_choice = input("选择模型 (1: QWQ, 2: GPT, 3: 两者): ")
|
|
|
|
|
|
try:
|
|
if model_choice == '1':
|
|
# 使用 QWQ 模型
|
|
messages = [
|
|
{"role": "system", "content": "你是一个有用的助手。"},
|
|
{"role": "user", "content": user_input}
|
|
]
|
|
response = await agent.chat_with_qwq(messages)
|
|
console.print(f"\n[bold cyan]QWQ 最终回答:[/bold cyan] {response}")
|
|
|
|
elif model_choice == '2':
|
|
# 使用 GPT 模型
|
|
messages = [
|
|
{"role": "system", "content": "你是一个有用的助手。"},
|
|
{"role": "user", "content": user_input}
|
|
]
|
|
response = await agent.chat_with_gpt(messages)
|
|
console.print(f"\n[bold magenta]GPT 最终回答:[/bold magenta] {response}")
|
|
|
|
elif model_choice == '3':
|
|
# 同时使用两个模型
|
|
responses = await agent.chat_with_both_models(user_input)
|
|
console.print(f"\n[bold cyan]QWQ 最终回答:[/bold cyan] {responses['qwq']}")
|
|
console.print(f"\n[bold magenta]GPT 最终回答:[/bold magenta] {responses['gpt']}")
|
|
|
|
else:
|
|
console.print("[bold red]无效的选择,请输入 1、2 或 3[/bold red]")
|
|
|
|
except Exception as e:
|
|
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
asyncio.run(main())
|
|
#from tools_for_ms.services_tools.mp_service_tools import search_material_property_from_material_project
|
|
#asyncio.run(search_material_property_from_material_project('Fe2O3'))
|
|
|
|
# 测试工具函数
|
|
|
|
# tool_name = 'get_crystal_structures_from_materials_project'
|
|
# result = asyncio.run(test_tool(tool_name))
|
|
# print(result)
|
|
|
|
#pass
|
|
|
|
# 知识检索API的接口 数据库
|