# Code-viewer metadata: Python, executable file, 438 lines, 16 KiB.
import asyncio
|
||
|
||
from api_key import *
|
||
from openai import OpenAI
|
||
import json
|
||
from typing import Dict, List, Any, Union, Optional
|
||
from rich.console import Console
|
||
import sys
|
||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||
from sci_mcp import *
|
||
|
||
# Load the tool registry.
# NOTE(review): these helpers presumably come from the star-imported sci_mcp
# package; their return shapes are not visible here. Within this file,
# `all_tools_schemas` is passed as `tools=` to the chat-completions API and
# `tools` is looked up by tool name and called — confirm against sci_mcp.
all_tools_schemas = get_all_tool_schemas()
tools = get_all_tools()
chemistry_tools = get_domain_tools("chemistry")

# Shared rich console used for all formatted output in this module.
console = Console()
#console.print(all_tools_schemas)
|
||
|
||
class ModelAgent:
    """
    Agent wrapper around an OpenAI-compatible chat-completions client.

    Sends conversations to the configured model (``gpt-4o`` by default)
    together with the module-level tool schemas, gives uniform access to the
    fields of a response, and dispatches the tool calls requested by the model
    to the module-level ``tools`` registry.
    """

    def __init__(self, model_name: str = "gpt-4o"):
        """
        Create the API client and remember the model name and tool schemas.

        Args:
            model_name: Name of the model passed to the chat-completions API.
        """
        # OPENAI_API_KEY / OPENAI_API_URL are module-level names, presumably
        # provided by the star-imported api_key module.
        self.client = OpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_API_URL,
        )

        self.model_name = model_name

        # Tool schemas advertised to the model on every request.
        self.tools = all_tools_schemas

    def get_response(self, messages: List[Dict[str, Any]]) -> Any:
        """
        Request one completion for the given conversation.

        This is a synchronous client call; it blocks until the API answers.

        Args:
            messages: Conversation so far, in chat-completions message format.

        Returns:
            The completion object returned by the client.
        """
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            tools=self.tools,
            tool_choice="auto",
            temperature=0.6,
        )
        return completion

    def extract_tool_calls(self, response: Any) -> Optional[List[Any]]:
        """
        Return the tool calls requested in the first choice of ``response``.

        Args:
            response: Completion object.

        Returns:
            The list of tool calls, or None when the attribute is missing,
            None, or empty.
        """
        # `or None` folds an empty list into None so callers have a single
        # "no tool calls" value to test for.
        return getattr(response.choices[0].message, 'tool_calls', None) or None

    def extract_content(self, response: Any) -> str:
        """
        Return the text content of the first choice of ``response``.

        Args:
            response: Completion object.

        Returns:
            The message content, or an empty string when it is None.
        """
        content = response.choices[0].message.content
        return content if content is not None else ""

    def extract_finish_reason(self, response: Any) -> str:
        """
        Return the finish reason of the first choice of ``response``.

        Args:
            response: Completion object.

        Returns:
            The ``finish_reason`` value of the first choice.
        """
        return response.choices[0].finish_reason

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Optional[str]) -> Dict[str, Any]:
        """
        Decode the JSON argument string attached to a tool call.

        Args:
            raw_arguments: JSON text; None, empty, or whitespace-only text is
                treated as "no arguments".

        Returns:
            The decoded keyword arguments.

        Raises:
            ValueError: If the text is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError`` subclass) or does not decode to an object.
        """
        if raw_arguments is None or not raw_arguments.strip():
            return {}
        parsed = json.loads(raw_arguments)
        if not isinstance(parsed, dict):
            raise ValueError(f"工具参数必须是 JSON 对象: {raw_arguments}")
        return parsed

    @staticmethod
    def _stringify_tool_result(tool_result: Any) -> str:
        """
        Convert a tool result into the string used as tool-message content.

        Strings pass through unchanged; other values are JSON-encoded, falling
        back to ``str()`` for values that ``json`` cannot serialise.
        """
        if isinstance(tool_result, str):
            return tool_result
        try:
            return json.dumps(tool_result, ensure_ascii=False)
        except (TypeError, ValueError):
            return str(tool_result)

    @staticmethod
    async def _invoke_tool(tool_function: Any, tool_arguments: Dict[str, Any]) -> Any:
        """
        Call ``tool_function(**tool_arguments)`` and await the result if needed.

        The awaitable check is made on the returned value rather than on the
        function, so synchronous wrappers that return a coroutine are awaited
        as well.
        """
        import inspect

        outcome = tool_function(**tool_arguments)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome

    async def call_tool(self, tool_name: str, tool_arguments: Dict[str, Any]) -> Any:
        """
        Look up a tool in the module-level registry and run it.

        Both synchronous and asynchronous tool functions are supported.

        Args:
            tool_name: Name of the tool to run.
            tool_arguments: Keyword arguments passed to the tool.

        Returns:
            Whatever the tool returns. An error string is returned instead
            when the tool is unknown or raises an exception.
        """
        if tool_name not in tools:
            return f"未找到工具: {tool_name}"

        try:
            return await self._invoke_tool(tools[tool_name], tool_arguments)
        except Exception as e:
            # Deliberately broad: the failure is reported back to the model as
            # the tool result instead of aborting the conversation.
            return f"工具调用错误: {str(e)}"

    async def chat(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
        """
        Converse with the model, executing requested tool calls between turns.

        Args:
            messages: Initial conversation; the list itself is not modified.
            max_turns: Maximum number of model requests before giving up.

        Returns:
            The assistant's content from the first turn that does not request
            tools, or a fixed notice when ``max_turns`` is exhausted.
        """
        # Shallow copy so the caller's list is not extended.
        current_messages = messages.copy()

        for turn in range(1, max_turns + 1):
            console.print(f"\n[bold magenta]第 {turn} 轮对话[/bold magenta]")

            response = self.get_response(current_messages)
            assistant_message = response.choices[0].message

            # The assistant turn must be in the context before its tool results.
            current_messages.append(assistant_message.model_dump())

            content = self.extract_content(response)
            tool_calls = self.extract_tool_calls(response)
            finish_reason = self.extract_finish_reason(response)

            console.print(f"[green]助手回复:[/green] {content}")

            # No tool request (or the model finished for another reason): done.
            if tool_calls is None or finish_reason != "tool_calls":
                return content

            for tool_call in tool_calls:
                tool_call_name = tool_call.function.name
                console.print(f"[yellow]调用工具:[/yellow] {tool_call_name}")

                try:
                    tool_call_arguments = self._parse_tool_arguments(
                        tool_call.function.arguments
                    )
                except ValueError as e:
                    # Every tool call still needs a tool message, so the parse
                    # failure becomes the result the model sees.
                    tool_result = f"工具参数解析错误: {e}"
                else:
                    console.print(f"[yellow]工具参数:[/yellow] {tool_call_arguments}")
                    tool_result = await self.call_tool(tool_call_name, tool_call_arguments)

                console.print(f"[blue]工具结果:[/blue] {tool_result}")

                current_messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": tool_call_name,
                    # Tool-message content must be a string.
                    "content": self._stringify_tool_result(tool_result),
                })

        return "达到最大对话轮数限制"
