构建mars_toolkit,删除tools_for_ms
This commit is contained in:
180
test_mars_toolkit.py
Normal file
180
test_mars_toolkit.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import asyncio
|
||||
import json
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
|
||||
async def test_tool(tool_name: str) -> str:
|
||||
"""
|
||||
测试指定的工具函数是否能正常被调用
|
||||
|
||||
Args:
|
||||
tool_name: 工具函数的名称
|
||||
|
||||
Returns:
|
||||
测试结果信息
|
||||
"""
|
||||
try:
|
||||
print(f"开始测试工具: {tool_name}")
|
||||
|
||||
if tool_name == "get_current_time":
|
||||
from mars_toolkit.misc.misc_tools import get_current_time
|
||||
result = await get_current_time(timezone="Asia/Shanghai")
|
||||
|
||||
elif tool_name == "search_online":
|
||||
from mars_toolkit.query.web_search import search_online
|
||||
result = await search_online(query="material science", num_results=2)
|
||||
|
||||
elif tool_name == "search_material_property_from_material_project":
|
||||
from mars_toolkit.query.mp_query import search_material_property_from_material_project
|
||||
result = await search_material_property_from_material_project(formula="Fe2O3")
|
||||
|
||||
elif tool_name == "get_crystal_structures_from_materials_project":
|
||||
from mars_toolkit.query.mp_query import get_crystal_structures_from_materials_project
|
||||
result = await get_crystal_structures_from_materials_project(formulas=["Fe2O3"])
|
||||
|
||||
elif tool_name == "get_mpid_from_formula":
|
||||
from mars_toolkit.query.mp_query import get_mpid_from_formula
|
||||
result = await get_mpid_from_formula(formula=["Fe2O3"])
|
||||
|
||||
elif tool_name == "optimize_crystal_structure":
|
||||
from mars_toolkit.compute.structure_opt import optimize_crystal_structure
|
||||
# 使用一个简单的CIF字符串作为测试输入
|
||||
simple_cif = """
|
||||
data_simple
|
||||
_cell_length_a 4.0
|
||||
_cell_length_b 4.0
|
||||
_cell_length_c 4.0
|
||||
_cell_angle_alpha 90
|
||||
_cell_angle_beta 90
|
||||
_cell_angle_gamma 90
|
||||
_symmetry_space_group_name_H-M 'P 1'
|
||||
loop_
|
||||
_atom_site_label
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
Si 0.0 0.0 0.0
|
||||
O 0.25 0.25 0.25
|
||||
"""
|
||||
result = await optimize_crystal_structure(content=simple_cif, input_format="cif")
|
||||
|
||||
elif tool_name == "generate_material":
|
||||
from mars_toolkit.compute.material_gen import generate_material
|
||||
# 使用简单的属性约束进行测试
|
||||
result = await generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1)
|
||||
|
||||
elif tool_name == "fetch_chemical_composition_from_OQMD":
|
||||
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||
result = await fetch_chemical_composition_from_OQMD(composition="Fe2O3")
|
||||
|
||||
elif tool_name == "retrieval_from_knowledge_base":
|
||||
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||
result = await retrieval_from_knowledge_base(query="CsPbBr3", topk=3)
|
||||
|
||||
elif tool_name == "predict_properties":
|
||||
from mars_toolkit.compute.property_pred import predict_properties
|
||||
# 使用一个简单的CsPbBr3结构CIF字符串作为测试输入
|
||||
_cif = """
|
||||
# generated using pymatgen
|
||||
data_CsPbBr3
|
||||
_symmetry_space_group_name_H-M 'P 1'
|
||||
_cell_length_a 8.37036600
|
||||
_cell_length_b 8.42533500
|
||||
_cell_length_c 12.01129500
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 1
|
||||
_chemical_formula_structural CsPbBr3
|
||||
_chemical_formula_sum 'Cs4 Pb4 Br12'
|
||||
_cell_volume 847.07421031
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Cs Cs0 1 0.50831300 0.46818500 0.25000000 1
|
||||
Cs Cs1 1 0.00831300 0.03181500 0.75000000 1
|
||||
Cs Cs2 1 0.99168700 0.96818500 0.25000000 1
|
||||
Cs Cs3 1 0.49168700 0.53181500 0.75000000 1
|
||||
Pb Pb4 1 0.50000000 0.00000000 0.50000000 1
|
||||
Pb Pb5 1 0.00000000 0.50000000 0.00000000 1
|
||||
Pb Pb6 1 0.00000000 0.50000000 0.50000000 1
|
||||
Pb Pb7 1 0.50000000 0.00000000 0.00000000 1
|
||||
Br Br8 1 0.54824500 0.99370800 0.75000000 1
|
||||
Br Br9 1 0.04824500 0.50629200 0.25000000 1
|
||||
Br Br10 1 0.79480800 0.20538600 0.02568800 1
|
||||
Br Br11 1 0.20519200 0.79461400 0.97431200 1
|
||||
Br Br12 1 0.29480800 0.29461400 0.97431200 1
|
||||
Br Br13 1 0.70519200 0.70538600 0.02568800 1
|
||||
Br Br14 1 0.29480800 0.29461400 0.52568800 1
|
||||
Br Br15 1 0.70519200 0.70538600 0.47431200 1
|
||||
Br Br16 1 0.20519200 0.79461400 0.52568800 1
|
||||
Br Br17 1 0.79480800 0.20538600 0.47431200 1
|
||||
Br Br18 1 0.95175500 0.49370800 0.75000000 1
|
||||
Br Br19 1 0.45175500 0.00629200 0.25000000 1
|
||||
"""
|
||||
result = await predict_properties(cif_content=_cif)
|
||||
|
||||
else:
|
||||
return f"未知工具: {tool_name}"
|
||||
|
||||
print(f"工具 {tool_name} 测试完成")
|
||||
return f"工具 {tool_name} 测试成功,返回结果类型: {type(result)}, 返回的结果: {result}"
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_details = traceback.format_exc()
|
||||
return f"工具 {tool_name} 测试失败: {str(e)}\n{error_details}"
|
||||
|
||||
|
||||
def print_tool_schemas():
|
||||
"""打印所有注册的工具函数的JSON模式"""
|
||||
import mars_toolkit
|
||||
schemas = mars_toolkit.get_tool_schemas()
|
||||
console.print("[bold green]已注册的工具函数列表:[/bold green]")
|
||||
for i, schema in enumerate(schemas, 1):
|
||||
console.print(f"[bold cyan]工具 {i}:[/bold cyan] {schema['function']['name']}")
|
||||
console.print(f"[bold yellow]描述:[/bold yellow] {schema['function']['description']}")
|
||||
console.print("[bold magenta]参数:[/bold magenta]")
|
||||
for param_name, param_info in schema['function']['parameters']['properties'].items():
|
||||
required = "必需" if param_name in schema['function']['parameters'].get('required', []) else "可选"
|
||||
console.print(f" - [bold]{param_name}[/bold] ({required}): {param_info.get('description', '无描述')}")
|
||||
console.print("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 打印所有工具函数的模式
|
||||
#print_tool_schemas()
|
||||
|
||||
# 测试工具函数列表
|
||||
tools_to_test = [
|
||||
"get_current_time", # 基础工具
|
||||
"search_online", # 网络搜索工具
|
||||
"search_material_property_from_material_project", # 材料项目查询工具
|
||||
"get_crystal_structures_from_materials_project", # 晶体结构查询工具
|
||||
"get_mpid_from_formula", # 材料ID查询工具
|
||||
"optimize_crystal_structure", # 晶体结构优化工具
|
||||
"generate_material", # 材料生成工具
|
||||
"fetch_chemical_composition_from_OQMD", # OQMD查询工具
|
||||
"retrieval_from_knowledge_base", # 知识库检索工具
|
||||
"predict_properties" # 属性预测工具
|
||||
]
|
||||
|
||||
# 选择要测试的工具
|
||||
tool_name = tools_to_test[5] # 测试 search_online 工具
|
||||
|
||||
# 运行测试
|
||||
result = asyncio.run(test_tool(tool_name))
|
||||
print(result)
|
||||
console.print(f"[bold blue]测试结果:[/bold blue]")
|
||||
console.print(result)
|
||||
Reference in New Issue
Block a user