# mars-mcp/agent_test.py
import asyncio
from tools_for_ms import *
from api_key import *
from openai import OpenAI
import json
from typing import Dict, List, Any, Union, Optional
from rich.console import Console
# Load the available tools and the schemas passed to the chat completion API
tools = get_tools()
tools_schema = get_tool_schemas()
tool_map = {tool_name: tool for tool_name, tool in tools.items()}  # name -> callable lookup
console = Console()
console.print(tools_schema)
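
# A sketch of the shape each entry in tools_schema is assumed to have, since it
# is passed straight to the `tools=` parameter of chat.completions.create (the
# OpenAI function-calling schema). The description and parameters shown are
# illustrative only, not taken from tools_for_ms:
#
#   {
#       "type": "function",
#       "function": {
#           "name": "get_crystal_structures_from_materials_project",
#           "description": "Fetch crystal structures for a formula",
#           "parameters": {"type": "object", "properties": {"formula": {"type": "string"}}},
#       },
#   }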


class DualModelAgent:
    """
    Agent that supports both the qwq and gpt-4o models.
    It handles the two different response formats and exposes a unified
    tool-calling interface.
    """

    def __init__(self, qwq_model_name: str = "qwq-32b", gpt_model_name: str = "gpt-4o"):
        """
        Initialize the clients for both models.

        Args:
            qwq_model_name: name of the qwq model
            gpt_model_name: name of the gpt model
        """
        # Initialize the qwq client
        self.qwq_client = OpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_API_URL,
        )
        # Initialize the gpt client (a different API key and base URL could be used here)
        self.gpt_client = OpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_API_URL,
        )
        # Model names
        self.qwq_model_name = qwq_model_name
        self.gpt_model_name = gpt_model_name
        # Tool schemas passed to the chat completion calls
        self.tools = tools_schema

    def get_qwq_response(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Get a response from the qwq model (dict format).

        Args:
            messages: list of messages

        Returns:
            The response as a dict.
        """
        completion = self.qwq_client.chat.completions.create(
            model=self.qwq_model_name,
            messages=messages,
            tools=self.tools,
            temperature=0.6,
            tool_choice='auto'
        )
        # qwq is returned as an object; convert it to a dict
        return completion.model_dump()

    def get_gpt_response(self, messages: List[Dict[str, Any]]) -> Any:
        """
        Get a response from the gpt model (object format).

        Args:
            messages: list of messages

        Returns:
            The response object.
        """
        completion = self.gpt_client.chat.completions.create(
            model=self.gpt_model_name,
            messages=messages,
            tools=self.tools,
            tool_choice="auto",
            temperature=0.6,
        )
        # gpt is returned as an object; return it as-is
        return completion

    def extract_tool_calls_from_qwq(self, response: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
        """
        Extract the tool calls from a qwq response.

        Args:
            response: qwq response dict

        Returns:
            The list of tool calls, or None if there are none.
        """
        assistant_message = response['choices'][0]['message']
        console.print("assistant_message", assistant_message)
        return assistant_message.get('tool_calls')

    def extract_tool_calls_from_gpt(self, response: Any) -> Optional[List[Any]]:
        """
        Extract the tool calls from a gpt response.

        Args:
            response: gpt response object

        Returns:
            The list of tool calls, or None if there are none.
        """
        if hasattr(response.choices[0].message, 'tool_calls'):
            return response.choices[0].message.tool_calls
        return None

    def extract_content_from_qwq(self, response: Dict[str, Any]) -> str:
        """
        Extract the content from a qwq response.

        Args:
            response: qwq response dict

        Returns:
            The content string.
        """
        content = response['choices'][0]['message'].get('content')
        return content if content is not None else ""

    def extract_content_from_gpt(self, response: Any) -> str:
        """
        Extract the content from a gpt response.

        Args:
            response: gpt response object

        Returns:
            The content string.
        """
        content = response.choices[0].message.content
        return content if content is not None else ""

    def extract_finish_reason_from_qwq(self, response: Dict[str, Any]) -> str:
        """
        Extract the finish reason from a qwq response.

        Args:
            response: qwq response dict

        Returns:
            The finish reason.
        """
        return response['choices'][0]['finish_reason']

    def extract_finish_reason_from_gpt(self, response: Any) -> str:
        """
        Extract the finish reason from a gpt response.

        Args:
            response: gpt response object

        Returns:
            The finish reason.
        """
        return response.choices[0].finish_reason
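
    # Summary of the two response shapes handled by the extract_* helpers above
    # (field paths taken directly from the accessors in this class):
    #   qwq (dict via model_dump()): response['choices'][0]['message']['content'],
    #       response['choices'][0]['message']['tool_calls'],
    #       response['choices'][0]['finish_reason']
    #   gpt (SDK object):            response.choices[0].message.content,
    #       response.choices[0].message.tool_calls,
    #       response.choices[0].finish_reason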

    async def call_tool(self, tool_name: str, tool_arguments: Dict[str, Any]) -> str:
        """
        Call a tool function.

        Args:
            tool_name: tool name
            tool_arguments: tool arguments

        Returns:
            The tool's result.
        """
        if tool_name in tools:
            tool_function = tools[tool_name]
            try:
                # Await the async tool function
                tool_result = await tool_function(**tool_arguments)
                return tool_result
            except Exception as e:
                return f"Tool call error: {str(e)}"
        else:
            return f"Tool not found: {tool_name}"

    async def chat_with_qwq(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
        """
        Chat with the qwq model, with tool-calling support.

        Args:
            messages: initial message list
            max_turns: maximum number of conversation turns

        Returns:
            The final answer.
        """
        current_messages = messages.copy()
        turn = 0
        while turn < max_turns:
            turn += 1
            console.print(f"\n[bold cyan]QWQ turn {turn}[/bold cyan]")
            console.print("messages", current_messages)
            # Get the qwq response
            response = self.get_qwq_response(current_messages)
            assistant_message = response['choices'][0]['message']
            if assistant_message['content'] is None:
                assistant_message['content'] = ""
            console.print("message", assistant_message)
            # Append the assistant message to the context
            current_messages.append(assistant_message)
            # Extract content and tool calls
            content = self.extract_content_from_qwq(response)
            tool_calls = self.extract_tool_calls_from_qwq(response)
            finish_reason = self.extract_finish_reason_from_qwq(response)
            console.print(f"[green]Assistant reply:[/green] {content}")
            # If there are no tool calls or the turn is finished, return the content
            if tool_calls is None or finish_reason != "tool_calls":
                return content
            # Handle the tool calls
            for tool_call in tool_calls:
                tool_call_name = tool_call['function']['name']
                tool_call_arguments = json.loads(tool_call['function']['arguments'])
                console.print(f"[yellow]Calling tool:[/yellow] {tool_call_name}")
                console.print(f"[yellow]Tool arguments:[/yellow] {tool_call_arguments}")
                # Execute the tool call
                tool_result = await self.call_tool(tool_call_name, tool_call_arguments)
                console.print(f"[blue]Tool result:[/blue] {tool_result}")
                # Append the tool result to the context (including tool_call_id,
                # which OpenAI-compatible endpoints expect on role "tool" messages)
                current_messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call['id'],
                    "name": tool_call_name,
                    "content": tool_result
                })
        return "Reached the maximum number of conversation turns"

    async def chat_with_gpt(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
        """
        Chat with the gpt model, with tool-calling support.

        Args:
            messages: initial message list
            max_turns: maximum number of conversation turns

        Returns:
            The final answer.
        """
        current_messages = messages.copy()
        turn = 0
        while turn < max_turns:
            turn += 1
            console.print(f"\n[bold magenta]GPT turn {turn}[/bold magenta]")
            # Get the gpt response
            response = self.get_gpt_response(current_messages)
            assistant_message = response.choices[0].message
            # Append the assistant message to the context
            current_messages.append(assistant_message.model_dump())
            # Extract content and tool calls
            content = self.extract_content_from_gpt(response)
            tool_calls = self.extract_tool_calls_from_gpt(response)
            finish_reason = self.extract_finish_reason_from_gpt(response)
            console.print(f"[green]Assistant reply:[/green] {content}")
            # If there are no tool calls or the turn is finished, return the content
            if tool_calls is None or finish_reason != "tool_calls":
                return content
            # Handle the tool calls
            for tool_call in tool_calls:
                tool_call_name = tool_call.function.name
                tool_call_arguments = json.loads(tool_call.function.arguments)
                console.print(f"[yellow]Calling tool:[/yellow] {tool_call_name}")
                console.print(f"[yellow]Tool arguments:[/yellow] {tool_call_arguments}")
                # Execute the tool call
                tool_result = await self.call_tool(tool_call_name, tool_call_arguments)
                console.print(f"[blue]Tool result:[/blue] {tool_result}")
                # Append the tool result to the context
                current_messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": tool_call_name,
                    "content": tool_result
                })
        return "Reached the maximum number of conversation turns"

    async def chat_with_both_models(self, user_input: str, system_prompt: str = "You are a helpful assistant.") -> Dict[str, str]:
        """
        Chat with both models at the same time and compare their answers.

        Args:
            user_input: user input
            system_prompt: system prompt

        Returns:
            A dict containing both models' answers.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input}
        ]
        # Call both models in parallel
        qwq_task = asyncio.create_task(self.chat_with_qwq(messages))
        gpt_task = asyncio.create_task(self.chat_with_gpt(messages))
        # Wait for both tasks to finish
        qwq_response, gpt_response = await asyncio.gather(qwq_task, gpt_task)
        return {
            "qwq": qwq_response,
            "gpt": gpt_response
        }
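
# A minimal usage sketch (assumes valid credentials in api_key.py and that the
# default model names are available on the configured endpoint). It mirrors what
# main() does interactively, but drives the agent programmatically; the question
# string is illustrative only:
#
#   agent = DualModelAgent()
#   answers = asyncio.run(agent.chat_with_both_models("Properties of Fe2O3?"))
#   console.print(answers["qwq"], answers["gpt"])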


async def main():
    """Main entry point."""
    # Create the dual-model agent
    agent = DualModelAgent()
    while True:
        # Get user input
        user_input = input("\nEnter a question (type 'exit' to quit): ")
        # user_input = 'the current time'
        if user_input.lower() == 'exit':
            break
        # Choose a model
        model_choice = input("Choose a model (1: QWQ, 2: GPT, 3: both): ")
        try:
            if model_choice == '1':
                # Use the QWQ model
                messages = [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": user_input}
                ]
                response = await agent.chat_with_qwq(messages)
                console.print(f"\n[bold cyan]QWQ final answer:[/bold cyan] {response}")
            elif model_choice == '2':
                # Use the GPT model
                messages = [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": user_input}
                ]
                response = await agent.chat_with_gpt(messages)
                console.print(f"\n[bold magenta]GPT final answer:[/bold magenta] {response}")
            elif model_choice == '3':
                # Use both models
                responses = await agent.chat_with_both_models(user_input)
                console.print(f"\n[bold cyan]QWQ final answer:[/bold cyan] {responses['qwq']}")
                console.print(f"\n[bold magenta]GPT final answer:[/bold magenta] {responses['gpt']}")
            else:
                console.print("[bold red]Invalid choice, please enter 1, 2, or 3[/bold red]")
        except Exception as e:
            console.print(f"[bold red]Error:[/bold red] {str(e)}")


if __name__ == "__main__":
    asyncio.run(main())
# from tools_for_ms.services_tools.mp_service_tools import search_material_property_from_material_project
# asyncio.run(search_material_property_from_material_project('Fe2O3'))
# Test a single tool function
# tool_name = 'get_crystal_structures_from_materials_project'
# result = asyncio.run(test_tool(tool_name))
# print(result)
# pass
# Knowledge-retrieval API interface / database