初次提交
This commit is contained in:
437
test_tools/agent_test.py
Executable file
437
test_tools/agent_test.py
Executable file
@@ -0,0 +1,437 @@
|
||||
import asyncio
|
||||
|
||||
from api_key import *
|
||||
from openai import OpenAI
|
||||
import json
|
||||
from typing import Dict, List, Any, Union, Optional
|
||||
from rich.console import Console
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||||
from sci_mcp import *
|
||||
|
||||
# 获取工具
|
||||
all_tools_schemas = get_all_tool_schemas()
|
||||
tools = get_all_tools()
|
||||
chemistry_tools = get_domain_tools("chemistry")
|
||||
|
||||
console = Console()
|
||||
#console.print(all_tools_schemas)
|
||||
|
||||
class ModelAgent:
|
||||
"""
|
||||
只支持 gpt-4o 模型的代理类
|
||||
处理返回值格式并提供统一的工具调用接口
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = "gpt-4o"):
|
||||
"""
|
||||
初始化模型客户端
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
"""
|
||||
# 初始化客户端
|
||||
self.client = OpenAI(
|
||||
api_key=OPENAI_API_KEY,
|
||||
base_url=OPENAI_API_URL,
|
||||
)
|
||||
|
||||
# 模型名称
|
||||
self.model_name = model_name
|
||||
|
||||
# 定义工具列表
|
||||
self.tools = all_tools_schemas
|
||||
def get_response(self, messages: List[Dict[str, Any]]) -> Any:
|
||||
"""
|
||||
获取模型的响应
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
响应对象
|
||||
"""
|
||||
completion = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
tools=self.tools,
|
||||
tool_choice="auto",
|
||||
temperature=0.6,
|
||||
)
|
||||
return completion
|
||||
|
||||
def extract_tool_calls(self, response: Any) -> Optional[List[Any]]:
|
||||
"""
|
||||
从响应中提取工具调用信息
|
||||
|
||||
Args:
|
||||
response: 响应对象
|
||||
|
||||
Returns:
|
||||
工具调用列表,如果没有则返回 None
|
||||
"""
|
||||
if hasattr(response.choices[0].message, 'tool_calls'):
|
||||
return response.choices[0].message.tool_calls
|
||||
return None
|
||||
|
||||
def extract_content(self, response: Any) -> str:
|
||||
"""
|
||||
从响应中提取内容
|
||||
|
||||
Args:
|
||||
response: 响应对象
|
||||
|
||||
Returns:
|
||||
内容字符串
|
||||
"""
|
||||
content = response.choices[0].message.content
|
||||
return content if content is not None else ""
|
||||
|
||||
def extract_finish_reason(self, response: Any) -> str:
|
||||
"""
|
||||
从响应中提取完成原因
|
||||
|
||||
Args:
|
||||
response: 响应对象
|
||||
|
||||
Returns:
|
||||
完成原因
|
||||
"""
|
||||
return response.choices[0].finish_reason
|
||||
|
||||
async def call_tool(self, tool_name: str, tool_arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
调用工具函数,支持同步和异步函数
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
if tool_name in tools:
|
||||
tool_function = tools[tool_name]
|
||||
try:
|
||||
# 检查函数是同步的还是异步的
|
||||
import asyncio
|
||||
import inspect
|
||||
|
||||
if asyncio.iscoroutinefunction(tool_function) or inspect.isawaitable(tool_function):
|
||||
# 异步调用工具函数
|
||||
tool_result = await tool_function(**tool_arguments)
|
||||
else:
|
||||
# 同步调用工具函数
|
||||
tool_result = tool_function(**tool_arguments)
|
||||
|
||||
return tool_result
|
||||
except Exception as e:
|
||||
return f"工具调用错误: {str(e)}"
|
||||
else:
|
||||
return f"未找到工具: {tool_name}"
|
||||
|
||||
async def chat(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
|
||||
"""
|
||||
与模型对话,支持工具调用
|
||||
|
||||
Args:
|
||||
messages: 初始消息列表
|
||||
max_turns: 最大对话轮数
|
||||
|
||||
Returns:
|
||||
最终回答
|
||||
"""
|
||||
current_messages = messages.copy()
|
||||
turn = 0
|
||||
|
||||
while turn < max_turns:
|
||||
turn += 1
|
||||
console.print(f"\n[bold magenta]第 {turn} 轮对话[/bold magenta]")
|
||||
|
||||
# 获取响应
|
||||
response = self.get_response(current_messages)
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# 将助手消息添加到上下文
|
||||
current_messages.append(assistant_message.model_dump())
|
||||
|
||||
# 提取内容和工具调用
|
||||
content = self.extract_content(response)
|
||||
tool_calls = self.extract_tool_calls(response)
|
||||
finish_reason = self.extract_finish_reason(response)
|
||||
|
||||
console.print(f"[green]助手回复:[/green] {content}")
|
||||
|
||||
# 如果没有工具调用或已完成,返回内容
|
||||
if tool_calls is None or finish_reason != "tool_calls":
|
||||
return content
|
||||
|
||||
# 处理工具调用
|
||||
for tool_call in tool_calls:
|
||||
tool_call_name = tool_call.function.name
|
||||
tool_call_arguments = json.loads(tool_call.function.arguments)
|
||||
|
||||
console.print(f"[yellow]调用工具:[/yellow] {tool_call_name}")
|
||||
console.print(f"[yellow]工具参数:[/yellow] {tool_call_arguments}")
|
||||
|
||||
# 执行工具调用
|
||||
tool_result = await self.call_tool(tool_call_name, tool_call_arguments)
|
||||
console.print(f"[blue]工具结果:[/blue] {tool_result}")
|
||||
|
||||
# 添加工具结果到上下文
|
||||
current_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call_name,
|
||||
"content": tool_result
|
||||
})
|
||||
|
||||
return "达到最大对话轮数限制"
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数"""
|
||||
# 创建模型代理
|
||||
agent = ModelAgent()
|
||||
|
||||
while True:
|
||||
# 获取用户输入
|
||||
user_input = input("\n请输入问题(输入 'exit' 退出): ")
|
||||
if user_input.lower() == 'exit':
|
||||
break
|
||||
|
||||
try:
|
||||
# 使用 GPT-4o 模型
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个有用的助手。"},
|
||||
{"role": "user", "content": user_input}
|
||||
]
|
||||
response = await agent.chat(messages)
|
||||
console.print(f"\n[bold magenta]最终回答:[/bold magenta] {response}")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
# 为每个工具函数生成的问题列表
|
||||
def get_tool_questions():
|
||||
"""
|
||||
获取为每个工具函数生成的问题列表
|
||||
这些问题设计为能够引导大模型调用相应的工具函数,但不直接命令模型调用特定工具
|
||||
|
||||
Returns:
|
||||
包含工具名称和对应问题的字典
|
||||
"""
|
||||
questions = {
|
||||
# 材料科学相关工具函数
|
||||
"search_crystal_structures_from_materials_project": [
|
||||
"我想了解氧化铁(Fe2O3)的晶体结构,能帮我找一下相关信息吗?",
|
||||
"锂离子电池中常用的LiFePO4材料有什么样的晶体结构?",
|
||||
"能否帮我查询一下钙钛矿(CaTiO3)的晶体结构数据?"
|
||||
],
|
||||
|
||||
"search_material_property_from_material_project": [
|
||||
"二氧化钛(TiO2)有哪些重要的物理和化学性质?",
|
||||
"我正在研究锂电池材料,能告诉我LiCoO2的主要性质吗?",
|
||||
"硅(Si)的带隙和电子性质是什么?能帮我查一下详细数据吗?"
|
||||
],
|
||||
|
||||
"query_material_from_OQMD": [
|
||||
"能帮我从OQMD数据库中查询一下铝合金(Al-Cu)的形成能吗?",
|
||||
"我想了解镍基超合金在OQMD数据库中的热力学稳定性数据",
|
||||
"OQMD数据库中有关于锌氧化物(ZnO)的什么信息?"
|
||||
],
|
||||
|
||||
"retrieval_from_knowledge_base": [
|
||||
"有关高温超导体的最新研究进展是什么?",
|
||||
"能否从材料科学知识库中找到关于石墨烯应用的信息?",
|
||||
"我想了解钙钛矿太阳能电池的工作原理和效率限制"
|
||||
],
|
||||
|
||||
"predict_properties": [
|
||||
"这个化学式为Li2FeSiO4的材料可能有什么样的电子性质?",
|
||||
"能预测一下Na3V2(PO4)3这种材料的离子导电性吗?",
|
||||
"如果我设计一个新的钙钛矿结构,能预测它的稳定性和带隙吗?"
|
||||
],
|
||||
|
||||
"generate_material": [
|
||||
"能生成一种可能具有铁磁性的新材料结构吗?"
|
||||
],
|
||||
|
||||
"optimize_crystal_structure": [
|
||||
"我有一个CIF文件HEu2H3EuH2EuH5.cif,能帮我优化一下使其更稳定吗?"
|
||||
],
|
||||
|
||||
"calculate_density": [
|
||||
"我有一个CIF文件HEu2H3EuH2EuH5.cif,能计算一下它的密度吗?"
|
||||
],
|
||||
|
||||
"get_element_composition": [
|
||||
"我有一个CIF文件HEu2H3EuH2EuH5.cif,能分析一下它的元素组成吗?"
|
||||
],
|
||||
|
||||
"calculate_symmetry": [
|
||||
"我有一个CIF文件HEu2H3EuH2EuH5.cif,能分析一下它的对称性和空间群吗?"
|
||||
],
|
||||
|
||||
# 化学相关工具函数
|
||||
"search_pubchem_advanced": [
|
||||
"阿司匹林的分子结构和性质是什么?"
|
||||
],
|
||||
|
||||
"calculate_molecular_properties": [
|
||||
"对乙酰氨基酚的SMILES是CC(=O)NC1=CC=C(C=C1)O,计算它的物理化学性质"
|
||||
],
|
||||
|
||||
"calculate_drug_likeness": [
|
||||
"布洛芬的SMILES是CC(C)CC1=CC=C(C=C1)C(C)C(=O)O,能计算它的药物性吗?"
|
||||
],
|
||||
|
||||
"calculate_topological_descriptors": [
|
||||
"咖啡因的SMILES是CN1C=NC2=C1C(=O)N(C(=O)N2C)C,能计算它的拓扑描述符吗?"
|
||||
],
|
||||
|
||||
"generate_molecular_fingerprints": [
|
||||
"尼古丁的SMILES是CN1CCCC1C2=CN=CC=C2,能为它生成Morgan指纹吗?"
|
||||
],
|
||||
|
||||
"calculate_molecular_similarity": [
|
||||
"阿司匹林的SMILES是CC(=O)OC1=CC=CC=C1C(=O)O,对乙酰氨基酚的SMILES是CC(=O)NC1=CC=C(C=C1)O,它们的分子相似性如何?"
|
||||
],
|
||||
|
||||
"analyze_molecular_structure": [
|
||||
"苯甲酸的SMILES是C1=CC=C(C=C1)C(=O)O,能分析它的结构特征吗?"
|
||||
],
|
||||
|
||||
"generate_molecular_conformer": [
|
||||
"甲基苯并噻唑的SMILES是CC1=NC2=CC=CC=C2S1,能生成它的3D构象吗?"
|
||||
],
|
||||
|
||||
"identify_scaffolds": [
|
||||
"奎宁的SMILES是COC1=CC2=C(C=CN=C2C=C1)C(C3CC4CCN3CC4C=C)O,它的核心骨架是什么?"
|
||||
],
|
||||
|
||||
"convert_between_chemical_formats": [
|
||||
"将乙醇的SMILES:CCO转换为InChI格式"
|
||||
],
|
||||
|
||||
"standardize_molecule": [
|
||||
"将四环素的SMILES:CC1C2C(C(=O)C3(C(CC4C(C3C(=O)C2C(=O)C(=C1O)C(=O)N)O)(C(=O)CO4)O)O)N(C)C标准化处理"
|
||||
],
|
||||
|
||||
"enumerate_stereoisomers": [
|
||||
"2-丁醇的SMILES是CCC(C)O,它可能有哪些立体异构体?,不要单纯靠你自身的知识,如果不确定可以使用工具。"
|
||||
],
|
||||
|
||||
"perform_substructure_search": [
|
||||
"在阿莫西林的SMILES:CC1(C(N2C(S1)C(C2=O)NC(=O)C(C3=CC=C(C=C3)O)N)C(=O)O)C中搜索羧酸基团"
|
||||
],
|
||||
|
||||
# RXN工具函数的测试问题
|
||||
"predict_reaction_outcome": [
|
||||
"我正在研究乙酸和乙醇的酯化反应,想知道这个反应的产物是什么。反应物的SMILES表示法是'CC(=O)O.CCO'。能帮我预测一下这个反应最可能的结果吗?"
|
||||
],
|
||||
|
||||
"predict_reaction_batch": [
|
||||
"我在实验室中设计了三个酯化反应系列,想同时了解它们的可能产物。这三个反应的SMILES分别是'CC(=O)O.CCO'(乙酸和乙醇)、'CC(=O)O.CCCO'(乙酸和丙醇)和'CC(=O)O.CCCCO'(乙酸和丁醇)。能否一次性预测这些反应的结果?"
|
||||
],
|
||||
|
||||
"predict_reaction_topn": [
|
||||
"我在研究丙烯醛和甲胺的反应机理,这个反应可能有多种产物路径。反应物的SMILES是'C=CC=O.CN'。能帮我分析出最可能的前3种产物及它们的相对可能性吗?"
|
||||
],
|
||||
|
||||
"predict_retrosynthesis": [
|
||||
"我需要为实验室合成阿司匹林,但不确定最佳的合成路线。阿司匹林的SMILES是'CC(=O)OC1=CC=CC=C1C(=O)O'。能帮我分析一下可能的合成路径,将其分解为更简单的前体化合物吗?"
|
||||
],
|
||||
|
||||
"predict_biocatalytic_retrosynthesis": [
|
||||
"我们实验室正在研究绿色化学合成方法,想知道是否可以使用酶催化方式合成这个含溴的芳香化合物(SMILES: 'OC1C(O)C=C(Br)C=C1')。能帮我分析可能的生物催化合成路径吗?"
|
||||
],
|
||||
|
||||
"predict_reaction_properties": [
|
||||
"我正在研究这个有机反应的机理:'CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F',特别想了解反应中的原子映射关系。能帮我分析一下反应前后各原子的对应关系吗?我需要atom-mapping属性。"
|
||||
],
|
||||
|
||||
"extract_reaction_actions": [
|
||||
"我从一篇有机合成文献中找到了这段实验步骤:'To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour.' 能帮我将这段文本转换为结构化的反应步骤吗?这样我可以更清晰地理解每个操作。"
|
||||
]
|
||||
}
|
||||
|
||||
return questions
|
||||
|
||||
# 测试特定工具函数的问题
|
||||
async def test_tool_with_question(question_index: int = 0):
|
||||
"""
|
||||
使用预设问题测试特定工具函数
|
||||
|
||||
Args:
|
||||
question_index: 问题索引,默认为0
|
||||
"""
|
||||
# 获取所有工具问题
|
||||
all_questions = get_tool_questions()
|
||||
|
||||
# 创建工具名称到问题的映射
|
||||
tool_questions = {}
|
||||
for tool_name, questions in all_questions.items():
|
||||
if questions:
|
||||
tool_questions[tool_name] = questions[min(question_index, len(questions)-1)]
|
||||
|
||||
# 打印可用的工具和问题
|
||||
console.print("[bold]可用的工具和问题:[/bold]")
|
||||
for i, (tool_name, question) in enumerate(tool_questions.items(), 1):
|
||||
console.print(f"{i}. [cyan]{tool_name}[/cyan]: {question}")
|
||||
|
||||
# 选择要测试的工具
|
||||
choice = input("\n请选择要测试的工具编号(输入'all'测试所有工具): ")
|
||||
|
||||
agent = ModelAgent()
|
||||
|
||||
if choice.lower() == 'all':
|
||||
# 测试所有工具
|
||||
for tool_name, question in tool_questions.items():
|
||||
console.print(f"\n[bold]测试工具: [cyan]{tool_name}[/cyan][/bold]")
|
||||
console.print(f"问题: {question}")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个有用的助手。"},
|
||||
{"role": "user", "content": question}
|
||||
]
|
||||
|
||||
try:
|
||||
response = await agent.chat(messages)
|
||||
#console.print(f"[green]回答:[/green] {response[:200]}...")
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
else:
|
||||
try:
|
||||
index = int(choice) - 1
|
||||
if 0 <= index < len(tool_questions):
|
||||
tool_name = list(tool_questions.keys())[index]
|
||||
question = tool_questions[tool_name]
|
||||
|
||||
console.print(f"\n[bold]测试工具: [cyan]{tool_name}[/cyan][/bold]")
|
||||
console.print(f"问题: {question}")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个有用的助手。"},
|
||||
{"role": "user", "content": question+'如果你不确定答案,请使用工具'}
|
||||
]
|
||||
|
||||
response = await agent.chat(messages)
|
||||
#console.print(f"[green]回答:[/green] {response}")
|
||||
else:
|
||||
console.print("[bold red]无效的选择[/bold red]")
|
||||
except ValueError:
|
||||
console.print("[bold red]请输入有效的数字[/bold red]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 取消注释以运行主函数
|
||||
# asyncio.run(main())
|
||||
|
||||
# 取消注释以测试工具函数问题
|
||||
asyncio.run(test_tool_with_question())
|
||||
|
||||
# pass
|
||||
|
||||
# 知识检索API的接口 数据库
|
||||
8
test_tools/api_key.py
Executable file
8
test_tools/api_key.py
Executable file
@@ -0,0 +1,8 @@
|
||||
OPENAI_API_KEY='sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d'
|
||||
OPENAI_API_URL='https://vip.apiyi.com/v1'
|
||||
|
||||
#OPENAI_API_KEY='gpustack_56f0adc61a865d22_c61cdbf601fa2cb95979d417618060e6'
|
||||
#OPENAI_API_URL='http://192.168.191.100:5080/v1'
|
||||
|
||||
|
||||
|
||||
123
test_tools/chemistry/test_pubchem.py
Normal file
123
test_tools/chemistry/test_pubchem.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Test script for PubChem tools
|
||||
|
||||
This script tests the search_pubchem_advanced function from the chemistry_mcp module.
|
||||
"""
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||||
import asyncio
|
||||
|
||||
from sci_mcp.chemistry_mcp.pubchem_tools.pubchem_tools import _search_by_formula
|
||||
|
||||
|
||||
from sci_mcp.chemistry_mcp import search_pubchem_advanced
|
||||
|
||||
async def test_search_by_name():
|
||||
"""Test searching compounds by name"""
|
||||
print("\n=== Testing search by name ===")
|
||||
result = await search_pubchem_advanced(name="Aspirin")
|
||||
print(result)
|
||||
|
||||
async def test_search_by_smiles():
|
||||
"""Test searching compounds by SMILES notation"""
|
||||
print("\n=== Testing search by SMILES ===")
|
||||
# SMILES for Caffeine
|
||||
result = await search_pubchem_advanced(smiles="CN1C=NC2=C1C(=O)N(C(=O)N2C)C")
|
||||
print(result)
|
||||
|
||||
async def test_search_by_formula():
|
||||
"""Test searching compounds by molecular formula"""
|
||||
print("\n=== Testing search by formula ===")
|
||||
# Formula for Aspirin
|
||||
result = await search_pubchem_advanced(formula="C9H8O4", max_results=2)
|
||||
print(result)
|
||||
|
||||
async def test_complex_formula():
|
||||
"""Test searching with a more complex formula that might cause timeout"""
|
||||
print("\n=== Testing complex formula search ===")
|
||||
# A more complex formula that might return many results
|
||||
result = await search_pubchem_advanced(
|
||||
formula="C6H12O6", # Glucose and isomers
|
||||
max_results=5
|
||||
)
|
||||
print(result)
|
||||
|
||||
async def test_complex_molecules():
|
||||
"""Test searching for complex molecules with rich molecular features"""
|
||||
print("\n=== Testing complex molecules with rich features ===")
|
||||
|
||||
# 1. Paclitaxel (Taxol) - Complex anticancer drug with many rotatable bonds and H-bond donors/acceptors
|
||||
print("\n--- Testing Paclitaxel (anticancer drug) ---")
|
||||
result = await search_pubchem_advanced(name="Paclitaxel")
|
||||
print(result)
|
||||
|
||||
# 2. Vancomycin - Complex antibiotic with many H-bond donors/acceptors
|
||||
print("\n--- Testing Vancomycin (antibiotic) ---")
|
||||
result = await search_pubchem_advanced(name="Vancomycin")
|
||||
print(result)
|
||||
|
||||
# 3. Cholesterol - Steroid with complex ring structure
|
||||
print("\n--- Testing Cholesterol (steroid) ---")
|
||||
result = await search_pubchem_advanced(name="Cholesterol")
|
||||
print(result)
|
||||
|
||||
# 4. Ibuprofen - Common NSAID with rotatable bonds
|
||||
print("\n--- Testing Ibuprofen (NSAID) ---")
|
||||
result = await search_pubchem_advanced(name="Ibuprofen")
|
||||
print(result)
|
||||
|
||||
# 5. Amoxicillin - Antibiotic with multiple functional groups
|
||||
print("\n--- Testing Amoxicillin (antibiotic) ---")
|
||||
result = await search_pubchem_advanced(name="Amoxicillin")
|
||||
print(result)
|
||||
|
||||
async def test_molecules_by_smiles():
|
||||
"""Test searching for complex molecules using SMILES notation"""
|
||||
print("\n=== Testing complex molecules by SMILES ===")
|
||||
|
||||
# 1. Atorvastatin (Lipitor) - Cholesterol-lowering drug with complex structure
|
||||
print("\n--- Testing Atorvastatin (Lipitor) ---")
|
||||
result = await search_pubchem_advanced(
|
||||
smiles="CC(C)C1=C(C(=C(C=C1)C(C)C)C2=CC(=C(C=C2)F)F)C(CC(CC(=O)O)O)NC(=O)C3=CC=C(C=C3)F"
|
||||
)
|
||||
print(result)
|
||||
|
||||
# 2. Morphine - Opioid with multiple rings and H-bond features
|
||||
print("\n--- Testing Morphine ---")
|
||||
result = await search_pubchem_advanced(
|
||||
smiles="CN1CCC23C4C1CC5=C2C(=C(C=C5)O)OC3C(C=C4)O"
|
||||
)
|
||||
print(result)
|
||||
|
||||
async def test_invalid_search():
|
||||
"""Test searching with invalid parameters"""
|
||||
print("\n=== Testing invalid search ===")
|
||||
# No parameters provided
|
||||
result = await search_pubchem_advanced()
|
||||
print(result)
|
||||
|
||||
# Invalid SMILES
|
||||
print("\n=== Testing invalid SMILES ===")
|
||||
result = await search_pubchem_advanced(smiles="INVALID_SMILES_STRING")
|
||||
print(result)
|
||||
|
||||
async def run_all_tests():
|
||||
"""Run all test functions"""
|
||||
await test_search_by_name()
|
||||
await test_search_by_smiles()
|
||||
await test_search_by_formula()
|
||||
await test_complex_formula()
|
||||
# await test_complex_molecules()
|
||||
# await test_molecules_by_smiles()
|
||||
#await test_invalid_search()
|
||||
# from sci_mcp.chemistry_mcp.pubchem_tools.pubchem_tools import _search_by_name
|
||||
# compounds=await _search_by_formula('C6H12O6')
|
||||
# print(compounds[0])
|
||||
if __name__ == "__main__":
|
||||
print("Testing PubChem search tools...")
|
||||
asyncio.run(run_all_tests())
|
||||
print("\nAll tests completed.")
|
||||
# import pubchempy
|
||||
# compunnds = pubchempy.get_compounds('Aspirin', 'name')
|
||||
# print(compunnds[0].to_dict())
|
||||
|
||||
159
test_tools/chemistry/test_rdkit.py
Normal file
159
test_tools/chemistry/test_rdkit.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Test script for RDKit tools.
|
||||
|
||||
This script tests the functionality of the RDKit tools implemented in the
|
||||
sci_mcp/chemistry_mcp/rdkit_tools module.
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the project root directory to the Python path
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||||
|
||||
from sci_mcp.chemistry_mcp.rdkit_tools.rdkit_tools import (
|
||||
calculate_molecular_properties,
|
||||
calculate_drug_likeness,
|
||||
calculate_topological_descriptors,
|
||||
generate_molecular_fingerprints,
|
||||
calculate_molecular_similarity,
|
||||
analyze_molecular_structure,
|
||||
generate_molecular_conformer,
|
||||
identify_scaffolds,
|
||||
convert_between_chemical_formats,
|
||||
standardize_molecule,
|
||||
enumerate_stereoisomers,
|
||||
perform_substructure_search
|
||||
)
|
||||
|
||||
def test_molecular_properties():
|
||||
"""Test the calculation of molecular properties."""
|
||||
print("Testing calculate_molecular_properties...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = calculate_molecular_properties(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_drug_likeness():
|
||||
"""Test the calculation of drug-likeness properties."""
|
||||
print("Testing calculate_drug_likeness...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = calculate_drug_likeness(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_topological_descriptors():
|
||||
"""Test the calculation of topological descriptors."""
|
||||
print("Testing calculate_topological_descriptors...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = calculate_topological_descriptors(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_molecular_fingerprints():
|
||||
"""Test the generation of molecular fingerprints."""
|
||||
print("Testing generate_molecular_fingerprints...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = generate_molecular_fingerprints(smiles, fingerprint_type="morgan")
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_molecular_similarity():
|
||||
"""Test the calculation of molecular similarity."""
|
||||
print("Testing calculate_molecular_similarity...")
|
||||
# Aspirin and Ibuprofen
|
||||
smiles1 = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin
|
||||
smiles2 = "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O" # Ibuprofen
|
||||
result = calculate_molecular_similarity(smiles1, smiles2)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_molecular_structure():
|
||||
"""Test the analysis of molecular structure."""
|
||||
print("Testing analyze_molecular_structure...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = analyze_molecular_structure(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_molecular_conformer():
|
||||
"""Test the generation of molecular conformers."""
|
||||
print("Testing generate_molecular_conformer...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = generate_molecular_conformer(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_scaffolds():
|
||||
"""Test the identification of molecular scaffolds."""
|
||||
print("Testing identify_scaffolds...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = identify_scaffolds(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_format_conversion():
|
||||
"""Test the conversion between chemical formats."""
|
||||
print("Testing convert_between_chemical_formats...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = convert_between_chemical_formats(smiles, "smiles", "inchi")
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_standardize_molecule():
|
||||
"""Test the standardization of molecules."""
|
||||
print("Testing standardize_molecule...")
|
||||
# Betaine with charges
|
||||
smiles = "C[N+](C)(C)CC(=O)[O-]"
|
||||
result = standardize_molecule(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_stereoisomers():
|
||||
"""Test the enumeration of stereoisomers."""
|
||||
print("Testing enumerate_stereoisomers...")
|
||||
# 3-penten-2-ol (has both a stereocenter and a stereobond)
|
||||
smiles = "CC(O)C=CC"
|
||||
result = enumerate_stereoisomers(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_substructure_search():
|
||||
"""Test the substructure search."""
|
||||
print("Testing perform_substructure_search...")
|
||||
# Aspirin, search for carboxylic acid group
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
pattern = "C(=O)O"
|
||||
result = perform_substructure_search(smiles, pattern)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def main():
|
||||
"""Run all tests."""
|
||||
print("Testing RDKit tools...\n")
|
||||
|
||||
# Uncomment the tests you want to run
|
||||
test_molecular_properties()
|
||||
test_drug_likeness()
|
||||
test_topological_descriptors()
|
||||
test_molecular_fingerprints()
|
||||
test_molecular_similarity()
|
||||
test_molecular_structure()
|
||||
test_molecular_conformer()
|
||||
test_scaffolds()
|
||||
test_format_conversion()
|
||||
test_standardize_molecule()
|
||||
test_stereoisomers()
|
||||
test_substructure_search()
|
||||
|
||||
print("All tests completed.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
289
test_tools/chemistry/test_rxn.py
Normal file
289
test_tools/chemistry/test_rxn.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
测试RXN工具函数模块
|
||||
|
||||
此模块包含用于测试rxn_tools模块中化学反应预测和分析工具函数的测试用例。
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||||
|
||||
from sci_mcp.chemistry_mcp.rxn_tools.rxn_tools import (
|
||||
predict_reaction_outcome_rxn,
|
||||
predict_reaction_topn_rxn,
|
||||
predict_reaction_properties_rxn,
|
||||
extract_reaction_actions_rxn
|
||||
)
|
||||
|
||||
# 创建控制台对象用于格式化输出
|
||||
console = Console()
|
||||
|
||||
|
||||
async def test_predict_reaction_outcome():
|
||||
"""测试反应结果预测功能"""
|
||||
console.print("[bold cyan]测试反应结果预测功能[/bold cyan]")
|
||||
|
||||
# 使用固定参数:溴和蒽的反应
|
||||
fixed_reactants = "BrBr.c1ccc2cc3ccccc3cc2c1"
|
||||
console.print(f"固定反应物: {fixed_reactants}")
|
||||
|
||||
try:
|
||||
result = await predict_reaction_outcome(fixed_reactants)
|
||||
console.print("[green]结果:[/green]")
|
||||
console.print(result)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
async def test_predict_reaction_topn():
|
||||
"""测试多产物预测功能"""
|
||||
console.print("\n[bold cyan]测试多产物预测功能[/bold cyan]")
|
||||
|
||||
# 测试1:单个反应(字符串格式)
|
||||
fixed_reactants = "C=CC=O.CN" # 丙烯醛和甲胺
|
||||
fixed_topn = 2
|
||||
console.print(f"测试1 - 单个反应(字符串格式)")
|
||||
console.print(f"固定反应物: {fixed_reactants}")
|
||||
console.print(f"固定预测产物数量: {fixed_topn}")
|
||||
|
||||
try:
|
||||
result = await predict_reaction_topn(fixed_reactants, fixed_topn)
|
||||
console.print("[green]结果:[/green]")
|
||||
console.print(result)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
# 测试2:单个反应(列表格式)
|
||||
fixed_reactants_list = ["BrBr", "c1ccc2cc3ccccc3cc2c1"] # 溴和蒽
|
||||
console.print(f"\n测试2 - 单个反应(列表格式)")
|
||||
console.print(f"固定反应物: {fixed_reactants_list}")
|
||||
console.print(f"固定预测产物数量: {fixed_topn}")
|
||||
|
||||
try:
|
||||
result = await predict_reaction_topn(fixed_reactants_list, fixed_topn)
|
||||
console.print("[green]结果:[/green]")
|
||||
console.print(result)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
# 测试3:多个反应(列表的列表格式)
|
||||
fixed_reactants_batch = [
|
||||
["BrBr", "c1ccc2cc3ccccc3cc2c1"], # 溴和蒽
|
||||
["BrBr", "c1ccc2cc3ccccc3cc2c1CCO"] # 溴和修饰的蒽
|
||||
]
|
||||
console.print(f"\n测试3 - 多个反应(列表的列表格式)")
|
||||
console.print(f"固定反应物批量: {fixed_reactants_batch}")
|
||||
console.print(f"固定预测产物数量: {fixed_topn}")
|
||||
|
||||
try:
|
||||
result = await predict_reaction_topn(fixed_reactants_batch, fixed_topn)
|
||||
console.print("[green]结果:[/green]")
|
||||
console.print(result)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
async def test_predict_retrosynthesis():
|
||||
"""测试逆合成分析功能"""
|
||||
console.print("\n[bold cyan]测试逆合成分析功能[/bold cyan]")
|
||||
|
||||
# 使用固定参数:阿司匹林的逆合成分析
|
||||
fixed_target_molecule = "CC(=O)OC1=CC=CC=C1C(=O)O" # 阿司匹林
|
||||
fixed_max_steps = 1
|
||||
console.print(f"固定目标分子: {fixed_target_molecule}")
|
||||
console.print(f"固定最大步骤数: {fixed_max_steps}")
|
||||
|
||||
try:
|
||||
result = await predict_retrosynthesis(fixed_target_molecule, fixed_max_steps)
|
||||
console.print("[green]结果:[/green]")
|
||||
console.print(result)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
async def test_predict_biocatalytic_retrosynthesis():
|
||||
"""测试生物催化逆合成分析功能"""
|
||||
console.print("\n[bold cyan]测试生物催化逆合成分析功能[/bold cyan]")
|
||||
|
||||
# 使用固定参数:一个可能适合酶催化的分子
|
||||
fixed_target_molecule = "OC1C(O)C=C(Br)C=C1"
|
||||
console.print(f"固定目标分子: {fixed_target_molecule}")
|
||||
|
||||
try:
|
||||
result = await predict_biocatalytic_retrosynthesis(fixed_target_molecule)
|
||||
console.print("[green]结果:[/green]")
|
||||
console.print(result)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
async def test_predict_reaction_properties():
|
||||
"""测试反应属性预测功能"""
|
||||
console.print("\n[bold cyan]测试反应属性预测功能[/bold cyan]")
|
||||
|
||||
# 使用固定参数:原子映射
|
||||
fixed_reaction = "CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F"
|
||||
fixed_property_type = "atom-mapping"
|
||||
console.print(f"固定反应: {fixed_reaction}")
|
||||
console.print(f"固定属性类型: {fixed_property_type}")
|
||||
|
||||
try:
|
||||
result = await predict_reaction_properties(fixed_reaction, fixed_property_type)
|
||||
console.print("[green]结果:[/green]")
|
||||
console.print(result)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
async def test_extract_reaction_actions():
|
||||
"""测试从文本提取反应步骤功能"""
|
||||
console.print("\n[bold cyan]测试从文本提取反应步骤功能[/bold cyan]")
|
||||
|
||||
# 使用固定参数:从文本描述中提取反应步骤
|
||||
fixed_reaction_text = """To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour."""
|
||||
console.print(f"固定反应文本: {fixed_reaction_text}")
|
||||
|
||||
try:
|
||||
result = await extract_reaction_actions(fixed_reaction_text)
|
||||
console.print("[green]结果:[/green]")
|
||||
console.print(result)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
async def test_all():
|
||||
"""测试所有RXN工具函数"""
|
||||
console.print("[bold magenta]===== 开始测试RXN工具函数 =====[/bold magenta]\n")
|
||||
|
||||
# 测试各个功能
|
||||
await test_predict_reaction_outcome()
|
||||
await test_predict_reaction_topn()
|
||||
await test_predict_retrosynthesis()
|
||||
await test_predict_biocatalytic_retrosynthesis()
|
||||
await test_predict_reaction_properties()
|
||||
await test_extract_reaction_actions()
|
||||
|
||||
console.print("\n[bold magenta]===== RXN工具函数测试完成 =====[/bold magenta]")
|
||||
|
||||
|
||||
def get_rxn_tool_questions():
|
||||
"""
|
||||
获取为RXN工具函数生成的问题列表
|
||||
这些问题设计为能够引导大模型调用相应的工具函数
|
||||
|
||||
Returns:
|
||||
包含工具名称和对应问题的字典
|
||||
"""
|
||||
questions = {
|
||||
"predict_reaction_outcome": [
|
||||
"如果我将溴和蒽混合在一起,会形成什么产物?",
|
||||
"乙酸和乙醇反应会生成什么?",
|
||||
"预测一下丙烯醛和甲胺反应的结果"
|
||||
],
|
||||
|
||||
|
||||
"predict_reaction_topn": [
|
||||
"丙烯醛和甲胺反应可能生成哪几种主要产物?",
|
||||
"预测溴和蒽反应可能的前3个产物",
|
||||
"乙酸和乙醇反应可能有哪些不同的结果?请给出最可能的几种产物"
|
||||
],
|
||||
|
||||
"predict_retrosynthesis": [
|
||||
"如何合成阿司匹林?请给出可能的合成路线",
|
||||
"对于分子CC(=O)OC1=CC=CC=C1C(=O)O,有哪些可能的合成路径?",
|
||||
"请分析一下布洛芬的可能合成路线"
|
||||
],
|
||||
|
||||
"predict_biocatalytic_retrosynthesis": [
|
||||
"有没有可能用酶催化合成OC1C(O)C=C(Br)C=C1这个分子?",
|
||||
"请提供一种使用生物催化方法合成对羟基苯甲醇的路线",
|
||||
"我想用酶催化方法合成一些复杂分子,能否分析一下OC1C(O)C=C(Br)C=C1的可能合成路径?"
|
||||
],
|
||||
|
||||
"predict_reaction_properties": [
|
||||
"在这个反应中:CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F,原子是如何映射的?",
|
||||
"这个反应的产率可能是多少:Clc1ccccn1.Cc1ccc(N)cc1>>Cc1ccc(Nc2ccccn2)cc1",
|
||||
"能分析一下这个反应中原子的去向吗:CC(=O)O.CCO>>CC(=O)OCC"
|
||||
],
|
||||
|
||||
"extract_reaction_actions": [
|
||||
"能否将这段实验描述转换为结构化的反应步骤?'To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour.'",
|
||||
"请从这段文本中提取出具体的反应操作步骤:'A solution of benzoic acid (1.0 g, 8.2 mmol) in thionyl chloride (10 mL) was heated under reflux for 2 hours. The excess thionyl chloride was removed under reduced pressure to give benzoyl chloride as a colorless liquid.'",
|
||||
"帮我解析这个实验步骤,提取出关键操作:'The aldehyde (5 mmol) was dissolved in methanol (20 mL) and sodium borohydride (7.5 mmol) was added portionwise at 0°C. The mixture was allowed to warm to room temperature and stirred for 3 hours.'"
|
||||
]
|
||||
}
|
||||
|
||||
return questions
|
||||
|
||||
|
||||
def update_agent_test_questions():
|
||||
"""
|
||||
更新agent_test.py中的工具问题字典,添加RXN工具函数的问题
|
||||
"""
|
||||
try:
|
||||
# 获取agent_test.py文件路径
|
||||
agent_test_path = Path('/home/ubuntu/sas0/lzy/multi_mcp_server/test_tools/agent_test.py')
|
||||
|
||||
# 读取文件内容
|
||||
with open(agent_test_path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# 获取RXN工具函数的问题
|
||||
rxn_questions = get_rxn_tool_questions()
|
||||
|
||||
# 检查文件中是否已包含RXN工具函数的问题
|
||||
rxn_tools_exist = any(tool in content for tool in rxn_questions.keys())
|
||||
|
||||
if not rxn_tools_exist:
|
||||
# 找到questions字典的结束位置
|
||||
dict_end_pos = content.find(' return questions')
|
||||
|
||||
if dict_end_pos != -1:
|
||||
# 构建要插入的RXN工具函数问题
|
||||
rxn_questions_str = ""
|
||||
for tool_name, questions_list in rxn_questions.items():
|
||||
rxn_questions_str += f'\n "{tool_name}": [\n'
|
||||
for q in questions_list:
|
||||
rxn_questions_str += f' "{q}",\n'
|
||||
rxn_questions_str += ' ],'
|
||||
|
||||
# 在字典结束前插入RXN工具函数问题
|
||||
new_content = content[:dict_end_pos] + rxn_questions_str + content[dict_end_pos:]
|
||||
|
||||
# 写回文件
|
||||
with open(agent_test_path, 'w') as f:
|
||||
f.write(new_content)
|
||||
|
||||
console.print("[green]成功更新agent_test.py,添加了RXN工具函数的测试问题[/green]")
|
||||
else:
|
||||
console.print("[yellow]无法找到questions字典的结束位置,未更新agent_test.py[/yellow]")
|
||||
else:
|
||||
console.print("[yellow]agent_test.py中已包含RXN工具函数的问题,无需更新[/yellow]")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]更新agent_test.py时出错:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行所有测试
|
||||
# asyncio.run(test_all())
|
||||
|
||||
# # 更新agent_test.py中的工具问题
|
||||
# update_agent_test_questions()
|
||||
|
||||
api_key = 'apk-8928522a146c2503f30b16d9909222d7583f412ee8f1049f08d32a089ba88d34'
|
||||
from rxn4chemistry import RXN4ChemistryWrapper
|
||||
|
||||
rxn4chemistry_wrapper = RXN4ChemistryWrapper(api_key=api_key)
|
||||
rxn4chemistry_wrapper.create_project('test_wrapper')
|
||||
response = rxn4chemistry_wrapper.predict_automatic_retrosynthesis(
|
||||
'Brc1c2ccccc2c(Br)c2ccccc12')
|
||||
results = rxn4chemistry_wrapper.get_predict_automatic_retrosynthesis_results(response['prediction_id'])
|
||||
print(results['status'])
|
||||
# NOTE: upon 'SUCCESS' you can inspect the predicted retrosynthetic paths.
|
||||
print(results['retrosynthetic_paths'][0])
|
||||
55
test_tools/complex_material_query.py
Normal file
55
test_tools/complex_material_query.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
from test_tools.multi_round_conversation import process_conversation_round
|
||||
|
||||
# 初始化rich控制台
|
||||
console = Console()
|
||||
|
||||
# 设计一个简单但需要多轮查询的问题(可能会调用mattergen)
|
||||
# complex_question = """我想了解LiFePO4材料在不同温度下的性能变化。请先告诉我这种材料的基本结构特性。"""
|
||||
|
||||
# 设计一个不调用mattergen但仍然可以触发多轮工具调用的问题(之前的尝试)
|
||||
# complex_question = """我想比较TiO2和ZnO这两种材料作为光催化剂的性能。请先告诉我TiO2的晶体结构和能带特性。"""
|
||||
|
||||
# 设计一个需要先获取信息然后基于这些信息进行进一步分析的问题
|
||||
complex_question = """我需要分析一种名为Na2Fe2(SO4)3的钠离子电池材料。请先查询这种材料的晶体结构。"""
|
||||
|
||||
def run_complex_query():
|
||||
"""运行复杂的材料科学查询演示"""
|
||||
console.print(Panel.fit(
|
||||
"[bold cyan]复杂材料科学查询演示[/bold cyan] - 测试多轮对话逻辑",
|
||||
border_style="cyan"
|
||||
))
|
||||
|
||||
# 处理复杂问题
|
||||
conversation_history = process_conversation_round(complex_question)
|
||||
|
||||
# 多轮对话循环
|
||||
while True:
|
||||
console.print("\n[bold cyan]输入问题继续对话,或输入 'exit' 或 'quit' 退出[/bold cyan]")
|
||||
user_input = input("> ")
|
||||
|
||||
# 检查是否退出
|
||||
if user_input.lower() in ['exit', 'quit', '退出']:
|
||||
console.print("[bold cyan]演示结束,再见![/bold cyan]")
|
||||
break
|
||||
|
||||
# 处理用户输入
|
||||
conversation_history = process_conversation_round(user_input, conversation_history)
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
run_complex_query()
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
|
||||
except Exception as e:
|
||||
console.print(f"\n[bold red]发生错误: {str(e)}[/bold red]")
|
||||
import traceback
|
||||
console.print(traceback.format_exc())
|
||||
375
test_tools/demo_conversation.py
Normal file
375
test_tools/demo_conversation.py
Normal file
@@ -0,0 +1,375 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
from rich.table import Table
|
||||
from rich import box
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
|
||||
initial_message=messages = [ {"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}]
|
||||
# 初始化rich控制台
|
||||
console = Console()
|
||||
|
||||
# 获取工具模式和映射
|
||||
tools_schemas = get_domain_tool_schemas(["material", 'general'])
|
||||
tool_map = get_domain_tools(["material", 'general'])
|
||||
|
||||
# API配置
|
||||
api_key = "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
|
||||
base_url = "http://gpustack.ddwtop.team/v1-openai"
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
def get_t1_response(messages):
|
||||
"""获取T1模型的响应"""
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold blue]正在调用MARS-T1模型..."),
|
||||
transient=True,
|
||||
) as progress:
|
||||
progress.add_task("", total=None)
|
||||
completion = client.chat.completions.create(
|
||||
model="MARS-T1",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
tools=tools_schemas,
|
||||
)
|
||||
|
||||
choice = completion.choices[0]
|
||||
reasoning_content = choice.message.content
|
||||
|
||||
tool_calls_list = []
|
||||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
tool_calls_list.append(tool_call_dict)
|
||||
|
||||
return reasoning_content, tool_calls_list
|
||||
|
||||
async def execute_tool(tool_name, tool_arguments):
|
||||
"""执行工具调用"""
|
||||
try:
|
||||
tool_func = tool_map[tool_name] # 获取工具函数
|
||||
arguments = {}
|
||||
if tool_arguments:
|
||||
# 检查arguments是字符串还是字典
|
||||
if isinstance(tool_arguments, dict):
|
||||
# 如果已经是字典,直接使用
|
||||
arguments = tool_arguments
|
||||
elif isinstance(tool_arguments, str):
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
try:
|
||||
# 尝试直接解析为JSON对象
|
||||
arguments = json.loads(tool_arguments)
|
||||
except json.JSONDecodeError:
|
||||
# 如果解析失败,可能是因为字符串中包含转义字符
|
||||
# 尝试修复常见的JSON字符串问题
|
||||
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
|
||||
try:
|
||||
arguments = json.loads(fixed_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果仍然失败,尝试将字符串作为原始字符串处理
|
||||
arguments = {"raw_string": tool_arguments}
|
||||
|
||||
# 调用工具函数
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
# 如果是异步函数,使用await调用
|
||||
result = await tool_func(**arguments)
|
||||
else:
|
||||
# 如果是同步函数,直接调用
|
||||
result = tool_func(**arguments)
|
||||
return result
|
||||
finally:
|
||||
# 清除LLM调用上下文标记
|
||||
pass
|
||||
|
||||
def get_all_tool_calls_results(tool_calls_list):
|
||||
"""获取所有工具调用的结果"""
|
||||
all_results = []
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold green]正在执行工具调用..."),
|
||||
transient=True,
|
||||
) as progress:
|
||||
task = progress.add_task("", total=len(tool_calls_list))
|
||||
|
||||
for tool_call in tool_calls_list:
|
||||
tool_name = tool_call['function']['name']
|
||||
tool_arguments = tool_call['function']['arguments']
|
||||
|
||||
# 显示当前执行的工具
|
||||
progress.update(task, description=f"执行 {tool_name}")
|
||||
|
||||
result = asyncio.run(execute_tool(tool_name, tool_arguments))
|
||||
result_str = f"[{tool_name} content begin]\n{result}\n[{tool_name} content end]\n"
|
||||
all_results.append(result_str)
|
||||
|
||||
# 更新进度
|
||||
progress.update(task, advance=1)
|
||||
|
||||
return all_results
|
||||
|
||||
def get_response_from_r1(messages):
|
||||
"""获取R1模型的响应"""
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold purple]正在调用MARS-R1模型..."),
|
||||
transient=True,
|
||||
) as progress:
|
||||
progress.add_task("", total=None)
|
||||
completion = client.chat.completions.create(
|
||||
model="MARS-R1",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
choice = completion.choices[0]
|
||||
return choice.message.content
|
||||
|
||||
def display_message(role, content, model=None, title_style="bold blue", border_style="blue"):
|
||||
"""显示单条消息"""
|
||||
title = role.capitalize()
|
||||
if model:
|
||||
title = f"{model} {title}"
|
||||
|
||||
if role == "user":
|
||||
console.print(Panel(
|
||||
content,
|
||||
title=f"[{title_style}]{title}[/{title_style}]",
|
||||
border_style=border_style,
|
||||
expand=False
|
||||
))
|
||||
elif role == "assistant" and model == "MARS-T1":
|
||||
console.print(Panel(
|
||||
content,
|
||||
title=f"[bold yellow]{title}[/bold yellow]",
|
||||
border_style="yellow",
|
||||
expand=False
|
||||
))
|
||||
elif role == "tool":
|
||||
# 创建一个表格来显示工具调用结果
|
||||
table = Table(box=box.ROUNDED, expand=False, show_header=False)
|
||||
table.add_column("内容", style="green")
|
||||
|
||||
# 分割工具调用结果并添加到表格
|
||||
results = content.split("\n")
|
||||
for result in results:
|
||||
table.add_row(result)
|
||||
|
||||
console.print(Panel(
|
||||
table,
|
||||
title=f"[bold green]{title}[/bold green]",
|
||||
border_style="green",
|
||||
expand=False
|
||||
))
|
||||
elif role == "assistant" and model == "MARS-R1":
|
||||
try:
|
||||
# 尝试将内容解析为Markdown
|
||||
md = Markdown(content)
|
||||
console.print(Panel(
|
||||
md,
|
||||
title=f"[bold purple]{title}[/bold purple]",
|
||||
border_style="purple",
|
||||
expand=False
|
||||
))
|
||||
except:
|
||||
# 如果解析失败,直接显示文本
|
||||
console.print(Panel(
|
||||
content,
|
||||
title=f"[bold purple]{title}[/bold purple]",
|
||||
border_style="purple",
|
||||
expand=False
|
||||
))
|
||||
|
||||
def process_conversation_round(user_input, conversation_history=None):
|
||||
"""处理一轮对话,返回更新后的对话历史"""
|
||||
if conversation_history is None:
|
||||
conversation_history = []
|
||||
|
||||
# 添加用户消息到历史
|
||||
conversation_history.append({
|
||||
"role": "user",
|
||||
"content": user_input
|
||||
})
|
||||
|
||||
# 显示用户消息
|
||||
display_message("user", user_input)
|
||||
|
||||
# 准备发送给T1模型的消息
|
||||
t1_messages = []
|
||||
for msg in conversation_history:
|
||||
if msg["role"] in ["user", "assistant"]:
|
||||
t1_messages.append({
|
||||
"role": msg["role"],
|
||||
"content": msg["content"]
|
||||
})
|
||||
|
||||
# 获取T1模型的响应
|
||||
reasoning_content, tool_calls_list = get_t1_response(t1_messages)
|
||||
|
||||
# 添加T1推理到历史
|
||||
conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": reasoning_content,
|
||||
"model": "MARS-T1"
|
||||
})
|
||||
|
||||
# 显示T1推理
|
||||
display_message("assistant", reasoning_content, model="MARS-T1")
|
||||
|
||||
# 如果有工具调用,执行并获取结果
|
||||
if tool_calls_list:
|
||||
tool_call_results = get_all_tool_calls_results(tool_calls_list)
|
||||
tool_call_results_str = "\n".join(tool_call_results)
|
||||
|
||||
# 添加工具调用结果到历史
|
||||
conversation_history.append({
|
||||
"role": "tool",
|
||||
"content": tool_call_results_str
|
||||
})
|
||||
|
||||
# 显示工具调用结果
|
||||
display_message("tool", tool_call_results_str)
|
||||
|
||||
# 准备发送给R1模型的消息
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": f"# 信息如下:\n{tool_call_results_str}\n# 问题如下:\n{user_input}"
|
||||
}
|
||||
|
||||
# 获取R1模型的响应
|
||||
r1_response = get_response_from_r1([user_message])
|
||||
|
||||
# 添加R1回答到历史
|
||||
conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": r1_response,
|
||||
"model": "MARS-R1"
|
||||
})
|
||||
|
||||
# 显示R1回答
|
||||
display_message("assistant", r1_response, model="MARS-R1")
|
||||
else:
|
||||
# 如果没有工具调用,直接使用T1的推理作为回答
|
||||
conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": reasoning_content,
|
||||
"model": "MARS-R1"
|
||||
})
|
||||
|
||||
# 显示R1回答(实际上是T1的推理)
|
||||
display_message("assistant", reasoning_content, model="MARS-R1")
|
||||
|
||||
return conversation_history
|
||||
|
||||
def run_demo():
|
||||
"""运行演示,使用初始消息作为第一个用户问题"""
|
||||
console.print(Panel.fit(
|
||||
"[bold cyan]多轮对话演示[/bold cyan] - 使用 MARS-T1 和 MARS-R1 模型",
|
||||
border_style="cyan"
|
||||
))
|
||||
|
||||
# 获取初始用户问题
|
||||
initial_user_input = initial_message[0]["content"]
|
||||
|
||||
# 处理第一轮对话
|
||||
conversation_history = process_conversation_round(initial_user_input)
|
||||
|
||||
# 多轮对话循环
|
||||
while True:
|
||||
console.print("\n[bold cyan]输入问题继续对话,或输入 'exit' 或 'quit' 退出[/bold cyan]")
|
||||
user_input = input("> ")
|
||||
|
||||
# 检查是否退出
|
||||
if user_input.lower() in ['exit', 'quit', '退出']:
|
||||
console.print("[bold cyan]演示结束,再见![/bold cyan]")
|
||||
break
|
||||
|
||||
# 处理用户输入
|
||||
conversation_history = process_conversation_round(user_input, conversation_history)
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
run_demo()
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
|
||||
except Exception as e:
|
||||
console.print(f"\n[bold red]发生错误: {str(e)}[/bold red]")
|
||||
import traceback
|
||||
console.print(traceback.format_exc())
|
||||
7
test_tools/general/test_searxng.py
Normal file
7
test_tools/general/test_searxng.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server/')
|
||||
from sci_mcp.general_mcp.searxng_query.searxng_query_tools import search_online
|
||||
import asyncio
|
||||
|
||||
# 字典
|
||||
print(asyncio.run(search_online("CsPbBr3", 5)))
|
||||
14
test_tools/material/mattergen/extract_cif.py
Normal file
14
test_tools/material/mattergen/extract_cif.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||||
from sci_mcp.material_mcp.mattergen_gen.mattergen_service import format_cif_content
|
||||
|
||||
cif_zip_path = '/home/ubuntu/sas0/lzy/multi_mcp_server/temp/material/20250508110506/generated_crystals_cif.zip'
|
||||
if os.path.exists(cif_zip_path):
|
||||
with open(cif_zip_path, 'rb') as f:
|
||||
cif_content = f.read().decode('utf-8', errors='replace')
|
||||
print(format_cif_content(cif_content))
|
||||
|
||||
|
||||
|
||||
|
||||
166
test_tools/material/mattergen/test_mattergen.py
Normal file
166
test_tools/material/mattergen/test_mattergen.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Test script for mattergen_gen/material_gen_tools.py
|
||||
|
||||
This script tests the generate_material function from the material_gen_tools module.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import unittest
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from sci_mcp.material_mcp.mattergen_gen.material_gen_tools import generate_material_MatterGen
|
||||
|
||||
|
||||
class TestMatterGen(unittest.TestCase):
|
||||
"""Test cases for MatterGen material generation tools."""
|
||||
|
||||
def test_unconditional_generation(self):
|
||||
"""Test unconditional crystal structure generation."""
|
||||
# 无条件生成(不指定属性)
|
||||
result = generate_material_MatterGen(properties=None, batch_size=1, num_batches=1)
|
||||
|
||||
# 验证结果是否包含预期的关键信息
|
||||
self.assertIsInstance(result, str)
|
||||
# 检查结果是否包含一些常见的描述性文本
|
||||
self.assertIn("Material", result)
|
||||
self.assertIn("structures", result)
|
||||
|
||||
print("无条件生成结果示例:")
|
||||
print(result[:500] + "...\n" if len(result) > 500 else result)
|
||||
|
||||
return result
|
||||
|
||||
# def test_single_property_generation(self):
|
||||
# """Test crystal structure generation with a single property constraint."""
|
||||
# # 单属性条件生成 - 使用化学系统属性
|
||||
# properties = {"chemical_system": "Si-O"}
|
||||
# result = generate_material(properties=properties, batch_size=1, num_batches=1)
|
||||
|
||||
# # 验证结果是否包含预期的关键信息
|
||||
# self.assertIsInstance(result, str)
|
||||
# # 检查结果是否包含相关的化学元素
|
||||
# self.assertIn("Si-O", result)
|
||||
|
||||
# print("单属性条件生成结果示例:")
|
||||
# print(result[:500] + "...\n" if len(result) > 500 else result)
|
||||
|
||||
# return result
|
||||
|
||||
# def test_multi_property_generation(self):
|
||||
# """Test crystal structure generation with multiple property constraints."""
|
||||
# # 多属性条件生成
|
||||
# properties = {
|
||||
# "chemical_system": "Fe-O",
|
||||
# "space_group": 227 # 立方晶系,空间群Fd-3m
|
||||
# }
|
||||
# result = generate_material(properties=properties, batch_size=1, num_batches=1)
|
||||
|
||||
# # 验证结果是否为字符串
|
||||
# self.assertIsInstance(result, str)
|
||||
|
||||
# # 检查结果 - 可能是成功生成或错误信息
|
||||
# if "Error" in result:
|
||||
# # 如果是错误信息,验证它包含相关的属性信息
|
||||
# self.assertIn("properties", result)
|
||||
# print("多属性条件生成返回错误 (这是预期的,因为可能不支持多属性):")
|
||||
# else:
|
||||
# # 如果成功,验证包含相关元素
|
||||
# self.assertIn("Fe", result)
|
||||
# self.assertIn("O", result)
|
||||
# print("多属性条件生成成功:")
|
||||
|
||||
# print(result[:500] + "...\n" if len(result) > 500 else result)
|
||||
|
||||
# return result
|
||||
|
||||
# def test_batch_generation(self):
|
||||
# """Test generating multiple structures in batches."""
|
||||
# # 测试批量生成
|
||||
# result = generate_material(properties=None, batch_size=2, num_batches=2)
|
||||
|
||||
# # 验证结果是否包含预期的关键信息
|
||||
# self.assertIsInstance(result, str)
|
||||
|
||||
# # 检查结果是否提到了批量生成
|
||||
# self.assertIn("structures", result)
|
||||
|
||||
# print("批量生成结果示例:")
|
||||
# print(result[:500] + "...\n" if len(result) > 500 else result)
|
||||
|
||||
# return result
|
||||
|
||||
# def test_guidance_factor(self):
|
||||
# """Test the effect of diffusion guidance factor."""
|
||||
# # 测试不同的diffusion_guidance_factor值
|
||||
# properties = {"chemical_system": "Al-O"}
|
||||
|
||||
# # 使用较低的指导因子
|
||||
# result_low = generate_material(
|
||||
# properties=properties,
|
||||
# batch_size=1,
|
||||
# num_batches=1,
|
||||
# diffusion_guidance_factor=1.0
|
||||
# )
|
||||
|
||||
# # 使用较高的指导因子
|
||||
# result_high = generate_material(
|
||||
# properties=properties,
|
||||
# batch_size=1,
|
||||
# num_batches=1,
|
||||
# diffusion_guidance_factor=3.0
|
||||
# )
|
||||
|
||||
# # 验证两个结果都是有效的
|
||||
# self.assertIsInstance(result_low, str)
|
||||
# self.assertIsInstance(result_high, str)
|
||||
# self.assertIn("Al-O", result_low)
|
||||
# self.assertIn("Al-O", result_high)
|
||||
|
||||
# # 验证两个结果都提到了diffusion guidance factor
|
||||
# self.assertIn("guidance factor", result_low)
|
||||
# self.assertIn("guidance factor", result_high)
|
||||
|
||||
# print("不同指导因子的生成结果示例:")
|
||||
# print("低指导因子 (1.0):")
|
||||
# print(result_low[:300] + "...\n" if len(result_low) > 300 else result_low)
|
||||
# print("高指导因子 (3.0):")
|
||||
# print(result_high[:300] + "...\n" if len(result_high) > 300 else result_high)
|
||||
|
||||
# return result_low, result_high
|
||||
|
||||
# def test_invalid_properties(self):
|
||||
# """Test handling of invalid properties."""
|
||||
# # 测试无效属性
|
||||
# invalid_properties = {"invalid_property": "value"}
|
||||
# result = generate_material(properties=invalid_properties)
|
||||
|
||||
# # 验证结果是否为字符串
|
||||
# self.assertIsInstance(result, str)
|
||||
|
||||
# # 检查结果 - 可能返回错误信息或尝试生成
|
||||
# if "Error" in result:
|
||||
# print("无效属性测试返回错误 (预期行为):")
|
||||
# else:
|
||||
# # 如果没有返回错误,至少应该包含我们请求的属性名称
|
||||
# self.assertIn("invalid_property", result)
|
||||
# print("无效属性测试尝试生成:")
|
||||
|
||||
# print(result)
|
||||
|
||||
# return result
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""运行所有测试。"""
|
||||
unittest.main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
73
test_tools/material/test_mgl.py
Normal file
73
test_tools/material/test_mgl.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import asyncio
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||||
import matgl
|
||||
from sci_mcp.material_mcp.matgl_tools import relax_crystal_structure_M3GNet,predict_formation_energy_M3GNet,run_molecular_dynamics_M3GNet,calculate_single_point_energy_M3GNet
|
||||
print(matgl.get_available_pretrained_models())
|
||||
cif_file_name = 'GdPbGdHGd.cif'
|
||||
cif_content="""data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
"""
|
||||
|
||||
#print(asyncio.run(relax_crystal_structure(cif_file_name)))
|
||||
print(asyncio.run(predict_formation_energy_M3GNet(cif_content)))
|
||||
#print(asyncio.run(calculate_single_point_energy(cif_file_name)))
|
||||
#print(asyncio.run(run_molecular_dynamics(cif_file_name)))
|
||||
|
||||
15
test_tools/material/test_mp.py
Normal file
15
test_tools/material/test_mp.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server/')
|
||||
from sci_mcp_server.material_mcp.mp_query.mp_query_tools import search_material_property_from_material_project,search_crystal_structures_from_materials_project
|
||||
import asyncio
|
||||
from sci_mcp_server.core.llm_tools import set_llm_context, clear_llm_context
|
||||
|
||||
set_llm_context(True)
|
||||
print(asyncio.run(search_material_property_from_material_project("CsPbBr3")))
|
||||
clear_llm_context()
|
||||
print(asyncio.run(search_material_property_from_material_project("CsPbBr3")))
|
||||
|
||||
set_llm_context(True)
|
||||
print(asyncio.run(search_crystal_structures_from_materials_project("CsPbBr3")))
|
||||
clear_llm_context()
|
||||
print(asyncio.run(search_crystal_structures_from_materials_project("CsPbBr3")))
|
||||
143
test_tools/material/test_property_pred.py
Normal file
143
test_tools/material/test_property_pred.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Test script for property_pred_tools.py
|
||||
|
||||
This script tests the predict_properties function from the property_pred_tools module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from sci_mcp.material_mcp.mattersim_pred.property_pred_tools import predict_properties_MatterSim
|
||||
|
||||
|
||||
class TestPropertyPrediction(unittest.TestCase):
|
||||
"""Test cases for property prediction tools."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
# 简单的CIF字符串示例 - 硅晶体结构
|
||||
self.simple_cif = """
|
||||
data_Si
|
||||
_cell_length_a 5.43
|
||||
_cell_length_b 5.43
|
||||
_cell_length_c 5.43
|
||||
_cell_angle_alpha 90
|
||||
_cell_angle_beta 90
|
||||
_cell_angle_gamma 90
|
||||
_symmetry_space_group_name_H-M 'P 1'
|
||||
_symmetry_Int_Tables_number 1
|
||||
loop_
|
||||
_atom_site_label
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
Si 0.0 0.0 0.0
|
||||
Si 0.5 0.5 0.0
|
||||
Si 0.5 0.0 0.5
|
||||
Si 0.0 0.5 0.5
|
||||
Si 0.25 0.25 0.25
|
||||
Si 0.75 0.75 0.25
|
||||
Si 0.75 0.25 0.75
|
||||
Si 0.25 0.75 0.75
|
||||
"""
|
||||
|
||||
def test_predict_properties_async(self):
|
||||
"""Test predict_properties function with a simple CIF string (异步版本)."""
|
||||
async def _async_test():
|
||||
result = await predict_properties(self.simple_cif)
|
||||
|
||||
# 验证结果是否包含预期的关键信息
|
||||
self.assertIsInstance(result, str)
|
||||
self.assertIn("Crystal Structure Property Prediction Results", result)
|
||||
self.assertIn("Total Energy (eV):", result)
|
||||
self.assertIn("Energy per Atom (eV/atom):", result)
|
||||
self.assertIn("Forces (eV/Angstrom):", result)
|
||||
self.assertIn("Stress (GPa):", result)
|
||||
self.assertIn("Stress (eV/A^3):", result)
|
||||
|
||||
print("预测结果示例:")
|
||||
print(result)
|
||||
return result
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_async_test())
|
||||
|
||||
def test_predict_properties_sync(self):
|
||||
"""同步方式测试predict_properties函数。"""
|
||||
self.test_predict_properties_async()
|
||||
|
||||
|
||||
class TestPropertyPredictionWithFile(unittest.TestCase):
|
||||
"""使用文件测试属性预测工具。"""
|
||||
|
||||
def setUp(self):
|
||||
"""设置测试夹具,创建临时CIF文件。"""
|
||||
self.temp_cif_path = "temp_test_structure.cif"
|
||||
|
||||
# 简单的CIF内容 - 氧化铝结构
|
||||
cif_content = """
|
||||
data_Al2O3
|
||||
_cell_length_a 4.76
|
||||
_cell_length_b 4.76
|
||||
_cell_length_c 12.99
|
||||
_cell_angle_alpha 90
|
||||
_cell_angle_beta 90
|
||||
_cell_angle_gamma 120
|
||||
_symmetry_space_group_name_H-M 'R -3 c'
|
||||
_symmetry_Int_Tables_number 167
|
||||
loop_
|
||||
_atom_site_label
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
Al 0.0 0.0 0.35
|
||||
Al 0.0 0.0 0.85
|
||||
O 0.31 0.0 0.25
|
||||
"""
|
||||
|
||||
# 创建临时文件
|
||||
with open(self.temp_cif_path, "w") as f:
|
||||
f.write(cif_content)
|
||||
|
||||
def tearDown(self):
|
||||
"""清理测试夹具,删除临时文件。"""
|
||||
if os.path.exists(self.temp_cif_path):
|
||||
os.remove(self.temp_cif_path)
|
||||
|
||||
def test_predict_properties_from_file_async(self):
|
||||
"""测试从文件预测属性(异步版本)。"""
|
||||
async def _async_test():
|
||||
result = await predict_properties(self.temp_cif_path)
|
||||
|
||||
# 验证结果
|
||||
self.assertIsInstance(result, str)
|
||||
self.assertIn("Crystal Structure Property Prediction Results", result)
|
||||
self.assertIn("Total Energy (eV):", result)
|
||||
|
||||
print("从文件预测结果示例:")
|
||||
print(result)
|
||||
return result
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_async_test())
|
||||
|
||||
def test_predict_properties_from_file_sync(self):
|
||||
"""同步方式测试从文件预测属性。"""
|
||||
self.test_predict_properties_from_file_async()
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""运行所有测试。"""
|
||||
unittest.main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
run_tests()
|
||||
208
test_tools/material/test_pymatgen_cal.py
Normal file
208
test_tools/material/test_pymatgen_cal.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
Test script for pymatgen_cal_tools.py
|
||||
|
||||
This script tests the functions from the pymatgen_cal_tools module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from sci_mcp.material_mcp.pymatgen_cal.pymatgen_cal_tools import (
|
||||
calculate_density,
|
||||
get_element_composition,
|
||||
calculate_symmetry
|
||||
)
|
||||
|
||||
|
||||
class TestPymatgenCalculations(unittest.TestCase):
|
||||
"""Test cases for pymatgen calculation tools."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
# 简单的CIF字符串示例 - 硅晶体结构
|
||||
self.simple_cif = """
|
||||
data_Si
|
||||
_cell_length_a 5.43
|
||||
_cell_length_b 5.43
|
||||
_cell_length_c 5.43
|
||||
_cell_angle_alpha 90
|
||||
_cell_angle_beta 90
|
||||
_cell_angle_gamma 90
|
||||
_symmetry_space_group_name_H-M 'P 1'
|
||||
_symmetry_Int_Tables_number 1
|
||||
loop_
|
||||
_atom_site_label
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
Si 0.0 0.0 0.0
|
||||
Si 0.5 0.5 0.0
|
||||
Si 0.5 0.0 0.5
|
||||
Si 0.0 0.5 0.5
|
||||
Si 0.25 0.25 0.25
|
||||
Si 0.75 0.75 0.25
|
||||
Si 0.75 0.25 0.75
|
||||
Si 0.25 0.75 0.75
|
||||
"""
|
||||
|
||||
def test_calculate_density_async(self):
|
||||
"""Test calculate_density function with a simple CIF string (异步版本)."""
|
||||
async def _async_test():
|
||||
result = await calculate_density(self.simple_cif)
|
||||
|
||||
# 验证结果是否包含预期的关键信息
|
||||
self.assertIsInstance(result, str)
|
||||
self.assertIn("Density Calculation", result)
|
||||
self.assertIn("Density", result)
|
||||
self.assertIn("g/cm³", result)
|
||||
|
||||
print("密度计算结果示例:")
|
||||
print(result)
|
||||
return result
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_async_test())
|
||||
|
||||
def test_get_element_composition_async(self):
|
||||
"""Test get_element_composition function with a simple CIF string (异步版本)."""
|
||||
async def _async_test():
|
||||
result = await get_element_composition(self.simple_cif)
|
||||
|
||||
# 验证结果是否包含预期的关键信息
|
||||
self.assertIsInstance(result, str)
|
||||
self.assertIn("Element Composition", result)
|
||||
self.assertIn("Composition", result)
|
||||
self.assertIn("Si", result)
|
||||
|
||||
print("元素组成计算结果示例:")
|
||||
print(result)
|
||||
return result
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_async_test())
|
||||
|
||||
def test_calculate_symmetry_async(self):
|
||||
"""Test calculate_symmetry function with a simple CIF string (异步版本)."""
|
||||
async def _async_test():
|
||||
result = await calculate_symmetry(self.simple_cif)
|
||||
|
||||
# 验证结果是否包含预期的关键信息
|
||||
self.assertIsInstance(result, str)
|
||||
self.assertIn("Symmetry Information", result)
|
||||
self.assertIn("Space Group", result)
|
||||
self.assertIn("Number", result)
|
||||
|
||||
print("对称性计算结果示例:")
|
||||
print(result)
|
||||
return result
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_async_test())
|
||||
|
||||
|
||||
class TestPymatgenCalculationsWithFile(unittest.TestCase):
|
||||
"""使用文件测试pymatgen计算工具。"""
|
||||
|
||||
def setUp(self):
|
||||
"""设置测试夹具,创建临时CIF文件。"""
|
||||
self.temp_cif_path = "temp_test_structure.cif"
|
||||
|
||||
# 简单的CIF内容 - 氧化铝结构
|
||||
cif_content = """
|
||||
data_Al2O3
|
||||
_cell_length_a 4.76
|
||||
_cell_length_b 4.76
|
||||
_cell_length_c 12.99
|
||||
_cell_angle_alpha 90
|
||||
_cell_angle_beta 90
|
||||
_cell_angle_gamma 120
|
||||
_symmetry_space_group_name_H-M 'R -3 c'
|
||||
_symmetry_Int_Tables_number 167
|
||||
_symmetry_equiv_pos_as_xyz 'x, y, z'
|
||||
_symmetry_equiv_pos_as_xyz '-x, -y, -z'
|
||||
loop_
|
||||
_atom_site_label
|
||||
_atom_site_type_symbol
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Al1 Al 0.0 0.0 0.35 1.0
|
||||
Al2 Al 0.0 0.0 0.85 1.0
|
||||
O1 O 0.31 0.0 0.25 1.0
|
||||
"""
|
||||
|
||||
# 创建临时文件
|
||||
with open(self.temp_cif_path, "w") as f:
|
||||
f.write(cif_content)
|
||||
|
||||
def tearDown(self):
|
||||
"""清理测试夹具,删除临时文件。"""
|
||||
if os.path.exists(self.temp_cif_path):
|
||||
os.remove(self.temp_cif_path)
|
||||
|
||||
def test_calculate_density_from_file_async(self):
|
||||
"""测试从文件计算密度(异步版本)。"""
|
||||
async def _async_test():
|
||||
result = await calculate_density(self.temp_cif_path)
|
||||
|
||||
# 验证结果
|
||||
self.assertIsInstance(result, str)
|
||||
self.assertIn("Density Calculation", result)
|
||||
self.assertIn("Density", result)
|
||||
|
||||
print("从文件计算密度结果示例:")
|
||||
print(result)
|
||||
return result
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_async_test())
|
||||
|
||||
def test_get_element_composition_from_file_async(self):
|
||||
"""测试从文件获取元素组成(异步版本)。"""
|
||||
async def _async_test():
|
||||
result = await get_element_composition(self.temp_cif_path)
|
||||
|
||||
# 验证结果
|
||||
self.assertIsInstance(result, str)
|
||||
self.assertIn("Element Composition", result)
|
||||
self.assertIn("Composition", result)
|
||||
|
||||
print("从文件获取元素组成结果示例:")
|
||||
print(result)
|
||||
return result
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_async_test())
|
||||
|
||||
def test_calculate_symmetry_from_file_async(self):
|
||||
"""测试从文件计算对称性(异步版本)。"""
|
||||
async def _async_test():
|
||||
result = await calculate_symmetry(self.temp_cif_path)
|
||||
|
||||
# 验证结果
|
||||
self.assertIsInstance(result, str)
|
||||
self.assertIn("Symmetry Information", result)
|
||||
self.assertIn("Space Group", result)
|
||||
|
||||
print("从文件计算对称性结果示例:")
|
||||
print(result)
|
||||
return result
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_async_test())
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""运行所有测试。"""
|
||||
unittest.main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
118
test_tools/material/test_structure_opt.py
Normal file
118
test_tools/material/test_structure_opt.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
测试晶体结构优化工具函数
|
||||
|
||||
此脚本测试改进后的optimize_crystal_structure函数,
|
||||
该函数接受单一的file_name_or_content_string参数,可以是文件路径或直接的结构内容。
|
||||
"""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
sys.path.append("/home/ubuntu/sas0/lzy/multi_mcp_server")
|
||||
|
||||
from sci_mcp.material_mcp.fairchem_structure_opt.structure_opt_tools import optimize_crystal_structure
|
||||
from sci_mcp.core.config import material_config
|
||||
|
||||
# 简单的CIF结构示例
|
||||
SAMPLE_CIF = """
|
||||
data_SrTiO3
|
||||
_cell_length_a 3.905
|
||||
_cell_length_b 3.905
|
||||
_cell_length_c 3.905
|
||||
_cell_angle_alpha 90
|
||||
_cell_angle_beta 90
|
||||
_cell_angle_gamma 90
|
||||
_symmetry_space_group_name_H-M 'P m -3 m'
|
||||
_symmetry_Int_Tables_number 221
|
||||
loop_
|
||||
_atom_site_label
|
||||
_atom_site_type_symbol
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
Sr1 Sr 0.0 0.0 0.0
|
||||
Ti1 Ti 0.5 0.5 0.5
|
||||
O1 O 0.5 0.5 0.0
|
||||
O2 O 0.5 0.0 0.5
|
||||
O3 O 0.0 0.5 0.5
|
||||
"""
|
||||
|
||||
async def test_with_content():
|
||||
"""测试使用直接结构内容"""
|
||||
print("\n=== 测试使用直接结构内容 ===")
|
||||
result = await optimize_crystal_structure(
|
||||
file_name_or_content_string=SAMPLE_CIF,
|
||||
format_type="cif",
|
||||
optimization_level="quick"
|
||||
)
|
||||
print(result)
|
||||
|
||||
async def test_with_file():
|
||||
"""测试使用文件路径(如果文件存在)"""
|
||||
print("\n=== 测试使用文件路径 ===")
|
||||
# 创建临时CIF文件
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w", delete=False) as tmp_file:
|
||||
tmp_file.write(SAMPLE_CIF)
|
||||
tmp_path = tmp_file.name
|
||||
|
||||
try:
|
||||
result = await optimize_crystal_structure(
|
||||
file_name_or_content_string=tmp_path,
|
||||
format_type="auto",
|
||||
optimization_level="quick"
|
||||
)
|
||||
print(result)
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
|
||||
async def test_with_temp_file():
|
||||
"""测试使用临时目录中的文件名"""
|
||||
print("\n=== 测试使用临时目录中的文件名 ===")
|
||||
|
||||
# 确保临时目录存在
|
||||
os.makedirs(material_config.TEMP_ROOT, exist_ok=True)
|
||||
|
||||
# 在临时目录中创建文件
|
||||
temp_filename = "test_structure.cif"
|
||||
temp_filepath = os.path.join(material_config.TEMP_ROOT, temp_filename)
|
||||
|
||||
with open(temp_filepath, 'w', encoding='utf-8') as f:
|
||||
f.write(SAMPLE_CIF)
|
||||
|
||||
try:
|
||||
# 只传递文件名,而不是完整路径
|
||||
result = await optimize_crystal_structure(
|
||||
file_name_or_content_string=temp_filename,
|
||||
format_type="auto",
|
||||
optimization_level="quick"
|
||||
)
|
||||
print(result)
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(temp_filepath):
|
||||
os.unlink(temp_filepath)
|
||||
|
||||
async def test_auto_format():
|
||||
"""测试自动格式检测"""
|
||||
print("\n=== 测试自动格式检测 ===")
|
||||
result = await optimize_crystal_structure(
|
||||
file_name_or_content_string=SAMPLE_CIF,
|
||||
format_type="auto"
|
||||
)
|
||||
print(result)
|
||||
|
||||
async def main():
|
||||
"""运行所有测试"""
|
||||
print("测试改进后的optimize_crystal_structure函数")
|
||||
|
||||
await test_with_content()
|
||||
await test_with_file()
|
||||
await test_with_temp_file()
|
||||
await test_auto_format()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
521
test_tools/multi_round_conversation.py
Normal file
521
test_tools/multi_round_conversation.py
Normal file
@@ -0,0 +1,521 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
from rich.table import Table
|
||||
from rich import box
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
|
||||
initial_message=messages = [ {"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}]
|
||||
# 初始化rich控制台
|
||||
console = Console()
|
||||
|
||||
# 获取工具模式和映射
|
||||
tools_schemas = get_domain_tool_schemas(["material", 'general'])
|
||||
tool_map = get_domain_tools(["material", 'general'])
|
||||
|
||||
# API配置
|
||||
api_key = "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
|
||||
base_url = "http://gpustack.ddwtop.team/v1-openai"
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
def get_t1_response(messages):
|
||||
"""获取T1模型的响应"""
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold blue]正在调用MARS-T1模型..."),
|
||||
transient=True,
|
||||
) as progress:
|
||||
progress.add_task("", total=None)
|
||||
completion = client.chat.completions.create(
|
||||
model="MARS-T1",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
tools=tools_schemas,
|
||||
)
|
||||
|
||||
choice = completion.choices[0]
|
||||
reasoning_content = choice.message.content
|
||||
|
||||
tool_calls_list = []
|
||||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
tool_calls_list.append(tool_call_dict)
|
||||
|
||||
return reasoning_content, tool_calls_list
|
||||
|
||||
async def execute_tool(tool_name, tool_arguments):
|
||||
"""执行工具调用"""
|
||||
try:
|
||||
tool_func = tool_map[tool_name] # 获取工具函数
|
||||
arguments = {}
|
||||
if tool_arguments:
|
||||
# 检查arguments是字符串还是字典
|
||||
if isinstance(tool_arguments, dict):
|
||||
# 如果已经是字典,直接使用
|
||||
arguments = tool_arguments
|
||||
elif isinstance(tool_arguments, str):
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
try:
|
||||
# 尝试直接解析为JSON对象
|
||||
arguments = json.loads(tool_arguments)
|
||||
except json.JSONDecodeError:
|
||||
# 如果解析失败,可能是因为字符串中包含转义字符
|
||||
# 尝试修复常见的JSON字符串问题
|
||||
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
|
||||
try:
|
||||
arguments = json.loads(fixed_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果仍然失败,尝试将字符串作为原始字符串处理
|
||||
arguments = {"raw_string": tool_arguments}
|
||||
|
||||
# 调用工具函数
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
# 如果是异步函数,使用await调用
|
||||
result = await tool_func(**arguments)
|
||||
else:
|
||||
# 如果是同步函数,直接调用
|
||||
result = tool_func(**arguments)
|
||||
return result
|
||||
finally:
|
||||
# 清除LLM调用上下文标记
|
||||
pass
|
||||
|
||||
def get_all_tool_calls_results(tool_calls_list):
|
||||
"""获取所有工具调用的结果"""
|
||||
all_results = []
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold green]正在执行工具调用..."),
|
||||
transient=True,
|
||||
) as progress:
|
||||
task = progress.add_task("", total=len(tool_calls_list))
|
||||
|
||||
for tool_call in tool_calls_list:
|
||||
tool_name = tool_call['function']['name']
|
||||
tool_arguments = tool_call['function']['arguments']
|
||||
|
||||
# 显示当前执行的工具
|
||||
progress.update(task, description=f"执行 {tool_name}")
|
||||
|
||||
result = asyncio.run(execute_tool(tool_name, tool_arguments))
|
||||
result_str = f"[{tool_name} content begin]\n{result}\n[{tool_name} content end]\n"
|
||||
all_results.append(result_str)
|
||||
|
||||
# 更新进度
|
||||
progress.update(task, advance=1)
|
||||
|
||||
return all_results
|
||||
|
||||
def get_response_from_r1(messages):
|
||||
"""获取R1模型的响应"""
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold purple]正在调用MARS-R1模型..."),
|
||||
transient=True,
|
||||
) as progress:
|
||||
progress.add_task("", total=None)
|
||||
completion = client.chat.completions.create(
|
||||
model="MARS-R1",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
|
||||
)
|
||||
|
||||
choice = completion.choices[0]
|
||||
return choice.message.content
|
||||
|
||||
def display_message(role, content, model=None, title_style="bold blue", border_style="blue"):
|
||||
"""显示单条消息"""
|
||||
title = role.capitalize()
|
||||
if model:
|
||||
title = f"{model} {title}"
|
||||
|
||||
if role == "user":
|
||||
console.print(Panel(
|
||||
content,
|
||||
title=f"[{title_style}]{title}[/{title_style}]",
|
||||
border_style=border_style,
|
||||
expand=False
|
||||
))
|
||||
elif role == "assistant" and model == "MARS-T1":
|
||||
console.print(Panel(
|
||||
content,
|
||||
title=f"[bold yellow]{title}[/bold yellow]",
|
||||
border_style="yellow",
|
||||
expand=False
|
||||
))
|
||||
elif role == "tool":
|
||||
# 创建一个表格来显示工具调用结果
|
||||
table = Table(box=box.ROUNDED, expand=False, show_header=False)
|
||||
table.add_column("内容", style="green")
|
||||
|
||||
# 分割工具调用结果并添加到表格
|
||||
results = content.split("\n")
|
||||
for result in results:
|
||||
table.add_row(result)
|
||||
|
||||
console.print(Panel(
|
||||
table,
|
||||
title=f"[bold green]{title}[/bold green]",
|
||||
border_style="green",
|
||||
expand=False
|
||||
))
|
||||
elif role == "assistant" and model == "MARS-R1":
|
||||
try:
|
||||
# 尝试将内容解析为Markdown
|
||||
md = Markdown(content)
|
||||
console.print(Panel(
|
||||
md,
|
||||
title=f"[bold purple]{title}[/bold purple]",
|
||||
border_style="purple",
|
||||
expand=False
|
||||
))
|
||||
except:
|
||||
# 如果解析失败,直接显示文本
|
||||
console.print(Panel(
|
||||
content,
|
||||
title=f"[bold purple]{title}[/bold purple]",
|
||||
border_style="purple",
|
||||
expand=False
|
||||
))
|
||||
|
||||
def process_conversation_round(user_input, conversation_history=None):
|
||||
"""处理一轮对话,返回更新后的对话历史"""
|
||||
if conversation_history is None:
|
||||
conversation_history = []
|
||||
|
||||
# 添加用户消息到外部历史
|
||||
conversation_history.append({
|
||||
"role": "user",
|
||||
"content": user_input
|
||||
})
|
||||
|
||||
# 显示用户消息
|
||||
display_message("user", user_input)
|
||||
|
||||
# 内部循环变量
|
||||
max_iterations = 3 # 防止无限循环
|
||||
iterations = 0
|
||||
|
||||
# 分别管理T1和R1的对话历史
|
||||
t1_messages = []
|
||||
r1_messages = []
|
||||
|
||||
# 初始化T1消息历史(从外部历史中提取用户和助手消息)
|
||||
for msg in conversation_history:
|
||||
if msg["role"] in ["user", "assistant"]:
|
||||
t1_messages.append({
|
||||
"role": msg["role"],
|
||||
"content": msg["content"]
|
||||
})
|
||||
|
||||
# 当前问题(初始为用户输入)
|
||||
current_question = user_input
|
||||
|
||||
while iterations < max_iterations:
|
||||
iterations += 1
|
||||
|
||||
# 如果不是第一次迭代,添加R1生成的后续问题作为新的用户消息
|
||||
if iterations > 1:
|
||||
# 显示后续问题
|
||||
display_message("user", f"[后续问题] {current_question}")
|
||||
|
||||
# 添加到T1消息历史
|
||||
t1_messages.append({
|
||||
"role": "user",
|
||||
"content": current_question
|
||||
})
|
||||
|
||||
# 获取T1模型的响应
|
||||
reasoning_content, tool_calls_list = get_t1_response(t1_messages)
|
||||
|
||||
# 显示T1推理
|
||||
display_message("assistant", reasoning_content, model="MARS-T1")
|
||||
|
||||
# 添加T1的回答到外部历史
|
||||
conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": reasoning_content,
|
||||
"model": "MARS-T1"
|
||||
})
|
||||
|
||||
# 添加T1的回答到T1消息历史
|
||||
t1_messages.append({
|
||||
"role": "assistant",
|
||||
"content": reasoning_content
|
||||
})
|
||||
|
||||
# 如果没有工具调用,使用T1的推理作为最终答案
|
||||
if not tool_calls_list:
|
||||
# 添加相同的回答作为R1的回答(因为没有工具调用)
|
||||
conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": reasoning_content,
|
||||
"model": "MARS-R1"
|
||||
})
|
||||
|
||||
display_message("assistant", reasoning_content, model="MARS-R1")
|
||||
break
|
||||
|
||||
# 执行工具调用并获取结果
|
||||
tool_call_results = get_all_tool_calls_results(tool_calls_list)
|
||||
tool_call_results_str = "\n".join(tool_call_results)
|
||||
|
||||
# 添加工具调用结果到外部历史
|
||||
conversation_history.append({
|
||||
"role": "tool",
|
||||
"content": tool_call_results_str
|
||||
})
|
||||
|
||||
# 显示工具调用结果
|
||||
display_message("tool", tool_call_results_str)
|
||||
|
||||
# 重置R1消息历史(每次迭代都重新构建)
|
||||
r1_messages = []
|
||||
|
||||
# 添加系统消息,指导R1如何处理信息
|
||||
r1_messages.append({
|
||||
"role": "system",
|
||||
"content": """你是一个能够分析工具调用结果并回答问题的助手。
|
||||
请分析提供的信息,并执行以下操作之一:
|
||||
1. 如果你能够基于提供的工具调用信息直接回答原始问题,请提供完整的回答。
|
||||
2. 如果目前的工具调用信息不足以让你回答原始问题,请明确说明缺少哪些信息,并生成一个新的问题来获取这些信息。
|
||||
新问题格式:<FOLLOW_UP_QUESTION>你的问题</FOLLOW_UP_QUESTION>
|
||||
|
||||
注意:如果你生成了后续问题,系统将自动将其发送给工具调用模型以获取更多信息。"""
|
||||
})
|
||||
|
||||
# 构建R1的用户消息,包含原始问题、工具调用信息和结果
|
||||
r1_user_message = f"""# 原始问题
|
||||
{user_input}
|
||||
|
||||
# 工具调用信息
|
||||
{reasoning_content}
|
||||
|
||||
# 工具调用结果
|
||||
{tool_call_results_str}"""
|
||||
|
||||
# 如果有后续问题,添加到R1用户消息
|
||||
if iterations > 1:
|
||||
r1_user_message += f"\n\n# 后续问题\n{current_question}"
|
||||
|
||||
# 添加构建好的用户消息
|
||||
r1_messages.append({
|
||||
"role": "user",
|
||||
"content": r1_user_message
|
||||
})
|
||||
|
||||
# 获取R1模型的响应
|
||||
r1_response = get_response_from_r1(r1_messages)
|
||||
|
||||
# 显示R1回答
|
||||
display_message("assistant", r1_response, model="MARS-R1")
|
||||
|
||||
# 检查R1是否生成了后续问题
|
||||
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', r1_response, re.DOTALL)
|
||||
|
||||
if follow_up_match:
|
||||
# 提取后续问题
|
||||
follow_up_question = follow_up_match.group(1).strip()
|
||||
|
||||
# 将后续问题作为新的当前问题
|
||||
current_question = follow_up_question
|
||||
|
||||
# 添加R1的回答到外部历史(不包括后续问题标记)
|
||||
clean_response = r1_response.replace(follow_up_match.group(0), "")
|
||||
conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": clean_response,
|
||||
"model": "MARS-R1"
|
||||
})
|
||||
|
||||
# 继续循环,使用新问题调用T1
|
||||
else:
|
||||
# R1能够回答问题,添加回答到历史并结束循环
|
||||
conversation_history.append({
|
||||
"role": "assistant",
|
||||
"content": r1_response,
|
||||
"model": "MARS-R1"
|
||||
})
|
||||
break
|
||||
|
||||
return conversation_history
|
||||
|
||||
def run_demo():
|
||||
"""运行演示,使用初始消息作为第一个用户问题"""
|
||||
console.print(Panel.fit(
|
||||
"[bold cyan]多轮对话演示[/bold cyan] - 使用 MARS-T1 和 MARS-R1 模型",
|
||||
border_style="cyan"
|
||||
))
|
||||
|
||||
# 获取初始用户问题
|
||||
initial_user_input = initial_message[0]["content"]
|
||||
|
||||
# 处理第一轮对话
|
||||
conversation_history = process_conversation_round(initial_user_input)
|
||||
|
||||
# 检查R1是否生成了后续问题并自动处理
|
||||
auto_process_follow_up_questions(conversation_history)
|
||||
|
||||
# 多轮对话循环
|
||||
while True:
|
||||
console.print("\n[bold cyan]输入问题继续对话,或输入 'exit' 或 'quit' 退出[/bold cyan]")
|
||||
user_input = input("> ")
|
||||
|
||||
# 检查是否退出
|
||||
if user_input.lower() in ['exit', 'quit', '退出']:
|
||||
console.print("[bold cyan]演示结束,再见![/bold cyan]")
|
||||
break
|
||||
|
||||
# 处理用户输入
|
||||
conversation_history = process_conversation_round(user_input, conversation_history)
|
||||
|
||||
# 检查R1是否生成了后续问题并自动处理
|
||||
auto_process_follow_up_questions(conversation_history)
|
||||
|
||||
def auto_process_follow_up_questions(conversation_history):
|
||||
"""自动处理R1生成的后续问题"""
|
||||
# 检查最后一条消息是否是R1的回答
|
||||
if not conversation_history or len(conversation_history) == 0:
|
||||
return
|
||||
|
||||
last_message = conversation_history[-1]
|
||||
if last_message["role"] == "assistant" and last_message.get("model") == "MARS-R1":
|
||||
# 检查是否包含后续问题
|
||||
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', last_message["content"], re.DOTALL)
|
||||
if follow_up_match:
|
||||
# 提取后续问题
|
||||
follow_up_question = follow_up_match.group(1).strip()
|
||||
|
||||
# 显示检测到的后续问题
|
||||
console.print(Panel(
|
||||
f"[bold yellow]检测到后续问题: {follow_up_question}[/bold yellow]",
|
||||
border_style="yellow",
|
||||
expand=False
|
||||
))
|
||||
|
||||
# 自动处理后续问题
|
||||
console.print("[bold cyan]自动处理后续问题...[/bold cyan]")
|
||||
|
||||
# 递归处理后续问题,直到没有更多后续问题或达到最大迭代次数
|
||||
max_auto_iterations = 3
|
||||
current_iterations = 0
|
||||
|
||||
while current_iterations < max_auto_iterations:
|
||||
current_iterations += 1
|
||||
|
||||
# 处理后续问题
|
||||
conversation_history = process_conversation_round(follow_up_question, conversation_history)
|
||||
|
||||
# 检查是否还有后续问题
|
||||
if len(conversation_history) > 0:
|
||||
last_message = conversation_history[-1]
|
||||
if last_message["role"] == "assistant" and last_message.get("model") == "MARS-R1":
|
||||
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', last_message["content"], re.DOTALL)
|
||||
if follow_up_match:
|
||||
# 提取后续问题
|
||||
follow_up_question = follow_up_match.group(1).strip()
|
||||
|
||||
# 显示检测到的后续问题
|
||||
console.print(Panel(
|
||||
f"[bold yellow]检测到后续问题: {follow_up_question}[/bold yellow]",
|
||||
border_style="yellow",
|
||||
expand=False
|
||||
))
|
||||
|
||||
# 自动处理后续问题
|
||||
console.print("[bold cyan]自动处理后续问题...[/bold cyan]")
|
||||
continue
|
||||
|
||||
# 如果没有更多后续问题,退出循环
|
||||
break
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
run_demo()
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
|
||||
except Exception as e:
|
||||
console.print(f"\n[bold red]发生错误: {str(e)}[/bold red]")
|
||||
import traceback
|
||||
console.print(traceback.format_exc())
|
||||
83
test_tools/test.py
Normal file
83
test_tools/test.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Test script for domain-specific tool retrieval functions
|
||||
"""
|
||||
|
||||
import json
|
||||
from pprint import pprint
|
||||
import sys
|
||||
sys.path.append("/home/ubuntu/sas0/lzy/multi_mcp_server")
|
||||
from sci_mcp.core.llm_tools import get_domain_tools, get_domain_tool_schemas, get_all_tool_schemas,get_all_tools
|
||||
|
||||
def test_get_domain_tools():
|
||||
"""Test retrieving tools from specific domains"""
|
||||
print("\n=== Testing get_domain_tools ===")
|
||||
|
||||
# Test with material and general domains
|
||||
domains = ['material', 'general']
|
||||
print(f"Getting tools for domains: {domains}")
|
||||
|
||||
domain_tools = get_domain_tools(domains)
|
||||
|
||||
# Print results
|
||||
for domain, tools in domain_tools.items():
|
||||
print(f"\nDomain: {domain}")
|
||||
print(f"Number of tools: {len(tools)}")
|
||||
print("Tool names:")
|
||||
for tool_name in tools.keys():
|
||||
print(f" - {tool_name}")
|
||||
|
||||
def test_get_domain_tool_schemas():
|
||||
"""Test retrieving tool schemas from specific domains"""
|
||||
print("\n=== Testing get_domain_tool_schemas (OpenAI format) ===")
|
||||
|
||||
# Test with material and general domains
|
||||
domains = ['material', 'general']
|
||||
print(f"Getting tool schemas for domains: {domains}")
|
||||
|
||||
domain_schemas = get_domain_tool_schemas(domains)
|
||||
|
||||
# Print results
|
||||
for domain, schemas in domain_schemas.items():
|
||||
print(f"\nDomain: {domain}")
|
||||
print(f"Number of schemas: {len(schemas)}")
|
||||
print("Tool names in schemas:")
|
||||
for schema in schemas:
|
||||
print(f" - {schema['function']['name']}")
|
||||
|
||||
|
||||
def test_all_tool_schemas():
|
||||
"""Test retrieving all tool schemas in both formats"""
|
||||
print("\n=== Testing get_all_tool_schemas ===")
|
||||
|
||||
# OpenAI format
|
||||
print("Getting all tool schemas in OpenAI format")
|
||||
openai_schemas = get_all_tool_schemas()
|
||||
print(f"Number of schemas: {len(openai_schemas)}")
|
||||
print("Tool names:")
|
||||
for schema in openai_schemas:
|
||||
print(f" - {schema['function']['name']}")
|
||||
|
||||
# MCP format
|
||||
print("\nGetting all tool schemas in MCP format")
|
||||
mcp_schemas = get_all_tool_schemas(use_mcp_format=True)
|
||||
print(f"Number of schemas: {len(mcp_schemas)}")
|
||||
print("Tool names:")
|
||||
for schema in mcp_schemas:
|
||||
print(f" - {schema['name']}")
|
||||
|
||||
def main():
|
||||
"""Main function to run tests"""
|
||||
print("Testing domain-specific tool retrieval functions")
|
||||
|
||||
#test_get_domain_tools()
|
||||
test_get_domain_tool_schemas()
|
||||
|
||||
test_all_tool_schemas()
|
||||
|
||||
if __name__ == "__main__":
|
||||
#print(get_all_tool_schemas())
|
||||
from rich.console import Console
|
||||
console = Console()
|
||||
console.print("Testing domain-specific tool retrieval functions", style="bold green")
|
||||
console.print(get_domain_tool_schemas(['chemistry']))
|
||||
|
||||
253
test_tools/test_mars_t1.py
Normal file
253
test_tools/test_mars_t1.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp_server.core.llm_tools import set_llm_context, clear_llm_context
|
||||
|
||||
# 创建Rich控制台对象
|
||||
console = Console()
|
||||
|
||||
# 定义分隔符样式
|
||||
def print_separator(title=""):
|
||||
console.print(Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False), justify="center")
|
||||
|
||||
api_key="gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
|
||||
base_url="http://gpustack.ddwtop.team/v1-openai"
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
from sci_mcp import *
|
||||
|
||||
tools_schemas = get_domain_tool_schemas(["material",'general'])
|
||||
tool_map = get_domain_tools(["material",'general'])
|
||||
|
||||
|
||||
# 打印消息的函数
|
||||
def print_message(message):
|
||||
# 处理不同类型的消息对象
|
||||
if hasattr(message, 'role'): # ChatCompletionMessage 对象
|
||||
role = message.role
|
||||
content = message.content if hasattr(message, 'content') else ""
|
||||
# 如果是工具消息,获取工具名称
|
||||
tool_name = None
|
||||
if role == "tool" and hasattr(message, 'name'):
|
||||
tool_name = message.name
|
||||
else: # 字典类型
|
||||
role = message.get("role", "unknown")
|
||||
content = message.get("content", "")
|
||||
# 如果是工具消息,获取工具名称
|
||||
tool_name = message.get("name") if role == "tool" else None
|
||||
|
||||
# 根据角色选择不同的颜色
|
||||
role_colors = {
|
||||
"system": "bright_blue",
|
||||
"user": "green",
|
||||
"assistant": "yellow",
|
||||
"tool": "bright_red"
|
||||
}
|
||||
color = role_colors.get(role, "white")
|
||||
|
||||
# 创建富文本面板
|
||||
text = Text()
|
||||
|
||||
# 如果是工具消息,添加工具名称
|
||||
if role == "tool" and tool_name:
|
||||
text.append(f"{role} ({tool_name}): ", style=f"bold {color}")
|
||||
else:
|
||||
text.append(f"{role}: ", style=f"bold {color}")
|
||||
|
||||
text.append(str(content))
|
||||
console.print(Panel(text, border_style=color))
|
||||
|
||||
messages = [
|
||||
{"role": "system",
|
||||
"content": "You are MARS-R1, a professional assistant in materials science. You first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> <structured_answer> </structured_answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here <structured_answer> structured answer here <structured_answer> </answer>'"},
|
||||
{"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}
|
||||
]
|
||||
#how to synthesize CsPbBr3 at room temperature
|
||||
#
|
||||
# 打印初始消息
|
||||
print_separator("初始消息")
|
||||
for message in messages:
|
||||
print_message(message)
|
||||
finish_reason = None
|
||||
|
||||
async def execute_tool(tool_name,tool_arguments):
|
||||
# 设置LLM调用上下文标记
|
||||
set_llm_context(True)
|
||||
try:
|
||||
tool_func = tool_map[tool_name] # 获取工具函数
|
||||
arguments = {}
|
||||
if tool_arguments:
|
||||
# 检查arguments是字符串还是字典
|
||||
if isinstance(tool_arguments, dict):
|
||||
# 如果已经是字典,直接使用
|
||||
arguments = tool_arguments
|
||||
elif isinstance(tool_arguments, str):
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
try:
|
||||
# 尝试直接解析为JSON对象
|
||||
arguments = json.loads(tool_arguments)
|
||||
except json.JSONDecodeError:
|
||||
# 如果解析失败,可能是因为字符串中包含转义字符
|
||||
# 尝试修复常见的JSON字符串问题
|
||||
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
|
||||
try:
|
||||
arguments = json.loads(fixed_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果仍然失败,尝试将字符串作为原始字符串处理
|
||||
arguments = {"raw_string": tool_arguments}
|
||||
|
||||
# 调用工具函数
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
# 如果是异步函数,使用await调用
|
||||
result = await tool_func(**arguments)
|
||||
else:
|
||||
# 如果是同步函数,直接调用
|
||||
result = tool_func(**arguments)
|
||||
# if func_name=='generate_material':
|
||||
# print("xxxxx",result)
|
||||
return result
|
||||
finally:
|
||||
# 清除LLM调用上下文标记
|
||||
clear_llm_context()
|
||||
|
||||
while finish_reason is None or finish_reason == "tool_calls":
|
||||
completion = client.chat.completions.create(
|
||||
model="MARS-R1",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
tools=tools_schemas, # <-- 我们通过 tools 参数,将定义好的 tools 提交给 Kimi 大模型
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
finish_reason = choice.finish_reason
|
||||
if finish_reason == "tool_calls": # <-- 判断当前返回内容是否包含 tool_calls
|
||||
# 打印assistant消息
|
||||
print_separator("Assistant消息")
|
||||
print_message(choice.message)
|
||||
|
||||
# 将ChatCompletionMessage对象转换为字典
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": choice.message.content if hasattr(choice.message, 'content') else None
|
||||
}
|
||||
|
||||
# 如果有工具调用,添加到字典中
|
||||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
|
||||
# 将tool_calls对象转换为字典列表
|
||||
tool_calls_list = []
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
tool_calls_list.append(tool_call_dict)
|
||||
assistant_message["tool_calls"] = tool_calls_list
|
||||
|
||||
# 添加消息到上下文
|
||||
messages.append(assistant_message) # <-- 我们将模型返回给我们的 assistant 消息也添加到上下文中,以便于下次请求时模型能理解我们的诉求
|
||||
|
||||
# 打印工具调用信息
|
||||
print_separator("工具调用")
|
||||
for tool_call in choice.message.tool_calls:
|
||||
console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]")
|
||||
console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
|
||||
console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
|
||||
console.print("")
|
||||
|
||||
tool_call_name = tool_call.function.name
|
||||
tool_call_arguments = json.loads(tool_call.function.arguments) # <-- arguments 是序列化后的 JSON Object,我们需要使用 json.loads 反序列化一下
|
||||
try:
|
||||
tool_result = asyncio.run(execute_tool(tool_name=tool_call_name,tool_arguments=tool_call_arguments)) # <-- 通过 tool_map 快速找到需要执行哪个函数
|
||||
except Exception as e:
|
||||
tool_result=f'工具调用失败{e}'
|
||||
# 构造工具响应消息
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call_name,
|
||||
"content": tool_result, # <-- 我们约定使用字符串格式向模型提交工具调用结果
|
||||
}
|
||||
|
||||
# 打印工具响应
|
||||
print_separator(f"工具响应: {tool_call_name}")
|
||||
print_message(tool_message)
|
||||
|
||||
# 添加消息到上下文
|
||||
messages.append(tool_message)
|
||||
|
||||
# 打印最终响应
|
||||
if choice.message.content:
|
||||
print_separator("最终响应")
|
||||
console.print(Panel(choice.message.content, border_style="green"))
|
||||
177
test_tools/test_mars_t1_.py
Normal file
177
test_tools/test_mars_t1_.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import asyncio
|
||||
import json
|
||||
from openai import OpenAI
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
|
||||
tools_schemas = get_domain_tool_schemas(["material",'general'])
|
||||
tool_map = get_domain_tools(["material",'general'])
|
||||
|
||||
|
||||
|
||||
api_key="gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
|
||||
base_url="http://gpustack.ddwtop.team/v1-openai"
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
messages = [ {"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}]
|
||||
|
||||
def get_t1_response(messages):
|
||||
completion = client.chat.completions.create(
|
||||
model="MARS-T1",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
tools=tools_schemas,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
reasoning_content = choice.message.content
|
||||
#print("Reasoning content:", reasoning_content)
|
||||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
|
||||
tool_calls_list = []
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
tool_calls_list.append(tool_call_dict)
|
||||
return reasoning_content, tool_calls_list
|
||||
|
||||
async def execute_tool(tool_name,tool_arguments):
|
||||
|
||||
try:
|
||||
tool_func = tool_map[tool_name] # 获取工具函数
|
||||
arguments = {}
|
||||
if tool_arguments:
|
||||
# 检查arguments是字符串还是字典
|
||||
if isinstance(tool_arguments, dict):
|
||||
# 如果已经是字典,直接使用
|
||||
arguments = tool_arguments
|
||||
elif isinstance(tool_arguments, str):
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
try:
|
||||
# 尝试直接解析为JSON对象
|
||||
arguments = json.loads(tool_arguments)
|
||||
except json.JSONDecodeError:
|
||||
# 如果解析失败,可能是因为字符串中包含转义字符
|
||||
# 尝试修复常见的JSON字符串问题
|
||||
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
|
||||
try:
|
||||
arguments = json.loads(fixed_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果仍然失败,尝试将字符串作为原始字符串处理
|
||||
arguments = {"raw_string": tool_arguments}
|
||||
|
||||
# 调用工具函数
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
# 如果是异步函数,使用await调用
|
||||
result = await tool_func(**arguments)
|
||||
else:
|
||||
# 如果是同步函数,直接调用
|
||||
result = tool_func(**arguments)
|
||||
# if func_name=='generate_material':
|
||||
# print("xxxxx",result)
|
||||
return result
|
||||
finally:
|
||||
# 清除LLM调用上下文标记
|
||||
pass
|
||||
|
||||
def get_all_tool_calls_results(tool_calls_list):
|
||||
all_results = []
|
||||
for tool_call in tool_calls_list:
|
||||
tool_name = tool_call['function']['name']
|
||||
tool_arguments = tool_call['function']['arguments']
|
||||
result = asyncio.run(execute_tool(tool_name,tool_arguments))
|
||||
result_str = f"[{tool_name} content begin]\n"+result+f"\n[{tool_name} content end]\n"
|
||||
all_results.append(result_str)
|
||||
|
||||
return all_results
|
||||
def get_response_from_r1(messages):
|
||||
completion = client.chat.completions.create(
|
||||
model="MARS-R1",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
return choice.message.content
|
||||
print("R1 RESPONSE:", choice.message.content)
|
||||
if __name__ == '__main__':
|
||||
reasoning_content, tool_calls_list=get_t1_response(messages)
|
||||
print("Reasoning content:", reasoning_content)
|
||||
tool_call_results=get_all_tool_calls_results(tool_calls_list)
|
||||
tool_call_results_str = "\n".join(tool_call_results)
|
||||
# for tool_call in tool_call_results:
|
||||
# print(tool_call)
|
||||
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": f"# 信息如下:{tool_call_results_str}# 问题如下 {messages[0]['content']}"
|
||||
}
|
||||
print("user_message_for_r1:", user_message)
|
||||
get_response_from_r1([user_message])
|
||||
339
test_tools/test_mars_t1_r1.py
Normal file
339
test_tools/test_mars_t1_r1.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
|
||||
# 创建Rich控制台对象
|
||||
console = Console()
|
||||
|
||||
# 创建一个列表来存储工具调用结果
|
||||
tool_results = []
|
||||
|
||||
# 定义分隔符样式
|
||||
def print_separator(title=""):
|
||||
console.print(Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False), justify="center")
|
||||
|
||||
api_key="gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
|
||||
base_url="http://gpustack.ddwtop.team/v1-openai"
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
from sci_mcp import *
|
||||
|
||||
tools_schemas = get_domain_tool_schemas(["material",'general'])
|
||||
tool_map = get_domain_tools(["material",'general'])
|
||||
|
||||
|
||||
# 打印消息的函数
|
||||
def print_message(message):
|
||||
# 处理不同类型的消息对象
|
||||
if hasattr(message, 'role'): # ChatCompletionMessage 对象
|
||||
role = message.role
|
||||
content = message.content if hasattr(message, 'content') else ""
|
||||
# 如果是工具消息,获取工具名称
|
||||
tool_name = None
|
||||
if role == "tool" and hasattr(message, 'name'):
|
||||
tool_name = message.name
|
||||
else: # 字典类型
|
||||
role = message.get("role", "unknown")
|
||||
content = message.get("content", "")
|
||||
# 如果是工具消息,获取工具名称
|
||||
tool_name = message.get("name") if role == "tool" else None
|
||||
|
||||
# 根据角色选择不同的颜色
|
||||
role_colors = {
|
||||
"system": "bright_blue",
|
||||
"user": "green",
|
||||
"assistant": "yellow",
|
||||
"tool": "bright_red"
|
||||
}
|
||||
color = role_colors.get(role, "white")
|
||||
|
||||
# 创建富文本面板
|
||||
text = Text()
|
||||
|
||||
# 如果是工具消息,添加工具名称
|
||||
if role == "tool" and tool_name:
|
||||
text.append(f"{role} ({tool_name}): ", style=f"bold {color}")
|
||||
else:
|
||||
text.append(f"{role}: ", style=f"bold {color}")
|
||||
|
||||
text.append(str(content))
|
||||
console.print(Panel(text, border_style=color))
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}
|
||||
]
|
||||
#how to synthesize CsPbBr3 at room temperature
|
||||
#
|
||||
# 打印初始消息
|
||||
print_separator("初始消息")
|
||||
for message in messages:
|
||||
print_message(message)
|
||||
finish_reason = None
|
||||
|
||||
async def execute_tool(tool_name,tool_arguments):
|
||||
# 设置LLM调用上下文标记
|
||||
|
||||
try:
|
||||
tool_func = tool_map[tool_name] # 获取工具函数
|
||||
arguments = {}
|
||||
if tool_arguments:
|
||||
# 检查arguments是字符串还是字典
|
||||
if isinstance(tool_arguments, dict):
|
||||
# 如果已经是字典,直接使用
|
||||
arguments = tool_arguments
|
||||
elif isinstance(tool_arguments, str):
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
try:
|
||||
# 尝试直接解析为JSON对象
|
||||
arguments = json.loads(tool_arguments)
|
||||
except json.JSONDecodeError:
|
||||
# 如果解析失败,可能是因为字符串中包含转义字符
|
||||
# 尝试修复常见的JSON字符串问题
|
||||
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
|
||||
try:
|
||||
arguments = json.loads(fixed_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果仍然失败,尝试将字符串作为原始字符串处理
|
||||
arguments = {"raw_string": tool_arguments}
|
||||
|
||||
# 调用工具函数
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
# 如果是异步函数,使用await调用
|
||||
result = await tool_func(**arguments)
|
||||
else:
|
||||
# 如果是同步函数,直接调用
|
||||
result = tool_func(**arguments)
|
||||
# if func_name=='generate_material':
|
||||
# print("xxxxx",result)
|
||||
return result
|
||||
finally:
|
||||
# 清除LLM调用上下文标记
|
||||
pass
|
||||
|
||||
# 定义一个函数来估算消息的token数量(粗略估计)
|
||||
def estimate_tokens(messages):
|
||||
# 简单估计:每个英文单词约1.3个token,每个中文字符约1个token
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = msg.get("content", "") if isinstance(msg, dict) else (msg.content if hasattr(msg, "content") else "")
|
||||
if content:
|
||||
# 粗略估计内容的token数
|
||||
total += len(content) * 1.3
|
||||
|
||||
# 估计工具调用的token数
|
||||
tool_calls = msg.get("tool_calls", []) if isinstance(msg, dict) else (msg.tool_calls if hasattr(msg, "tool_calls") else [])
|
||||
for tool_call in tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
args = tool_call.get("function", {}).get("arguments", "")
|
||||
total += len(args) * 1.3
|
||||
else:
|
||||
args = tool_call.function.arguments if hasattr(tool_call, "function") else ""
|
||||
total += len(args) * 1.3
|
||||
|
||||
return int(total)
|
||||
|
||||
# 管理消息历史,保持在token限制以内
|
||||
def manage_message_history(messages, max_tokens=7000):
|
||||
# 保留第一条消息(通常是系统消息或初始用户消息)
|
||||
if len(messages) <= 1:
|
||||
return messages
|
||||
|
||||
# 估算当前消息的token数
|
||||
current_tokens = estimate_tokens(messages)
|
||||
|
||||
# 如果当前token数已经接近限制,开始裁剪历史消息
|
||||
if current_tokens > max_tokens:
|
||||
# 保留第一条消息和最近的消息
|
||||
preserved_messages = [messages[0]]
|
||||
|
||||
# 从最新的消息开始添加,直到接近但不超过token限制
|
||||
temp_messages = []
|
||||
for msg in reversed(messages[1:]):
|
||||
temp_messages.insert(0, msg)
|
||||
if estimate_tokens([messages[0]] + temp_messages) > max_tokens:
|
||||
# 如果添加这条消息会超过限制,则停止添加
|
||||
temp_messages.pop(0)
|
||||
break
|
||||
|
||||
# 如果裁剪后的消息太少,至少保留最近的几条消息
|
||||
if len(temp_messages) < 4 and len(messages) > 4:
|
||||
temp_messages = messages[-4:]
|
||||
|
||||
return preserved_messages + temp_messages
|
||||
|
||||
return messages
|
||||
|
||||
while finish_reason is None or finish_reason == "tool_calls":
|
||||
# 在发送请求前管理消息历史
|
||||
managed_messages = manage_message_history(messages)
|
||||
if len(managed_messages) < len(messages):
|
||||
print_separator(f"消息历史已裁剪,从{len(messages)}条减少到{len(managed_messages)}条")
|
||||
messages = managed_messages
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model="MARS-T1",
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
tools=tools_schemas,
|
||||
timeout=120,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
finish_reason = choice.finish_reason
|
||||
if finish_reason == "tool_calls": # <-- 判断当前返回内容是否包含 tool_calls
|
||||
# 打印assistant消息
|
||||
print_separator("Assistant消息")
|
||||
print_message(choice.message)
|
||||
|
||||
# 将ChatCompletionMessage对象转换为字典
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": choice.message.content if hasattr(choice.message, 'content') else None
|
||||
}
|
||||
|
||||
# 如果有工具调用,添加到字典中
|
||||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
|
||||
# 将tool_calls对象转换为字典列表
|
||||
tool_calls_list = []
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
tool_calls_list.append(tool_call_dict)
|
||||
assistant_message["tool_calls"] = tool_calls_list
|
||||
|
||||
# 添加消息到上下文
|
||||
messages.append(assistant_message) # <-- 我们将模型返回给我们的 assistant 消息也添加到上下文中,以便于下次请求时模型能理解我们的诉求
|
||||
|
||||
# 打印工具调用信息
|
||||
print_separator("工具调用")
|
||||
for tool_call in choice.message.tool_calls:
|
||||
console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]")
|
||||
console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
|
||||
console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
|
||||
console.print("")
|
||||
|
||||
tool_call_name = tool_call.function.name
|
||||
tool_call_arguments = json.loads(tool_call.function.arguments) # <-- arguments 是序列化后的 JSON Object,我们需要使用 json.loads 反序列化一下
|
||||
try:
|
||||
tool_result = asyncio.run(execute_tool(tool_name=tool_call_name,tool_arguments=tool_call_arguments)) # <-- 通过 tool_map 快速找到需要执行哪个函数
|
||||
except Exception as e:
|
||||
tool_result=f'工具调用失败{e}'
|
||||
|
||||
# 将工具调用结果保存到单独的列表中,使用指定格式包裹
|
||||
formatted_result = f"[{tool_call_name} content begin]{tool_result}[{tool_call_name} content end]"
|
||||
tool_results.append({
|
||||
"tool_name": tool_call_name,
|
||||
"tool_id": tool_call.id,
|
||||
"formatted_result": formatted_result,
|
||||
"raw_result": tool_result
|
||||
})
|
||||
|
||||
# 打印保存的工具结果信息
|
||||
console.print(f"[bold green]已保存工具结果:[/] [cyan]{tool_call_name}[/]")
|
||||
|
||||
# 构造工具响应消息
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call_name,
|
||||
"content": tool_result, # <-- 我们约定使用字符串格式向模型提交工具调用结果
|
||||
}
|
||||
|
||||
# 打印工具响应
|
||||
print_separator(f"工具响应: {tool_call_name}")
|
||||
print_message(tool_message)
|
||||
|
||||
# 添加消息到上下文
|
||||
#messages.append(tool_message)
|
||||
|
||||
# 打印最终响应
|
||||
if choice.message.content:
|
||||
print_separator("最终响应")
|
||||
console.print(Panel(choice.message.content, border_style="green"))
|
||||
|
||||
# 打印收集的所有工具调用结果
|
||||
if tool_results:
|
||||
print_separator("所有工具调用结果")
|
||||
console.print(f"[bold cyan]共收集了 {len(tool_results)} 个工具调用结果[/]")
|
||||
|
||||
# 将所有格式化的结果写入文件
|
||||
with open("tool_results.txt", "w", encoding="utf-8") as f:
|
||||
for result in tool_results:
|
||||
f.write(f"{result['formatted_result']}\n\n")
|
||||
|
||||
console.print(f"[bold green]所有工具调用结果已保存到 tool_results.txt 文件中[/]")
|
||||
Reference in New Issue
Block a user