182 lines
7.5 KiB
Python
Executable File
182 lines
7.5 KiB
Python
Executable File
import asyncio
|
|
import json
|
|
from rich.console import Console
|
|
|
|
# Shared rich Console used for coloured output by print_tool_schemas() and the
# __main__ script at the bottom of this file.
console = Console()
|
|
|
|
# Minimal single-cell CIF used as sample input for optimize_crystal_structure.
_SIMPLE_CIF = """
data_simple
_cell_length_a 4.0
_cell_length_b 4.0
_cell_length_c 4.0
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 90
_symmetry_space_group_name_H-M 'P 1'
loop_
_atom_site_label
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Si 0.0 0.0 0.0
O 0.25 0.25 0.25
"""

# CsPbBr3 structure (P 1 setting) used as sample input for predict_properties.
_CSPBBR3_CIF = """
# generated using pymatgen
data_CsPbBr3
_symmetry_space_group_name_H-M 'P 1'
_cell_length_a 8.37036600
_cell_length_b 8.42533500
_cell_length_c 12.01129500
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 1
_chemical_formula_structural CsPbBr3
_chemical_formula_sum 'Cs4 Pb4 Br12'
_cell_volume 847.07421031
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Cs Cs0 1 0.50831300 0.46818500 0.25000000 1
Cs Cs1 1 0.00831300 0.03181500 0.75000000 1
Cs Cs2 1 0.99168700 0.96818500 0.25000000 1
Cs Cs3 1 0.49168700 0.53181500 0.75000000 1
Pb Pb4 1 0.50000000 0.00000000 0.50000000 1
Pb Pb5 1 0.00000000 0.50000000 0.00000000 1
Pb Pb6 1 0.00000000 0.50000000 0.50000000 1
Pb Pb7 1 0.50000000 0.00000000 0.00000000 1
Br Br8 1 0.54824500 0.99370800 0.75000000 1
Br Br9 1 0.04824500 0.50629200 0.25000000 1
Br Br10 1 0.79480800 0.20538600 0.02568800 1
Br Br11 1 0.20519200 0.79461400 0.97431200 1
Br Br12 1 0.29480800 0.29461400 0.97431200 1
Br Br13 1 0.70519200 0.70538600 0.02568800 1
Br Br14 1 0.29480800 0.29461400 0.52568800 1
Br Br15 1 0.70519200 0.70538600 0.47431200 1
Br Br16 1 0.20519200 0.79461400 0.52568800 1
Br Br17 1 0.79480800 0.20538600 0.47431200 1
Br Br18 1 0.95175500 0.49370800 0.75000000 1
Br Br19 1 0.45175500 0.00629200 0.25000000 1
"""


def _tool_cases():
    """Return a fresh mapping of tool name -> (module path, sample kwargs).

    Every tool is importable as ``<module path>.<tool name>``. The mapping is
    rebuilt on each call so list/dict arguments are never shared between runs,
    in case a tool mutates its inputs.
    """
    return {
        "get_current_time": (
            "mars_toolkit.misc.misc_tools",
            {"timezone": "Asia/Shanghai"},
        ),
        "search_online": (
            "mars_toolkit.query.web_search",
            {"query": "material science", "num_results": 2},
        ),
        "search_material_property_from_material_project": (
            "mars_toolkit.query.mp_query",
            {"formula": "Fe2O3"},
        ),
        "get_crystal_structures_from_materials_project": (
            "mars_toolkit.query.mp_query",
            {"formulas": ["Fe2O3"]},
        ),
        # NOTE(review): a list is passed to `formula` here, whereas
        # search_material_property_from_material_project gets a plain string;
        # confirm against get_mpid_from_formula's signature.
        "get_mpid_from_formula": (
            "mars_toolkit.query.mp_query",
            {"formula": ["Fe2O3"]},
        ),
        "optimize_crystal_structure": (
            "mars_toolkit.compute.structure_opt",
            {"content": _SIMPLE_CIF, "input_format": "cif"},
        ),
        # Simple property constraint for a small generation run.
        "generate_material": (
            "mars_toolkit.compute.material_gen",
            {"properties": {'dft_mag_density': 0.15}, "batch_size": 2, "num_batches": 1},
        ),
        "fetch_chemical_composition_from_OQMD": (
            "mars_toolkit.query.oqmd_query",
            {"composition": "Fe2O3"},
        ),
        "retrieval_from_knowledge_base": (
            "mars_toolkit.query.dify_search",
            {"query": "CsPbBr3", "topk": 3},
        ),
        "predict_properties": (
            "mars_toolkit.compute.property_pred",
            {"cif_content": _CSPBBR3_CIF},
        ),
    }