|
||
|
||
|
||
async def main():
    """
    Run an interactive question/answer loop on the terminal.

    Each question is sent to a ModelAgent as a fresh two-message conversation
    (system prompt plus the user's text). The loop ends when the user types
    'exit' or when standard input is closed. Errors raised while answering are
    printed and the loop continues.
    """
    agent = ModelAgent()

    while True:
        try:
            user_input = input("\n请输入问题(输入 'exit' 退出): ")
        except EOFError:
            # stdin was closed (Ctrl-D or exhausted pipe): leave quietly
            # instead of crashing with a traceback.
            break

        command = user_input.strip()
        if command.lower() == 'exit':
            break
        if not command:
            # Nothing to ask; do not spend an API call on a blank prompt.
            continue

        try:
            messages = [
                {"role": "system", "content": "你是一个有用的助手。"},
                {"role": "user", "content": user_input}
            ]
            response = await agent.chat(messages)
            console.print(f"\n[bold magenta]最终回答:[/bold magenta] {response}")

        except Exception as e:
            # Top-level boundary of the interactive loop: report and go on.
            console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||
|
||
|
||
# Sample questions generated for each tool function
def get_tool_questions():
    """
    Build the catalogue of sample questions, keyed by tool name.

    The questions are phrased so that a model is nudged towards calling the
    matching tool without being told outright which tool to use.

    Returns:
        A new dict mapping each tool name to its list of question strings,
        ordered by domain: materials science, chemistry, then reaction (RXN)
        tools.
    """
    # Materials-science tools
    materials = {
        "search_crystal_structures_from_materials_project": [
            "我想了解氧化铁(Fe2O3)的晶体结构,能帮我找一下相关信息吗?",
            "锂离子电池中常用的LiFePO4材料有什么样的晶体结构?",
            "能否帮我查询一下钙钛矿(CaTiO3)的晶体结构数据?",
        ],
        "search_material_property_from_material_project": [
            "二氧化钛(TiO2)有哪些重要的物理和化学性质?",
            "我正在研究锂电池材料,能告诉我LiCoO2的主要性质吗?",
            "硅(Si)的带隙和电子性质是什么?能帮我查一下详细数据吗?",
        ],
        "query_material_from_OQMD": [
            "能帮我从OQMD数据库中查询一下铝合金(Al-Cu)的形成能吗?",
            "我想了解镍基超合金在OQMD数据库中的热力学稳定性数据",
            "OQMD数据库中有关于锌氧化物(ZnO)的什么信息?",
        ],
        "retrieval_from_knowledge_base": [
            "有关高温超导体的最新研究进展是什么?",
            "能否从材料科学知识库中找到关于石墨烯应用的信息?",
            "我想了解钙钛矿太阳能电池的工作原理和效率限制",
        ],
        "predict_properties": [
            "这个化学式为Li2FeSiO4的材料可能有什么样的电子性质?",
            "能预测一下Na3V2(PO4)3这种材料的离子导电性吗?",
            "如果我设计一个新的钙钛矿结构,能预测它的稳定性和带隙吗?",
        ],
        "generate_material": [
            "能生成一种可能具有铁磁性的新材料结构吗?",
        ],
        "optimize_crystal_structure": [
            "我有一个CIF文件HEu2H3EuH2EuH5.cif,能帮我优化一下使其更稳定吗?",
        ],
        "calculate_density": [
            "我有一个CIF文件HEu2H3EuH2EuH5.cif,能计算一下它的密度吗?",
        ],
        "get_element_composition": [
            "我有一个CIF文件HEu2H3EuH2EuH5.cif,能分析一下它的元素组成吗?",
        ],
        "calculate_symmetry": [
            "我有一个CIF文件HEu2H3EuH2EuH5.cif,能分析一下它的对称性和空间群吗?",
        ],
    }

    # Chemistry tools
    chemistry = {
        "search_pubchem_advanced": [
            "阿司匹林的分子结构和性质是什么?",
        ],
        "calculate_molecular_properties": [
            "对乙酰氨基酚的SMILES是CC(=O)NC1=CC=C(C=C1)O,计算它的物理化学性质",
        ],
        "calculate_drug_likeness": [
            "布洛芬的SMILES是CC(C)CC1=CC=C(C=C1)C(C)C(=O)O,能计算它的药物性吗?",
        ],
        "calculate_topological_descriptors": [
            "咖啡因的SMILES是CN1C=NC2=C1C(=O)N(C(=O)N2C)C,能计算它的拓扑描述符吗?",
        ],
        "generate_molecular_fingerprints": [
            "尼古丁的SMILES是CN1CCCC1C2=CN=CC=C2,能为它生成Morgan指纹吗?",
        ],
        "calculate_molecular_similarity": [
            "阿司匹林的SMILES是CC(=O)OC1=CC=CC=C1C(=O)O,对乙酰氨基酚的SMILES是CC(=O)NC1=CC=C(C=C1)O,它们的分子相似性如何?",
        ],
        "analyze_molecular_structure": [
            "苯甲酸的SMILES是C1=CC=C(C=C1)C(=O)O,能分析它的结构特征吗?",
        ],
        "generate_molecular_conformer": [
            "甲基苯并噻唑的SMILES是CC1=NC2=CC=CC=C2S1,能生成它的3D构象吗?",
        ],
        "identify_scaffolds": [
            "奎宁的SMILES是COC1=CC2=C(C=CN=C2C=C1)C(C3CC4CCN3CC4C=C)O,它的核心骨架是什么?",
        ],
        "convert_between_chemical_formats": [
            "将乙醇的SMILES:CCO转换为InChI格式",
        ],
        "standardize_molecule": [
            "将四环素的SMILES:CC1C2C(C(=O)C3(C(CC4C(C3C(=O)C2C(=O)C(=C1O)C(=O)N)O)(C(=O)CO4)O)O)N(C)C标准化处理",
        ],
        "enumerate_stereoisomers": [
            "2-丁醇的SMILES是CCC(C)O,它可能有哪些立体异构体?,不要单纯靠你自身的知识,如果不确定可以使用工具。",
        ],
        "perform_substructure_search": [
            "在阿莫西林的SMILES:CC1(C(N2C(S1)C(C2=O)NC(=O)C(C3=CC=C(C=C3)O)N)C(=O)O)C中搜索羧酸基团",
        ],
    }

    # Test questions for the RXN tools
    reactions = {
        "predict_reaction_outcome": [
            "我正在研究乙酸和乙醇的酯化反应,想知道这个反应的产物是什么。反应物的SMILES表示法是'CC(=O)O.CCO'。能帮我预测一下这个反应最可能的结果吗?",
        ],
        "predict_reaction_batch": [
            "我在实验室中设计了三个酯化反应系列,想同时了解它们的可能产物。这三个反应的SMILES分别是'CC(=O)O.CCO'(乙酸和乙醇)、'CC(=O)O.CCCO'(乙酸和丙醇)和'CC(=O)O.CCCCO'(乙酸和丁醇)。能否一次性预测这些反应的结果?",
        ],
        "predict_reaction_topn": [
            "我在研究丙烯醛和甲胺的反应机理,这个反应可能有多种产物路径。反应物的SMILES是'C=CC=O.CN'。能帮我分析出最可能的前3种产物及它们的相对可能性吗?",
        ],
        "predict_retrosynthesis": [
            "我需要为实验室合成阿司匹林,但不确定最佳的合成路线。阿司匹林的SMILES是'CC(=O)OC1=CC=CC=C1C(=O)O'。能帮我分析一下可能的合成路径,将其分解为更简单的前体化合物吗?",
        ],
        "predict_biocatalytic_retrosynthesis": [
            "我们实验室正在研究绿色化学合成方法,想知道是否可以使用酶催化方式合成这个含溴的芳香化合物(SMILES: 'OC1C(O)C=C(Br)C=C1')。能帮我分析可能的生物催化合成路径吗?",
        ],
        "predict_reaction_properties": [
            "我正在研究这个有机反应的机理:'CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F',特别想了解反应中的原子映射关系。能帮我分析一下反应前后各原子的对应关系吗?我需要atom-mapping属性。",
        ],
        "extract_reaction_actions": [
            "我从一篇有机合成文献中找到了这段实验步骤:'To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour.' 能帮我将这段文本转换为结构化的反应步骤吗?这样我可以更清晰地理解每个操作。",
        ],
    }

    # Merging in this order keeps the key order of the catalogue stable.
    return {**materials, **chemistry, **reactions}
