工具函数零散能用版
This commit is contained in:
401
agent_test.py
Normal file
401
agent_test.py
Normal file
@@ -0,0 +1,401 @@
|
||||
import asyncio
|
||||
from tools_for_ms import *
|
||||
from api_key import *
|
||||
from openai import OpenAI
|
||||
import json
|
||||
from typing import Dict, List, Any, Union, Optional
|
||||
from rich.console import Console
|
||||
|
||||
# 获取工具
|
||||
tools = get_tools()
|
||||
tools_schema = get_tool_schemas()
|
||||
tool_map = {tool_name: tool for tool_name, tool in tools.items()}
|
||||
|
||||
console = Console()
|
||||
console.print(tools_schema)
|
||||
|
||||
class DualModelAgent:
|
||||
"""
|
||||
同时支持 qwq 和 gpt4o 两种大模型的代理类
|
||||
处理两种不同的返回值格式并提供统一的工具调用接口
|
||||
"""
|
||||
|
||||
def __init__(self, qwq_model_name: str = "qwq-32b", gpt_model_name: str = "gpt-4o"):
|
||||
"""
|
||||
初始化两个模型的客户端
|
||||
|
||||
Args:
|
||||
qwq_model_name: qwq 模型名称
|
||||
gpt_model_name: gpt 模型名称
|
||||
"""
|
||||
# 初始化 qwq 客户端
|
||||
self.qwq_client = OpenAI(
|
||||
api_key=OPENAI_API_KEY,
|
||||
base_url=OPENAI_API_URL,
|
||||
)
|
||||
|
||||
# 初始化 gpt 客户端 (可以使用不同的 API 密钥和 URL)
|
||||
self.gpt_client = OpenAI(
|
||||
api_key=OPENAI_API_KEY,
|
||||
base_url=OPENAI_API_URL,
|
||||
)
|
||||
|
||||
# 模型名称
|
||||
self.qwq_model_name = qwq_model_name
|
||||
self.gpt_model_name = gpt_model_name
|
||||
|
||||
# 定义工具列表
|
||||
self.tools = tools_schema
|
||||
|
||||
def get_qwq_response(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
获取 qwq 模型的响应(返回字典格式)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
字典格式的响应
|
||||
"""
|
||||
completion = self.qwq_client.chat.completions.create(
|
||||
model=self.qwq_model_name,
|
||||
messages=messages,
|
||||
tools=self.tools,
|
||||
temperature=0.6,
|
||||
tool_choice='auto'
|
||||
)
|
||||
# qwq 返回的是对象,需要转换为字典
|
||||
return completion.model_dump()
|
||||
|
||||
def get_gpt_response(self, messages: List[Dict[str, Any]]) -> Any:
|
||||
"""
|
||||
获取 gpt 模型的响应(返回对象格式)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
对象格式的响应
|
||||
"""
|
||||
completion = self.gpt_client.chat.completions.create(
|
||||
model=self.gpt_model_name,
|
||||
messages=messages,
|
||||
tools=self.tools,
|
||||
tool_choice="auto",
|
||||
temperature=0.6,
|
||||
)
|
||||
# gpt 返回的是对象,直接返回
|
||||
return completion
|
||||
|
||||
def extract_tool_calls_from_qwq(self, response: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
从 qwq 响应中提取工具调用信息
|
||||
|
||||
Args:
|
||||
response: qwq 响应字典
|
||||
|
||||
Returns:
|
||||
工具调用列表,如果没有则返回 None
|
||||
"""
|
||||
assistant_message = response['choices'][0]['message']
|
||||
console.print("assistant_message",assistant_message)
|
||||
return assistant_message.get('tool_calls')
|
||||
|
||||
def extract_tool_calls_from_gpt(self, response: Any) -> Optional[List[Any]]:
|
||||
"""
|
||||
从 gpt 响应中提取工具调用信息
|
||||
|
||||
Args:
|
||||
response: gpt 响应对象
|
||||
|
||||
Returns:
|
||||
工具调用列表,如果没有则返回 None
|
||||
"""
|
||||
if hasattr(response.choices[0].message, 'tool_calls'):
|
||||
return response.choices[0].message.tool_calls
|
||||
return None
|
||||
|
||||
def extract_content_from_qwq(self, response: Dict[str, Any]) -> str:
|
||||
"""
|
||||
从 qwq 响应中提取内容
|
||||
|
||||
Args:
|
||||
response: qwq 响应字典
|
||||
|
||||
Returns:
|
||||
内容字符串
|
||||
"""
|
||||
content = response['choices'][0]['message'].get('content')
|
||||
return content if content is not None else ""
|
||||
|
||||
def extract_content_from_gpt(self, response: Any) -> str:
|
||||
"""
|
||||
从 gpt 响应中提取内容
|
||||
|
||||
Args:
|
||||
response: gpt 响应对象
|
||||
|
||||
Returns:
|
||||
内容字符串
|
||||
"""
|
||||
content = response.choices[0].message.content
|
||||
return content if content is not None else ""
|
||||
|
||||
def extract_finish_reason_from_qwq(self, response: Dict[str, Any]) -> str:
|
||||
"""
|
||||
从 qwq 响应中提取完成原因
|
||||
|
||||
Args:
|
||||
response: qwq 响应字典
|
||||
|
||||
Returns:
|
||||
完成原因
|
||||
"""
|
||||
return response['choices'][0]['finish_reason']
|
||||
|
||||
def extract_finish_reason_from_gpt(self, response: Any) -> str:
|
||||
"""
|
||||
从 gpt 响应中提取完成原因
|
||||
|
||||
Args:
|
||||
response: gpt 响应对象
|
||||
|
||||
Returns:
|
||||
完成原因
|
||||
"""
|
||||
return response.choices[0].finish_reason
|
||||
|
||||
async def call_tool(self, tool_name: str, tool_arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
调用工具函数
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
if tool_name in tools:
|
||||
tool_function = tools[tool_name]
|
||||
try:
|
||||
# 异步调用工具函数
|
||||
tool_result = await tool_function(**tool_arguments)
|
||||
return tool_result
|
||||
except Exception as e:
|
||||
return f"工具调用错误: {str(e)}"
|
||||
else:
|
||||
return f"未找到工具: {tool_name}"
|
||||
|
||||
async def chat_with_qwq(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
|
||||
"""
|
||||
与 qwq 模型对话,支持工具调用
|
||||
|
||||
Args:
|
||||
messages: 初始消息列表
|
||||
max_turns: 最大对话轮数
|
||||
|
||||
Returns:
|
||||
最终回答
|
||||
"""
|
||||
|
||||
current_messages = messages.copy()
|
||||
turn = 0
|
||||
|
||||
while turn < max_turns:
|
||||
turn += 1
|
||||
console.print(f"\n[bold cyan]第 {turn} 轮 QWQ 对话[/bold cyan]")
|
||||
console.print("message",current_messages)
|
||||
# 获取 qwq 响应
|
||||
response = self.get_qwq_response(current_messages)
|
||||
assistant_message = response['choices'][0]['message']
|
||||
if assistant_message['content'] is None:
|
||||
assistant_message['content'] = ""
|
||||
console.print("message", assistant_message)
|
||||
# 将助手消息添加到上下文
|
||||
current_messages.append(assistant_message)
|
||||
|
||||
# 提取内容和工具调用
|
||||
content = self.extract_content_from_qwq(response)
|
||||
tool_calls = self.extract_tool_calls_from_qwq(response)
|
||||
finish_reason = self.extract_finish_reason_from_qwq(response)
|
||||
|
||||
console.print(f"[green]助手回复:[/green] {content}")
|
||||
|
||||
# 如果没有工具调用或已完成,返回内容
|
||||
if tool_calls is None or finish_reason != "tool_calls":
|
||||
return content
|
||||
|
||||
# 处理工具调用
|
||||
for tool_call in tool_calls:
|
||||
tool_call_name = tool_call['function']['name']
|
||||
tool_call_arguments = json.loads(tool_call['function']['arguments'])
|
||||
|
||||
console.print(f"[yellow]调用工具:[/yellow] {tool_call_name}")
|
||||
console.print(f"[yellow]工具参数:[/yellow] {tool_call_arguments}")
|
||||
|
||||
# 执行工具调用
|
||||
tool_result = await self.call_tool(tool_call_name, tool_call_arguments)
|
||||
console.print(f"[blue]工具结果:[/blue] {tool_result}")
|
||||
|
||||
# 添加工具结果到上下文
|
||||
current_messages.append({
|
||||
"role": "tool",
|
||||
"name": tool_call_name,
|
||||
"content": tool_result
|
||||
})
|
||||
|
||||
return "达到最大对话轮数限制"
|
||||
|
||||
async def chat_with_gpt(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
|
||||
"""
|
||||
与 gpt 模型对话,支持工具调用
|
||||
|
||||
Args:
|
||||
messages: 初始消息列表
|
||||
max_turns: 最大对话轮数
|
||||
|
||||
Returns:
|
||||
最终回答
|
||||
"""
|
||||
current_messages = messages.copy()
|
||||
turn = 0
|
||||
|
||||
while turn < max_turns:
|
||||
turn += 1
|
||||
console.print(f"\n[bold magenta]第 {turn} 轮 GPT 对话[/bold magenta]")
|
||||
|
||||
# 获取 gpt 响应
|
||||
response = self.get_gpt_response(current_messages)
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# 将助手消息添加到上下文
|
||||
current_messages.append(assistant_message.model_dump())
|
||||
|
||||
# 提取内容和工具调用
|
||||
content = self.extract_content_from_gpt(response)
|
||||
tool_calls = self.extract_tool_calls_from_gpt(response)
|
||||
finish_reason = self.extract_finish_reason_from_gpt(response)
|
||||
|
||||
console.print(f"[green]助手回复:[/green] {content}")
|
||||
|
||||
# 如果没有工具调用或已完成,返回内容
|
||||
if tool_calls is None or finish_reason != "tool_calls":
|
||||
return content
|
||||
|
||||
# 处理工具调用
|
||||
for tool_call in tool_calls:
|
||||
tool_call_name = tool_call.function.name
|
||||
tool_call_arguments = json.loads(tool_call.function.arguments)
|
||||
|
||||
console.print(f"[yellow]调用工具:[/yellow] {tool_call_name}")
|
||||
console.print(f"[yellow]工具参数:[/yellow] {tool_call_arguments}")
|
||||
|
||||
# 执行工具调用
|
||||
tool_result = await self.call_tool(tool_call_name, tool_call_arguments)
|
||||
console.print(f"[blue]工具结果:[/blue] {tool_result}")
|
||||
|
||||
# 添加工具结果到上下文
|
||||
current_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call_name,
|
||||
"content": tool_result
|
||||
})
|
||||
|
||||
return "达到最大对话轮数限制"
|
||||
|
||||
async def chat_with_both_models(self, user_input: str, system_prompt: str = "你是一个有用的助手。") -> Dict[str, str]:
|
||||
"""
|
||||
同时与两个模型对话,比较它们的回答
|
||||
|
||||
Args:
|
||||
user_input: 用户输入
|
||||
system_prompt: 系统提示
|
||||
|
||||
Returns:
|
||||
包含两个模型回答的字典
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_input}
|
||||
]
|
||||
|
||||
# 并行调用两个模型
|
||||
qwq_task = asyncio.create_task(self.chat_with_qwq(messages))
|
||||
gpt_task = asyncio.create_task(self.chat_with_gpt(messages))
|
||||
|
||||
# 等待两个任务完成
|
||||
qwq_response, gpt_response = await asyncio.gather(qwq_task, gpt_task)
|
||||
|
||||
return {
|
||||
"qwq": qwq_response,
|
||||
"gpt": gpt_response
|
||||
}
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数"""
|
||||
# 创建双模型代理
|
||||
agent = DualModelAgent()
|
||||
|
||||
while True:
|
||||
# 获取用户输入
|
||||
user_input = input("\n请输入问题(输入 'exit' 退出): ")
|
||||
#user_input='现在的时间'
|
||||
if user_input.lower() == 'exit':
|
||||
break
|
||||
|
||||
# 选择模型
|
||||
model_choice = input("选择模型 (1: QWQ, 2: GPT, 3: 两者): ")
|
||||
|
||||
|
||||
try:
|
||||
if model_choice == '1':
|
||||
# 使用 QWQ 模型
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个有用的助手。"},
|
||||
{"role": "user", "content": user_input}
|
||||
]
|
||||
response = await agent.chat_with_qwq(messages)
|
||||
console.print(f"\n[bold cyan]QWQ 最终回答:[/bold cyan] {response}")
|
||||
|
||||
elif model_choice == '2':
|
||||
# 使用 GPT 模型
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个有用的助手。"},
|
||||
{"role": "user", "content": user_input}
|
||||
]
|
||||
response = await agent.chat_with_gpt(messages)
|
||||
console.print(f"\n[bold magenta]GPT 最终回答:[/bold magenta] {response}")
|
||||
|
||||
elif model_choice == '3':
|
||||
# 同时使用两个模型
|
||||
responses = await agent.chat_with_both_models(user_input)
|
||||
console.print(f"\n[bold cyan]QWQ 最终回答:[/bold cyan] {responses['qwq']}")
|
||||
console.print(f"\n[bold magenta]GPT 最终回答:[/bold magenta] {responses['gpt']}")
|
||||
|
||||
else:
|
||||
console.print("[bold red]无效的选择,请输入 1、2 或 3[/bold red]")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
asyncio.run(main())
|
||||
#from tools_for_ms.services_tools.mp_service_tools import search_material_property_from_material_project
|
||||
#asyncio.run(search_material_property_from_material_project('Fe2O3'))
|
||||
|
||||
# 测试工具函数
|
||||
|
||||
# tool_name = 'get_crystal_structures_from_materials_project'
|
||||
# result = asyncio.run(test_tool(tool_name))
|
||||
# print(result)
|
||||
|
||||
#pass
|
||||
|
||||
# 知识检索API的接口 数据库
|
||||
Reference in New Issue
Block a user