async def test_tool(tool_name: str) -> str:
    """Check that the named tool function can be called successfully.

    The tool is imported lazily and called with a fixed sample input. Both
    synchronous and asynchronous tools are supported: the call's return value
    is awaited only when it is awaitable.

    Args:
        tool_name: Name of the tool function to exercise.

    Returns:
        A human-readable report. It is an "unknown tool" message for names
        without a sample case, a success message with the result's type and
        value, or a failure message with the exception text and traceback.
        Exceptions are never propagated.
    """
    import importlib
    import inspect
    import traceback

    try:
        print(f"开始测试工具: {tool_name}")

        case = _tool_cases().get(tool_name)
        if case is None:
            return f"未知工具: {tool_name}"
        module_path, kwargs = case

        # Import inside the try so a missing dependency of one tool is
        # reported as that tool's failure instead of crashing the script.
        module = importlib.import_module(module_path)
        tool = getattr(module, tool_name)

        # generate_material used to be called without `await` although it is
        # async (see its previously commented-out awaited call). That reported
        # a never-awaited coroutine object as the "result". Await only when
        # needed so sync and async tools both work.
        result = tool(**kwargs)
        if inspect.isawaitable(result):
            result = await result

        print(f"工具 {tool_name} 测试完成")
        return f"工具 {tool_name} 测试成功,返回结果类型: {type(result)}, 返回的结果: {result}"

    except Exception as e:
        # Top-level harness boundary: turn any failure into a report string.
        error_details = traceback.format_exc()
        return f"工具 {tool_name} 测试失败: {str(e)}\n{error_details}"


# The "test_" prefix would make pytest collect this helper as a test and fail
# on the missing `tool_name` fixture; opt it out of collection explicitly.
test_tool.__test__ = False
|
|
|
|
|
|
def print_tool_schemas():
    """Print the name, description and parameters of every registered tool.

    Schemas come from ``mars_toolkit.get_tool_schemas()``. Judging by the keys
    read here, each looks like
    ``{"function": {"name": ..., "description": ..., "parameters":
    {"properties": {...}, "required": [...]}}}``.
    A missing description or parameters entry is tolerated. All schema text is
    escaped so square brackets in it are not swallowed as rich markup.
    """
    import mars_toolkit
    from rich.markup import escape

    schemas = mars_toolkit.get_tool_schemas()
    console.print("[bold green]已注册的工具函数列表:[/bold green]")
    for i, schema in enumerate(schemas, 1):
        function = schema['function']
        parameters = function.get('parameters') or {}
        # Hoisted: membership is checked once per parameter below.
        required_names = set(parameters.get('required', []))

        console.print(f"[bold cyan]工具 {i}:[/bold cyan] {escape(str(function['name']))}")
        console.print(
            f"[bold yellow]描述:[/bold yellow] {escape(str(function.get('description', '无描述')))}"
        )
        console.print("[bold magenta]参数:[/bold magenta]")
        for param_name, param_info in parameters.get('properties', {}).items():
            required = "必需" if param_name in required_names else "可选"
            description = escape(str(param_info.get('description', '无描述')))
            console.print(f"  - [bold]{escape(str(param_name))}[/bold] ({required}): {description}")
        console.print("")
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Print the schema of every registered tool first.
    print_tool_schemas()

    # Tools that test_tool() knows how to exercise; the index is usable on the CLI.
    tools_to_test = [
        "get_current_time",  # 0 basic tool
        "search_online",  # 1 web search tool
        "search_material_property_from_material_project",  # 2 Materials Project property query
        "get_crystal_structures_from_materials_project",  # 3 crystal structure query
        "get_mpid_from_formula",  # 4 material ID query
        "optimize_crystal_structure",  # 5 crystal structure optimization
        "generate_material",  # 6 material generation
        "fetch_chemical_composition_from_OQMD",  # 7 OQMD query
        "retrieval_from_knowledge_base",  # 8 knowledge-base retrieval
        "predict_properties",  # 9 property prediction
    ]

    # Pick the tool to test. An optional CLI argument (a tool name or an index
    # into tools_to_test) overrides the default. The default is index 7,
    # fetch_chemical_composition_from_OQMD.
    selection = sys.argv[1] if len(sys.argv) > 1 else "7"
    if selection.isdecimal() and int(selection) < len(tools_to_test):
        tool_name = tools_to_test[int(selection)]
    else:
        # test_tool() reports names it does not know about.
        tool_name = selection

    # Run the test.
    result = asyncio.run(test_tool(tool_name))

    console.print("[bold blue]测试结果:[/bold blue]")
    # The result is arbitrary tool output or a traceback. Print it verbatim,
    # once, without interpreting brackets in it as rich markup.
    console.print(result, markup=False)
|