|
||
|
||
# Try out the sample question of a specific tool function
async def test_tool_with_question(question_index: int = 0):
    """
    Interactively run the preset sample questions through a ModelAgent.

    Lists one question per tool, asks on the terminal which one to run (a
    1-based number, or 'all'), and sends the chosen question(s) to the agent.
    Errors raised while the agent answers are printed rather than propagated.

    Args:
        question_index: Which of each tool's questions to use. Indexes past the
            end of a tool's list fall back to its last question. Defaults to 0.
    """
    all_questions = get_tool_questions()

    # One question per tool; tools without any question are left out.
    tool_questions = {
        tool_name: questions[min(question_index, len(questions) - 1)]
        for tool_name, questions in all_questions.items()
        if questions
    }

    console.print("[bold]可用的工具和问题:[/bold]")
    for i, (tool_name, question) in enumerate(tool_questions.items(), 1):
        console.print(f"{i}. [cyan]{tool_name}[/cyan]: {question}")

    choice = input("\n请选择要测试的工具编号(输入'all'测试所有工具): ")

    agent = ModelAgent()

    if choice.lower() == 'all':
        for tool_name, question in tool_questions.items():
            console.print(f"\n[bold]测试工具: [cyan]{tool_name}[/cyan][/bold]")
            console.print(f"问题: {question}")

            messages = [
                {"role": "system", "content": "你是一个有用的助手。"},
                {"role": "user", "content": question}
            ]

            try:
                await agent.chat(messages)
            except Exception as e:
                # Keep going with the remaining tools after a failure.
                console.print(f"[bold red]错误:[/bold red] {str(e)}")
        return

    # Only the number parsing is guarded by ValueError, so a ValueError raised
    # later by the chat (e.g. a JSON decode error) is not misreported as an
    # invalid number.
    try:
        index = int(choice) - 1
    except ValueError:
        console.print("[bold red]请输入有效的数字[/bold red]")
        return

    if not 0 <= index < len(tool_questions):
        console.print("[bold red]无效的选择[/bold red]")
        return

    tool_name = list(tool_questions)[index]
    question = tool_questions[tool_name]

    console.print(f"\n[bold]测试工具: [cyan]{tool_name}[/cyan][/bold]")
    console.print(f"问题: {question}")

    # The single-tool run appends a hint encouraging the model to use tools.
    messages = [
        {"role": "system", "content": "你是一个有用的助手。"},
        {"role": "user", "content": question+'如果你不确定答案,请使用工具'}
    ]

    try:
        await agent.chat(messages)
    except Exception as e:
        # Same reporting style as the 'all' branch.
        console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Uncomment to run the free-form interactive chat loop instead.
    # asyncio.run(main())

    # Run the interactive tester for the per-tool sample questions.
    asyncio.run(test_tool_with_question())

    # pass
|
||
|
||
# Interface to the knowledge-retrieval API / database
|