初次提交
This commit is contained in:
11
.gitignore
vendored
Normal file
11
.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
|
||||
__pycache__/
|
||||
|
||||
# 忽略所有.pth和.ckqt文件(包括子目录)
|
||||
**/*.pth
|
||||
**/*.ckpt
|
||||
/reference
|
||||
/sci_mcp_server
|
||||
/temp
|
||||
/sci_mcp/material_mcp/support/pretrained_models
|
||||
/sci_mcp/material_mcp/mattergen_gen/mattergen/*
|
||||
15
.vscode/launch.json
vendored
Normal file
15
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
// 使用 IntelliSense 了解相关属性。
|
||||
// 悬停以查看现有属性的描述。
|
||||
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python 调试程序: 当前文件",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
}
|
||||
11
multi_mcp_server.code-workspace
Normal file
11
multi_mcp_server.code-workspace
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"folders": [
|
||||
{
|
||||
"path": "../../.."
|
||||
},
|
||||
{
|
||||
"path": "../../../../SciToolAgent/ToolsAgent/ToolsFuns"
|
||||
}
|
||||
],
|
||||
"settings": {}
|
||||
}
|
||||
315
requirements.txt
Normal file
315
requirements.txt
Normal file
@@ -0,0 +1,315 @@
|
||||
absl-py==2.2.2
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.11.16
|
||||
aioitertools==0.12.0
|
||||
aiosignal==1.3.2
|
||||
annotated-types==0.7.0
|
||||
antlr4-python3-runtime==4.9.3
|
||||
anyio==4.9.0
|
||||
argon2-cffi==23.1.0
|
||||
argon2-cffi-bindings==21.2.0
|
||||
arrow==1.3.0
|
||||
ase==3.25.0
|
||||
astroid==3.3.9
|
||||
asttokens==3.0.0
|
||||
async-lru==2.0.5
|
||||
async-timeout==4.0.3
|
||||
atomate2==0.0.18
|
||||
attrs==25.3.0
|
||||
autopep8==2.3.2
|
||||
azure-core==1.33.0
|
||||
azure-identity==1.21.0
|
||||
azure-storage-blob==12.25.1
|
||||
babel==2.17.0
|
||||
bcrypt==4.3.0
|
||||
beautifulsoup4==4.13.4
|
||||
bleach==6.2.0
|
||||
boto3==1.37.35
|
||||
botocore==1.37.35
|
||||
cachetools==5.5.2
|
||||
certifi==2025.1.31
|
||||
cffi==1.17.1
|
||||
charset-normalizer==3.4.1
|
||||
click==8.1.8
|
||||
cloudpickle==3.1.1
|
||||
colorama==0.4.6
|
||||
comm==0.2.2
|
||||
contextlib2==21.6.0
|
||||
contourpy==1.3.2
|
||||
cryptography==44.0.2
|
||||
custodian==2025.4.14
|
||||
cycler==0.12.1
|
||||
dataclasses-json==0.6.7
|
||||
debugpy==1.8.14
|
||||
decorator==5.2.1
|
||||
defusedxml==0.7.1
|
||||
deprecated==1.2.18
|
||||
dgl==2.4.0+cu124
|
||||
dill==0.4.0
|
||||
distro==1.9.0
|
||||
dnspython==2.7.0
|
||||
docker-pycreds==0.4.0
|
||||
docstring-parser==0.16
|
||||
e3nn==0.5.6
|
||||
emmet-core==0.84.6
|
||||
exceptiongroup==1.2.2
|
||||
executing==2.2.0
|
||||
fairchem-core==1.9.0
|
||||
fastjsonschema==2.21.1
|
||||
filelock==3.18.0
|
||||
fire==0.7.0
|
||||
fonttools==4.57.0
|
||||
fqdn==1.5.1
|
||||
frozenlist==1.5.0
|
||||
fsspec==2025.3.2
|
||||
gitdb==4.0.12
|
||||
gitpython==3.1.44
|
||||
greenlet==3.2.0
|
||||
grpcio==1.71.0
|
||||
h11==0.14.0
|
||||
h5py==3.13.0
|
||||
httpcore==1.0.8
|
||||
httpx==0.28.1
|
||||
httpx-sse==0.4.0
|
||||
huggingface-hub==0.30.2
|
||||
hydra-core==1.3.1
|
||||
hydra-joblib-launcher==1.1.5
|
||||
idna==3.10
|
||||
iniconfig==2.1.0
|
||||
ipykernel==6.29.5
|
||||
ipython==8.35.0
|
||||
isodate==0.7.2
|
||||
isoduration==20.11.0
|
||||
isort==6.0.1
|
||||
jedi==0.19.2
|
||||
jinja2==3.1.6
|
||||
jiter==0.9.0
|
||||
jmespath==1.0.1
|
||||
jobflow==0.1.19
|
||||
joblib==1.4.2
|
||||
json5==0.12.0
|
||||
jsonlines==4.0.0
|
||||
jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
jsonschema==4.23.0
|
||||
jsonschema-specifications==2024.10.1
|
||||
jupyter-client==8.6.3
|
||||
jupyter-core==5.7.2
|
||||
jupyter-events==0.12.0
|
||||
jupyter-lsp==2.2.5
|
||||
jupyter-server==2.15.0
|
||||
jupyter-server-terminals==0.5.3
|
||||
jupyterlab==4.4.0
|
||||
jupyterlab-pygments==0.3.0
|
||||
jupyterlab-server==2.27.3
|
||||
kiwisolver==1.4.8
|
||||
langchain==0.3.23
|
||||
langchain-community==0.3.21
|
||||
langchain-core==0.3.52
|
||||
langchain-text-splitters==0.3.8
|
||||
langsmith==0.3.31
|
||||
latexcodec==3.0.0
|
||||
lightning==2.5.1
|
||||
lightning-utilities==0.14.3
|
||||
llvmlite==0.44.0
|
||||
lmdb==1.6.2
|
||||
loguru==0.7.3
|
||||
lxml==5.3.2
|
||||
maggma==0.71.5
|
||||
markdown==3.8
|
||||
markdown-it-py==3.0.0
|
||||
markupsafe==3.0.2
|
||||
marshmallow==3.26.1
|
||||
matgl==1.2.6
|
||||
matplotlib==3.8.4
|
||||
matplotlib-inline==0.1.7
|
||||
matscipy==1.1.1
|
||||
-e file:///home/ubuntu/sas0/lzy/mars-mcp/mattergen
|
||||
mattersim==1.1.2
|
||||
mccabe==0.7.0
|
||||
mcp==1.6.0
|
||||
mdurl==0.1.2
|
||||
mistune==3.1.3
|
||||
mongomock==4.3.0
|
||||
monty==2025.3.3
|
||||
mp-api==0.45.3
|
||||
mpmath==1.3.0
|
||||
msal==1.32.0
|
||||
msal-extensions==1.3.1
|
||||
msgpack==1.1.0
|
||||
multidict==6.4.3
|
||||
multiprocess==0.70.18
|
||||
mypy-extensions==1.0.0
|
||||
mysql-connector==2.2.9
|
||||
narwhals==1.35.0
|
||||
nbclient==0.10.2
|
||||
nbconvert==7.16.6
|
||||
nbformat==5.10.4
|
||||
nest-asyncio==1.6.0
|
||||
networkx==3.4.2
|
||||
notebook==7.4.0
|
||||
notebook-shim==0.2.4
|
||||
numba==0.61.2
|
||||
numpy==1.26.4
|
||||
nvidia-cublas-cu11==11.11.3.6
|
||||
nvidia-cublas-cu12==12.4.2.65
|
||||
nvidia-cuda-cupti-cu11==11.8.87
|
||||
nvidia-cuda-cupti-cu12==12.4.99
|
||||
nvidia-cuda-nvrtc-cu11==11.8.89
|
||||
nvidia-cuda-nvrtc-cu12==12.4.99
|
||||
nvidia-cuda-runtime-cu11==11.8.89
|
||||
nvidia-cuda-runtime-cu12==12.4.99
|
||||
nvidia-cudnn-cu11==8.7.0.84
|
||||
nvidia-cudnn-cu12==9.1.0.70
|
||||
nvidia-cufft-cu11==10.9.0.58
|
||||
nvidia-cufft-cu12==11.2.0.44
|
||||
nvidia-curand-cu11==10.3.0.86
|
||||
nvidia-curand-cu12==10.3.5.119
|
||||
nvidia-cusolver-cu11==11.4.1.48
|
||||
nvidia-cusolver-cu12==11.6.0.99
|
||||
nvidia-cusparse-cu11==11.7.5.86
|
||||
nvidia-cusparse-cu12==12.3.0.142
|
||||
nvidia-nccl-cu11==2.19.3
|
||||
nvidia-nccl-cu12==2.20.5
|
||||
nvidia-nvjitlink-cu12==12.4.99
|
||||
nvidia-nvtx-cu11==11.8.86
|
||||
nvidia-nvtx-cu12==12.4.99
|
||||
omegaconf==2.3.0
|
||||
openai==1.75.0
|
||||
opt-einsum==3.4.0
|
||||
opt-einsum-fx==0.1.4
|
||||
orjson==3.10.16
|
||||
overrides==7.7.0
|
||||
packaging==24.2
|
||||
palettable==3.3.3
|
||||
pandarallel==1.6.5
|
||||
pandas==2.2.3
|
||||
pandocfilters==1.5.1
|
||||
paramiko==3.5.1
|
||||
parso==0.8.4
|
||||
pathos==0.3.3
|
||||
pexpect==4.9.0
|
||||
phonopy==2.38.1
|
||||
pillow==11.2.1
|
||||
pip==25.0.1
|
||||
platformdirs==4.3.7
|
||||
plotly==6.0.1
|
||||
pluggy==1.5.0
|
||||
pox==0.3.6
|
||||
ppft==1.7.7
|
||||
prometheus-client==0.21.1
|
||||
prompt-toolkit==3.0.51
|
||||
propcache==0.3.1
|
||||
protobuf==5.29.4
|
||||
psutil==7.0.0
|
||||
ptyprocess==0.7.0
|
||||
pubchempy==1.0.4
|
||||
pure-eval==0.2.3
|
||||
pybtex==0.24.0
|
||||
pycodestyle==2.13.0
|
||||
pycparser==2.22
|
||||
pydantic==2.11.3
|
||||
pydantic-core==2.33.1
|
||||
pydantic-settings==2.8.1
|
||||
pydash==8.0.5
|
||||
pyg-lib==0.4.0+pt24cu124
|
||||
pygments==2.19.1
|
||||
pyjwt==2.10.1
|
||||
pylint==3.3.6
|
||||
pymatgen==2024.10.29
|
||||
pymongo==4.10.1
|
||||
pynacl==1.5.0
|
||||
pyparsing==3.2.3
|
||||
pyre-extensions==0.0.32
|
||||
pytest==8.3.5
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv==1.1.0
|
||||
python-json-logger==3.3.0
|
||||
pytorch-lightning==2.0.6
|
||||
pytz==2025.2
|
||||
pyyaml==6.0.2
|
||||
pyzmq==26.4.0
|
||||
referencing==0.36.2
|
||||
regex==2024.11.6
|
||||
requests==2.32.3
|
||||
requests-toolbelt==1.0.0
|
||||
rfc3339-validator==0.1.4
|
||||
rfc3986-validator==0.1.1
|
||||
rich==14.0.0
|
||||
rpds-py==0.24.0
|
||||
ruamel-yaml==0.18.10
|
||||
ruamel-yaml-clib==0.2.12
|
||||
s3transfer==0.11.4
|
||||
scikit-learn==1.6.1
|
||||
scipy==1.15.2
|
||||
seaborn==0.13.2
|
||||
seekpath==2.1.0
|
||||
send2trash==1.8.3
|
||||
sentinels==1.0.0
|
||||
sentry-sdk==2.26.1
|
||||
setproctitle==1.3.5
|
||||
setuptools==78.1.0
|
||||
shellingham==1.5.4
|
||||
six==1.17.0
|
||||
smact==3.1.0
|
||||
smart-open==7.1.0
|
||||
smmap==5.0.2
|
||||
sniffio==1.3.1
|
||||
soupsieve==2.6
|
||||
spglib==2.6.0
|
||||
sqlalchemy==2.0.40
|
||||
sse-starlette==2.2.1
|
||||
sshtunnel==0.4.0
|
||||
stack-data==0.6.3
|
||||
starlette==0.46.2
|
||||
submitit==1.5.2
|
||||
symfc==1.3.4
|
||||
sympy==1.13.3
|
||||
tabulate==0.9.0
|
||||
tenacity==9.1.2
|
||||
tensorboard==2.19.0
|
||||
tensorboard-data-server==0.7.2
|
||||
termcolor==3.0.1
|
||||
terminado==0.18.1
|
||||
threadpoolctl==3.6.0
|
||||
tiktoken==0.9.0
|
||||
tinycss2==1.4.0
|
||||
tomli==2.2.1
|
||||
tomlkit==0.13.2
|
||||
torch==2.4.0+cu124
|
||||
torch-cluster==1.6.3+pt24cu124
|
||||
torch-ema==0.3
|
||||
torch-geometric==2.6.1
|
||||
torch-runstats==0.2.0
|
||||
torch-scatter==2.1.2+pt24cu124
|
||||
torch-sparse==0.6.18+pt24cu124
|
||||
torch-spline-conv==1.2.2+pt24cu124
|
||||
torchaudio==2.4.0+cu124
|
||||
torchdata==0.8.0
|
||||
torchmetrics==1.7.1
|
||||
torchtnt==0.2.4
|
||||
torchvision==0.19.0+cu124
|
||||
tornado==6.4.2
|
||||
tqdm==4.67.1
|
||||
traitlets==5.14.3
|
||||
triton==3.0.0
|
||||
typer==0.15.2
|
||||
types-python-dateutil==2.9.0.20241206
|
||||
typing-extensions==4.13.2
|
||||
typing-inspect==0.9.0
|
||||
typing-inspection==0.4.0
|
||||
tzdata==2025.2
|
||||
uncertainties==3.2.2
|
||||
uri-template==1.3.0
|
||||
urllib3==2.4.0
|
||||
uvicorn==0.34.1
|
||||
wandb==0.19.9
|
||||
wcwidth==0.2.13
|
||||
webcolors==24.11.1
|
||||
webencodings==0.5.1
|
||||
websocket-client==1.8.0
|
||||
werkzeug==3.1.3
|
||||
wheel==0.45.1
|
||||
wrapt==1.17.2
|
||||
yarl==1.20.0
|
||||
zstandard==0.23.0
|
||||
38
sci_mcp/__init__.py
Normal file
38
sci_mcp/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from .core.llm_tools import llm_tool,get_all_tools, get_all_tool_schemas, get_domain_tools, get_domain_tool_schemas
|
||||
|
||||
#general_mcp
|
||||
#from .general_mcp.searxng_query.searxng_query_tools import search_online
|
||||
|
||||
# #material_mcp
|
||||
from .material_mcp.mp_query.mp_query_tools import search_crystal_structures_from_materials_project,search_material_property_from_materials_project
|
||||
from .material_mcp.oqmd_query.oqmd_query_tools import query_material_from_OQMD
|
||||
from .material_mcp.knowledge_base_query.retrieval_from_knowledge_base_tools import retrieval_from_knowledge_base
|
||||
|
||||
from .material_mcp.mattersim_pred.property_pred_tools import predict_properties_MatterSim
|
||||
from .material_mcp.mattergen_gen.material_gen_tools import generate_material_MatterGen
|
||||
from .material_mcp.fairchem_structure_opt.structure_opt_tools import optimize_crystal_structure_FairChem
|
||||
from .material_mcp.pymatgen_cal.pymatgen_cal_tools import calculate_density_Pymatgen,get_element_composition_Pymatgen,calculate_symmetry_Pymatgen
|
||||
from .material_mcp.matgl_tools.matgl_tools import relax_crystal_structure_M3GNet,predict_formation_energy_M3GNet,run_molecular_dynamics_M3GNet,calculate_single_point_energy_M3GNet
|
||||
#chemistry_mcp
|
||||
from .chemistry_mcp.pubchem_tools.pubchem_tools import search_advanced_pubchem
|
||||
from .chemistry_mcp.rdkit_tools.rdkit_tools import (
|
||||
calculate_molecular_properties_rdkit,
|
||||
calculate_drug_likeness_rdkit,
|
||||
calculate_topological_descriptors_rdkit,
|
||||
generate_molecular_fingerprints_rdkit,
|
||||
calculate_molecular_similarity_rdkit,
|
||||
analyze_molecular_structure_rdkit,
|
||||
generate_molecular_conformer_rdkit,
|
||||
identify_scaffolds_rdkit,
|
||||
convert_between_chemical_formats_rdkit,
|
||||
standardize_molecule_rdkit,
|
||||
enumerate_stereoisomers_rdkit,
|
||||
perform_substructure_search_rdkit
|
||||
)
|
||||
from .chemistry_mcp.rxn_tools.rxn_tools import (
|
||||
predict_reaction_outcome_rxn,
|
||||
predict_reaction_topn_rxn,
|
||||
predict_reaction_properties_rxn,
|
||||
extract_reaction_actions_rxn
|
||||
)
|
||||
__all__ = ["llm_tool", "get_all_tools", "get_all_tool_schemas", "get_domain_tools", "get_domain_tool_schemas"]
|
||||
9
sci_mcp/chemistry_mcp/__init__.py
Normal file
9
sci_mcp/chemistry_mcp/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Chemistry MCP Module
|
||||
|
||||
This module provides tools for chemistry-related operations.
|
||||
"""
|
||||
|
||||
from .pubchem_tools import search_advanced_pubchem
|
||||
|
||||
__all__ = ["search_advanced_pubchem"]
|
||||
9
sci_mcp/chemistry_mcp/pubchem_tools/__init__.py
Normal file
9
sci_mcp/chemistry_mcp/pubchem_tools/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
PubChem Tools Module
|
||||
|
||||
This module provides tools for accessing and processing chemical data from PubChem.
|
||||
"""
|
||||
|
||||
from .pubchem_tools import search_advanced_pubchem
|
||||
|
||||
__all__ = ["search_advanced_pubchem"]
|
||||
325
sci_mcp/chemistry_mcp/pubchem_tools/pubchem_tools.py
Normal file
325
sci_mcp/chemistry_mcp/pubchem_tools/pubchem_tools.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
PubChem Tools Module
|
||||
|
||||
This module provides tools for searching and retrieving chemical compound information
|
||||
from the PubChem database using the PubChemPy library.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Union, Optional, Any
|
||||
|
||||
import pubchempy as pcp
|
||||
from ...core.llm_tools import llm_tool
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
|
||||
def compound_to_dict(compound: pcp.Compound) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a PubChem compound to a structured dictionary with relevant information.
|
||||
|
||||
Args:
|
||||
compound: PubChem compound object
|
||||
|
||||
Returns:
|
||||
Dictionary containing organized compound information
|
||||
"""
|
||||
if not compound:
|
||||
return {}
|
||||
|
||||
# Basic information
|
||||
result = {
|
||||
"basic_info": {
|
||||
"cid": compound.cid,
|
||||
"iupac_name": compound.iupac_name,
|
||||
"molecular_formula": compound.molecular_formula,
|
||||
"molecular_weight": compound.molecular_weight,
|
||||
"canonical_smiles": compound.canonical_smiles,
|
||||
"isomeric_smiles": compound.isomeric_smiles,
|
||||
},
|
||||
"identifiers": {
|
||||
"inchi": compound.inchi,
|
||||
"inchikey": compound.inchikey,
|
||||
},
|
||||
"physical_properties": {
|
||||
"xlogp": compound.xlogp,
|
||||
"exact_mass": compound.exact_mass,
|
||||
"monoisotopic_mass": compound.monoisotopic_mass,
|
||||
"tpsa": compound.tpsa,
|
||||
"complexity": compound.complexity,
|
||||
"charge": compound.charge,
|
||||
},
|
||||
"molecular_features": {
|
||||
"h_bond_donor_count": compound.h_bond_donor_count,
|
||||
"h_bond_acceptor_count": compound.h_bond_acceptor_count,
|
||||
"rotatable_bond_count": compound.rotatable_bond_count,
|
||||
"heavy_atom_count": compound.heavy_atom_count,
|
||||
"atom_stereo_count": compound.atom_stereo_count,
|
||||
"defined_atom_stereo_count": compound.defined_atom_stereo_count,
|
||||
"undefined_atom_stereo_count": compound.undefined_atom_stereo_count,
|
||||
"bond_stereo_count": compound.bond_stereo_count,
|
||||
"defined_bond_stereo_count": compound.defined_bond_stereo_count,
|
||||
"undefined_bond_stereo_count": compound.undefined_bond_stereo_count,
|
||||
"covalent_unit_count": compound.covalent_unit_count,
|
||||
}
|
||||
}
|
||||
|
||||
# Add synonyms if available
|
||||
if hasattr(compound, 'synonyms') and compound.synonyms:
|
||||
result["alternative_names"] = {
|
||||
"synonyms": compound.synonyms[:10] if len(compound.synonyms) > 10 else compound.synonyms
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _search_by_name(name: str, max_results: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search compounds by name asynchronously.
|
||||
|
||||
Args:
|
||||
name: Chemical compound name
|
||||
max_results: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of compound dictionaries
|
||||
"""
|
||||
try:
|
||||
compounds = await asyncio.to_thread(
|
||||
pcp.get_compounds, name, 'name', max_records=max_results
|
||||
)
|
||||
#print(compounds[0].to_dict())
|
||||
return [compound.to_dict() for compound in compounds]
|
||||
except Exception as e:
|
||||
logging.error(f"Error searching by name '{name}': {str(e)}")
|
||||
return [{"error": f"Error: {str(e)}"}]
|
||||
|
||||
|
||||
async def _search_by_smiles(smiles: str, max_results: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search compounds by SMILES notation asynchronously.
|
||||
|
||||
Args:
|
||||
smiles: SMILES notation of chemical compound
|
||||
max_results: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of compound dictionaries
|
||||
"""
|
||||
try:
|
||||
compounds = await asyncio.to_thread(
|
||||
pcp.get_compounds, smiles, 'smiles', max_records=max_results
|
||||
)
|
||||
return [compound.to_dict() for compound in compounds]
|
||||
except Exception as e:
|
||||
logging.error(f"Error searching by SMILES '{smiles}': {str(e)}")
|
||||
return [{"error": f"Error: {str(e)}"}]
|
||||
|
||||
|
||||
async def _search_by_formula(
|
||||
formula: str,
|
||||
max_results: int = 5,
|
||||
listkey_count: int = 5,
|
||||
listkey_start: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search compounds by molecular formula asynchronously.
|
||||
|
||||
Uses pagination with listkey parameters to avoid timeout errors when searching
|
||||
formulas that might return many results.
|
||||
|
||||
Args:
|
||||
formula: Molecular formula
|
||||
max_results: Maximum number of results to return
|
||||
listkey_count: Number of results per page (default: 5)
|
||||
listkey_start: Starting position for pagination (default: 0)
|
||||
|
||||
Returns:
|
||||
List of compound dictionaries
|
||||
"""
|
||||
try:
|
||||
# Use listkey parameters to avoid timeout errors
|
||||
compounds = await asyncio.to_thread(
|
||||
pcp.get_compounds,
|
||||
formula,
|
||||
'formula',
|
||||
max_records=max_results,
|
||||
listkey_count=listkey_count,
|
||||
listkey_start=listkey_start
|
||||
)
|
||||
|
||||
return [compound.to_dict() for compound in compounds]
|
||||
except Exception as e:
|
||||
logging.error(f"Error searching by formula '{formula}': {str(e)}")
|
||||
return [{"error": f"Error: {str(e)}"}]
|
||||
|
||||
|
||||
def _format_results_as_markdown(results: List[Dict[str, Any]], query_type: str, query_value: str) -> str:
|
||||
"""
|
||||
Format search results as a structured Markdown string.
|
||||
|
||||
Args:
|
||||
results: List of compound dictionaries from compound.to_dict()
|
||||
query_type: Type of search query (name, SMILES, formula)
|
||||
query_value: Value of the search query
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string
|
||||
"""
|
||||
if not results:
|
||||
return f"## PubChem Search Results\n\nNo compounds found for {query_type}: `{query_value}`"
|
||||
|
||||
if "error" in results[0]:
|
||||
return f"## PubChem Search Error\n\n{results[0]['error']}"
|
||||
|
||||
markdown = f"## PubChem Search Results\n\nSearch by {query_type}: `{query_value}`\n\nFound {len(results)} compound(s)\n\n"
|
||||
|
||||
for i, compound in enumerate(results):
|
||||
# Extract information from the compound.to_dict() structure
|
||||
cid = compound.get("cid", "N/A")
|
||||
iupac_name = compound.get("iupac_name", "Unknown")
|
||||
molecular_formula = compound.get("molecular_formula", "N/A")
|
||||
molecular_weight = compound.get("molecular_weight", "N/A")
|
||||
canonical_smiles = compound.get("canonical_smiles", "N/A")
|
||||
isomeric_smiles = compound.get("isomeric_smiles", "N/A")
|
||||
inchi = compound.get("inchi", "N/A")
|
||||
inchikey = compound.get("inchikey", "N/A")
|
||||
xlogp = compound.get("xlogp", "N/A")
|
||||
exact_mass = compound.get("exact_mass", "N/A")
|
||||
tpsa = compound.get("tpsa", "N/A")
|
||||
h_bond_donor_count = compound.get("h_bond_donor_count", "N/A")
|
||||
h_bond_acceptor_count = compound.get("h_bond_acceptor_count", "N/A")
|
||||
rotatable_bond_count = compound.get("rotatable_bond_count", "N/A")
|
||||
heavy_atom_count = compound.get("heavy_atom_count", "N/A")
|
||||
|
||||
# Get atoms and bonds information if available
|
||||
atoms = compound.get("atoms", [])
|
||||
bonds = compound.get("bonds", [])
|
||||
|
||||
# Format the markdown output
|
||||
markdown += f"### Compound {i+1}: {iupac_name}\n\n"
|
||||
|
||||
# Basic information section
|
||||
markdown += "#### Basic Information\n\n"
|
||||
markdown += f"- **CID**: {cid}\n"
|
||||
markdown += f"- **Formula**: {molecular_formula}\n"
|
||||
markdown += f"- **Molecular Weight**: {molecular_weight} g/mol\n"
|
||||
markdown += f"- **Canonical SMILES**: `{canonical_smiles}`\n"
|
||||
markdown += f"- **Isomeric SMILES**: `{isomeric_smiles}`\n"
|
||||
|
||||
# Identifiers section
|
||||
markdown += "\n#### Identifiers\n\n"
|
||||
markdown += f"- **InChI**: `{inchi}`\n"
|
||||
markdown += f"- **InChIKey**: `{inchikey}`\n"
|
||||
|
||||
# Physical properties section
|
||||
markdown += "\n#### Physical Properties\n\n"
|
||||
markdown += f"- **XLogP**: {xlogp}\n"
|
||||
markdown += f"- **Exact Mass**: {exact_mass}\n"
|
||||
markdown += f"- **TPSA**: {tpsa} Ų\n"
|
||||
|
||||
# Molecular features section
|
||||
markdown += "\n#### Molecular Features\n\n"
|
||||
markdown += f"- **H-Bond Donors**: {h_bond_donor_count}\n"
|
||||
markdown += f"- **H-Bond Acceptors**: {h_bond_acceptor_count}\n"
|
||||
markdown += f"- **Rotatable Bonds**: {rotatable_bond_count}\n"
|
||||
markdown += f"- **Heavy Atoms**: {heavy_atom_count}\n"
|
||||
|
||||
# Structure information
|
||||
markdown += "\n#### Structure Information\n\n"
|
||||
markdown += f"- **Atoms Count**: {len(atoms)}\n"
|
||||
markdown += f"- **Bonds Count**: {len(bonds)}\n"
|
||||
|
||||
# Add a summary of atom elements if available
|
||||
if atoms:
|
||||
elements = {}
|
||||
for atom in atoms:
|
||||
element = atom.get("element", "")
|
||||
if element:
|
||||
elements[element] = elements.get(element, 0) + 1
|
||||
|
||||
if elements:
|
||||
markdown += "- **Elements**: "
|
||||
elements_str = ", ".join([f"{element}: {count}" for element, count in elements.items()])
|
||||
markdown += f"{elements_str}\n"
|
||||
|
||||
markdown += "\n---\n\n" if i < len(results) - 1 else "\n"
|
||||
|
||||
return markdown
|
||||
|
||||
|
||||
@llm_tool(name="search_advanced_pubchem",
|
||||
description="Search for chemical compounds on PubChem database using name, SMILES notation, or molecular formula via PubChemPy library")
|
||||
async def search_advanced_pubchem(
|
||||
name: Optional[str] = None,
|
||||
smiles: Optional[str] = None,
|
||||
formula: Optional[str] = None,
|
||||
max_results: int = 3
|
||||
) -> str:
|
||||
"""
|
||||
Perform an advanced search for chemical compounds on PubChem using various identifiers.
|
||||
|
||||
This function allows searching by compound name, SMILES notation, or molecular formula.
|
||||
At least one search parameter must be provided. If multiple parameters are provided,
|
||||
the search will prioritize in the order: name > smiles > formula.
|
||||
|
||||
Args:
|
||||
name: Name of the chemical compound (e.g., "Aspirin", "Caffeine")
|
||||
smiles: SMILES notation of the chemical compound (e.g., "CC(=O)OC1=CC=CC=C1C(=O)O" for Aspirin)
|
||||
formula: Molecular formula (e.g., "C9H8O4" for Aspirin)
|
||||
max_results: Maximum number of results to return (default: 3)
|
||||
|
||||
Returns:
|
||||
A formatted Markdown string with search results
|
||||
|
||||
Examples:
|
||||
>>> search_advanced_pubchem(name="Aspirin")
|
||||
# Returns information about Aspirin
|
||||
|
||||
>>> search_advanced_pubchem(smiles="CC(=O)OC1=CC=CC=C1C(=O)O")
|
||||
# Returns information about compounds matching the SMILES notation
|
||||
|
||||
>>> search_advanced_pubchem(formula="C9H8O4", max_results=5)
|
||||
# Returns up to 5 compounds with the formula C9H8O4
|
||||
"""
|
||||
logging.info(f"Performing advanced PubChem search with parameters: name={name}, smiles={smiles}, formula={formula}, max_results={max_results}")
|
||||
|
||||
# Validate input parameters
|
||||
if name is None and smiles is None and formula is None:
|
||||
return "## PubChem Search Error\n\nError: At least one search parameter (name, smiles, or formula) must be provided"
|
||||
|
||||
# Validate max_results
|
||||
if max_results < 1:
|
||||
max_results = 1
|
||||
elif max_results > 10:
|
||||
max_results = 10 # Limit to 10 results to prevent overwhelming responses
|
||||
|
||||
try:
|
||||
results = []
|
||||
query_type = ""
|
||||
query_value = ""
|
||||
|
||||
# Prioritize search by name, then SMILES, then formula
|
||||
if name is not None:
|
||||
results = await _search_by_name(name, max_results)
|
||||
query_type = "name"
|
||||
query_value = name
|
||||
elif smiles is not None:
|
||||
results = await _search_by_smiles(smiles, max_results)
|
||||
query_type = "SMILES"
|
||||
query_value = smiles
|
||||
elif formula is not None:
|
||||
# Use pagination parameters for formula searches to avoid timeout
|
||||
# Using the default values from _search_by_formula
|
||||
results = await _search_by_formula(formula, max_results)
|
||||
query_type = "formula"
|
||||
query_value = formula
|
||||
|
||||
# Return results as markdown
|
||||
return _format_results_as_markdown(results, query_type, query_value)
|
||||
|
||||
except Exception as e:
|
||||
return f"## PubChem Search Error\n\nError: {str(e)}"
|
||||
9
sci_mcp/chemistry_mcp/rdkit_tools/__init__.py
Normal file
9
sci_mcp/chemistry_mcp/rdkit_tools/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
RDKit Tools Package
|
||||
|
||||
This package provides tools for molecular analysis, manipulation, and visualization
|
||||
using the RDKit library. It includes functions for calculating molecular descriptors,
|
||||
generating molecular fingerprints, analyzing molecular structures, and more.
|
||||
"""
|
||||
|
||||
from .rdkit_tools import *
|
||||
1154
sci_mcp/chemistry_mcp/rdkit_tools/rdkit_tools.py
Normal file
1154
sci_mcp/chemistry_mcp/rdkit_tools/rdkit_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
6
sci_mcp/chemistry_mcp/rxn_tools/__init__.py
Normal file
6
sci_mcp/chemistry_mcp/rxn_tools/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
RXN Tools Module
|
||||
|
||||
This module provides tools for chemical reaction prediction and analysis
|
||||
using the IBM RXN for Chemistry API.
|
||||
"""
|
||||
772
sci_mcp/chemistry_mcp/rxn_tools/rxn_tools.py
Normal file
772
sci_mcp/chemistry_mcp/rxn_tools/rxn_tools.py
Normal file
@@ -0,0 +1,772 @@
|
||||
"""
|
||||
RXN Tools Module
|
||||
|
||||
This module provides tools for chemical reaction prediction and analysis
|
||||
using the IBM RXN for Chemistry API through the rxn4chemistry package.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Union, Optional, Any, Tuple
|
||||
|
||||
from rxn4chemistry import RXN4ChemistryWrapper
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ...core.config import Chemistry_Config
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
DEFAULT_MAX_RESULTS = 3
|
||||
DEFAULT_TIMEOUT = 180 # seconds - increased from 60 to 180
|
||||
|
||||
|
||||
def _get_rxn_wrapper() -> RXN4ChemistryWrapper:
|
||||
"""
|
||||
Get an initialized RXN4Chemistry wrapper with API key and project set.
|
||||
|
||||
Returns:
|
||||
Initialized RXN4ChemistryWrapper instance with project set
|
||||
|
||||
Raises:
|
||||
ValueError: If API key is not available
|
||||
"""
|
||||
# Try to get API key from environment or config
|
||||
api_key = Chemistry_Config.RXN4_CHEMISTRY_KEY or os.environ.get("RXN_API_KEY")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("Error: RXN API key not found. Please set the RXN_API_KEY environment variable.")
|
||||
|
||||
# Initialize the wrapper
|
||||
wrapper = RXN4ChemistryWrapper(api_key=api_key)
|
||||
|
||||
try:
|
||||
# Create a new project
|
||||
project_name = f"RXN_Tools_Project_{os.getpid()}" # Add process ID to make name unique
|
||||
project_response = wrapper.create_project(project_name)
|
||||
|
||||
# Extract project ID from response
|
||||
# The API response format is nested: {'response': {'payload': {'id': '...'}}
|
||||
if project_response and isinstance(project_response, dict):
|
||||
# Try to extract project ID from different possible response formats
|
||||
project_id = None
|
||||
|
||||
# Direct format: {"project_id": "..."}
|
||||
if "project_id" in project_response:
|
||||
project_id = project_response["project_id"]
|
||||
|
||||
# Nested format: {"response": {"payload": {"id": "..."}}}
|
||||
elif "response" in project_response and isinstance(project_response["response"], dict):
|
||||
payload = project_response["response"].get("payload", {})
|
||||
if isinstance(payload, dict) and "id" in payload:
|
||||
project_id = payload["id"]
|
||||
|
||||
if project_id:
|
||||
wrapper.set_project(project_id)
|
||||
logger.info(f"RXN project '{project_name}' created and set successfully with ID: {project_id}")
|
||||
else:
|
||||
logger.warning(f"Could not extract project ID from response: {project_response}")
|
||||
else:
|
||||
logger.warning(f"Unexpected project creation response: {project_response}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating RXN project: {e}")
|
||||
|
||||
# Check if project is set
|
||||
if not hasattr(wrapper, "project_id") or not wrapper.project_id:
|
||||
logger.warning("No project set. Some API calls may fail.")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
|
||||
def _format_reaction_markdown(reactants: str, products: List[str],
|
||||
confidence: Optional[List[float]] = None) -> str:
|
||||
"""
|
||||
Format reaction results as Markdown.
|
||||
|
||||
Args:
|
||||
reactants: SMILES of reactants
|
||||
products: List of product SMILES
|
||||
confidence: Optional list of confidence scores
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string
|
||||
"""
|
||||
markdown = f"## 反应预测结果\n\n"
|
||||
markdown += f"### 输入反应物\n\n`{reactants}`\n\n"
|
||||
|
||||
markdown += f"### 预测产物\n\n"
|
||||
|
||||
for i, product in enumerate(products):
|
||||
conf_str = f" (置信度: {confidence[i]:.2f})" if confidence and i < len(confidence) else ""
|
||||
markdown += f"{i+1}. `{product}`{conf_str}\n"
|
||||
|
||||
return markdown
|
||||
|
||||
|
||||
@llm_tool(name="predict_reaction_outcome_rxn",
|
||||
description="Predict chemical reaction outcomes for given reactants using IBM RXN for Chemistry API")
|
||||
async def predict_reaction_outcome_rxn(reactants: str) -> str:
|
||||
"""
|
||||
Predict chemical reaction outcomes for given reactants.
|
||||
|
||||
This function uses the IBM RXN for Chemistry API to predict the most likely
|
||||
products formed when the given reactants are combined.
|
||||
|
||||
Args:
|
||||
reactants: SMILES notation of reactants, multiple reactants separated by dots (.).
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string containing the predicted reaction results.
|
||||
|
||||
Examples:
|
||||
>>> predict_reaction_outcome_rxn("BrBr.c1ccc2cc3ccccc3cc2c1")
|
||||
# Returns predicted products of bromine and anthracene reaction
|
||||
"""
|
||||
try:
|
||||
# Get RXN wrapper
|
||||
wrapper = _get_rxn_wrapper()
|
||||
|
||||
# Clean input
|
||||
reactants = reactants.strip()
|
||||
|
||||
# Submit prediction
|
||||
response = await asyncio.to_thread(
|
||||
wrapper.predict_reaction, reactants
|
||||
)
|
||||
|
||||
if not response or "prediction_id" not in response:
|
||||
return "Error: 无法提交反应预测请求"
|
||||
|
||||
# 直接获取结果,而不是通过_wait_for_result
|
||||
results = await asyncio.to_thread(
|
||||
wrapper.get_predict_reaction_results,
|
||||
response["prediction_id"]
|
||||
)
|
||||
|
||||
# Extract products
|
||||
try:
|
||||
attempts = results.get("response", {}).get("payload", {}).get("attempts", [])
|
||||
if not attempts:
|
||||
return "Error: 未找到预测结果"
|
||||
|
||||
# Get the top predicted product
|
||||
product_smiles = attempts[0].get("smiles", "")
|
||||
confidence = attempts[0].get("confidence", None)
|
||||
|
||||
# Format results
|
||||
return _format_reaction_markdown(
|
||||
reactants,
|
||||
[product_smiles] if product_smiles else ["无法预测产物"],
|
||||
[confidence] if confidence is not None else None
|
||||
)
|
||||
|
||||
except (KeyError, IndexError) as e:
|
||||
logger.error(f"Error parsing prediction results: {e}")
|
||||
return f"Error: 解析预测结果时出错: {str(e)}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in predict_reaction_outcome: {e}")
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
|
||||
def _collect_smiles_predictions(predictions, products, confidences):
    """Append SMILES/confidence pairs from a list of prediction dicts.

    Each prediction is expected to be a dict with a "smiles" key; the value
    may be a plain SMILES string or a list of strings (the first entry is
    used). Entries without a "smiles" key are skipped.
    """
    for pred in predictions:
        if isinstance(pred, dict) and "smiles" in pred:
            smiles = pred["smiles"]
            if isinstance(smiles, list) and smiles:
                products.append(smiles[0])
            else:
                products.append(smiles)
            confidences.append(pred.get("confidence", 0.0))


def _extract_topn_products(reaction_result):
    """Extract (products, confidences) from one reaction's API result.

    Handles the response shapes the RXN batch-topn endpoint has been
    observed to return: a list of predictions, a dict with a "results"
    list, a dict that is itself a single prediction, or a dict with a
    "products" list. Returns two parallel lists (possibly empty).
    """
    products = []
    confidences = []
    if isinstance(reaction_result, list):
        _collect_smiles_predictions(reaction_result, products, confidences)
    elif isinstance(reaction_result, dict):
        if "results" in reaction_result:
            _collect_smiles_predictions(reaction_result.get("results", []), products, confidences)
        elif "smiles" in reaction_result:
            _collect_smiles_predictions([reaction_result], products, confidences)
        elif "products" in reaction_result:
            _collect_smiles_predictions(reaction_result.get("products", []), products, confidences)
    return products, confidences


@llm_tool(name="predict_reaction_topn_rxn",
          description="Predict multiple possible products for chemical reactions using IBM RXN for Chemistry API")
async def predict_reaction_topn_rxn(reactants: Union[str, List[str], List[List[str]]], topn: int = 3) -> str:
    """
    Predict multiple possible products for chemical reactions.

    This function uses the IBM RXN for Chemistry API to predict multiple
    products that may be formed from given reactants, ranked by likelihood.
    Suitable for scenarios where multiple reaction pathways need to be
    considered.

    Args:
        reactants: Reactants in one of the following formats:
            - String: SMILES notation for a single reaction, multiple reactants separated by dots (.)
            - List of strings: Multiple reactants for a single reaction, each reactant as a SMILES string
            - List of lists of strings: Multiple reactions, each reaction composed of multiple reactant SMILES strings
        topn: Number of predicted products to return for each reaction
            (clamped to the range 1..10), default is 3.

    Returns:
        Formatted Markdown string containing multiple predicted reaction
        results, or an "Error: ..." string on failure (no exception is
        propagated to the caller).

    Examples:
        >>> predict_reaction_topn_rxn("BrBr.c1ccc2cc3ccccc3cc2c1", 5)
        # Returns top 5 possible products for bromine and anthracene reaction

        >>> predict_reaction_topn_rxn(["BrBr", "c1ccc2cc3ccccc3cc2c1"], 3)
        # Returns top 3 possible products for bromine and anthracene reaction

        >>> predict_reaction_topn_rxn([
        ...     ["BrBr", "c1ccc2cc3ccccc3cc2c1"],
        ...     ["BrBr", "c1ccc2cc3ccccc3cc2c1CCO"]
        ... ], 3)
        # Returns top 3 possible products for two different reactions
    """
    try:
        wrapper = _get_rxn_wrapper()

        # Clamp topn to the API-supported range [1, 10].
        if topn < 1:
            topn = 1
        elif topn > 10:
            topn = 10
            logger.warning("topn限制为最大10个结果")

        # Normalize the input into a list of precursor lists plus the
        # dotted display strings used in the Markdown output.
        if isinstance(reactants, str):
            reactants = reactants.strip()
            precursors_lists = [reactants.split(".")]
            reactants_display = [reactants]
        elif isinstance(reactants, list):
            if all(isinstance(r, str) for r in reactants):
                # Single reaction given as a list of reactant SMILES.
                precursors_lists = [reactants]
                reactants_display = [".".join(reactants)]
            elif all(isinstance(r, list) for r in reactants):
                # Multiple reactions, each a list of reactant SMILES.
                precursors_lists = reactants
                reactants_display = [".".join(r) for r in reactants]
            else:
                return "Error: 反应物列表格式无效,必须是字符串列表或字符串列表的列表"
        else:
            return "Error: 反应物参数类型无效,必须是字符串或列表"

        # Submit prediction
        response = await asyncio.to_thread(
            wrapper.predict_reaction_batch_topn,
            precursors_lists=precursors_lists,
            topn=topn
        )

        if not response or "task_id" not in response:
            return "Error: 无法提交多产物反应预测请求"

        # Fetch results directly (no polling loop).
        results = await asyncio.to_thread(
            wrapper.get_predict_reaction_batch_topn_results,
            response["task_id"]
        )

        try:
            # The API response shape is not stable; probe known layouts.
            # Bug fix: handle a bare-list response BEFORE calling .keys(),
            # which previously raised AttributeError on lists and made the
            # list fallback branch unreachable.
            if isinstance(results, list):
                reaction_results = results
                logger.info("Using entire results as list")
            else:
                logger.info(f"Results structure: {results.keys()}")
                reaction_results = results.get("result", [])
                if not reaction_results and "predictions" in results:
                    reaction_results = results.get("predictions", [])
                    logger.info("Using 'predictions' key instead of 'result'")

            if not reaction_results:
                if not isinstance(results, list):
                    logger.warning(f"No reaction results found. Available keys: {results.keys()}")
                return "Error: 未找到预测结果。请检查API响应格式。"

            markdown = "## 反应预测结果\n\n"

            # Keep results and display strings aligned; truncate to the
            # shorter of the two if the API returned a mismatched count.
            if len(reaction_results) != len(reactants_display):
                logger.warning(f"Mismatch between results ({len(reaction_results)}) and reactants ({len(reactants_display)})")
                min_len = min(len(reaction_results), len(reactants_display))
                reaction_results = reaction_results[:min_len]
                reactants_display = reactants_display[:min_len]

            for i, (reaction_result, reactants_str) in enumerate(zip(reaction_results, reactants_display), 1):
                if not reaction_result:
                    markdown += f"### 反应 {i}: 未找到预测结果\n\n"
                    continue

                logger.info(f"Reaction {i} result structure: {type(reaction_result)}")

                # Shape-tolerant extraction (was duplicated inline four times).
                products, confidences = _extract_topn_products(reaction_result)

                markdown += f"### 反应 {i}\n\n"
                markdown += f"**输入反应物:** `{reactants_str}`\n\n"

                if products:
                    markdown += "**预测产物:**\n\n"
                    for j, (product, confidence) in enumerate(zip(products, confidences), 1):
                        markdown += f"{j}. `{product}` (置信度: {confidence:.2f})\n"
                else:
                    markdown += "**预测产物:** 无法解析产物结构\n\n"
                    # Include the raw result to aid debugging of new shapes.
                    markdown += f"**原始结果:** `{reaction_result}`\n\n"

                markdown += "\n"

            return markdown

        except Exception as e:
            logger.error(f"Error parsing topn prediction results: {e}", exc_info=True)
            return f"Error: 解析多产物预测结果时出错: {str(e)}"

    except Exception as e:
        logger.error(f"Error in predict_reaction_topn: {e}")
        return f"Error: {str(e)}"
|
||||
|
||||
|
||||
# @llm_tool(name="predict_retrosynthesis",
|
||||
# description="预测目标分子的逆合成路径")
|
||||
# async def predict_retrosynthesis(target_molecule: str, max_steps: int = 3) -> str:
|
||||
# """
|
||||
# 预测目标分子的逆合成路径。
|
||||
|
||||
# 此函数使用IBM RXN for Chemistry API建议可能的合成路线,
|
||||
# 将目标分子分解为可能商业可得的更简单前体。
|
||||
|
||||
# Args:
|
||||
# target_molecule: 目标分子的SMILES表示法。
|
||||
# max_steps: 考虑的最大逆合成步骤数,默认为3。
|
||||
|
||||
# Returns:
|
||||
# 包含预测逆合成路径的格式化Markdown字符串。
|
||||
|
||||
# Examples:
|
||||
# >>> predict_retrosynthesis("Brc1c2ccccc2c(Br)c2ccccc12")
|
||||
# # 返回目标分子的可能合成路线
|
||||
# """
|
||||
# try:
|
||||
# # Get RXN wrapper
|
||||
# wrapper = _get_rxn_wrapper()
|
||||
|
||||
# # Clean input
|
||||
# target_molecule = target_molecule.strip()
|
||||
|
||||
# # Validate max_steps
|
||||
# if max_steps < 1:
|
||||
# max_steps = 1
|
||||
# elif max_steps > 5:
|
||||
# max_steps = 5
|
||||
# logger.warning("max_steps限制为最大5步")
|
||||
|
||||
# # Submit prediction
|
||||
# response = await asyncio.to_thread(
|
||||
# wrapper.predict_automatic_retrosynthesis,
|
||||
# product=target_molecule,
|
||||
# max_steps=max_steps
|
||||
# )
|
||||
|
||||
# if not response or "prediction_id" not in response:
|
||||
# return "Error: 无法提交逆合成预测请求"
|
||||
|
||||
# # 直接获取结果,而不是通过_wait_for_result
|
||||
# results = await asyncio.to_thread(
|
||||
# wrapper.get_predict_automatic_retrosynthesis_results,
|
||||
# response["prediction_id"]
|
||||
# )
|
||||
|
||||
# # Extract retrosynthetic paths
|
||||
# try:
|
||||
# paths = results.get("retrosynthetic_paths", [])
|
||||
|
||||
# if not paths:
|
||||
# return "## 逆合成分析结果\n\n未找到可行的逆合成路径。目标分子可能太复杂或结构有问题。"
|
||||
|
||||
# # Format results
|
||||
# markdown = f"## 逆合成分析结果\n\n"
|
||||
# markdown += f"### 目标分子\n\n`{target_molecule}`\n\n"
|
||||
# markdown += f"### 找到的合成路径: {len(paths)}\n\n"
|
||||
|
||||
# # Limit to top 3 paths for readability
|
||||
# display_paths = paths[:3]
|
||||
|
||||
# for i, path in enumerate(display_paths, 1):
|
||||
# markdown += f"#### 路径 {i}\n\n"
|
||||
|
||||
# # Extract sequence information
|
||||
# sequence_id = path.get("sequenceId", "未知")
|
||||
# confidence = path.get("confidence", 0.0)
|
||||
|
||||
# markdown += f"**置信度:** {confidence:.2f}\n\n"
|
||||
|
||||
# # Extract steps
|
||||
# steps = path.get("steps", [])
|
||||
|
||||
# if steps:
|
||||
# markdown += "**合成步骤:**\n\n"
|
||||
|
||||
# for j, step in enumerate(steps, 1):
|
||||
# # Extract reactants and products
|
||||
# reactants = step.get("reactants", [])
|
||||
# reactant_smiles = [r.get("smiles", "") for r in reactants if "smiles" in r]
|
||||
|
||||
# product = step.get("product", {})
|
||||
# product_smiles = product.get("smiles", "")
|
||||
|
||||
# markdown += f"步骤 {j}: "
|
||||
|
||||
# if reactant_smiles and product_smiles:
|
||||
# markdown += f"`{'.' if len(reactant_smiles) > 1 else ''.join(reactant_smiles)}` → `{product_smiles}`\n\n"
|
||||
# else:
|
||||
# markdown += "反应细节不可用\n\n"
|
||||
# else:
|
||||
# markdown += "**合成步骤:** 未提供详细步骤\n\n"
|
||||
|
||||
# markdown += "---\n\n"
|
||||
|
||||
# if len(paths) > 3:
|
||||
# markdown += f"*注: 仅显示前3条路径,共找到{len(paths)}条可能的合成路径。*\n"
|
||||
|
||||
# return markdown
|
||||
|
||||
# except (KeyError, IndexError) as e:
|
||||
# logger.error(f"Error parsing retrosynthesis results: {e}")
|
||||
# return f"Error: 解析逆合成结果时出错: {str(e)}"
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in predict_retrosynthesis: {e}")
|
||||
# return f"Error: {str(e)}"
|
||||
|
||||
|
||||
# @llm_tool(name="predict_biocatalytic_retrosynthesis",
|
||||
# description="使用生物催化模型预测目标分子的逆合成路径")
|
||||
# async def predict_biocatalytic_retrosynthesis(target_molecule: str) -> str:
|
||||
# """
|
||||
# 使用生物催化模型预测目标分子的逆合成路径。
|
||||
|
||||
# 此函数使用IBM RXN for Chemistry API的专门生物催化模型,
|
||||
# 建议可能的酶催化合成路线来创建目标分子。
|
||||
|
||||
# Args:
|
||||
# target_molecule: 目标分子的SMILES表示法。
|
||||
|
||||
# Returns:
|
||||
# 包含预测生物催化逆合成路径的格式化Markdown字符串。
|
||||
|
||||
# Examples:
|
||||
# >>> predict_biocatalytic_retrosynthesis("OC1C(O)C=C(Br)C=C1")
|
||||
# # 返回目标分子的可能酶催化合成路线
|
||||
# """
|
||||
# try:
|
||||
# # Get RXN wrapper
|
||||
# wrapper = _get_rxn_wrapper()
|
||||
|
||||
# # Clean input
|
||||
# target_molecule = target_molecule.strip()
|
||||
|
||||
# # Submit prediction with enzymatic model
|
||||
# # Note: The model name might change in future API versions
|
||||
# response = await asyncio.to_thread(
|
||||
# wrapper.predict_automatic_retrosynthesis,
|
||||
# product=target_molecule,
|
||||
# ai_model="enzymatic-2021-04-16" # Use the enzymatic model
|
||||
# )
|
||||
|
||||
# if not response or "prediction_id" not in response:
|
||||
# return "Error: 无法提交生物催化逆合成预测请求"
|
||||
|
||||
# # 直接获取结果,而不是通过_wait_for_result
|
||||
# results = await asyncio.to_thread(
|
||||
# wrapper.get_predict_automatic_retrosynthesis_results,
|
||||
# response["prediction_id"]
|
||||
# )
|
||||
|
||||
# # Extract retrosynthetic paths
|
||||
# try:
|
||||
# paths = results.get("retrosynthetic_paths", [])
|
||||
|
||||
# if not paths:
|
||||
# return "## 生物催化逆合成分析结果\n\n未找到可行的酶催化合成路径。目标分子可能不适合酶催化或结构有问题。"
|
||||
|
||||
# # Format results
|
||||
# markdown = f"## 生物催化逆合成分析结果\n\n"
|
||||
# markdown += f"### 目标分子\n\n`{target_molecule}`\n\n"
|
||||
# markdown += f"### 找到的酶催化合成路径: {len(paths)}\n\n"
|
||||
|
||||
# # Limit to top 3 paths for readability
|
||||
# display_paths = paths[:3]
|
||||
|
||||
# for i, path in enumerate(display_paths, 1):
|
||||
# markdown += f"#### 路径 {i}\n\n"
|
||||
|
||||
# # Extract sequence information
|
||||
# sequence_id = path.get("sequenceId", "未知")
|
||||
# confidence = path.get("confidence", 0.0)
|
||||
|
||||
# markdown += f"**置信度:** {confidence:.2f}\n\n"
|
||||
|
||||
# # Extract steps
|
||||
# steps = path.get("steps", [])
|
||||
|
||||
# if steps:
|
||||
# markdown += "**酶催化步骤:**\n\n"
|
||||
|
||||
# for j, step in enumerate(steps, 1):
|
||||
# # Extract reactants and products
|
||||
# reactants = step.get("reactants", [])
|
||||
# reactant_smiles = [r.get("smiles", "") for r in reactants if "smiles" in r]
|
||||
|
||||
# product = step.get("product", {})
|
||||
# product_smiles = product.get("smiles", "")
|
||||
|
||||
# markdown += f"步骤 {j}: "
|
||||
|
||||
# if reactant_smiles and product_smiles:
|
||||
# markdown += f"`{'.' if len(reactant_smiles) > 1 else ''.join(reactant_smiles)}` → `{product_smiles}`\n\n"
|
||||
# else:
|
||||
# markdown += "反应细节不可用\n\n"
|
||||
|
||||
# # Add enzyme information if available
|
||||
# if "enzymeClass" in step:
|
||||
# markdown += f"*可能的酶类别: {step['enzymeClass']}*\n\n"
|
||||
# else:
|
||||
# markdown += "**合成步骤:** 未提供详细步骤\n\n"
|
||||
|
||||
# markdown += "---\n\n"
|
||||
|
||||
# if len(paths) > 3:
|
||||
# markdown += f"*注: 仅显示前3条路径,共找到{len(paths)}条可能的酶催化合成路径。*\n"
|
||||
|
||||
# return markdown
|
||||
|
||||
# except (KeyError, IndexError) as e:
|
||||
# logger.error(f"Error parsing biocatalytic retrosynthesis results: {e}")
|
||||
# return f"Error: 解析生物催化逆合成结果时出错: {str(e)}"
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in predict_biocatalytic_retrosynthesis: {e}")
|
||||
# return f"Error: {str(e)}"
|
||||
|
||||
|
||||
@llm_tool(name="predict_reaction_properties_rxn",
          description="Predict chemical reaction properties such as atom mapping and yield using IBM RXN for Chemistry API")
async def predict_reaction_properties_rxn(
    reaction: str,
    property_type: str = "atom-mapping"
) -> str:
    """
    Predict chemical reaction properties such as atom mapping and yield.

    This function uses the IBM RXN for Chemistry API to predict various
    properties of chemical reactions, including atom-to-atom mapping
    (showing how atoms in reactants correspond to atoms in products) and
    potential reaction yields.

    Args:
        reaction: SMILES notation of the reaction (reactants>>products).
        property_type: Type of property to predict. Options: "atom-mapping", "yield".

    Returns:
        Formatted Markdown string containing predicted reaction properties,
        or an "Error: ..." string on failure (no exception is propagated).

    Examples:
        >>> predict_reaction_properties_rxn("CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F", "atom-mapping")
        # Returns atom mapping for the reaction
    """
    try:
        wrapper = _get_rxn_wrapper()

        reaction = reaction.strip()

        valid_property_types = ["atom-mapping", "yield"]
        if property_type not in valid_property_types:
            return f"Error: 无效的属性类型 '{property_type}'。支持的类型: {', '.join(valid_property_types)}"

        # Select the API model matching the requested property.
        ai_model = "atom-mapping-2020" if property_type == "atom-mapping" else "yield-2020-08-10"

        response = await asyncio.to_thread(
            wrapper.predict_reaction_properties,
            reactions=[reaction],
            ai_model=ai_model
        )

        if not response or "response" not in response:
            return f"Error: 无法提交{property_type}预测请求"

        try:
            content = response.get("response", {}).get("payload", {}).get("content", [])

            if not content:
                return f"Error: 未找到{property_type}预测结果"

            markdown = f"## 反应{property_type}预测结果\n\n"
            markdown += f"### 输入反应\n\n`{reaction}`\n\n"

            if property_type == "atom-mapping":
                mapped_reaction = content[0].get("value", "")

                if not mapped_reaction:
                    return "Error: 无法生成原子映射"

                markdown += "### 原子映射结果\n\n"
                markdown += f"`{mapped_reaction}`\n\n"

                if ">>" in mapped_reaction:
                    # Bug fix: split on the first ">>" only, so a stray second
                    # separator cannot raise an unpacking ValueError.
                    mapped_reactants, mapped_products = mapped_reaction.split(">>", 1)
                    markdown += "### 映射解释\n\n"
                    markdown += "原子映射显示了反应物中的原子如何对应到产物中的原子。\n"
                    markdown += "每个原子上的数字表示映射ID,相同ID的原子在反应前后是同一个原子。\n\n"
                    markdown += f"**映射的反应物:** `{mapped_reactants}`\n\n"
                    markdown += f"**映射的产物:** `{mapped_products}`\n"

            elif property_type == "yield":
                predicted_yield = content[0].get("value", "")

                # Bug fix: a legitimate predicted yield of 0 / 0.0 is falsy;
                # only treat a truly missing value as an error.
                if predicted_yield in ("", None):
                    return "Error: 无法预测反应产率"

                try:
                    yield_value = float(predicted_yield)
                    markdown += "### 产率预测结果\n\n"
                    markdown += f"**预测产率:** {yield_value:.1f}%\n\n"

                    # Add a rough qualitative interpretation of the yield.
                    if yield_value < 30:
                        markdown += "**解释:** 预测产率较低,反应可能效率不高。考虑优化反应条件或探索替代路线。\n"
                    elif yield_value < 70:
                        markdown += "**解释:** 预测产率中等,反应可能是可行的,但有优化空间。\n"
                    else:
                        markdown += "**解释:** 预测产率较高,反应可能非常有效。\n"
                except ValueError:
                    # Non-numeric value: show it verbatim.
                    markdown += f"**预测产率:** {predicted_yield}\n"

            return markdown

        except (KeyError, IndexError) as e:
            logger.error(f"Error parsing reaction properties results: {e}")
            return f"Error: 解析反应属性预测结果时出错: {str(e)}"

    except Exception as e:
        logger.error(f"Error in predict_reaction_properties: {e}")
        return f"Error: {str(e)}"
|
||||
|
||||
|
||||
@llm_tool(name="extract_reaction_actions_rxn",
          description="Extract structured reaction steps from text descriptions using IBM RXN for Chemistry API")
async def extract_reaction_actions_rxn(reaction_text: str) -> str:
    """
    Extract structured reaction steps from text descriptions.

    This function uses the IBM RXN for Chemistry API to parse text
    descriptions of chemical procedures and extract structured actions
    representing the steps of the procedure.

    Args:
        reaction_text: Text description of a chemical reaction procedure.

    Returns:
        Formatted Markdown string containing the extracted reaction steps,
        or an "Error: ..." string on failure (no exception is propagated).

    Examples:
        >>> extract_reaction_actions_rxn("To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour.")
        # Returns structured steps extracted from the text
    """
    try:
        wrapper = _get_rxn_wrapper()

        reaction_text = reaction_text.strip()
        if not reaction_text:
            return "Error: 反应文本为空"

        # Submit extraction request
        response = await asyncio.to_thread(
            wrapper.paragraph_to_actions,
            paragraph=reaction_text
        )

        if not response:
            return "Error: 无法从文本中提取反应步骤"

        # Return the raw response embedded in a Markdown report; the
        # wrapper already returns the action list in displayable form.
        # (Bug fix: removed an unreachable `return markdown` that followed
        # this return and referenced an undefined variable.)
        return f"""## 反应步骤提取结果

### 输入文本

{reaction_text}

### 提取的反应步骤

```
{response}
```
"""

    except Exception as e:
        logger.error(f"Error in extract_reaction_actions: {e}")
        return f"Error: {str(e)}"
|
||||
84
sci_mcp/core/config.py
Executable file
84
sci_mcp/core/config.py
Executable file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Configuration Module
|
||||
|
||||
This module provides configuration settings for the Mars Toolkit.
|
||||
It includes API keys, endpoints, paths, and other configuration parameters.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
class Config:
    """Base class for grouped configuration constants.

    Subclasses declare settings as plain class attributes; this base
    provides dict export (`as_dict`) and in-place updating of existing
    settings (`update`).
    """

    @classmethod
    def as_dict(cls) -> Dict[str, Any]:
        """Return the class's own configuration settings as a dictionary.

        Only plain data attributes defined directly on *cls* are included;
        inherited settings are not. Dunder attributes, callables, and
        descriptor objects are excluded.
        """
        return {
            key: value for key, value in cls.__dict__.items()
            if not key.startswith('__')
            and not callable(value)
            # Bug fix: classmethod/staticmethod objects stored in
            # cls.__dict__ are not themselves callable, so they previously
            # leaked into the exported settings on classes defining methods.
            and not isinstance(value, (classmethod, staticmethod, property))
        }

    @classmethod
    def update(cls, **kwargs):
        """Update existing configuration settings in place.

        Unknown keys are silently ignored: only attributes that already
        exist on the class (possibly inherited) are set.
        """
        for key, value in kwargs.items():
            if hasattr(cls, key):
                setattr(cls, key, value)
|
||||
|
||||
|
||||
class General_Config(Config):
    """Configuration class for General MCP"""

    # SearXNG metasearch instance used by the general web-search tools.
    # NOTE(review): private-network address — confirm reachability per deployment.
    SEARXNG_HOST="http://192.168.168.1:40032/"
    # Maximum number of search results to return per query.
    SEARXNG_MAX_RESULTS=10
|
||||
|
||||
|
||||
class Material_Config(Config):
    """Configuration for the Material MCP tools.

    NOTE(review): API keys are hardcoded below — consider loading them from
    environment variables or a secrets store instead of source control.
    """

    # Materials Project
    # NOTE(review): hardcoded API key — move to an environment variable.
    MP_API_KEY = 'PMASAg256b814q3OaSRWeVc7MKx4mlKI'
    MP_ENDPOINT = 'https://api.materialsproject.org/'
    # Default number of top results returned from Materials Project queries.
    MP_TOPK = 3

    # Local mirrors of Materials Project property / CIF datasets.
    LOCAL_MP_PROPS_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Props/'
    LOCAL_MP_CIF_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/MPDatasets/'
    # Proxy (empty string disables the proxy)
    HTTP_PROXY = ''#'http://192.168.168.1:20171'
    HTTPS_PROXY = ''#'http://192.168.168.1:20171'

    # FairChem
    # Pretrained checkpoint path used by FairChem-based tools.
    FAIRCHEM_MODEL_PATH = '/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/support/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
    # Force-convergence threshold for relaxation — presumably eV/Å; TODO confirm units.
    FMAX = 0.05

    # MatterGen
    MATTERGENMODEL_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/support/pretrained_models/mattergen_ckpt'
    MATTERGEN_ROOT='/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/mattergen_gen/mattergen'
    # Relative directory where MatterGen writes generated structures.
    MATTERGENMODEL_RESULT_PATH = 'results/'

    # Dify
    DIFY_ROOT_URL = 'http://192.168.191.101:6080'
    # NOTE(review): hardcoded API key — move to an environment variable.
    DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'

    # Scratch directory for intermediate files produced by material tools.
    TEMP_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/temp/material'
|
||||
|
||||
|
||||
class Chemistry_Config(Config):
    """Configuration for the Chemistry MCP tools."""

    # Scratch directory for intermediate files produced by chemistry tools.
    TEMP_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/temp/chemistry'

    # IBM RXN for Chemistry API key.
    # NOTE(review): hardcoded secret — move to an environment variable.
    RXN4_CHEMISTRY_KEY='apk-8928522a146c2503f30b16d9909222d7583f412ee8f1049f08d32a089ba88d34'
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Shared, ready-to-import configuration instances. All settings live on the
# classes themselves, so these instances are interchangeable with the classes.
material_config = Material_Config()
general_config = General_Config()
chemistry_config = Chemistry_Config()
|
||||
|
||||
|
||||
|
||||
315
sci_mcp/core/llm_tools.py
Executable file
315
sci_mcp/core/llm_tools.py
Executable file
@@ -0,0 +1,315 @@
|
||||
"""
|
||||
LLM Tools Module
|
||||
|
||||
This module provides decorators and utilities for defining, registering, and managing LLM tools.
|
||||
It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving
|
||||
registered tools for use with LLM APIs.
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
import os
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args
|
||||
import docstring_parser
|
||||
from pydantic import BaseModel, create_model, Field
|
||||
|
||||
# Registry to store all registered tools.
# Maps tool name -> decorated wrapper function; populated as a side effect
# of the @llm_tool decorator.
_TOOL_REGISTRY: Dict[str, Callable] = {}
# Mapping of domain names to their module paths.
# Used by import_domain_tools/get_domain_tools to locate and import the
# packages whose modules register tools.
_DOMAIN_MODULE_MAPPING: Dict[str, str] = {
    'material': 'sci_mcp.material_mcp',
    'general': 'sci_mcp.general_mcp',
    'biology': 'sci_mcp.biology_mcp',
    'chemistry': 'sci_mcp.chemistry_mcp'
}
|
||||
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
    """
    Decorator that registers a function as an LLM tool.

    Supports both bare and parameterized usage:

        @llm_tool
        def f(...): ...

        @llm_tool(name="weather_lookup", description="Get current weather")
        def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
            '''Get weather information for a specific location.'''
            return {"temperature": 22.5, "conditions": "sunny"}

    Args:
        name: Optional custom tool name; defaults to the function's name.
        description: Optional custom description; defaults to the function's
            docstring.

    Returns:
        The decorated function, augmented with schema/registry metadata.
    """
    if callable(name):
        # Bare usage (@llm_tool without parentheses): `name` is actually
        # the decorated function itself; use defaults for name/description.
        return _llm_tool_impl(name, None, None)

    # Parameterized usage: return a decorator closing over name/description.
    def decorator(func: Callable) -> Callable:
        return _llm_tool_impl(func, name, description)

    return decorator
|
||||
|
||||
def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
    """Implementation of the llm_tool decorator.

    Builds OpenAI- and MCP-format JSON schemas plus a Pydantic args model
    from *func*'s signature and docstring, wraps the function (preserving
    sync/async-ness), attaches the metadata as attributes on the wrapper,
    and registers the wrapper in the module-level _TOOL_REGISTRY under its
    tool name (later registrations with the same name overwrite earlier ones).

    Args:
        func: The function being exposed as an LLM tool.
        name: Optional tool-name override; defaults to func.__name__.
        description: Optional description override; defaults to the
            function's full docstring.

    Returns:
        The wrapper function with tool metadata attributes attached
        (is_llm_tool, tool_name, tool_description, openai_schema,
        mcp_schema, args_schema).
    """
    # Get function signature and docstring
    sig = inspect.signature(func)
    doc = inspect.getdoc(func) or ""
    parsed_doc = docstring_parser.parse(doc)

    # Determine tool name
    tool_name = name or func.__name__

    # Determine tool description
    tool_description = description or doc

    # Create parameter properties for JSON schema
    properties = {}
    required = []

    for param_name, param in sig.parameters.items():
        # Skip self parameter for methods
        if param_name == "self":
            continue

        param_type = param.annotation
        param_default = None if param.default is inspect.Parameter.empty else param.default
        param_required = param.default is inspect.Parameter.empty

        # Get parameter description from docstring if available
        param_desc = ""
        for param_doc in parsed_doc.params:
            if param_doc.arg_name == param_name:
                param_desc = param_doc.description
                break

        # Handle Annotated types: unwrap to the real type and prefer the
        # annotation's string metadata as the description.
        # NOTE(review): relies on get_origin(...).__name__ == "Annotated";
        # verify this holds across the Python versions being supported.
        if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
            args = get_args(param_type)
            param_type = args[0]  # The actual type
            if len(args) > 1 and isinstance(args[1], str):
                param_desc = args[1]  # The description

        # Create property for parameter
        param_schema = {
            "type": _get_json_type(param_type),
            "description": param_desc,
            "title": param_name.replace("_", " ").title()
        }

        # Add default value if available
        # NOTE(review): an explicit default of None is indistinguishable
        # from "no default" here and is omitted from the schema.
        if param_default is not None:
            param_schema["default"] = param_default

        properties[param_name] = param_schema

        # Add to required list if no default value
        if param_required:
            required.append(param_name)

    # Create OpenAI format JSON schema (function-calling shape)
    openai_schema = {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required
            }
        }
    }

    # Create MCP format JSON schema (Model Context Protocol shape)
    mcp_schema = {
        "name": tool_name,
        "description": tool_description,
        "inputSchema": {
            "type": "object",
            "properties": properties,
            "required": required
        }
    }

    # Create Pydantic model for args schema; Ellipsis marks a required field.
    field_definitions = {}
    for param_name, param in sig.parameters.items():
        if param_name == "self":
            continue

        param_type = param.annotation
        param_default = ... if param.default is inspect.Parameter.empty else param.default

        # Handle Annotated types (same unwrapping as above)
        if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
            args = get_args(param_type)
            param_type = args[0]
            description = args[1] if len(args) > 1 and isinstance(args[1], str) else ""
            field_definitions[param_name] = (param_type, Field(default=param_default, description=description))
        else:
            field_definitions[param_name] = (param_type, Field(default=param_default))

    # Create args schema model
    model_name = f"{tool_name.title().replace('_', '')}Schema"
    args_schema = create_model(model_name, **field_definitions)

    # Create a wrapper of the matching kind (async vs sync) so that
    # awaiting/calling behavior is preserved for the original function.
    if asyncio.iscoroutinefunction(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
    else:
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

    # Attach metadata to function
    wrapper.is_llm_tool = True
    wrapper.tool_name = tool_name
    wrapper.tool_description = tool_description
    wrapper.openai_schema = openai_schema
    wrapper.mcp_schema = mcp_schema
    wrapper.args_schema = args_schema

    # Register the tool
    _TOOL_REGISTRY[tool_name] = wrapper

    return wrapper
|
||||
|
||||
def get_all_tools() -> Dict[str, Callable]:
    """
    Get all registered LLM tools.

    Returns:
        A dictionary mapping tool names to their corresponding functions.
        NOTE: this is the live module-level registry, not a copy — mutating
        the returned dict mutates global tool-registration state.
    """
    return _TOOL_REGISTRY
|
||||
|
||||
def get_all_tool_schemas(schema_type='openai') -> List[Dict[str, Any]]:
    """
    Get JSON schemas for all registered LLM tools.

    Args:
        schema_type: 'openai' (default) for OpenAI function-calling format,
            or 'mcp' for Model Context Protocol format.

    Returns:
        A list of JSON schemas for all registered tools, suitable for use
        with LLM APIs.
    """
    if schema_type == 'mcp':
        return [tool.mcp_schema for tool in _TOOL_REGISTRY.values()]
    return [tool.openai_schema for tool in _TOOL_REGISTRY.values()]
|
||||
|
||||
def import_domain_tools(domains: List[str]) -> None:
    """
    Import tools from specified domains to ensure they are registered.

    Dynamically imports every submodule under each known domain package so
    that all functions decorated with @llm_tool end up in _TOOL_REGISTRY.
    Unknown domain names are skipped; import failures are reported but do
    not abort the remaining imports.

    Args:
        domains: List of domain names (e.g., ['material', 'general'])
    """
    for domain in domains:
        module_path = _DOMAIN_MODULE_MAPPING.get(domain)
        if module_path is None:
            # Not a known domain — nothing to import.
            continue

        try:
            base_module = importlib.import_module(module_path)
            base_path = os.path.dirname(base_module.__file__)

            # Walk and import every submodule of the domain package.
            for _, submodule_name, _ in pkgutil.walk_packages([base_path], f"{module_path}."):
                try:
                    importlib.import_module(submodule_name)
                except ImportError as e:
                    print(f"Error importing {submodule_name}: {e}")
        except ImportError as e:
            print(f"Error importing domain {domain}: {e}")
|
||||
|
||||
def get_domain_tools(domains: List[str]) -> Dict[str, Callable]:
    """
    Get tools from specified domains.

    Args:
        domains: List of domain names (e.g., ['material', 'general'])

    Returns:
        A flat dictionary mapping tool names to their functions, restricted
        to tools whose defining module lives under one of the requested
        domain packages.  (The previous annotation claimed a dict keyed by
        domain; the function has always returned a flat mapping.)
    """
    # First, ensure all tools from the specified domains are imported and registered.
    import_domain_tools(domains)

    # Module-path prefixes of all known requested domains.
    prefixes = [_DOMAIN_MODULE_MAPPING[d] for d in domains if d in _DOMAIN_MODULE_MAPPING]

    domain_tools = {}
    for tool_name, tool_func in _TOOL_REGISTRY.items():
        # A tool belongs to a domain when its defining module's import path
        # starts with that domain's package path.  Guard against a missing
        # or None __module__ on exotic callables.
        module = getattr(tool_func, "__module__", None)
        if module and any(module.startswith(prefix) for prefix in prefixes):
            domain_tools[tool_name] = tool_func

    return domain_tools
|
||||
|
||||
def get_domain_tool_schemas(domains: List[str],schema_type='openai') -> List[Dict[str, Any]]:
    """
    Get JSON schemas for tools from specified domains.

    Args:
        domains: List of domain names (e.g., ['material', 'general'])
        schema_type: 'mcp' for MCP-style schemas; any other value
            (default 'openai') yields OpenAI-style schemas.

    Returns:
        A flat list of tool schemas for all tools in the requested domains.
    """
    # First, get all domain tools
    domain_tools = get_domain_tools(domains)

    # Pick the schema flavor attached to each tool wrapper by @llm_tool.
    if schema_type == 'mcp':
        tools_schema_list = [tool.mcp_schema for tool in domain_tools.values()]
    else:
        tools_schema_list = [tool.openai_schema for tool in domain_tools.values()]

    return tools_schema_list
|
||||
|
||||
|
||||
def _get_json_type(python_type: Any) -> str:
|
||||
"""
|
||||
Convert Python type to JSON schema type.
|
||||
|
||||
Args:
|
||||
python_type: Python type annotation
|
||||
|
||||
Returns:
|
||||
Corresponding JSON schema type as string
|
||||
"""
|
||||
if python_type is str:
|
||||
return "string"
|
||||
elif python_type is int:
|
||||
return "integer"
|
||||
elif python_type is float:
|
||||
return "number"
|
||||
elif python_type is bool:
|
||||
return "boolean"
|
||||
elif python_type is list or python_type is List:
|
||||
return "array"
|
||||
elif python_type is dict or python_type is Dict:
|
||||
return "object"
|
||||
else:
|
||||
# Default to string for complex types
|
||||
return "string"
|
||||
0
sci_mcp/general_mcp/__init__.py
Normal file
0
sci_mcp/general_mcp/__init__.py
Normal file
0
sci_mcp/general_mcp/searxng_query/__init__.py
Normal file
0
sci_mcp/general_mcp/searxng_query/__init__.py
Normal file
78
sci_mcp/general_mcp/searxng_query/searxng_query_tools.py
Normal file
78
sci_mcp/general_mcp/searxng_query/searxng_query_tools.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Search Online Module
|
||||
|
||||
This module provides functions for searching information on the web.
|
||||
"""
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ...core.config import general_config
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Annotated, Any, Dict, List, Union
|
||||
from langchain_community.utilities import SearxSearchWrapper
|
||||
import mcp.types as types
|
||||
|
||||
|
||||
@llm_tool(name="search_online_searxng", description="Search scientific information online using searxng")
async def search_online_searxng(
    query: Annotated[str, "Search term"],
    num_results: Annotated[int, "Number of results"] = 5
) -> str:
    """
    Searches for scientific information online and returns results as a formatted string.

    Args:
        query: Search term for scientific content
        num_results: Number of results to return (capped at
            general_config.SEARXNG_MAX_RESULTS)

    Returns:
        Formatted string with search results (titles, snippets, links)
    """
    # NOTE(lzy): possibly remove before release — searxng is deployed locally,
    # so no proxy is needed while debugging locally.
    os.environ['HTTP_PROXY'] = ''
    os.environ['HTTPS_PROXY'] = ''
    try:
        max_results = min(int(num_results), general_config.SEARXNG_MAX_RESULTS)
        search = SearxSearchWrapper(
            searx_host=general_config.SEARXNG_HOST,
            categories=["science",],
            # BUG FIX: was k=num_results, which bypassed the configured cap.
            k=max_results
        )

        # Execute search in a separate thread to avoid blocking the event loop
        # since SearxSearchWrapper doesn't have native async support.
        # get_running_loop() is the non-deprecated form inside a coroutine.
        loop = asyncio.get_running_loop()
        raw_results = await loop.run_in_executor(
            None,
            lambda: search.results(query, language=['en','zh'], num_results=max_results)
        )

        # Transform results into structured format
        formatted_results = []
        for result in raw_results:
            formatted_results.append({
                "title": result.get("title", ""),
                "snippet": result.get("snippet", ""),
                "link": result.get("link", ""),
                "source": result.get("source", "")
            })

        # Format results into a readable Markdown string
        result_str = f"Search Results for '{query}' ({len(formatted_results)} items):\n\n"
        if len(formatted_results) > 0:
            for i, res in enumerate(formatted_results):
                title = res.get("title", "No Title")
                snippet = res.get("snippet", "No Snippet")
                link = res.get("link", "No Link")
                source = res.get("source", "No Source")
                result_str += f"{i + 1}. **{title}**\n"
                result_str += f" - Snippet: {snippet}\n"
                result_str += f" - Link: [{link}]({link})\n"
                result_str += f" - Source: {source}\n\n"
        else:
            result_str += "No results found.\n"

        return result_str

    except Exception as e:
        return f"Error: {str(e)}"
|
||||
|
||||
34
sci_mcp/material_mcp/__init__.py
Executable file
34
sci_mcp/material_mcp/__init__.py
Executable file
@@ -0,0 +1,34 @@
|
||||
|
||||
# # Core modules
|
||||
# from mars_toolkit.core.config import config
|
||||
|
||||
|
||||
# # Basic tools
|
||||
# from mars_toolkit.misc.misc_tools import get_current_time
|
||||
|
||||
# # Compute modules
|
||||
# from mars_toolkit.compute.material_gen import generate_material
|
||||
# from mars_toolkit.compute.property_pred import predict_properties
|
||||
# from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure
|
||||
|
||||
# # Query modules
|
||||
# from mars_toolkit.query.mp_query import (
|
||||
# search_material_property_from_material_project,
|
||||
# get_crystal_structures_from_materials_project,
|
||||
# get_mpid_from_formula
|
||||
# )
|
||||
# from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||
# from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||
# from mars_toolkit.query.web_search import search_online
|
||||
|
||||
# # Visualization modules
|
||||
|
||||
|
||||
# from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# __version__ = "0.1.0"
|
||||
# __all__ = ["llm_tool", "get_tools", "get_tool_schemas"]
|
||||
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Structure Optimization Module
|
||||
|
||||
This module provides functions for optimizing crystal structures using the FairChem model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from io import StringIO
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from ase.io import read, write
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
from ..support.utils import convert_structure, remove_symmetry_equiv_xyz, read_structure_from_file_name_or_content_string
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ...core.config import material_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化FairChem模型
|
||||
calc = None
|
||||
|
||||
def init_model():
    """Initialize the global FairChem OCP calculator (idempotent no-op if loaded)."""
    global calc
    if calc is None:
        try:
            # Imported lazily: fairchem is heavy and only needed here.
            from fairchem.core import OCPCalculator

            calc = OCPCalculator(checkpoint_path=material_config.FAIRCHEM_MODEL_PATH)
            logger.info("FairChem model initialized successfully")
        except Exception as exc:
            logger.error(f"Failed to initialize FairChem model: {str(exc)}")
            raise
|
||||
|
||||
|
||||
|
||||
def generate_symmetry_cif(structure: Structure) -> str:
    """
    Generate a symmetrized CIF string for a structure.

    The structure is refined to its conventional setting via
    SpacegroupAnalyzer before being serialized with symmetry information.

    Args:
        structure: Pymatgen Structure object

    Returns:
        CIF-formatted string
    """
    analyzer = SpacegroupAnalyzer(structure)
    structure_refined = analyzer.get_refined_structure()

    with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
        try:
            cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
            cif_writer.write_file(tmp_file.name)
            tmp_file.seek(0)
            content = tmp_file.read()
        finally:
            # BUG FIX: always remove the temp file, even if writing/reading
            # raises (previously the file leaked on any exception).
            os.unlink(tmp_file.name)
    return content
|
||||
|
||||
def optimize_structure(atoms: Atoms, output_format: str, fmax: float = 0.05) -> str:
    """
    Relax a crystal structure with the FIRE optimizer and the global
    FairChem calculator, then serialize the result.

    Args:
        atoms: ASE Atoms object to optimize (modified in place)
        output_format: output format (cif, xyz, vasp, ...)
        fmax: force convergence criterion in eV/Å

    Returns:
        Formatted string with the optimization log, total energy and the
        optimized structure, or an "Error: ..." message on failure.
    """
    atoms.calc = calc

    try:
        # Capture the optimizer's stdout so the log can be reported back.
        temp_output = StringIO()
        original_stdout = sys.stdout
        sys.stdout = temp_output
        try:
            # Optimize cell and positions together via FrechetCellFilter.
            dyn = FIRE(FrechetCellFilter(atoms))
            dyn.run(fmax=fmax)
        finally:
            # BUG FIX: always restore stdout, even when the optimizer raises
            # (previously an exception left sys.stdout permanently redirected).
            sys.stdout = original_stdout
        optimization_log = temp_output.getvalue()
        temp_output.close()

        # Total potential energy of the relaxed structure.
        total_energy = atoms.get_potential_energy()

        # Serialize the optimized structure in the requested format.
        if output_format == "cif":
            optimized_structure = Structure.from_ase_atoms(atoms)
            content = generate_symmetry_cif(optimized_structure)
            content = remove_symmetry_equiv_xyz(content)

        else:
            with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
                try:
                    write(tmp_file.name, atoms)
                    tmp_file.seek(0)
                    content = tmp_file.read()
                finally:
                    # Remove the temp file even if write/read fails.
                    os.unlink(tmp_file.name)

        # Format the report exactly as downstream LLM tooling expects.
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**

#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
        return format_result
    except Exception as e:
        return f"Error: Failed to optimize structure: {str(e)}"
|
||||
|
||||
@llm_tool(name="optimize_crystal_structure_FairChem",
          description="Optimizes crystal structures using the FairChem model")
async def optimize_crystal_structure_FairChem(
    structure_source: str,
    format_type: str = "auto",
    optimization_level: str = "normal"
) -> str:
    """
    Optimizes a crystal structure to find its lowest energy configuration.

    Args:
        structure_source: Either a file name or direct structure content (CIF, XYZ, POSCAR)
        format_type: Structure format type (auto, cif, xyz, poscar). Default "auto" will attempt to detect format.
        optimization_level: Optimization precision level (quick, normal, precise)

    Returns:
        Optimized structure with total energy and optimization details
    """
    # Lazily load the FairChem model on first use.
    if calc is None:
        init_model()

    # Map the requested precision level to a force-convergence threshold;
    # unknown levels fall back to "normal".
    fmax = {"quick": 0.1, "normal": 0.05, "precise": 0.01}.get(optimization_level, 0.05)

    def run_optimization():
        # Synchronous worker; executed off the event loop below.
        try:
            # Resolve the input to raw structure content plus its format.
            content, actual_format = read_structure_from_file_name_or_content_string(structure_source, format_type)

            # Normalize format names for ASE ("poscar" -> "vasp"); default to cif.
            final_format = {
                "cif": "cif",
                "xyz": "xyz",
                "poscar": "vasp",
                "vasp": "vasp"
            }.get(actual_format.lower(), "cif")

            atoms = convert_structure(final_format, content)
            if atoms is None:
                return f"Error: Unable to convert input structure. Please check if the format is correct."

            # Relax and serialize the structure.
            return optimize_structure(atoms, final_format, fmax=fmax)

        except Exception as e:
            return f"Error: Failed to optimize structure: {str(e)}"

    # Offload the blocking optimization so the event loop stays responsive.
    return await asyncio.to_thread(run_optimization)
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import codecs
|
||||
import json
|
||||
import requests
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ...core.config import material_config
|
||||
|
||||
@llm_tool(
    name="retrieval_from_knowledge_base",
    description="Retrieve information from local materials science literature knowledge base"
)
async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
    """
    Retrieve relevant entries from the local materials-science literature
    knowledge base via the Dify chat API.

    Args:
        query: Query string, e.g. a material name such as "CsPbBr3"
        topk: Maximum number of results to return (default: 3)

    Returns:
        A JSON string containing document id, title, content and relevance
        score, or an "Error: ..." message on failure.
    """
    # Dify chat-messages endpoint.
    url = f'{material_config.DIFY_ROOT_URL}/v1/chat-messages'

    # API key and JSON content type for the Dify request.
    headers = {
        'Authorization': f'Bearer {material_config.DIFY_API_KEY}',
        'Content-Type': 'application/json'
    }

    data = {
        "inputs": {"topK": topk},        # maximum number of hits to return
        "query": query,                  # the retrieval query
        "response_mode": "blocking",     # wait for the complete response
        "conversation_id": "",           # stateless: each call is independent
        "user": "abc-123"                # fixed user identifier
    }

    try:
        # Long timeout: knowledge-base retrieval can be slow.
        response = requests.post(url, headers=headers, json=data, timeout=1111)
        # Surface HTTP errors through the except branch below.
        response.raise_for_status()

        # BUG FIX: the previous codecs.decode(text, 'unicode_escape') pass
        # corrupted non-ASCII (e.g. Chinese) content in valid UTF-8 JSON;
        # requests/json already decode \uXXXX escapes correctly.
        result_json = response.json()

        # Extract metadata from the response.
        metadata = result_json.get("metadata", {})

        # Assemble the key information for the caller.
        useful_info = {
            "id": metadata.get("document_id"),       # document ID
            "title": result_json.get("title"),       # document title
            "content": result_json.get("answer", ""),  # answer text carries the content
            "score": metadata.get("score")           # relevance score
        }

        return json.dumps(useful_info, ensure_ascii=False, indent=2)

    except Exception as e:
        # Catch-all so the tool degrades to an error string instead of raising.
        return f"Error: {str(e)}"
|
||||
8
sci_mcp/material_mcp/matgl_tools/__init__.py
Normal file
8
sci_mcp/material_mcp/matgl_tools/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
MatGL Tools Module
|
||||
|
||||
This module provides tools for material structure relaxation and property prediction
|
||||
using MatGL (Materials Graph Library) models.
|
||||
"""
|
||||
|
||||
from .matgl_tools import *
|
||||
487
sci_mcp/material_mcp/matgl_tools/matgl_tools.py
Normal file
487
sci_mcp/material_mcp/matgl_tools/matgl_tools.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
MatGL Tools Module
|
||||
|
||||
This module provides tools for material structure relaxation and property prediction
|
||||
using MatGL (Materials Graph Library) models.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from ...core.config import material_config
|
||||
|
||||
|
||||
import warnings
|
||||
import json
|
||||
from typing import Dict, List, Union, Optional, Any
|
||||
|
||||
import torch
|
||||
from pymatgen.core import Lattice, Structure
|
||||
from pymatgen.ext.matproj import MPRester
|
||||
from pymatgen.io.ase import AseAtomsAdaptor
|
||||
|
||||
import matgl
|
||||
from matgl.ext.ase import Relaxer, MolecularDynamics, PESCalculator
|
||||
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
|
||||
|
||||
from ...core.llm_tools import llm_tool
|
||||
import os
|
||||
from ..support.utils import read_structure_from_file_name_or_content_string
|
||||
# To suppress warnings for clearer output
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
|
||||
@llm_tool(name="relax_crystal_structure_M3GNet",
          description="Optimize crystal structure geometry using M3GNet universal potential from a structure file or content string")
async def relax_crystal_structure_M3GNet(
    structure_source: str,
    fmax: float = 0.01
) -> str:
    """
    Optimize crystal structure geometry to find its equilibrium configuration.

    Uses M3GNet universal potential for fast and accurate structure relaxation without DFT.
    Accepts a structure file or content string.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        fmax: Maximum force tolerance for convergence in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string with the relaxation results or an error message.
    """
    try:
        # Resolve the input (file name or raw content) to a content string and its format.
        structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)

        # Parse the structure with pymatgen.
        structure = Structure.from_str(structure_content, fmt=content_format)

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # Load the M3GNet universal potential model
        pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")

        # Create a relaxer and relax the structure
        relaxer = Relaxer(potential=pot)
        relax_results = relaxer.relax(structure, fmax=fmax)

        # Get the relaxed structure
        relaxed_structure = relax_results["final_structure"]
        reduced_formula = relaxed_structure.composition.reduced_formula

        # Collect summary information about the relaxed cell.
        lattice_info = relaxed_structure.lattice
        volume = relaxed_structure.volume
        density = relaxed_structure.density
        symmetry = relaxed_structure.get_space_group_info()

        # Build a fixed-width table of fractional atomic positions.
        sites_table = " # SP a b c\n"
        sites_table += "--- ---- -------- -------- --------\n"
        for i, site in enumerate(relaxed_structure):
            frac_coords = site.frac_coords
            sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"

        # Markdown report consumed directly by the LLM caller.
        return (f"## Structure Relaxation\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Force Tolerance**: `{fmax} eV/Å`\n"
                f"- **Status**: `Successfully relaxed`\n\n"
                f"### Relaxed Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {relaxed_structure.lattice.pbc[0]!s:5s} {relaxed_structure.lattice.pbc[1]!s:5s} {relaxed_structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(relaxed_structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error during structure relaxation: {str(e)}"
|
||||
|
||||
|
||||
# 内部函数,用于结构优化,返回结构对象而不是格式化字符串
|
||||
async def _relax_crystal_structure_M3GNet_internal(
    structure_source: str,
    fmax: float = 0.01
) -> Union[Structure, str]:
    """
    Internal helper: relax a structure and return the Structure object
    instead of a formatted report string.

    Args:
        structure_source: structure file name or raw content string
        fmax: force convergence threshold in eV/Å

    Returns:
        The relaxed pymatgen Structure, or an error string on failure.
    """
    try:
        # Resolve the input to raw content plus its detected format.
        raw_content, detected_format = read_structure_from_file_name_or_content_string(structure_source)

        # Parse with pymatgen.
        parsed = Structure.from_str(raw_content, fmt=detected_format)
        if parsed is None:
            return "Error: Failed to obtain a valid structure"

        # Relax with the M3GNet universal potential.
        potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
        outcome = Relaxer(potential=potential).relax(parsed, fmax=fmax)

        return outcome["final_structure"]
    except Exception as e:
        return f"Error during structure relaxation: {str(e)}"
|
||||
|
||||
|
||||
@llm_tool(name="predict_formation_energy_M3GNet",
          description="Predict the formation energy of a crystal structure using the M3GNet formation energy model from a structure file or content string, with optional structure optimization")
async def predict_formation_energy_M3GNet(
    structure_source: str,
    optimize_structure: bool = True,
    fmax: float = 0.01
) -> str:
    """
    Predict the formation energy of a crystal structure using the M3GNet formation energy model.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        optimize_structure: Whether to optimize the structure before prediction (default: True).
        fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string containing the predicted formation energy in eV/atom or an error message.
    """
    try:
        # Obtain the structure (optionally relaxed first).
        if optimize_structure:
            # Relax via the internal helper (returns Structure or error string).
            structure = await _relax_crystal_structure_M3GNet_internal(
                structure_source=structure_source,
                fmax=fmax
            )

            # The helper signals failure with an "Error ..." string.
            if isinstance(structure, str) and structure.startswith("Error"):
                return structure
        else:
            # Read the structure directly, without optimization.
            structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
            structure = Structure.from_str(structure_content, fmt=content_format)

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # Load the pretrained formation-energy model.
        model = matgl.load_model("M3GNet-MP-2018.6.1-Eform")

        # Predict the formation energy (reported in eV/atom below).
        eform = model.predict_structure(structure)
        reduced_formula = structure.composition.reduced_formula

        # Human-readable status for the report.
        optimization_status = "optimized" if optimize_structure else "non-optimized"

        # Collect summary information about the cell.
        lattice_info = structure.lattice
        volume = structure.volume
        density = structure.density
        symmetry = structure.get_space_group_info()

        # Build a fixed-width table of fractional atomic positions.
        sites_table = " # SP a b c\n"
        sites_table += "--- ---- -------- -------- --------\n"
        for i, site in enumerate(structure):
            frac_coords = site.frac_coords
            sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"

        # Markdown report consumed directly by the LLM caller.
        return (f"## Formation Energy Prediction\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Structure Status**: `{optimization_status}`\n"
                f"- **Formation Energy**: `{float(eform):.3f} eV/atom`\n\n"
                f"### Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {structure.lattice.pbc[0]!s:5s} {structure.lattice.pbc[1]!s:5s} {structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error: {str(e)}"
|
||||
|
||||
|
||||
@llm_tool(name="run_molecular_dynamics_M3GNet",
          description="Run molecular dynamics simulation on a crystal structure using M3GNet universal potential, with optional structure optimization")
async def run_molecular_dynamics_M3GNet(
    structure_source: str,
    temperature_K: float = 300,
    steps: int = 100,
    optimize_structure: bool = True,
    fmax: float = 0.01
) -> str:
    """
    Run molecular dynamics simulation on a crystal structure using M3GNet universal potential.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        temperature_K: Temperature for MD simulation in Kelvin (default: 300).
        steps: Number of MD steps to run (default: 100).
        optimize_structure: Whether to optimize the structure before simulation (default: True).
        fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string containing the simulation results, including final potential energy.
    """
    try:
        # Obtain the structure (optionally relaxed first).
        if optimize_structure:
            # Relax via the internal helper (returns Structure or error string).
            structure = await _relax_crystal_structure_M3GNet_internal(
                structure_source=structure_source,
                fmax=fmax
            )

            # The helper signals failure with an "Error ..." string.
            if isinstance(structure, str) and structure.startswith("Error"):
                return structure
        else:
            # Read the structure directly, without optimization.
            structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
            structure = Structure.from_str(structure_content, fmt=content_format)

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # Load the M3GNet universal potential model
        pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")

        # Convert pymatgen structure to ASE atoms
        ase_adaptor = AseAtomsAdaptor()
        atoms = ase_adaptor.get_atoms(structure)

        # Initialize the velocity according to Maxwell Boltzmann distribution
        MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_K)

        # Create the MD class and run simulation
        driver = MolecularDynamics(atoms, potential=pot, temperature=temperature_K)
        driver.run(steps)

        # Get final potential energy
        final_energy = atoms.get_potential_energy()

        # Get final structure
        final_structure = ase_adaptor.get_structure(atoms)
        reduced_formula = final_structure.composition.reduced_formula

        # Human-readable status for the report.
        optimization_status = "optimized" if optimize_structure else "non-optimized"

        # Collect summary information about the final cell.
        lattice_info = final_structure.lattice
        volume = final_structure.volume
        density = final_structure.density
        symmetry = final_structure.get_space_group_info()

        # Build a fixed-width table of fractional atomic positions.
        sites_table = " # SP a b c\n"
        sites_table += "--- ---- -------- -------- --------\n"
        for i, site in enumerate(final_structure):
            frac_coords = site.frac_coords
            sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"

        # Markdown report consumed directly by the LLM caller.
        return (f"## Molecular Dynamics Simulation\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Structure Status**: `{optimization_status}`\n"
                f"- **Temperature**: `{temperature_K} K`\n"
                f"- **Steps**: `{steps}`\n"
                f"- **Final Potential Energy**: `{float(final_energy):.3f} eV`\n\n"
                f"### Final Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {final_structure.lattice.pbc[0]!s:5s} {final_structure.lattice.pbc[1]!s:5s} {final_structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(final_structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error: {str(e)}"
|
||||
|
||||
|
||||
@llm_tool(name="calculate_single_point_energy_M3GNet",
          description="Calculate single point energy of a crystal structure using M3GNet universal potential, with optional structure optimization")
async def calculate_single_point_energy_M3GNet(
    structure_source: str,
    optimize_structure: bool = True,
    fmax: float = 0.01
) -> str:
    """
    Calculate single point energy of a crystal structure using M3GNet universal potential.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        optimize_structure: Whether to optimize the structure before calculation (default: True).
        fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string containing the calculated potential energy in eV.
    """
    try:
        # Obtain the structure (optionally relaxed first).
        if optimize_structure:
            # Relax via the internal helper (returns Structure or error string).
            structure = await _relax_crystal_structure_M3GNet_internal(
                structure_source=structure_source,
                fmax=fmax
            )

            # The helper signals failure with an "Error ..." string.
            if isinstance(structure, str) and structure.startswith("Error"):
                return structure
        else:
            # Read the structure directly, without optimization.
            structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
            structure = Structure.from_str(structure_content, fmt=content_format)

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # Load the M3GNet universal potential model
        pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")

        # Convert pymatgen structure to ASE atoms
        ase_adaptor = AseAtomsAdaptor()
        atoms = ase_adaptor.get_atoms(structure)

        # Set up the calculator for the atoms object.
        # BUG FIX: Atoms.set_calculator() is deprecated and removed in modern
        # ASE; assign the .calc attribute instead (consistent with the
        # FairChem structure-optimization module in this package).
        calc = PESCalculator(pot)
        atoms.calc = calc

        # Calculate potential energy
        energy = atoms.get_potential_energy()
        reduced_formula = structure.composition.reduced_formula

        # Human-readable status for the report.
        optimization_status = "optimized" if optimize_structure else "non-optimized"

        # Collect summary information about the cell.
        lattice_info = structure.lattice
        volume = structure.volume
        density = structure.density
        symmetry = structure.get_space_group_info()

        # Build a fixed-width table of fractional atomic positions.
        sites_table = " # SP a b c\n"
        sites_table += "--- ---- -------- -------- --------\n"
        for i, site in enumerate(structure):
            frac_coords = site.frac_coords
            sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"

        # Markdown report consumed directly by the LLM caller.
        return (f"## Single Point Energy Calculation\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Structure Status**: `{optimization_status}`\n"
                f"- **Potential Energy**: `{float(energy):.3f} eV`\n\n"
                f"### Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {structure.lattice.pbc[0]!s:5s} {structure.lattice.pbc[1]!s:5s} {structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error: {str(e)}"
|
||||
|
||||
#Error: Bad serialized model or bad model name. It is possible that you have an older model cached. Please clear your cache by running `python -c "import matgl; matgl.clear_cache()"`
|
||||
# @llm_tool(name="predict_band_gap",
|
||||
# description="Predict the band gap of a crystal structure using MEGNet multi-fidelity model from either a chemical formula or CIF file, with structure optimization")
|
||||
# async def predict_band_gap(
|
||||
# formula: str = None,
|
||||
# cif_file_name: str = None,
|
||||
# method: str = "PBE",
|
||||
# fmax: float = 0.01
|
||||
# ) -> str:
|
||||
# """
|
||||
# Predict the band gap of a crystal structure using the MEGNet multi-fidelity band gap model.
|
||||
|
||||
# First optimizes the crystal structure using M3GNet universal potential, then predicts
|
||||
# the band gap based on the relaxed structure for more accurate results.
|
||||
|
||||
# Accepts either a chemical formula (searches Materials Project database) or a CIF file.
|
||||
|
||||
# Args:
|
||||
# formula: Chemical formula to retrieve from Materials Project (e.g., "Fe2O3").
|
||||
# cif_file_name: Name of CIF file in temp directory to use as structure source.
|
||||
# method: The DFT method to use for the prediction. Options are "PBE", "GLLB-SC", "HSE", or "SCAN".
|
||||
# Default is "PBE".
|
||||
# fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).
|
||||
|
||||
# Returns:
|
||||
# A string containing the predicted band gap in eV or an error message.
|
||||
# """
|
||||
# try:
|
||||
# # First, relax the crystal structure
|
||||
# relaxed_result = await relax_crystal_structure(
|
||||
# formula=formula,
|
||||
# cif_file_name=cif_file_name,
|
||||
# fmax=fmax
|
||||
# )
|
||||
|
||||
# # Check if relaxation was successful
|
||||
# if isinstance(relaxed_result, str) and relaxed_result.startswith("Error"):
|
||||
# return relaxed_result
|
||||
|
||||
# # Use the relaxed structure for band gap prediction
|
||||
# structure = relaxed_result
|
||||
|
||||
# if structure is None:
|
||||
# return "Error: Failed to obtain a valid relaxed structure"
|
||||
|
||||
# # Load the pre-trained MEGNet band gap model
|
||||
# model = matgl.load_model("MEGNet-MP-2019.4.1-BandGap-mfi")
|
||||
|
||||
# # Map method name to index
|
||||
# method_map = {"PBE": 0, "GLLB-SC": 1, "HSE": 2, "SCAN": 3}
|
||||
# if method not in method_map:
|
||||
# return f"Error: Unsupported method: {method}. Choose from PBE, GLLB-SC, HSE, or SCAN."
|
||||
|
||||
# # Set the graph label based on the method
|
||||
# graph_attrs = torch.tensor([method_map[method]])
|
||||
|
||||
# # Predict the band gap using the relaxed structure
|
||||
# bandgap = model.predict_structure(structure=structure, state_attr=graph_attrs)
|
||||
# reduced_formula = structure.reduced_formula
|
||||
|
||||
# # Return the band gap as a string
|
||||
# return f"The predicted band gap for relaxed {reduced_formula} using {method} method is {float(bandgap):.3f} eV."
|
||||
# except Exception as e:
|
||||
# return f"Error: {str(e)}"
|
||||
|
||||
|
||||
240
sci_mcp/material_mcp/mattergen_gen/material_gen_tools.py
Executable file
240
sci_mcp/material_mcp/mattergen_gen/material_gen_tools.py
Executable file
@@ -0,0 +1,240 @@
|
||||
|
||||
import ast
import json
import logging

# Configure a module-level logger that also echoes to stderr.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)

import tempfile
import os
import datetime
import asyncio
import zipfile
import shutil
import re
import multiprocessing
from multiprocessing import Process, Queue
from pathlib import Path
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
import logging
# Force the 'spawn' start method so CUDA can be initialized safely in
# worker processes ('fork' inherits an already-initialized CUDA context,
# which breaks CUDA in the child).
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    # The start method can only be set once per interpreter; ignore if
    # something else already configured it.
    pass

from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter

# Relative imports resolved against the sci_mcp package layout.
from ...core.llm_tools import llm_tool
from .mattergen_wrapper import *

# Used alongside mattergen_wrapper.
import sys
import os
||||
def convert_values(data_str):
    """Decode a JSON string, falling back to the raw input.

    Args:
        data_str: Text that may contain a JSON document.

    Returns:
        The decoded Python object, or ``data_str`` unchanged when it is
        not valid JSON.
    """
    try:
        return json.loads(data_str)
    except json.JSONDecodeError:
        # Not JSON — hand the caller the original string untouched.
        return data_str
|
||||
|
||||
|
||||
def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
    """
    Preprocess a property value based on its name, converting it to the appropriate type.

    Numeric properties given as strings (e.g. from JSON) are coerced to float;
    ``chemical_system`` values are forced to string; ``space_group`` values are
    validated as integers in the crystallographic range 1-230.

    Args:
        property_name: Name of the property
        property_value: Value of the property (can be string, float, or int)

    Returns:
        Tuple of (property_name, processed_value)

    Raises:
        ValueError: If the property name is unknown, or the value is invalid
            for the given property (e.g. a non-numeric or out-of-range
            space group).
    """
    valid_properties = [
        "dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
        "energy_above_hull", "formation_energy_per_atom", "space_group",
        "hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap"
    ]

    if property_name not in valid_properties:
        raise ValueError(f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}")

    # Numeric properties may arrive as strings; try to coerce them.
    if isinstance(property_value, str) and property_name != "chemical_system":
        try:
            property_value = float(property_value)
        except ValueError:
            # Leave as-is; the per-property checks below raise a clear error.
            pass

    # Per-property normalization.
    if property_name == "chemical_system":
        if isinstance(property_value, (int, float)):
            logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
            property_value = str(property_value)
    elif property_name == "space_group":
        # BUGFIX: a non-numeric string previously reached the `< 1` comparison
        # and raised TypeError; coerce explicitly and raise ValueError instead.
        try:
            space_group = int(float(property_value))
        except (TypeError, ValueError):
            raise ValueError(f"Invalid space_group value: {property_value}. Must be an integer between 1 and 230.")
        if space_group < 1 or space_group > 230:
            raise ValueError(f"Invalid space_group value: {space_group}. Must be an integer between 1 and 230.")
        # Space groups are integer indices — return an int, not a float.
        property_value = space_group

    return property_name, property_value
|
||||
|
||||
|
||||
def main(
    output_path: str,
    pretrained_name: PRETRAINED_MODEL_NAME | None = None,
    model_path: str | None = None,
    batch_size: int = 2,
    num_batches: int = 1,
    config_overrides: list[str] | None = None,
    checkpoint_epoch: Literal["best", "last"] | int = "last",
    properties_to_condition_on: TargetProperty | None = None,
    sampling_config_path: str | None = None,
    sampling_config_name: str = "default",
    sampling_config_overrides: list[str] | None = None,
    record_trajectories: bool = True,
    diffusion_guidance_factor: float | None = None,
    strict_checkpoint_loading: bool = True,
    target_compositions: list[dict[str, int]] | None = None,
):
    """
    Sample crystal structures from a MatterGen diffusion checkpoint.

    Exactly one of ``pretrained_name`` and ``model_path`` must be supplied.

    Args:
        output_path: Directory the generated structures are written to
            (created if missing).
        pretrained_name: Name of a pretrained checkpoint on the Hugging Face hub.
        model_path: Path to a local DiffusionLightningModule checkpoint directory.
        batch_size: Number of structures sampled per batch.
        num_batches: Number of batches to sample.
        config_overrides: Hydra-style overrides for the model config,
            e.g. ``model.num_layers=3 model.hidden_dim=128``.
        checkpoint_epoch: Which epoch to load from a local checkpoint
            ("best", "last", or an explicit epoch number).
        properties_to_condition_on: Target property values for conditional
            sampling; an empty mapping (default) samples unconditionally.
        sampling_config_path: Directory holding sampling configs; None uses
            the package default.
        sampling_config_name: Config name, resolved as
            ``{sampling_config_path}/{sampling_config_name}.yaml``.
        sampling_config_overrides: Overrides for the sampling config,
            e.g. ``condition_loader_partial.batch_size=32``.
        record_trajectories: Whether to also record denoising trajectories.
        diffusion_guidance_factor: Classifier-free guidance strength;
            None is treated as 0.0 (no guidance).
        strict_checkpoint_loading: Raise when checkpoint parameters cannot
            all be matched to the model.
        target_compositions: ``{element: atom_count}`` dicts to condition on;
            only supported by CSP-trained models.

    NOTE: When specifying dictionary values via the CLI, avoid whitespace
    between key and value, e.g. ``--properties_to_condition_on={key1:value1}``.
    """
    # Exactly one model source must be given.
    assert (
        pretrained_name is not None or model_path is not None
    ), "Either pretrained_name or model_path must be provided."
    assert (
        pretrained_name is None or model_path is None
    ), "Only one of pretrained_name or model_path can be provided."

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    # Normalize optional collection arguments to empty containers.
    config_overrides = config_overrides or []
    sampling_config_overrides = sampling_config_overrides or []
    target_compositions = target_compositions or []
    properties_to_condition_on = properties_to_condition_on or {}

    # Resolve the checkpoint: hub download or local directory.
    if pretrained_name is None:
        checkpoint_info = MatterGenCheckpointInfo(
            model_path=Path(model_path).resolve(),
            load_epoch=checkpoint_epoch,
            config_overrides=config_overrides,
            strict_checkpoint_loading=strict_checkpoint_loading,
        )
    else:
        checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
            pretrained_name, config_overrides=config_overrides
        )

    resolved_sampling_cfg = None if sampling_config_path is None else Path(sampling_config_path)

    generator = CrystalGenerator(
        checkpoint_info=checkpoint_info,
        properties_to_condition_on=properties_to_condition_on,
        batch_size=batch_size,
        num_batches=num_batches,
        sampling_config_name=sampling_config_name,
        sampling_config_path=resolved_sampling_cfg,
        sampling_config_overrides=sampling_config_overrides,
        record_trajectories=record_trajectories,
        diffusion_guidance_factor=(
            0.0 if diffusion_guidance_factor is None else diffusion_guidance_factor
        ),
        target_compositions_dict=target_compositions,
    )
    generator.generate(output_dir=Path(output_path))
|
||||
|
||||
|
||||
@llm_tool(name="generate_material_MatterGen", description="Generate crystal structures with optional property constraints using MatterGen model")
def generate_material_MatterGen(
    properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
    batch_size: int = 2,
    num_batches: int = 1,
    diffusion_guidance_factor: float = 2.0
) -> str:
    """
    Generate crystal structures, optionally conditioned on target properties.

    Depending on ``properties`` this runs in one of three modes:

    1. Unconditional generation (``None`` or empty dict)
    2. Single-property conditioning (one key-value pair)
    3. Multi-property conditioning (several key-value pairs)

    Valid property names include "dft_band_gap", "chemical_system", etc.

    Args:
        properties: Optional mapping of property name -> target value.
        batch_size: Number of structures per batch.
        num_batches: Number of batches to generate.
        diffusion_guidance_factor: How strongly sampling adheres to the
            requested property values.

    Returns:
        Descriptive text with the generated structures in CIF format, or
        an error string when generation fails.
    """
    # Imported lazily so the heavy service is only loaded in the worker
    # process that actually performs generation.
    from .mattergen_service import MatterGenService
    logger.info("子进程成功导入MatterGenService")

    service = MatterGenService.get_instance()
    logger.info("子进程成功获取MatterGenService实例")

    logger.info("子进程开始调用generate方法...")
    result = service.generate(
        properties=properties,
        batch_size=batch_size,
        num_batches=num_batches,
        diffusion_guidance_factor=diffusion_guidance_factor,
    )
    logger.info("子进程generate方法调用完成")

    # Surface generation failures as a short tool-level error message.
    if "Error generating structures" not in result:
        return result
    return f"Error: Invalid properties {properties}."
|
||||
466
sci_mcp/material_mcp/mattergen_gen/mattergen_service.py
Executable file
466
sci_mcp/material_mcp/mattergen_gen/mattergen_service.py
Executable file
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
MatterGen service for mars_toolkit.
|
||||
|
||||
This module provides a service for generating crystal structures using MatterGen.
|
||||
The service initializes the CrystalGenerator once and reuses it for multiple
|
||||
generation requests, improving performance.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Union, List
|
||||
import threading
|
||||
import torch
|
||||
|
||||
from .mattergen_wrapper import *
|
||||
from ...core.config import material_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
def format_cif_content(content):
    """Split raw CIF text into labelled per-structure sections.

    Strips ZIP local-file-header residue (runs starting with "PK") that
    precedes each CIF when the text was read straight out of an archive,
    then wraps every structure in "[cif N begin] ... [cif N end]" markers
    with a ``data_<formula>`` heading.

    Args:
        content: String containing CIF content, possibly with PK headers.

    Returns:
        Formatted string, or '' when no structure is found.
    """
    if not content or content.strip() == '':
        return ''

    # Drop archive residue ahead of each structure, plus any trailing
    # residue that is not followed by another structure.
    content = re.sub(r'PK.*?(?=_chemical_formula_structural)', '', content, flags=re.DOTALL)
    content = re.sub(r'PK[^_]*$', '', content, flags=re.DOTALL)
    content = re.sub(r'PK.*?(?!.*_chemical_formula_structural)$', '', content, flags=re.DOTALL)

    # Every structure begins at a "_chemical_formula_structural" tag; the
    # tag itself is kept inside its block.
    starts = [m.start() for m in re.finditer(r'_chemical_formula_structural', content)]
    if not starts:
        return ''

    # Slice the text into per-structure blocks and pull out each formula.
    sections = []
    bounds = starts + [len(content)]
    for begin, end in zip(bounds, bounds[1:]):
        block = content[begin:end].strip()
        tag = re.search(r'_chemical_formula_structural\s+(\S+)', block)
        if tag:
            sections.append((tag.group(1), block))

    return "\n".join(
        f"[cif {n} begin]\ndata_{formula}\n{block}\n[cif {n} end]\n"
        for n, (formula, block) in enumerate(sections, 1)
    )
|
||||
|
||||
def extract_cif_file_from_zip(cifs_zip_path: str):
    """Pull individual CIF structures out of a MatterGen result archive and
    save each one as ``<formula>.cif`` in the temp directory.

    NOTE(review): this reads the archive's raw bytes and regex-scans the
    decoded text instead of using ``zipfile`` — relying on
    ``format_cif_content`` to strip the PK headers. That only works when
    the CIF entries are stored uncompressed; confirm against the writer.

    Args:
        cifs_zip_path: Path to the zip file.

    Returns:
        list: (index, formula, cif_path) tuples, one per structure written;
        ``index`` is the string captured from the "[cif N begin]" marker.
    """
    raw = {}
    if os.path.exists(cifs_zip_path):
        with open(cifs_zip_path, 'rb') as fh:
            raw['cif_content'] = fh.read()

    payload = raw.get('cif_content', b'')
    text = payload.decode('utf-8', errors='replace') if isinstance(payload, bytes) else str(payload)
    cifs_content = format_cif_content(text)

    # Pair up the begin/end markers emitted by format_cif_content.
    pattern = r'\[cif (\d+) begin\]\n(.*?)\n\[cif \1 end\]'
    saved_files = []
    for idx, cif_content in re.findall(pattern, cifs_content, re.DOTALL):
        # The formula is embedded as the data_<formula> heading.
        name_match = re.search(r'data_([^\s]+)', cif_content)
        if not name_match:
            continue
        formula = name_match.group(1)
        cif_path = os.path.join(material_config.TEMP_ROOT, f"{formula}.cif")
        with open(cif_path, 'w') as fh:
            fh.write(cif_content)
        saved_files.append((idx, formula, cif_path))

    return saved_files
|
||||
|
||||
|
||||
class MatterGenService:
    """
    Service for generating crystal structures using MatterGen.

    This service initializes the CrystalGenerator once and reuses it for multiple
    generation requests, improving performance. Access it through
    :meth:`get_instance`; it is a process-wide singleton.
    """

    # Singleton instance and the lock guarding its lazy construction.
    _instance = None
    _lock = threading.Lock()

    # Static model-name -> GPU-id assignment used to spread the different
    # conditional models across devices.
    # NOTE(review): assumes at least 8 visible CUDA devices — confirm on
    # the deployment host.
    MODEL_TO_GPU = {
        "mattergen_base": "0",  # base (unconditional) model on GPU 0
        "dft_mag_density": "1",  # magnetic density model on GPU 1
        "dft_bulk_modulus": "2",  # bulk modulus model on GPU 2
        "dft_shear_modulus": "3",  # shear modulus model on GPU 3
        "energy_above_hull": "4",  # energy-above-hull model on GPU 4
        "formation_energy_per_atom": "5",  # formation energy model on GPU 5
        "space_group": "6",  # space group model on GPU 6
        "hhi_score": "7",  # HHI score model on GPU 7
        "ml_bulk_modulus": "0",  # ML bulk modulus model on GPU 0
        "chemical_system": "1",  # chemical system model on GPU 1
        "dft_band_gap": "2",  # band gap model on GPU 2
        "dft_mag_density_hhi_score": "3",  # multi-property model on GPU 3
        "chemical_system_energy_above_hull": "4"  # multi-property model on GPU 4
    }

    @classmethod
    def get_instance(cls):
        """
        Get the singleton instance of MatterGenService.

        Uses double-checked locking so the common already-initialized path
        avoids acquiring the lock.

        Returns:
            MatterGenService: The singleton instance.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """
        Initialize the MatterGenService.

        This initializes the base generator without any property conditioning.
        Specific generators for different property conditions will be initialized
        on demand.
        """
        # Cache of CrystalGenerator instances keyed by generator key
        # ("base", "single_<prop>", "multi_<...>").
        self._generators = {}
        self._output_dir = material_config.TEMP_ROOT

        # Ensure the output directory exists.
        if not os.path.exists(self._output_dir):
            os.makedirs(self._output_dir)

        # Eagerly build the unconditional generator.
        self._init_base_generator()

    def _init_base_generator(self):
        """
        Initialize the base generator for unconditional generation.

        Logs a warning and leaves the cache untouched when the base model
        directory is missing; logs an error (without raising) when generator
        construction fails.
        """
        model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, "mattergen_base")

        if not os.path.exists(model_path):
            logger.warning(f"Base model directory not found at {model_path}. MatterGen service may not work properly.")
            return

        logger.info(f"Initializing base MatterGen generator from {model_path}")

        try:
            checkpoint_info = MatterGenCheckpointInfo(
                model_path=Path(model_path).resolve(),
                load_epoch="last",
                config_overrides=[],
                strict_checkpoint_loading=True,
            )

            generator = CrystalGenerator(
                checkpoint_info=checkpoint_info,
                properties_to_condition_on=None,
                batch_size=2,  # default; overridden per generate() call
                num_batches=1,  # default; overridden per generate() call
                sampling_config_name="default",
                sampling_config_path=None,
                sampling_config_overrides=[],
                record_trajectories=True,
                diffusion_guidance_factor=0.0,
                target_compositions_dict=[],
            )

            self._generators["base"] = generator
            logger.info("Base MatterGen generator initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize base MatterGen generator: {e}")

    def _get_or_create_generator(
        self,
        properties: Optional[Dict[str, Any]] = None,
        batch_size: int = 2,
        num_batches: int = 1,
        diffusion_guidance_factor: float = 2.0
    ):
        """
        Get or create a generator for the specified properties.

        Falls back to the cached base generator when the property-specific
        model directory is missing or its initialization fails.

        Args:
            properties: Optional property constraints
            batch_size: Number of structures per batch
            num_batches: Number of batches to generate
            diffusion_guidance_factor: Controls adherence to target properties

        Returns:
            tuple: (generator, generator_key, properties_to_condition_on, gpu_id)
        """
        # No constraints: serve the cached unconditional generator.
        if not properties:
            if "base" not in self._generators:
                self._init_base_generator()
            gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")  # default: GPU 0
            return self._generators.get("base"), "base", None, gpu_id

        # Copy the constraints into the conditioning mapping.
        properties_to_condition_on = {}
        for property_name, property_value in properties.items():
            properties_to_condition_on[property_name] = property_value

        # Pick the model directory matching the requested conditioning.
        if len(properties) == 1:
            # Single-property conditioning.
            property_name = list(properties.keys())[0]
            property_to_model = {
                "dft_mag_density": "dft_mag_density",
                "dft_bulk_modulus": "dft_bulk_modulus",
                "dft_shear_modulus": "dft_shear_modulus",
                "energy_above_hull": "energy_above_hull",
                "formation_energy_per_atom": "formation_energy_per_atom",
                "space_group": "space_group",
                "hhi_score": "hhi_score",
                "ml_bulk_modulus": "ml_bulk_modulus",
                "chemical_system": "chemical_system",
                "dft_band_gap": "dft_band_gap"
            }
            model_dir = property_to_model.get(property_name, property_name)
            generator_key = f"single_{property_name}"
        else:
            # Multi-property conditioning: only two dedicated combinations
            # exist; anything else falls back to the first property's model.
            property_keys = set(properties.keys())
            if property_keys == {"dft_mag_density", "hhi_score"}:
                model_dir = "dft_mag_density_hhi_score"
                generator_key = "multi_dft_mag_density_hhi_score"
            elif property_keys == {"chemical_system", "energy_above_hull"}:
                model_dir = "chemical_system_energy_above_hull"
                generator_key = "multi_chemical_system_energy_above_hull"
            else:
                # No dedicated multi-property model: use the first property's.
                first_property = list(properties.keys())[0]
                model_dir = first_property
                generator_key = f"multi_{first_property}_etc"

        # GPU assigned to this model (default: GPU 0).
        gpu_id = self.MODEL_TO_GPU.get(model_dir, "0")

        # Full path of the checkpoint directory.
        model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, model_dir)

        # Missing model directory: fall back to the base model.
        if not os.path.exists(model_path):
            logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
            model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, "mattergen_base")
            generator_key = "base"

        # Reuse a cached generator, refreshing its per-call parameters.
        if generator_key in self._generators:
            generator = self._generators[generator_key]
            generator.batch_size = batch_size
            generator.num_batches = num_batches
            generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0
            return generator, generator_key, properties_to_condition_on, gpu_id

        # Build and cache a new generator.
        try:
            logger.info(f"Initializing new MatterGen generator for {generator_key} from {model_path}")

            checkpoint_info = MatterGenCheckpointInfo(
                model_path=Path(model_path).resolve(),
                load_epoch="last",
                config_overrides=[],
                strict_checkpoint_loading=True,
            )

            generator = CrystalGenerator(
                checkpoint_info=checkpoint_info,
                properties_to_condition_on=properties_to_condition_on,
                batch_size=batch_size,
                num_batches=num_batches,
                sampling_config_name="default",
                sampling_config_path=None,
                sampling_config_overrides=[],
                record_trajectories=True,
                diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0,
                target_compositions_dict=[],
            )

            self._generators[generator_key] = generator
            logger.info(f"MatterGen generator for {generator_key} initialized successfully")
            return generator, generator_key, properties_to_condition_on, gpu_id
        except Exception as e:
            logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
            # Fall back to the unconditional base generator.
            if "base" not in self._generators:
                self._init_base_generator()
            base_gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")
            return self._generators.get("base"), "base", None, base_gpu_id

    def generate(
        self,
        properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
        batch_size: int = 2,
        num_batches: int = 1,
        diffusion_guidance_factor: float = 2.0
    ) -> str:
        """
        Generate crystal structures with optional property constraints.

        Args:
            properties: Optional property constraints (a dict, or a JSON
                string that decodes to one)
            batch_size: Number of structures per batch
            num_batches: Number of batches to generate
            diffusion_guidance_factor: Controls adherence to target properties

        Returns:
            str: Descriptive text with generated crystal structures in CIF format

        Raises:
            ValueError: If ``properties`` is a string that is not valid JSON.
        """

        # Accept a JSON-encoded string for properties.
        if isinstance(properties, str):
            try:
                properties = json.loads(properties)
            except json.JSONDecodeError:
                raise ValueError(f"Invalid properties JSON string: {properties}")

        # None defaults to an empty dict (unconditional generation).
        properties = properties or {}

        # Look up (or lazily build) the generator and its assigned GPU.
        generator, generator_key, properties_to_condition_on, gpu_id = self._get_or_create_generator(
            properties, batch_size, num_batches, diffusion_guidance_factor
        )
        print("gpu_id", gpu_id)
        if generator is None:
            return "Error: Failed to initialize MatterGen generator"

        # Select the assigned GPU directly via torch.cuda.set_device().
        try:
            # gpu_id is stored as a string in MODEL_TO_GPU; convert to int.
            cuda_device_id = int(gpu_id)
            torch.cuda.set_device(cuda_device_id)
            logger.info(f"Setting CUDA device to GPU {cuda_device_id} for model {generator_key}")
            print(f"Using GPU {cuda_device_id} (CUDA device index) for model {generator_key}")
        except Exception as e:
            logger.warning(f"Error setting CUDA device: {e}. Falling back to default device.")

        # Run the sampling into a timestamped sub-directory.
        try:
            output_dir = Path(self._output_dir + f'/{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}')
            Path.mkdir(output_dir, parents=True, exist_ok=True)
            generator.generate(output_dir=output_dir)
        except Exception as e:
            logger.error(f"Error generating structures: {e}")
            return f"Error generating structures: {e}"

        # Collect generated artifacts.
        result_dict = {}

        # Paths MatterGen writes its outputs to.
        cif_zip_path = os.path.join(str(output_dir), f"generated_crystals_cif.zip")
        xyz_file_path = os.path.join(str(output_dir), f"generated_crystals.extxyz")
        trajectories_zip_path = os.path.join(str(output_dir), f"generated_trajectories.zip")

        # Read the zipped CIF archive as raw bytes; format_cif_content
        # strips the zip ("PK") headers from the decoded text below.
        if os.path.exists(cif_zip_path):
            with open(cif_zip_path, 'rb') as f:
                result_dict['cif_content'] = f.read()

        # Build the descriptive report for the conditioning mode used.
        if not properties:
            generation_type = "unconditional"
            title = "Generated Material Structures"
            description = "These structures were generated unconditionally, meaning no specific properties were targeted."
            property_description = "unconditionally"
        elif len(properties) == 1:
            generation_type = "single_property"
            property_name = list(properties.keys())[0]
            property_value = properties[property_name]
            title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
            description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
            property_description = f"conditioned on {property_name} = {property_value}"
        else:
            generation_type = "multi_property"
            title = "Generated Material Structures Conditioned on Multiple Properties"
            description = "These structures were generated with multi-property conditioning, targeting the specified property values."
            property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"

        # Assemble the final markdown report.
        prompt = f"""
# {title}

This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.

{'' if generation_type == 'unconditional' else f'''
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
'''}

## CIF Files (Crystallographic Information Files)

- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools

```
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
```

{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""
        # print("prompt", prompt)
        # Clean-up of intermediate files after reading (currently disabled).
        # try:
        #     if os.path.exists(cif_zip_path):
        #         os.remove(cif_zip_path)
        #     if os.path.exists(xyz_file_path):
        #         os.remove(xyz_file_path)
        #     if os.path.exists(trajectories_zip_path):
        #         os.remove(trajectories_zip_path)
        # except Exception as e:
        #     logger.warning(f"Error cleaning up files: {e}")

        # The GPU device was already selected via torch.cuda.set_device()
        # before generation; no extra cleanup is needed here.
        logger.info(f"Generation completed on GPU for model {generator_key}")

        return prompt
|
||||
26
sci_mcp/material_mcp/mattergen_gen/mattergen_wrapper.py
Executable file
26
sci_mcp/material_mcp/mattergen_gen/mattergen_wrapper.py
Executable file
@@ -0,0 +1,26 @@
|
||||
"""
This is a wrapper module that provides access to the mattergen modules
by modifying the Python path at runtime.
"""
import sys
import os
from pathlib import Path

from ...core.config import material_config

# Make the vendored mattergen checkout importable as a top-level package.
mattergen_dir = material_config.MATTERGEN_ROOT
sys.path.insert(0, mattergen_dir)

# Import the pieces of mattergen this package relies on; fail loudly with
# diagnostic output so a broken MATTERGEN_ROOT is easy to spot.
try:
    from mattergen import generator
    from mattergen.common.data import chemgraph
    from mattergen.common.data.types import TargetProperty
    from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
    from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME
except ImportError as e:
    print(f"Error importing mattergen modules: {e}")
    print(f"Python path: {sys.path}")
    raise
# Convenience alias so callers can import CrystalGenerator directly.
CrystalGenerator = generator.CrystalGenerator
# Re-export the modules and names that make up this wrapper's public API.
__all__ = ['generator', 'chemgraph', 'TargetProperty', 'MatterGenCheckpointInfo', 'PRETRAINED_MODEL_NAME','CrystalGenerator']
|
||||
73
sci_mcp/material_mcp/mattersim_pred/property_pred_tools.py
Normal file
73
sci_mcp/material_mcp/mattersim_pred/property_pred_tools.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Property Prediction Module
|
||||
|
||||
This module provides functions for predicting properties of crystal structures.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
import numpy as np
|
||||
from ase.units import GPa
|
||||
from mattersim.forcefield import MatterSimCalculator
|
||||
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ..support.utils import convert_structure,read_structure_from_file_name_or_content_string
|
||||
|
||||
@llm_tool(
    name="predict_properties_MatterSim",
    description="Predict energy, forces, and stress of crystal structures using MatterSim model based on CIF string",
)
async def predict_properties_MatterSim(structure_source: str) -> str:
    """
    Use MatterSim model to predict energy, forces, and stress of crystal structures.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string

    Returns:
        String containing prediction results
    """
    # The ML inference below is blocking; it is executed off the event loop
    # via asyncio.to_thread at the end of this function.
    def run_prediction():
        # Resolve the input (file name or raw content) and convert it into an
        # ASE Atoms object via convert_structure.
        structure_content,content_format=read_structure_from_file_name_or_content_string(structure_source)
        structure = convert_structure(content_format, structure_content)
        if structure is None:
            return "Unable to parse CIF string. Please check if the format is correct."

        # Prefer GPU when available.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Attach the MatterSim force-field calculator to the structure.
        structure.calc = MatterSimCalculator(device=device)

        # Energy, forces, and the full (non-Voigt) stress tensor.
        energy = structure.get_potential_energy()
        forces = structure.get_forces()
        stresses = structure.get_stress(voigt=False)

        # Per-atom energy.
        num_atoms = len(structure)
        energy_per_atom = energy / num_atoms

        # Stress in both eV/A^3 (ASE native units) and GPa.
        stresses_ev_a3 = stresses
        stresses_gpa = stresses / GPa

        # Build the human-readable result report.
        prompt = f"""
## {structure.get_chemical_formula()} Crystal Structure Property Prediction Results

Prediction results using the provided CIF structure:

- Total Energy (eV): {energy}
- Energy per Atom (eV/atom): {energy_per_atom:.4f}
- Forces (eV/Angstrom): {forces[0]} # Forces on the first atom
- Stress (GPa): {stresses_gpa[0][0]} # First component of the stress tensor
- Stress (eV/A^3): {stresses_ev_a3[0][0]} # First component of the stress tensor

"""
        return prompt

    # Execute the blocking prediction in a worker thread.
    return await asyncio.to_thread(run_prediction)
|
||||
0
sci_mcp/material_mcp/mp_query/__init__.py
Normal file
0
sci_mcp/material_mcp/mp_query/__init__.py
Normal file
42
sci_mcp/material_mcp/mp_query/get_mp_id.py
Normal file
42
sci_mcp/material_mcp/mp_query/get_mp_id.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
from typing import List
|
||||
from mp_api.client import MPRester
|
||||
from ...core.config import material_config
|
||||
|
||||
async def get_mpid_from_formula(formula: str) -> "List[str] | str":
    """
    Get material IDs (mpid) from Materials Project database by chemical formula.
    Returns mpids for the lowest energy structures.

    Args:
        formula: Chemical formula (e.g., "Fe2O3"). A "name=formula" pair is
            also accepted; only the part after "=" is used.

    Returns:
        List of material IDs on success, or an error / "No materials found"
        message string on failure (callers check ``isinstance(result, str)``).
    """
    # Route MP API traffic through the configured proxy (empty string disables).
    os.environ['HTTP_PROXY'] = material_config.HTTP_PROXY or ''
    os.environ['HTTPS_PROXY'] = material_config.HTTPS_PROXY or ''

    try:
        id_list = []

        # Strip whitespace and quote characters that LLM-generated input may carry.
        cleaned_formula = formula.replace(" ", "").replace("\n", "").replace("\'", "").replace("\"", "")
        if "=" in cleaned_formula:
            # "name=formula" form: keep only the part after "=".
            # (Renamed from `id`, which shadowed the builtin.)
            _, query_formula = cleaned_formula.split("=")
        else:
            query_formula = cleaned_formula

        formula_list = [query_formula]

        with MPRester(material_config.MP_API_KEY) as mpr:
            docs = mpr.materials.summary.search(formula=formula_list)
            if not docs:
                return "No materials found"
            for doc in docs:
                id_list.append(doc.material_id)
            return id_list
    except Exception as e:
        return f"Error: get_mpid_from_formula: {str(e)}"
|
||||
168
sci_mcp/material_mcp/mp_query/mp_query_tools.py
Normal file
168
sci_mcp/material_mcp/mp_query/mp_query_tools.py
Normal file
@@ -0,0 +1,168 @@
|
||||
|
||||
import glob
|
||||
import json
|
||||
from typing import Dict, Any, Union
|
||||
from ...core.llm_tools import llm_tool
|
||||
from .get_mp_id import get_mpid_from_formula
|
||||
from ..support.utils import extract_cif_info, remove_symmetry_equiv_xyz
|
||||
from ...core.config import material_config
|
||||
|
||||
from pymatgen.core import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
@llm_tool(name="search_crystal_structures_from_materials_project",
|
||||
description="Retrieve and optimize crystal structures from Materials Project database using a chemical formula")
|
||||
async def search_crystal_structures_from_materials_project(
|
||||
formula: str,
|
||||
conventional_unit_cell: bool = True,
|
||||
symprec: float = 0.1
|
||||
) -> str:
|
||||
"""
|
||||
Retrieves crystal structures for a given chemical formula from Materials Project database and applies symmetry optimization.
|
||||
|
||||
Args:
|
||||
formula: Chemical formula to search for (e.g., "Fe2O3")
|
||||
conventional_unit_cell: If True, returns conventional unit cell; if False, returns primitive cell
|
||||
symprec: Symmetry precision parameter for structure refinement (default: 0.1)
|
||||
|
||||
Returns:
|
||||
Formatted CIF data for the retrieved crystal structures with symmetry analysis
|
||||
"""
|
||||
try:
|
||||
structures = {}
|
||||
mp_id_list = await get_mpid_from_formula(formula=formula)
|
||||
if isinstance(mp_id_list, str):
|
||||
return mp_id_list # 直接返回错误信息
|
||||
|
||||
for i, mp_id in enumerate(mp_id_list):
|
||||
try:
|
||||
# 文件操作可能引发异常
|
||||
cif_files = glob.glob(material_config.LOCAL_MP_CIF_ROOT + f"/{mp_id}.cif")
|
||||
if not cif_files:
|
||||
continue # 如果没有找到文件,跳过这个mp_id
|
||||
|
||||
cif_file = cif_files[0]
|
||||
structure = Structure.from_file(cif_file)
|
||||
|
||||
# 结构处理可能引发异常
|
||||
if conventional_unit_cell:
|
||||
structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure()
|
||||
|
||||
# 对结构进行对称化处理
|
||||
sga = SpacegroupAnalyzer(structure, symprec=symprec)
|
||||
symmetrized_structure = sga.get_refined_structure()
|
||||
|
||||
# 使用CifWriter生成CIF数据
|
||||
cif_writer = CifWriter(symmetrized_structure, symprec=symprec, refine_struct=True)
|
||||
cif_data = str(cif_writer)
|
||||
|
||||
# 删除CIF文件中的对称性操作部分
|
||||
cif_data = remove_symmetry_equiv_xyz(cif_data)
|
||||
cif_data = cif_data.replace('# generated using pymatgen', "")
|
||||
|
||||
# 生成一个唯一的键
|
||||
formula_key = structure.composition.reduced_formula
|
||||
key = f"{formula_key}_{i}"
|
||||
|
||||
structures[key] = cif_data
|
||||
|
||||
# 只保留前config.MP_TOPK个结果
|
||||
if len(structures) >= material_config.MP_TOPK:
|
||||
break
|
||||
|
||||
except (FileNotFoundError, IndexError) as file_error:
|
||||
# 处理文件相关错误
|
||||
continue # 跳过这个mp_id,继续处理下一个
|
||||
except ValueError as value_error:
|
||||
# 处理结构处理中的值错误
|
||||
continue # 跳过这个mp_id,继续处理下一个
|
||||
except Exception as process_error:
|
||||
# 记录处理特定结构时的错误,但继续处理其他结构
|
||||
print(f"Error: processing structure {mp_id}: {str(process_error)}")
|
||||
continue
|
||||
|
||||
# 如果没有成功处理任何结构
|
||||
if not structures:
|
||||
return f"No valid crystal structures found for formula: {formula}"
|
||||
|
||||
# 格式化结果为可读字符串
|
||||
prompt = f"""
|
||||
# Materials Project Symmetrized Crystal Structure Data
|
||||
|
||||
Below are symmetrized crystal structure data for {len(structures)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
|
||||
These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
|
||||
"""
|
||||
|
||||
for i, (key, cif_data) in enumerate(structures.items(), 1):
|
||||
prompt += f"[cif {i} begin]\n"
|
||||
prompt += cif_data
|
||||
prompt += f"\n[cif {i} end]\n\n"
|
||||
|
||||
return prompt
|
||||
|
||||
except Exception as e:
|
||||
# 捕获整个函数执行过程中的任何未处理异常
|
||||
return f"Error: An unexpected error occurred while processing crystal structures: {str(e)}"
|
||||
|
||||
@llm_tool(name="search_material_property_from_material_project",
|
||||
description="Query material properties from Materials Project database using chemical formula")
|
||||
async def search_material_property_from_materials_project(
|
||||
formula: str,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve detailed property data for materials matching a chemical formula from Materials Project database.
|
||||
|
||||
Args:
|
||||
formula: Chemical formula of the material(s) to search for (e.g. 'Fe2O3', 'LiFePO4')
|
||||
|
||||
Returns:
|
||||
Formatted string containing material properties including structure, electronic, thermodynamic and mechanical data
|
||||
"""
|
||||
# 获取MP ID列表
|
||||
mp_id_list = await get_mpid_from_formula(formula=formula)
|
||||
|
||||
# 检查get_mpid_from_formula的返回值类型
|
||||
# 如果返回的是字符串,说明发生了错误或没有找到材料
|
||||
if isinstance(mp_id_list, str):
|
||||
return mp_id_list # 直接返回错误信息
|
||||
|
||||
# 如果代码执行到这里,说明mp_id_list是一个有效的ID列表
|
||||
try:
|
||||
# 获取材料属性
|
||||
properties = []
|
||||
for mp_id in mp_id_list:
|
||||
try:
|
||||
file_path = material_config.LOCAL_MP_PROPS_ROOT + f"/{mp_id}.json"
|
||||
crystal_props = extract_cif_info(file_path, ['all_fields'])
|
||||
properties.append(crystal_props)
|
||||
except Exception as file_error:
|
||||
# 记录单个文件处理错误但继续处理其他ID
|
||||
continue
|
||||
|
||||
# 检查是否有结果
|
||||
if len(properties) == 0:
|
||||
return "No material properties found for the given formula, please try again."
|
||||
|
||||
# 只保留前MP_TOPK个结果
|
||||
properties = properties[:material_config.MP_TOPK]
|
||||
|
||||
# 格式化结果
|
||||
formatted_results = []
|
||||
for i, item in enumerate(properties, 1):
|
||||
formatted_result = f"[property {i} begin]\n"
|
||||
formatted_result += json.dumps(item, indent=2)
|
||||
formatted_result += f"\n[property {i} end]\n\n"
|
||||
formatted_results.append(formatted_result)
|
||||
|
||||
# 将所有结果合并为一个字符串
|
||||
res_chunk = "\n\n".join(formatted_results)
|
||||
res_template = f"""
|
||||
Here are the search material property from the Materials Project database:
|
||||
Due to length limitations, only the top {len(properties)} results are shown below:\n
|
||||
{res_chunk}
|
||||
"""
|
||||
return res_template
|
||||
|
||||
except Exception as e:
|
||||
return f"Error: processing material properties: {str(e)}"
|
||||
0
sci_mcp/material_mcp/oqmd_query/__init__.py
Normal file
0
sci_mcp/material_mcp/oqmd_query/__init__.py
Normal file
92
sci_mcp/material_mcp/oqmd_query/oqmd_query_tools.py
Executable file
92
sci_mcp/material_mcp/oqmd_query/oqmd_query_tools.py
Executable file
@@ -0,0 +1,92 @@
|
||||
import logging
|
||||
import httpx
|
||||
import pandas as pd
|
||||
from bs4 import BeautifulSoup
|
||||
from io import StringIO
|
||||
from typing import Annotated, Any, Dict, List
|
||||
import mcp.types as types
|
||||
from ...core.llm_tools import llm_tool
|
||||
|
||||
|
||||
|
||||
@llm_tool(name="query_material_from_OQMD", description="Query material properties by chemical formula from OQMD database")
|
||||
async def query_material_from_OQMD(
|
||||
formula: Annotated[str, "Chemical formula (e.g., Fe2O3, LiFePO4)"]
|
||||
) -> str:
|
||||
"""
|
||||
Query material information by chemical formula from OQMD database.
|
||||
|
||||
Args:
|
||||
formula: Chemical formula of the material (e.g., Fe2O3, LiFePO4)
|
||||
|
||||
Returns:
|
||||
Formatted text with material information and property tables
|
||||
"""
|
||||
# Fetch data from OQMD
|
||||
url = f"https://www.oqmd.org/materials/composition/{formula}"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=100.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# Validate response content
|
||||
if not response.text or len(response.text) < 100:
|
||||
raise ValueError("Invalid response content from OQMD API")
|
||||
|
||||
# Parse HTML data
|
||||
html = response.text
|
||||
soup = BeautifulSoup(html, 'html.parser')
|
||||
|
||||
# Parse basic data
|
||||
basic_data = []
|
||||
h1_element = soup.find('h1')
|
||||
if h1_element:
|
||||
basic_data.append(h1_element.text.strip())
|
||||
else:
|
||||
basic_data.append(f"Material: {formula}")
|
||||
|
||||
for script in soup.find_all('p'):
|
||||
if script:
|
||||
combined_text = ""
|
||||
for element in script.contents:
|
||||
if hasattr(element, 'name') and element.name == 'a' and 'href' in element.attrs:
|
||||
url = "https://www.oqmd.org" + element['href']
|
||||
combined_text += f"[{element.text.strip()}]({url}) "
|
||||
elif hasattr(element, 'text'):
|
||||
combined_text += element.text.strip() + " "
|
||||
else:
|
||||
combined_text += str(element).strip() + " "
|
||||
basic_data.append(combined_text.strip())
|
||||
|
||||
# Parse table data
|
||||
table_data = ""
|
||||
table = soup.find('table')
|
||||
if table:
|
||||
try:
|
||||
df = pd.read_html(StringIO(str(table)))[0]
|
||||
df = df.fillna('')
|
||||
df = df.replace([float('inf'), float('-inf')], '')
|
||||
table_data = df.to_markdown(index=False)
|
||||
except Exception as e:
|
||||
|
||||
table_data = "Error: parsing table data"
|
||||
|
||||
# Integrate data into a single text
|
||||
combined_text = "\n\n".join(basic_data)
|
||||
if table_data:
|
||||
combined_text += "\n\n## Material Properties Table\n\n" + table_data
|
||||
|
||||
return combined_text
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
return f"Error: OQMD API request failed - {str(e)}"
|
||||
except httpx.TimeoutException:
|
||||
return "Error: OQMD API request timed out"
|
||||
except httpx.NetworkError as e:
|
||||
return f"Error: Network error occurred - {str(e)}"
|
||||
except ValueError as e:
|
||||
return f"Error: Invalid response content - {str(e)}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error occurred - {str(e)}"
|
||||
|
||||
|
||||
95
sci_mcp/material_mcp/pymatgen_cal/pymatgen_cal_tools.py
Normal file
95
sci_mcp/material_mcp/pymatgen_cal/pymatgen_cal_tools.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import os
|
||||
import asyncio
|
||||
from pymatgen.core import Structure
|
||||
from ...core.config import material_config
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ..support.utils import read_structure_from_file_name_or_content_string
|
||||
|
||||
@llm_tool(name="calculate_density_Pymatgen", description="Calculate the density of a crystal structure from a file or content string using Pymatgen")
|
||||
async def calculate_density_Pymatgen(structure_source: str) -> str:
|
||||
"""
|
||||
Calculates the density of a structure from a file or content string.
|
||||
|
||||
Args:
|
||||
structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.
|
||||
|
||||
Returns:
|
||||
str: A Markdown formatted string with the density or an error message if the calculation fails.
|
||||
"""
|
||||
# 使用asyncio.to_thread异步执行可能阻塞的操作
|
||||
|
||||
try:
|
||||
# 使用read_structure_from_file_name_or_content_string函数读取结构
|
||||
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
|
||||
# # 使用pymatgen读取结构
|
||||
structure = Structure.from_str(structure_content,fmt=content_format)
|
||||
density = structure.density
|
||||
|
||||
# 删除临时文件
|
||||
|
||||
return (f"## Density Calculation\n\n"
|
||||
f"- **Structure**: `{structure.composition.reduced_formula}`\n"
|
||||
f"- **Density**: `{density:.2f} g/cm³`\n")
|
||||
except Exception as e:
|
||||
return f"Error: error occurred while calculating density: {str(e)}\n"
|
||||
|
||||
|
||||
@llm_tool(name="get_element_composition_Pymatgen", description="Analyze and retrieve the elemental composition of a crystal structure from a file or content string using Pymatgen")
|
||||
async def get_element_composition_Pymatgen(structure_source: str) -> str:
|
||||
"""
|
||||
Returns the elemental composition of a structure from a file or content string.
|
||||
|
||||
Args:
|
||||
structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.
|
||||
|
||||
Returns:
|
||||
str: A Markdown formatted string with the elemental composition or an error message if the operation fails.
|
||||
"""
|
||||
# 使用asyncio.to_thread异步执行可能阻塞的操作
|
||||
|
||||
try:
|
||||
# 使用read_structure_from_file_name_or_content_string函数读取结构
|
||||
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
|
||||
|
||||
# 使用pymatgen读取结构
|
||||
structure = Structure.from_str(structure_content, fmt=content_format)
|
||||
composition = structure.composition
|
||||
|
||||
return (f"## Element Composition\n\n"
|
||||
f"- **Structure**: `{structure.composition.reduced_formula}`\n"
|
||||
f"- **Composition**: `{composition}`\n")
|
||||
except Exception as e:
|
||||
return f"Error: error occurred while getting element composition: {str(e)}\n"
|
||||
|
||||
|
||||
|
||||
@llm_tool(name="calculate_symmetry_Pymatgen", description="Determine the space group and symmetry operations of a crystal structure from a file or content string using Pymatgen")
|
||||
async def calculate_symmetry_Pymatgen(structure_source: str) -> str:
|
||||
"""
|
||||
Calculates the symmetry of a structure from a file or content string.
|
||||
|
||||
Args:
|
||||
structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.
|
||||
|
||||
Returns:
|
||||
str: A Markdown formatted string with the symmetry information or an error message if the operation fails.
|
||||
"""
|
||||
# 使用asyncio.to_thread异步执行可能阻塞的操作
|
||||
|
||||
try:
|
||||
# 使用read_structure_from_file_name_or_content_string函数读取结构
|
||||
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
|
||||
|
||||
|
||||
|
||||
# 使用pymatgen读取结构
|
||||
structure = Structure.from_str(structure_content, fmt=content_format)
|
||||
symmetry = structure.get_space_group_info()
|
||||
|
||||
return (f"## Symmetry Information\n\n"
|
||||
f"- **Structure**: `{structure.composition.reduced_formula}`\n"
|
||||
f"- **Space Group**: `{symmetry[0]}`\n"
|
||||
f"- **Number**: `{symmetry[1]}`\n")
|
||||
except Exception as e:
|
||||
return f"Error: error occurred while calculating symmetry: {str(e)}\n"
|
||||
|
||||
0
sci_mcp/material_mcp/support/__init__.py
Normal file
0
sci_mcp/material_mcp/support/__init__.py
Normal file
212
sci_mcp/material_mcp/support/utils.py
Executable file
212
sci_mcp/material_mcp/support/utils.py
Executable file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
CIF Utilities Module
|
||||
|
||||
This module provides basic functions for handling CIF (Crystallographic Information File) files,
|
||||
which are commonly used in materials science for representing crystal structures.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from ase.io import read
|
||||
import tempfile
|
||||
from typing import Optional, Tuple
|
||||
from ase import Atoms
|
||||
from ...core.config import material_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def read_cif_txt_file(file_path):
    """Read a CIF (or any text) file and return its content.

    Args:
        file_path: Path to the CIF file

    Returns:
        The file's text content, or None if reading fails for any reason.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            content = handle.read()
    except Exception as e:
        logger.error(f"Error reading file {file_path}: {e}")
        return None
    return content
|
||||
|
||||
def extract_cif_info(path: str, fields_name: list):
    """
    Extract specific fields from the CIF description JSON file.

    Args:
        path: Path to the JSON file containing CIF information
        fields_name: List of field categories to extract. Use 'all_fields' to
            extract everything; other options are 'basic_fields',
            'energy_electronic_fields' and 'metal_magentic_fields'.

    Returns:
        Dictionary mapping each selected field to its value ('' if absent).
    """
    # Field categories. NOTE: 'metal_magentic_fields' keeps the original,
    # misspelled category name for backward compatibility with callers.
    field_groups = {
        'basic_fields': ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density'],
        'energy_electronic_fields': ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct'],
        'metal_magentic_fields': ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites'],
    }

    selected_fields = []
    # Guard against an empty list (original code raised IndexError here).
    if fields_name and fields_name[0] == 'all_fields':
        # Preserve the original concatenation order of all three groups.
        for group in field_groups.values():
            selected_fields.extend(group)
    else:
        # Explicit mapping replaces the original locals().get() hack, which was
        # fragile and could match unrelated local variables.
        for field in fields_name:
            selected_fields.extend(field_groups.get(field, []))

    with open(path, 'r') as f:
        docs = json.load(f)

    # Missing fields default to '' as before.
    return {field_name: docs.get(field_name, '') for field_name in selected_fields}
|
||||
|
||||
def remove_symmetry_equiv_xyz(cif_content):
    """
    Remove symmetry operations section from CIF file content.

    This is often useful when working with CIF files in certain visualization tools
    or when focusing on the basic structure without symmetry operations.

    Args:
        cif_content: CIF file content string

    Returns:
        Cleaned CIF content string with symmetry operations removed
    """
    lines = cif_content.split('\n')
    output_lines = []

    i = 0
    while i < len(lines):
        line = lines[i].strip()

        # Detect the start of a CIF loop.
        if line == 'loop_':
            # Peek ahead at the contiguous "_"-prefixed tag lines to decide
            # whether this loop declares symmetry operations.
            next_lines = []
            j = i + 1
            while j < len(lines) and lines[j].strip().startswith('_'):
                next_lines.append(lines[j].strip())
                j += 1

            # Is this the symmetry-operations loop?
            if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
                # Skip the whole loop block: this 'loop_' line is dropped, and
                # i advances until the line AFTER i begins the next section.
                while i < len(lines):
                    if i + 1 >= len(lines):
                        break

                    next_line = lines[i + 1].strip()
                    # Next loop or data block begins — stop skipping.
                    if next_line == 'loop_' or next_line.startswith('data_'):
                        break

                    # Atom-site tags begin — stop skipping.
                    # NOTE(review): if _atom_site_ tags follow without their own
                    # 'loop_' line, the loop_ header for them was already
                    # dropped with this block; assumed not to occur for
                    # pymatgen CifWriter output — confirm for other producers.
                    if next_line.startswith('_atom_site_'):
                        break

                    i += 1
            else:
                # Not a symmetry loop: keep the 'loop_' line.
                output_lines.append(lines[i])
        else:
            # Any non-loop line is kept verbatim.
            output_lines.append(lines[i])

        i += 1

    return '\n'.join(output_lines)
|
||||
|
||||
def _infer_format_from_ext(path: str) -> str:
    """Infer a structure format from a file extension; default to 'cif'."""
    ext = os.path.splitext(path)[1].lower().lstrip('.')
    if ext in ('cif', 'xyz', 'vasp', 'poscar'):
        # '.poscar' files are VASP-format.
        return 'cif' if ext == 'cif' else 'xyz' if ext == 'xyz' else 'vasp'
    # Unknown extension: assume CIF.
    return 'cif'


def _infer_format_from_content(content: str) -> str:
    """Heuristically infer a structure format from raw file content."""
    # CIF files usually contain a "data_" block and "_cell_" tags.
    if "data_" in content and "_cell_" in content:
        return "cif"
    lines = content.strip().split('\n')
    # XYZ files start with the atom count on the first line.
    if lines[0].strip().isdigit():
        return "xyz"
    # POSCAR/VASP: lines 3-5 hold the three lattice vectors (3 fields each).
    if len(lines) > 5 and all(len(line.split()) == 3 for line in lines[2:5]):
        return "vasp"
    # Default: assume CIF.
    return "cif"


def read_structure_from_file_name_or_content_string(file_name_or_content_string: str, format_type: str = "auto") -> Tuple[str, str]:
    """
    Resolve a structure input that may be a full path, a bare file name, or raw content.

    When treated as a bare file name, the file is looked up under
    material_config.TEMP_ROOT (where LLM-generated temporary files live).

    Args:
        file_name_or_content_string: File path, file name, or structure content string
        format_type: Structure format; "auto" infers it from the extension or content

    Returns:
        tuple: (content string, resolved format type)
    """
    # 1) Full path to an existing file.
    if os.path.exists(file_name_or_content_string) and os.path.isfile(file_name_or_content_string):
        with open(file_name_or_content_string, 'r', encoding='utf-8') as f:
            content = f.read()
        if format_type == "auto":
            format_type = _infer_format_from_ext(file_name_or_content_string)
        return content, format_type

    # 2) Bare file name inside the configured temp directory.
    temp_path = os.path.join(material_config.TEMP_ROOT, file_name_or_content_string)
    if os.path.exists(temp_path) and os.path.isfile(temp_path):
        with open(temp_path, 'r', encoding='utf-8') as f:
            content = f.read()
        if format_type == "auto":
            format_type = _infer_format_from_ext(temp_path)
        return content, format_type

    # 3) Otherwise treat the input as raw structure content.
    content = file_name_or_content_string
    if format_type == "auto":
        format_type = _infer_format_from_content(content)
    return content, format_type
|
||||
|
||||
def convert_structure(input_format: str='cif', content: str=None) -> Optional[Atoms]:
    """
    Convert raw structure content into an ASE Atoms object.

    Args:
        input_format: Input format (cif, xyz, vasp, ...) — used as the temp
            file's extension so ASE can pick the right parser
        content: Structure content string

    Returns:
        ASE Atoms object, or None if the conversion fails
    """
    try:
        # ASE's reader wants a file, so spill the content to a temp file.
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            tmp_file.write(content)
            tmp_path = tmp_file.name

        # try/finally guarantees the temp file is removed even when read()
        # raises (the original leaked the file on parse failure).
        try:
            return read(tmp_path)
        finally:
            os.unlink(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
|
||||
306
server.py
Executable file
306
server.py
Executable file
@@ -0,0 +1,306 @@
|
||||
"""Mars Toolkit MCP Server implementation."""
|
||||
|
||||
import anyio
|
||||
import asyncio
|
||||
import click
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import time
|
||||
from prompts.material_synthesis import create_messages
|
||||
|
||||
# 添加mars_toolkit模块的路径
|
||||
sys.path.append('/home/ubuntu/50T/lzy/mars-mcp')
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.lowlevel import Server
|
||||
|
||||
# 导入提示词处理器
|
||||
#from prompts.material_synthesis import register_prompt_handlers
|
||||
|
||||
# 导入Mars Toolkit工具函数
|
||||
try:
|
||||
# 获取当前时间
|
||||
from mars_toolkit.misc.misc_tools import get_current_time
|
||||
# 网络搜索
|
||||
from mars_toolkit.query.web_search import search_online
|
||||
# 从Materials Project查询材料属性
|
||||
from mars_toolkit.query.mp_query import search_material_property_from_material_project
|
||||
# 从Materials Project获取晶体结构
|
||||
from mars_toolkit.query.mp_query import get_crystal_structures_from_materials_project
|
||||
# 从化学式获取Materials Project ID
|
||||
from mars_toolkit.query.mp_query import get_mpid_from_formula
|
||||
# 优化晶体结构
|
||||
from mars_toolkit.compute.structure_opt import optimize_crystal_structure
|
||||
# 生成材料
|
||||
from mars_toolkit.compute.material_gen import generate_material
|
||||
# 从OQMD获取化学成分
|
||||
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||
# 从知识库检索
|
||||
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||
# 预测属性
|
||||
from mars_toolkit.compute.property_pred import predict_properties
|
||||
|
||||
# 获取所有工具函数
|
||||
from mars_toolkit import get_tools, get_tool_schemas
|
||||
|
||||
MARS_TOOLKIT_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
print(f"警告: 无法导入Mars Toolkit: {e}", file=sys.stderr)
|
||||
MARS_TOOLKIT_AVAILABLE = False
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
app = Server("mars-toolkit-server")
|
||||
|
||||
|
||||
async def call_mars_toolkit_function(func_name: str, arguments: Dict[str, Any]) -> Any:
    """
    Invoke a registered Mars Toolkit tool function by name.

    Args:
        func_name: Name of the tool function
        arguments: Keyword arguments forwarded to the tool

    Returns:
        Whatever the tool function returns

    Raises:
        ValueError: If the toolkit is unavailable or the name is unknown
    """
    if not MARS_TOOLKIT_AVAILABLE:
        raise ValueError("Mars Toolkit不可用")

    # Look the function up in the registry of decorated tools.
    tools = get_tools()
    if func_name not in tools:
        raise ValueError(f"函数 '{func_name}' 不存在于工具函数字典中")

    tool_func = tools[func_name]

    # Await coroutine tools; call plain functions directly.
    if asyncio.iscoroutinefunction(tool_func):
        result = await tool_func(**arguments)
    else:
        result = tool_func(**arguments)

    # Debug print("result1", ...) replaced with lazy logging.
    logger.debug("tool %s returned %r", func_name, result)
    return result
|
||||
|
||||
|
||||
def get_tool_schemas_dict() -> Dict[str, Dict[str, Any]]:
    """
    Build a mapping from tool function name to its schema.

    Returns:
        Dict keyed by tool name; empty when the toolkit is unavailable.
    """
    if not MARS_TOOLKIT_AVAILABLE:
        return {}
    # Each schema carries its own name under schema["function"]["name"].
    return {schema["function"]["name"]: schema for schema in get_tool_schemas()}
|
||||
|
||||
|
||||
@click.command()
@click.option("--port", default=5666, help="Port to listen on for SSE")
@click.option(
    "--transport",
    type=click.Choice(["stdio", "sse"]),
    default="sse",
    help="Transport type",
)
def main(port: int, transport: str = "sse") -> int:
    """
    Mars Toolkit MCP server entry point.

    Args:
        port: Port number used by the SSE transport.
        transport: Transport type, "stdio" or "sse".

    Returns:
        Exit code: 0 on success, 1 when the Mars Toolkit is unavailable.
    """
    # Bug fix: the signature default used to be "SSE" (upper case). The
    # dispatch below compares against the lower-case "sse", so a direct
    # call bypassing click would have silently started the stdio transport.
    if not MARS_TOOLKIT_AVAILABLE:
        print("错误: Mars Toolkit不可用,请确保已正确安装", file=sys.stderr)
        return 1

    # Tool-function schema dict, captured by the nested MCP handlers below.
    schemas_dict = get_tool_schemas_dict()

    # Prompt handlers are registered inline below rather than through
    # register_prompt_handlers(app).

    @app.list_prompts()
    async def list_prompts() -> list[types.Prompt]:
        """Advertise the prompts this server supports."""
        return [
            types.Prompt(
                name="material_synthesis",
                description="生成材料并设计合成方案,使用mermaid绘制合成流程图",
                arguments=[
                    types.PromptArgument(
                        name="properties",
                        description="材料性质及其值的JSON字符串,例如 {\"dft_band_gap\": \"2.0\"}",
                        required=False,
                    ),
                    types.PromptArgument(
                        name="batch_size",
                        description="生成材料的数量,默认为2",
                        required=False,
                    ),
                ],
            )
        ]

    @app.get_prompt()
    async def get_prompt(
        name: str, arguments: dict[str, str] | None = None
    ) -> types.GetPromptResult:
        """Resolve the material_synthesis prompt from client arguments."""
        if name != "material_synthesis":
            raise ValueError(f"未知的提示词: {name}")

        if arguments is None:
            arguments = {}

        # Parse the optional JSON-encoded "properties" argument; fall back
        # to an empty dict on malformed input instead of failing the call.
        # (json is imported at module level — the redundant local import
        # was removed.)
        properties = {}
        if "properties" in arguments and arguments["properties"]:
            try:
                properties = json.loads(arguments["properties"])
            except json.JSONDecodeError:
                properties = {}

        # Parse the optional "batch_size" argument, defaulting to 2.
        batch_size = 2
        if "batch_size" in arguments and arguments["batch_size"]:
            try:
                batch_size = int(arguments["batch_size"])
            except ValueError:
                pass  # keep the default

        return types.GetPromptResult(
            messages=create_messages(properties=properties, batch_size=batch_size),
            description="生成材料并设计合成方案,使用mermaid绘制合成流程图",
        )

    @app.call_tool()
    async def call_tool(
        name: str, arguments: Dict[str, Any]
    ) -> List[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        """
        Invoke a Mars Toolkit tool function.

        Args:
            name: Tool function name.
            arguments: Keyword arguments forwarded to the tool.

        Returns:
            The tool result wrapped as MCP text content. Errors are logged
            and returned as text so the client receives a message instead
            of a protocol failure.
        """
        try:
            print(f"调用{name},参数为{arguments}")
            result = await call_mars_toolkit_function(name, arguments)
            print("result", result)
            # Serialize structured results as JSON, everything else via str().
            if isinstance(result, (dict, list)):
                result_str = json.dumps(result, ensure_ascii=False, indent=2)
            else:
                result_str = str(result)

            return [types.TextContent(type="text", text=result_str)]
        except Exception as e:
            error_msg = f"调用工具函数 {name} 时出错: {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            return [types.TextContent(type="text", text=error_msg)]

    @app.list_tools()
    async def list_tools() -> List[types.Tool]:
        """
        List every available tool function.

        Returns:
            One types.Tool per entry in schemas_dict.
        """
        tools = []
        print("列举所有可用的工具函数")
        for func_name, schema in schemas_dict.items():
            description = schema["function"].get("description", f"Mars Toolkit工具: {func_name}")
            parameters = schema["function"].get("parameters", {})
            tools.append(
                types.Tool(
                    name=func_name,
                    description=description,
                    inputSchema=parameters,
                )
            )
        return tools

    if transport == "sse":
        from mcp.server.sse import SseServerTransport
        from starlette.applications import Starlette
        from starlette.routing import Mount, Route

        sse = SseServerTransport("/messages/")

        async def handle_sse(request):
            # Bridge the incoming Starlette request onto the MCP server loop.
            async with sse.connect_sse(
                request.scope, request.receive, request._send
            ) as streams:
                await app.run(
                    streams[0], streams[1], app.create_initialization_options()
                )

        starlette_app = Starlette(
            debug=True,
            routes=[
                Route("/sse", endpoint=handle_sse),
                Mount("/messages/", app=sse.handle_post_message),
            ],
        )

        import uvicorn

        uvicorn.run(starlette_app, host="0.0.0.0", port=port)
    else:
        from mcp.server.stdio import stdio_server

        async def arun():
            async with stdio_server() as streams:
                await app.run(
                    streams[0], streams[1], app.create_initialization_options()
                )

        anyio.run(arun)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Cleanup: the debug dump `print(get_tool_schemas_dict())` that ran
    # before startup was removed to keep stdout clean (it printed the full
    # schema dict on every launch).
    main()
|
||||
|
||||
437
test_tools/agent_test.py
Executable file
437
test_tools/agent_test.py
Executable file
@@ -0,0 +1,437 @@
|
||||
import asyncio
|
||||
|
||||
from api_key import *
|
||||
from openai import OpenAI
|
||||
import json
|
||||
from typing import Dict, List, Any, Union, Optional
|
||||
from rich.console import Console
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||||
from sci_mcp import *
|
||||
|
||||
# 获取工具
|
||||
all_tools_schemas = get_all_tool_schemas()
|
||||
tools = get_all_tools()
|
||||
chemistry_tools = get_domain_tools("chemistry")
|
||||
|
||||
console = Console()
|
||||
#console.print(all_tools_schemas)
|
||||
|
||||
class ModelAgent:
    """
    Agent wrapper around the gpt-4o chat-completions API.

    Normalizes the response format and provides a unified interface for
    invoking tools (both sync and async tool functions).
    """

    def __init__(self, model_name: str = "gpt-4o"):
        """
        Initialize the model client.

        Args:
            model_name: Name of the chat model to use.
        """
        # OpenAI-compatible client; credentials come from api_key.py.
        self.client = OpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_API_URL,
        )

        # Model name used for every completion request.
        self.model_name = model_name

        # Tool schemas advertised to the model on every request.
        self.tools = all_tools_schemas

    def get_response(self, messages: List[Dict[str, Any]]) -> Any:
        """
        Request one chat completion.

        Args:
            messages: Conversation so far, in OpenAI message format.

        Returns:
            The raw completion object.
        """
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            tools=self.tools,
            tool_choice="auto",
            temperature=0.6,
        )
        return completion

    def extract_tool_calls(self, response: Any) -> Optional[List[Any]]:
        """
        Pull the tool-call list out of a completion, if any.

        Args:
            response: Completion object from get_response().

        Returns:
            The message's tool_calls list, or None when absent.
        """
        if hasattr(response.choices[0].message, 'tool_calls'):
            return response.choices[0].message.tool_calls
        return None

    def extract_content(self, response: Any) -> str:
        """
        Pull the text content out of a completion.

        Args:
            response: Completion object from get_response().

        Returns:
            Message content, or "" when the model returned no text
            (e.g. a pure tool-call turn).
        """
        content = response.choices[0].message.content
        return content if content is not None else ""

    def extract_finish_reason(self, response: Any) -> str:
        """
        Pull the finish reason out of a completion.

        Args:
            response: Completion object from get_response().

        Returns:
            The finish_reason string (e.g. "stop" or "tool_calls").
        """
        return response.choices[0].finish_reason

    async def call_tool(self, tool_name: str, tool_arguments: Dict[str, Any]) -> str:
        """
        Invoke a tool function, supporting both sync and async tools.

        Args:
            tool_name: Name of the tool to run.
            tool_arguments: Keyword arguments for the tool.

        Returns:
            The tool's result, or an error-message string.
        """
        if tool_name not in tools:
            return f"未找到工具: {tool_name}"

        tool_function = tools[tool_name]
        try:
            # Bug fix: the old code also tested inspect.isawaitable() on the
            # function object itself — that is always False for a plain
            # function, so iscoroutinefunction() is the correct (and only
            # needed) check. The redundant in-method imports of asyncio and
            # inspect were removed (asyncio is imported at module level).
            if asyncio.iscoroutinefunction(tool_function):
                tool_result = await tool_function(**tool_arguments)
            else:
                tool_result = tool_function(**tool_arguments)
            return tool_result
        except Exception as e:
            return f"工具调用错误: {str(e)}"

    async def chat(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
        """
        Converse with the model, executing tool calls between turns.

        Args:
            messages: Initial message list.
            max_turns: Maximum number of model round-trips.

        Returns:
            The final assistant answer, or a notice when max_turns is hit.
        """
        current_messages = messages.copy()
        turn = 0

        while turn < max_turns:
            turn += 1
            console.print(f"\n[bold magenta]第 {turn} 轮对话[/bold magenta]")

            response = self.get_response(current_messages)
            assistant_message = response.choices[0].message

            # Keep the assistant turn in context for the next request.
            current_messages.append(assistant_message.model_dump())

            content = self.extract_content(response)
            tool_calls = self.extract_tool_calls(response)
            finish_reason = self.extract_finish_reason(response)

            console.print(f"[green]助手回复:[/green] {content}")

            # No tool calls requested -> this is the final answer.
            if tool_calls is None or finish_reason != "tool_calls":
                return content

            for tool_call in tool_calls:
                tool_call_name = tool_call.function.name
                tool_call_arguments = json.loads(tool_call.function.arguments)

                console.print(f"[yellow]调用工具:[/yellow] {tool_call_name}")
                console.print(f"[yellow]工具参数:[/yellow] {tool_call_arguments}")

                tool_result = await self.call_tool(tool_call_name, tool_call_arguments)
                console.print(f"[blue]工具结果:[/blue] {tool_result}")

                # Robustness fix: the chat API requires tool-message content
                # to be a string; coerce non-string tool results.
                current_messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": tool_call_name,
                    "content": tool_result if isinstance(tool_result, str) else str(tool_result)
                })

        return "达到最大对话轮数限制"
|
||||
|
||||
|
||||
async def main():
    """Interactive loop: read a question, answer it via the agent, repeat."""
    agent = ModelAgent()

    while True:
        user_input = input("\n请输入问题(输入 'exit' 退出): ")
        if user_input.lower() == 'exit':
            break

        conversation = [
            {"role": "system", "content": "你是一个有用的助手。"},
            {"role": "user", "content": user_input}
        ]
        try:
            answer = await agent.chat(conversation)
        except Exception as exc:
            console.print(f"[bold red]错误:[/bold red] {str(exc)}")
        else:
            console.print(f"\n[bold magenta]最终回答:[/bold magenta] {answer}")
|
||||
|
||||
|
||||
# 为每个工具函数生成的问题列表
|
||||
def get_tool_questions() -> Dict[str, List[str]]:
    """
    Return the canned test questions for every tool function.

    The questions are phrased to lead the LLM toward calling the matching
    tool on its own, without explicitly ordering it to use a specific tool.

    Returns:
        Dict mapping each tool-function name to its list of questions.
    """
    questions = {
        # Materials-science tool functions
        "search_crystal_structures_from_materials_project": [
            "我想了解氧化铁(Fe2O3)的晶体结构,能帮我找一下相关信息吗?",
            "锂离子电池中常用的LiFePO4材料有什么样的晶体结构?",
            "能否帮我查询一下钙钛矿(CaTiO3)的晶体结构数据?"
        ],

        "search_material_property_from_material_project": [
            "二氧化钛(TiO2)有哪些重要的物理和化学性质?",
            "我正在研究锂电池材料,能告诉我LiCoO2的主要性质吗?",
            "硅(Si)的带隙和电子性质是什么?能帮我查一下详细数据吗?"
        ],

        "query_material_from_OQMD": [
            "能帮我从OQMD数据库中查询一下铝合金(Al-Cu)的形成能吗?",
            "我想了解镍基超合金在OQMD数据库中的热力学稳定性数据",
            "OQMD数据库中有关于锌氧化物(ZnO)的什么信息?"
        ],

        "retrieval_from_knowledge_base": [
            "有关高温超导体的最新研究进展是什么?",
            "能否从材料科学知识库中找到关于石墨烯应用的信息?",
            "我想了解钙钛矿太阳能电池的工作原理和效率限制"
        ],

        "predict_properties": [
            "这个化学式为Li2FeSiO4的材料可能有什么样的电子性质?",
            "能预测一下Na3V2(PO4)3这种材料的离子导电性吗?",
            "如果我设计一个新的钙钛矿结构,能预测它的稳定性和带隙吗?"
        ],

        "generate_material": [
            "能生成一种可能具有铁磁性的新材料结构吗?"
        ],

        "optimize_crystal_structure": [
            "我有一个CIF文件HEu2H3EuH2EuH5.cif,能帮我优化一下使其更稳定吗?"
        ],

        "calculate_density": [
            "我有一个CIF文件HEu2H3EuH2EuH5.cif,能计算一下它的密度吗?"
        ],

        "get_element_composition": [
            "我有一个CIF文件HEu2H3EuH2EuH5.cif,能分析一下它的元素组成吗?"
        ],

        "calculate_symmetry": [
            "我有一个CIF文件HEu2H3EuH2EuH5.cif,能分析一下它的对称性和空间群吗?"
        ],

        # Chemistry tool functions
        "search_pubchem_advanced": [
            "阿司匹林的分子结构和性质是什么?"
        ],

        "calculate_molecular_properties": [
            "对乙酰氨基酚的SMILES是CC(=O)NC1=CC=C(C=C1)O,计算它的物理化学性质"
        ],

        "calculate_drug_likeness": [
            "布洛芬的SMILES是CC(C)CC1=CC=C(C=C1)C(C)C(=O)O,能计算它的药物性吗?"
        ],

        "calculate_topological_descriptors": [
            "咖啡因的SMILES是CN1C=NC2=C1C(=O)N(C(=O)N2C)C,能计算它的拓扑描述符吗?"
        ],

        "generate_molecular_fingerprints": [
            "尼古丁的SMILES是CN1CCCC1C2=CN=CC=C2,能为它生成Morgan指纹吗?"
        ],

        "calculate_molecular_similarity": [
            "阿司匹林的SMILES是CC(=O)OC1=CC=CC=C1C(=O)O,对乙酰氨基酚的SMILES是CC(=O)NC1=CC=C(C=C1)O,它们的分子相似性如何?"
        ],

        "analyze_molecular_structure": [
            "苯甲酸的SMILES是C1=CC=C(C=C1)C(=O)O,能分析它的结构特征吗?"
        ],

        "generate_molecular_conformer": [
            "甲基苯并噻唑的SMILES是CC1=NC2=CC=CC=C2S1,能生成它的3D构象吗?"
        ],

        "identify_scaffolds": [
            "奎宁的SMILES是COC1=CC2=C(C=CN=C2C=C1)C(C3CC4CCN3CC4C=C)O,它的核心骨架是什么?"
        ],

        "convert_between_chemical_formats": [
            "将乙醇的SMILES:CCO转换为InChI格式"
        ],

        "standardize_molecule": [
            "将四环素的SMILES:CC1C2C(C(=O)C3(C(CC4C(C3C(=O)C2C(=O)C(=C1O)C(=O)N)O)(C(=O)CO4)O)O)N(C)C标准化处理"
        ],

        "enumerate_stereoisomers": [
            "2-丁醇的SMILES是CCC(C)O,它可能有哪些立体异构体?,不要单纯靠你自身的知识,如果不确定可以使用工具。"
        ],

        "perform_substructure_search": [
            "在阿莫西林的SMILES:CC1(C(N2C(S1)C(C2=O)NC(=O)C(C3=CC=C(C=C3)O)N)C(=O)O)C中搜索羧酸基团"
        ],

        # Test questions for the RXN tool functions
        "predict_reaction_outcome": [
            "我正在研究乙酸和乙醇的酯化反应,想知道这个反应的产物是什么。反应物的SMILES表示法是'CC(=O)O.CCO'。能帮我预测一下这个反应最可能的结果吗?"
        ],

        "predict_reaction_batch": [
            "我在实验室中设计了三个酯化反应系列,想同时了解它们的可能产物。这三个反应的SMILES分别是'CC(=O)O.CCO'(乙酸和乙醇)、'CC(=O)O.CCCO'(乙酸和丙醇)和'CC(=O)O.CCCCO'(乙酸和丁醇)。能否一次性预测这些反应的结果?"
        ],

        "predict_reaction_topn": [
            "我在研究丙烯醛和甲胺的反应机理,这个反应可能有多种产物路径。反应物的SMILES是'C=CC=O.CN'。能帮我分析出最可能的前3种产物及它们的相对可能性吗?"
        ],

        "predict_retrosynthesis": [
            "我需要为实验室合成阿司匹林,但不确定最佳的合成路线。阿司匹林的SMILES是'CC(=O)OC1=CC=CC=C1C(=O)O'。能帮我分析一下可能的合成路径,将其分解为更简单的前体化合物吗?"
        ],

        "predict_biocatalytic_retrosynthesis": [
            "我们实验室正在研究绿色化学合成方法,想知道是否可以使用酶催化方式合成这个含溴的芳香化合物(SMILES: 'OC1C(O)C=C(Br)C=C1')。能帮我分析可能的生物催化合成路径吗?"
        ],

        "predict_reaction_properties": [
            "我正在研究这个有机反应的机理:'CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F',特别想了解反应中的原子映射关系。能帮我分析一下反应前后各原子的对应关系吗?我需要atom-mapping属性。"
        ],

        "extract_reaction_actions": [
            "我从一篇有机合成文献中找到了这段实验步骤:'To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour.' 能帮我将这段文本转换为结构化的反应步骤吗?这样我可以更清晰地理解每个操作。"
        ]
    }

    return questions
|
||||
|
||||
# 测试特定工具函数的问题
|
||||
async def _ask_agent(agent: "ModelAgent", question: str) -> str:
    """Send one question to the agent with the standard system prompt."""
    messages = [
        {"role": "system", "content": "你是一个有用的助手。"},
        {"role": "user", "content": question}
    ]
    return await agent.chat(messages)


async def test_tool_with_question(question_index: int = 0):
    """
    Drive the agent with a preset question for one or all tools.

    Args:
        question_index: Index into each tool's question list (clamped to
            the list length). Defaults to 0.
    """
    all_questions = get_tool_questions()

    # One question per tool, clamping the index to the available questions.
    tool_questions = {}
    for tool_name, questions in all_questions.items():
        if questions:
            tool_questions[tool_name] = questions[min(question_index, len(questions) - 1)]

    # Show the menu of tools and their questions.
    console.print("[bold]可用的工具和问题:[/bold]")
    for i, (tool_name, question) in enumerate(tool_questions.items(), 1):
        console.print(f"{i}. [cyan]{tool_name}[/cyan]: {question}")

    choice = input("\n请选择要测试的工具编号(输入'all'测试所有工具): ")

    agent = ModelAgent()

    if choice.lower() == 'all':
        # Run every tool's question; duplicated message-building code was
        # factored into the _ask_agent helper.
        for tool_name, question in tool_questions.items():
            console.print(f"\n[bold]测试工具: [cyan]{tool_name}[/cyan][/bold]")
            console.print(f"问题: {question}")
            try:
                await _ask_agent(agent, question)
            except Exception as e:
                console.print(f"[bold red]错误:[/bold red] {str(e)}")
    else:
        try:
            index = int(choice) - 1
            if 0 <= index < len(tool_questions):
                tool_name = list(tool_questions.keys())[index]
                question = tool_questions[tool_name]

                console.print(f"\n[bold]测试工具: [cyan]{tool_name}[/cyan][/bold]")
                console.print(f"问题: {question}")

                # The single-tool path appends a hint nudging the model to
                # use tools when unsure (original behavior preserved).
                await _ask_agent(agent, question + '如果你不确定答案,请使用工具')
            else:
                console.print("[bold red]无效的选择[/bold red]")
        except ValueError:
            console.print("[bold red]请输入有效的数字[/bold red]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Uncomment to run the interactive chat loop instead:
    # asyncio.run(main())

    # Drive the per-tool question tester.
    asyncio.run(test_tool_with_question())

    # pass

# TODO: knowledge-retrieval API endpoint / database (translated original note)
|
||||
8
test_tools/api_key.py
Executable file
8
test_tools/api_key.py
Executable file
@@ -0,0 +1,8 @@
|
||||
# SECURITY NOTE(review): a live-looking API key is committed here in plain
# text. Rotate this credential and load it from an environment variable or an
# untracked secrets file instead of keeping it under version control.
OPENAI_API_KEY='sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d'
OPENAI_API_URL='https://vip.apiyi.com/v1'

# Alternate (local gpustack) endpoint, kept for reference:
#OPENAI_API_KEY='gpustack_56f0adc61a865d22_c61cdbf601fa2cb95979d417618060e6'
#OPENAI_API_URL='http://192.168.191.100:5080/v1'
|
||||
|
||||
|
||||
|
||||
123
test_tools/chemistry/test_pubchem.py
Normal file
123
test_tools/chemistry/test_pubchem.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Test script for PubChem tools
|
||||
|
||||
This script tests the search_pubchem_advanced function from the chemistry_mcp module.
|
||||
"""
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||||
import asyncio
|
||||
|
||||
from sci_mcp.chemistry_mcp.pubchem_tools.pubchem_tools import _search_by_formula
|
||||
|
||||
|
||||
from sci_mcp.chemistry_mcp import search_pubchem_advanced
|
||||
|
||||
async def test_search_by_name():
|
||||
"""Test searching compounds by name"""
|
||||
print("\n=== Testing search by name ===")
|
||||
result = await search_pubchem_advanced(name="Aspirin")
|
||||
print(result)
|
||||
|
||||
async def test_search_by_smiles():
|
||||
"""Test searching compounds by SMILES notation"""
|
||||
print("\n=== Testing search by SMILES ===")
|
||||
# SMILES for Caffeine
|
||||
result = await search_pubchem_advanced(smiles="CN1C=NC2=C1C(=O)N(C(=O)N2C)C")
|
||||
print(result)
|
||||
|
||||
async def test_search_by_formula():
|
||||
"""Test searching compounds by molecular formula"""
|
||||
print("\n=== Testing search by formula ===")
|
||||
# Formula for Aspirin
|
||||
result = await search_pubchem_advanced(formula="C9H8O4", max_results=2)
|
||||
print(result)
|
||||
|
||||
async def test_complex_formula():
|
||||
"""Test searching with a more complex formula that might cause timeout"""
|
||||
print("\n=== Testing complex formula search ===")
|
||||
# A more complex formula that might return many results
|
||||
result = await search_pubchem_advanced(
|
||||
formula="C6H12O6", # Glucose and isomers
|
||||
max_results=5
|
||||
)
|
||||
print(result)
|
||||
|
||||
async def test_complex_molecules():
|
||||
"""Test searching for complex molecules with rich molecular features"""
|
||||
print("\n=== Testing complex molecules with rich features ===")
|
||||
|
||||
# 1. Paclitaxel (Taxol) - Complex anticancer drug with many rotatable bonds and H-bond donors/acceptors
|
||||
print("\n--- Testing Paclitaxel (anticancer drug) ---")
|
||||
result = await search_pubchem_advanced(name="Paclitaxel")
|
||||
print(result)
|
||||
|
||||
# 2. Vancomycin - Complex antibiotic with many H-bond donors/acceptors
|
||||
print("\n--- Testing Vancomycin (antibiotic) ---")
|
||||
result = await search_pubchem_advanced(name="Vancomycin")
|
||||
print(result)
|
||||
|
||||
# 3. Cholesterol - Steroid with complex ring structure
|
||||
print("\n--- Testing Cholesterol (steroid) ---")
|
||||
result = await search_pubchem_advanced(name="Cholesterol")
|
||||
print(result)
|
||||
|
||||
# 4. Ibuprofen - Common NSAID with rotatable bonds
|
||||
print("\n--- Testing Ibuprofen (NSAID) ---")
|
||||
result = await search_pubchem_advanced(name="Ibuprofen")
|
||||
print(result)
|
||||
|
||||
# 5. Amoxicillin - Antibiotic with multiple functional groups
|
||||
print("\n--- Testing Amoxicillin (antibiotic) ---")
|
||||
result = await search_pubchem_advanced(name="Amoxicillin")
|
||||
print(result)
|
||||
|
||||
async def test_molecules_by_smiles():
|
||||
"""Test searching for complex molecules using SMILES notation"""
|
||||
print("\n=== Testing complex molecules by SMILES ===")
|
||||
|
||||
# 1. Atorvastatin (Lipitor) - Cholesterol-lowering drug with complex structure
|
||||
print("\n--- Testing Atorvastatin (Lipitor) ---")
|
||||
result = await search_pubchem_advanced(
|
||||
smiles="CC(C)C1=C(C(=C(C=C1)C(C)C)C2=CC(=C(C=C2)F)F)C(CC(CC(=O)O)O)NC(=O)C3=CC=C(C=C3)F"
|
||||
)
|
||||
print(result)
|
||||
|
||||
# 2. Morphine - Opioid with multiple rings and H-bond features
|
||||
print("\n--- Testing Morphine ---")
|
||||
result = await search_pubchem_advanced(
|
||||
smiles="CN1CCC23C4C1CC5=C2C(=C(C=C5)O)OC3C(C=C4)O"
|
||||
)
|
||||
print(result)
|
||||
|
||||
async def test_invalid_search():
|
||||
"""Test searching with invalid parameters"""
|
||||
print("\n=== Testing invalid search ===")
|
||||
# No parameters provided
|
||||
result = await search_pubchem_advanced()
|
||||
print(result)
|
||||
|
||||
# Invalid SMILES
|
||||
print("\n=== Testing invalid SMILES ===")
|
||||
result = await search_pubchem_advanced(smiles="INVALID_SMILES_STRING")
|
||||
print(result)
|
||||
|
||||
async def run_all_tests():
    """Run the enabled PubChem search tests sequentially.

    The commented-out calls below are optional / slower cases that can be
    re-enabled manually when needed.
    """
    await test_search_by_name()
    await test_search_by_smiles()
    await test_search_by_formula()
    await test_complex_formula()
    # await test_complex_molecules()
    # await test_molecules_by_smiles()
    #await test_invalid_search()
    # from sci_mcp.chemistry_mcp.pubchem_tools.pubchem_tools import _search_by_name
    # compounds=await _search_by_formula('C6H12O6')
    # print(compounds[0])
|
||||
if __name__ == "__main__":
|
||||
print("Testing PubChem search tools...")
|
||||
asyncio.run(run_all_tests())
|
||||
print("\nAll tests completed.")
|
||||
# import pubchempy
|
||||
# compunnds = pubchempy.get_compounds('Aspirin', 'name')
|
||||
# print(compunnds[0].to_dict())
|
||||
|
||||
159
test_tools/chemistry/test_rdkit.py
Normal file
159
test_tools/chemistry/test_rdkit.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Test script for RDKit tools.
|
||||
|
||||
This script tests the functionality of the RDKit tools implemented in the
|
||||
sci_mcp/chemistry_mcp/rdkit_tools module.
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the project root directory to the Python path
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||||
|
||||
from sci_mcp.chemistry_mcp.rdkit_tools.rdkit_tools import (
|
||||
calculate_molecular_properties,
|
||||
calculate_drug_likeness,
|
||||
calculate_topological_descriptors,
|
||||
generate_molecular_fingerprints,
|
||||
calculate_molecular_similarity,
|
||||
analyze_molecular_structure,
|
||||
generate_molecular_conformer,
|
||||
identify_scaffolds,
|
||||
convert_between_chemical_formats,
|
||||
standardize_molecule,
|
||||
enumerate_stereoisomers,
|
||||
perform_substructure_search
|
||||
)
|
||||
|
||||
def test_molecular_properties():
|
||||
"""Test the calculation of molecular properties."""
|
||||
print("Testing calculate_molecular_properties...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = calculate_molecular_properties(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_drug_likeness():
|
||||
"""Test the calculation of drug-likeness properties."""
|
||||
print("Testing calculate_drug_likeness...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = calculate_drug_likeness(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_topological_descriptors():
|
||||
"""Test the calculation of topological descriptors."""
|
||||
print("Testing calculate_topological_descriptors...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = calculate_topological_descriptors(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_molecular_fingerprints():
|
||||
"""Test the generation of molecular fingerprints."""
|
||||
print("Testing generate_molecular_fingerprints...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = generate_molecular_fingerprints(smiles, fingerprint_type="morgan")
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_molecular_similarity():
|
||||
"""Test the calculation of molecular similarity."""
|
||||
print("Testing calculate_molecular_similarity...")
|
||||
# Aspirin and Ibuprofen
|
||||
smiles1 = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin
|
||||
smiles2 = "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O" # Ibuprofen
|
||||
result = calculate_molecular_similarity(smiles1, smiles2)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_molecular_structure():
|
||||
"""Test the analysis of molecular structure."""
|
||||
print("Testing analyze_molecular_structure...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = analyze_molecular_structure(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_molecular_conformer():
|
||||
"""Test the generation of molecular conformers."""
|
||||
print("Testing generate_molecular_conformer...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = generate_molecular_conformer(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_scaffolds():
|
||||
"""Test the identification of molecular scaffolds."""
|
||||
print("Testing identify_scaffolds...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = identify_scaffolds(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_format_conversion():
|
||||
"""Test the conversion between chemical formats."""
|
||||
print("Testing convert_between_chemical_formats...")
|
||||
# Aspirin
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
result = convert_between_chemical_formats(smiles, "smiles", "inchi")
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_standardize_molecule():
|
||||
"""Test the standardization of molecules."""
|
||||
print("Testing standardize_molecule...")
|
||||
# Betaine with charges
|
||||
smiles = "C[N+](C)(C)CC(=O)[O-]"
|
||||
result = standardize_molecule(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_stereoisomers():
|
||||
"""Test the enumeration of stereoisomers."""
|
||||
print("Testing enumerate_stereoisomers...")
|
||||
# 3-penten-2-ol (has both a stereocenter and a stereobond)
|
||||
smiles = "CC(O)C=CC"
|
||||
result = enumerate_stereoisomers(smiles)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def test_substructure_search():
|
||||
"""Test the substructure search."""
|
||||
print("Testing perform_substructure_search...")
|
||||
# Aspirin, search for carboxylic acid group
|
||||
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
|
||||
pattern = "C(=O)O"
|
||||
result = perform_substructure_search(smiles, pattern)
|
||||
print(result)
|
||||
print("-" * 80)
|
||||
|
||||
def main():
    """Run every RDKit tool test in sequence."""
    print("Testing RDKit tools...\n")

    # Remove entries here to skip individual tests.
    for test_fn in (
        test_molecular_properties,
        test_drug_likeness,
        test_topological_descriptors,
        test_molecular_fingerprints,
        test_molecular_similarity,
        test_molecular_structure,
        test_molecular_conformer,
        test_scaffolds,
        test_format_conversion,
        test_standardize_molecule,
        test_stereoisomers,
        test_substructure_search,
    ):
        test_fn()

    print("All tests completed.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
289
test_tools/chemistry/test_rxn.py
Normal file
289
test_tools/chemistry/test_rxn.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
测试RXN工具函数模块
|
||||
|
||||
此模块包含用于测试rxn_tools模块中化学反应预测和分析工具函数的测试用例。
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||||
|
||||
from sci_mcp.chemistry_mcp.rxn_tools.rxn_tools import (
|
||||
predict_reaction_outcome_rxn,
|
||||
predict_reaction_topn_rxn,
|
||||
predict_reaction_properties_rxn,
|
||||
extract_reaction_actions_rxn
|
||||
)
|
||||
|
||||
# 创建控制台对象用于格式化输出
|
||||
console = Console()
|
||||
|
||||
|
||||
async def test_predict_reaction_outcome():
    """Exercise single-reaction outcome prediction (bromine + anthracene).

    Bug fix: the body called predict_reaction_outcome, which is never
    imported in this file; the imported name is predict_reaction_outcome_rxn,
    so the old code raised NameError (swallowed by the except below).
    """
    console.print("[bold cyan]测试反应结果预测功能[/bold cyan]")

    # Fixed input: bromine reacting with anthracene.
    fixed_reactants = "BrBr.c1ccc2cc3ccccc3cc2c1"
    console.print(f"固定反应物: {fixed_reactants}")

    try:
        result = await predict_reaction_outcome_rxn(fixed_reactants)
        console.print("[green]结果:[/green]")
        console.print(result)
    except Exception as e:
        console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
async def test_predict_reaction_topn():
    """Exercise top-N product prediction in its three input formats.

    Bug fixes: the body called predict_reaction_topn, which is never
    imported (the imported name is predict_reaction_topn_rxn), and the
    third case printed the same error message twice.
    """
    console.print("\n[bold cyan]测试多产物预测功能[/bold cyan]")

    # Case 1: single reaction, string form (acrolein + methylamine).
    fixed_reactants = "C=CC=O.CN"
    fixed_topn = 2
    console.print(f"测试1 - 单个反应(字符串格式)")
    console.print(f"固定反应物: {fixed_reactants}")
    console.print(f"固定预测产物数量: {fixed_topn}")

    try:
        result = await predict_reaction_topn_rxn(fixed_reactants, fixed_topn)
        console.print("[green]结果:[/green]")
        console.print(result)
    except Exception as e:
        console.print(f"[bold red]错误:[/bold red] {str(e)}")

    # Case 2: single reaction, list form (bromine + anthracene).
    fixed_reactants_list = ["BrBr", "c1ccc2cc3ccccc3cc2c1"]
    console.print(f"\n测试2 - 单个反应(列表格式)")
    console.print(f"固定反应物: {fixed_reactants_list}")
    console.print(f"固定预测产物数量: {fixed_topn}")

    try:
        result = await predict_reaction_topn_rxn(fixed_reactants_list, fixed_topn)
        console.print("[green]结果:[/green]")
        console.print(result)
    except Exception as e:
        console.print(f"[bold red]错误:[/bold red] {str(e)}")

    # Case 3: batch of reactions, list-of-lists form.
    fixed_reactants_batch = [
        ["BrBr", "c1ccc2cc3ccccc3cc2c1"],  # bromine + anthracene
        ["BrBr", "c1ccc2cc3ccccc3cc2c1CCO"]  # bromine + modified anthracene
    ]
    console.print(f"\n测试3 - 多个反应(列表的列表格式)")
    console.print(f"固定反应物批量: {fixed_reactants_batch}")
    console.print(f"固定预测产物数量: {fixed_topn}")

    try:
        result = await predict_reaction_topn_rxn(fixed_reactants_batch, fixed_topn)
        console.print("[green]结果:[/green]")
        console.print(result)
    except Exception as e:
        console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
async def test_predict_retrosynthesis():
    """测试逆合成分析功能 (retrosynthesis analysis test)."""
    # NOTE(review): predict_retrosynthesis is not imported anywhere in this
    # file (only *_rxn-suffixed functions are). This call raises NameError,
    # which the except below reports as an "error" result. Presumably the
    # intended callee is a retrosynthesis function exported by rxn_tools —
    # confirm its name and add it to the imports.
    console.print("\n[bold cyan]测试逆合成分析功能[/bold cyan]")

    # Fixed input: one-step retrosynthesis of aspirin.
    fixed_target_molecule = "CC(=O)OC1=CC=CC=C1C(=O)O"  # aspirin
    fixed_max_steps = 1
    console.print(f"固定目标分子: {fixed_target_molecule}")
    console.print(f"固定最大步骤数: {fixed_max_steps}")

    try:
        result = await predict_retrosynthesis(fixed_target_molecule, fixed_max_steps)
        console.print("[green]结果:[/green]")
        console.print(result)
    except Exception as e:
        console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
async def test_predict_biocatalytic_retrosynthesis():
    """测试生物催化逆合成分析功能 (biocatalytic retrosynthesis test)."""
    # NOTE(review): predict_biocatalytic_retrosynthesis is not imported in
    # this file (only *_rxn-suffixed functions are), so this call raises
    # NameError, reported by the except below. Confirm the exported name in
    # rxn_tools and import it.
    console.print("\n[bold cyan]测试生物催化逆合成分析功能[/bold cyan]")

    # Fixed input: a molecule plausibly suited to enzymatic catalysis.
    fixed_target_molecule = "OC1C(O)C=C(Br)C=C1"
    console.print(f"固定目标分子: {fixed_target_molecule}")

    try:
        result = await predict_biocatalytic_retrosynthesis(fixed_target_molecule)
        console.print("[green]结果:[/green]")
        console.print(result)
    except Exception as e:
        console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
async def test_predict_reaction_properties():
    """Exercise reaction-property prediction (atom mapping) on a fixed reaction."""
    console.print("\n[bold cyan]测试反应属性预测功能[/bold cyan]")

    # Fixed inputs: a thioether coupling reaction, atom-mapping property.
    rxn_smiles = "CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F"
    property_kind = "atom-mapping"
    console.print(f"固定反应: {rxn_smiles}")
    console.print(f"固定属性类型: {property_kind}")

    try:
        outcome = await predict_reaction_properties(rxn_smiles, property_kind)
        console.print("[green]结果:[/green]")
        console.print(outcome)
    except Exception as exc:
        # Network / service failures are reported, not propagated.
        console.print(f"[bold red]错误:[/bold red] {str(exc)}")
|
||||
|
||||
|
||||
async def test_extract_reaction_actions():
    """Exercise extraction of structured reaction steps from free text."""
    console.print("\n[bold cyan]测试从文本提取反应步骤功能[/bold cyan]")

    # Fixed input: a literature-style experimental procedure.
    procedure_text = """To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour."""
    console.print(f"固定反应文本: {procedure_text}")

    try:
        outcome = await extract_reaction_actions(procedure_text)
        console.print("[green]结果:[/green]")
        console.print(outcome)
    except Exception as exc:
        # Network / service failures are reported, not propagated.
        console.print(f"[bold red]错误:[/bold red] {str(exc)}")
|
||||
|
||||
|
||||
async def test_all():
    """Run every RXN tool test case in sequence."""
    console.print("[bold magenta]===== 开始测试RXN工具函数 =====[/bold magenta]\n")

    # Execute each test coroutine one after another, in the original order.
    for case in (
        test_predict_reaction_outcome,
        test_predict_reaction_topn,
        test_predict_retrosynthesis,
        test_predict_biocatalytic_retrosynthesis,
        test_predict_reaction_properties,
        test_extract_reaction_actions,
    ):
        await case()

    console.print("\n[bold magenta]===== RXN工具函数测试完成 =====[/bold magenta]")
|
||||
|
||||
|
||||
def get_rxn_tool_questions():
    """Return the canned questions used to drive each RXN tool via an LLM.

    Returns:
        dict mapping tool name -> list of three natural-language questions
        designed so a tool-calling model will invoke that specific tool.
    """
    return {
        # Forward reaction prediction (single best product).
        "predict_reaction_outcome": [
            "如果我将溴和蒽混合在一起,会形成什么产物?",
            "乙酸和乙醇反应会生成什么?",
            "预测一下丙烯醛和甲胺反应的结果",
        ],
        # Forward prediction, top-N candidate products.
        "predict_reaction_topn": [
            "丙烯醛和甲胺反应可能生成哪几种主要产物?",
            "预测溴和蒽反应可能的前3个产物",
            "乙酸和乙醇反应可能有哪些不同的结果?请给出最可能的几种产物",
        ],
        # Classical retrosynthesis route planning.
        "predict_retrosynthesis": [
            "如何合成阿司匹林?请给出可能的合成路线",
            "对于分子CC(=O)OC1=CC=CC=C1C(=O)O,有哪些可能的合成路径?",
            "请分析一下布洛芬的可能合成路线",
        ],
        # Enzyme-catalysed retrosynthesis.
        "predict_biocatalytic_retrosynthesis": [
            "有没有可能用酶催化合成OC1C(O)C=C(Br)C=C1这个分子?",
            "请提供一种使用生物催化方法合成对羟基苯甲醇的路线",
            "我想用酶催化方法合成一些复杂分子,能否分析一下OC1C(O)C=C(Br)C=C1的可能合成路径?",
        ],
        # Reaction property prediction (atom mapping, yield, ...).
        "predict_reaction_properties": [
            "在这个反应中:CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F,原子是如何映射的?",
            "这个反应的产率可能是多少:Clc1ccccn1.Cc1ccc(N)cc1>>Cc1ccc(Nc2ccccn2)cc1",
            "能分析一下这个反应中原子的去向吗:CC(=O)O.CCO>>CC(=O)OCC",
        ],
        # Structured action extraction from experimental prose.
        "extract_reaction_actions": [
            "能否将这段实验描述转换为结构化的反应步骤?'To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour.'",
            "请从这段文本中提取出具体的反应操作步骤:'A solution of benzoic acid (1.0 g, 8.2 mmol) in thionyl chloride (10 mL) was heated under reflux for 2 hours. The excess thionyl chloride was removed under reduced pressure to give benzoyl chloride as a colorless liquid.'",
            "帮我解析这个实验步骤,提取出关键操作:'The aldehyde (5 mmol) was dissolved in methanol (20 mL) and sodium borohydride (7.5 mmol) was added portionwise at 0°C. The mixture was allowed to warm to room temperature and stirred for 3 hours.'",
        ],
    }
|
||||
|
||||
|
||||
def update_agent_test_questions(agent_test_path=None):
    """Insert the RXN tool questions into agent_test.py's ``questions`` dict.

    The target file is rewritten in place, but only when none of the RXN tool
    names are already present in it. Failures of any kind are reported on the
    console instead of being raised.

    Args:
        agent_test_path: optional path to agent_test.py. Defaults to the
            original hard-coded location so existing callers are unaffected.
    """
    try:
        # Backward-compatible default: the previously hard-coded path.
        if agent_test_path is None:
            agent_test_path = Path('/home/ubuntu/sas0/lzy/multi_mcp_server/test_tools/agent_test.py')
        else:
            agent_test_path = Path(agent_test_path)

        # FIX: read/write with an explicit encoding — the file contains
        # Chinese text and must not depend on the platform default codec.
        with open(agent_test_path, 'r', encoding='utf-8') as f:
            content = f.read()

        rxn_questions = get_rxn_tool_questions()

        # Skip the update when any RXN tool name already appears in the file.
        rxn_tools_exist = any(tool in content for tool in rxn_questions.keys())

        if not rxn_tools_exist:
            # Anchor on the dict's return statement to find the insertion point.
            # NOTE(review): the anchor's exact indentation was reconstructed
            # from a whitespace-mangled source — confirm against agent_test.py.
            dict_end_pos = content.find('    return questions')

            if dict_end_pos != -1:
                # Render the RXN questions as dict-entry source code.
                rxn_questions_str = ""
                for tool_name, questions_list in rxn_questions.items():
                    rxn_questions_str += f'\n        "{tool_name}": [\n'
                    for q in questions_list:
                        rxn_questions_str += f'            "{q}",\n'
                    rxn_questions_str += '        ],'

                # Splice the new entries in just before the return statement.
                new_content = content[:dict_end_pos] + rxn_questions_str + content[dict_end_pos:]

                with open(agent_test_path, 'w', encoding='utf-8') as f:
                    f.write(new_content)

                console.print("[green]成功更新agent_test.py,添加了RXN工具函数的测试问题[/green]")
            else:
                console.print("[yellow]无法找到questions字典的结束位置,未更新agent_test.py[/yellow]")
        else:
            console.print("[yellow]agent_test.py中已包含RXN工具函数的问题,无需更新[/yellow]")

    except Exception as e:
        console.print(f"[bold red]更新agent_test.py时出错:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the full test suite:
    # asyncio.run(test_all())

    # Update the tool-question dictionary in agent_test.py:
    # update_agent_test_questions()

    import os

    # SECURITY: avoid committing live credentials. The environment variable
    # takes precedence; the literal fallback keeps the script runnable as
    # before, but the committed key should be rotated and removed.
    api_key = os.environ.get(
        'RXN_API_KEY',
        'apk-8928522a146c2503f30b16d9909222d7583f412ee8f1049f08d32a089ba88d34',
    )
    from rxn4chemistry import RXN4ChemistryWrapper

    rxn4chemistry_wrapper = RXN4ChemistryWrapper(api_key=api_key)
    rxn4chemistry_wrapper.create_project('test_wrapper')
    response = rxn4chemistry_wrapper.predict_automatic_retrosynthesis(
        'Brc1c2ccccc2c(Br)c2ccccc12')
    # NOTE(review): results are fetched immediately after submission, so
    # status is likely still pending; the API expects polling until 'SUCCESS'.
    results = rxn4chemistry_wrapper.get_predict_automatic_retrosynthesis_results(response['prediction_id'])
    print(results['status'])
    # NOTE: upon 'SUCCESS' you can inspect the predicted retrosynthetic paths.
    print(results['retrosynthetic_paths'][0])
|
||||
55
test_tools/complex_material_query.py
Normal file
55
test_tools/complex_material_query.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
from test_tools.multi_round_conversation import process_conversation_round
|
||||
|
||||
# 初始化rich控制台
|
||||
console = Console()
|
||||
|
||||
# Demo prompt: a question that forces at least one lookup round (crystal
# structure query) before any further analysis can happen.
# Earlier prompt variants, kept for reference:
# (may trigger mattergen)
# complex_question = """我想了解LiFePO4材料在不同温度下的性能变化。请先告诉我这种材料的基本结构特性。"""
# (multi-round tool calls without mattergen)
# complex_question = """我想比较TiO2和ZnO这两种材料作为光催化剂的性能。请先告诉我TiO2的晶体结构和能带特性。"""
complex_question = """我需要分析一种名为Na2Fe2(SO4)3的钠离子电池材料。请先查询这种材料的晶体结构。"""
|
||||
|
||||
def run_complex_query():
    """Drive the multi-round material-science query demo interactively."""
    banner = Panel.fit(
        "[bold cyan]复杂材料科学查询演示[/bold cyan] - 测试多轮对话逻辑",
        border_style="cyan",
    )
    console.print(banner)

    # Seed the conversation with the preset complex question.
    history = process_conversation_round(complex_question)

    # Interactive loop: keep feeding user turns until an exit command.
    while True:
        console.print("\n[bold cyan]输入问题继续对话,或输入 'exit' 或 'quit' 退出[/bold cyan]")
        user_input = input("> ")

        if user_input.lower() in ['exit', 'quit', '退出']:
            console.print("[bold cyan]演示结束,再见![/bold cyan]")
            return

        history = process_conversation_round(user_input, history)
|
||||
|
||||
if __name__ == '__main__':
    try:
        run_complex_query()
    except KeyboardInterrupt:
        # Ctrl-C is a normal way to leave the demo.
        console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
    except Exception as err:
        # Report unexpected failures with a full traceback for debugging.
        import traceback
        console.print(f"\n[bold red]发生错误: {str(err)}[/bold red]")
        console.print(traceback.format_exc())
|
||||
375
test_tools/demo_conversation.py
Normal file
375
test_tools/demo_conversation.py
Normal file
@@ -0,0 +1,375 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
from rich.table import Table
|
||||
from rich import box
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
|
||||
initial_message=messages = [ {"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}]
|
||||
# Rich console used for all demo output.
console = Console()

# Tool schemas and name->callable dispatch table for the demo domains.
tools_schemas = get_domain_tool_schemas(["material", 'general'])
tool_map = get_domain_tools(["material", 'general'])

# API configuration. SECURITY: prefer environment variables over committed
# credentials; the literal fallbacks preserve the previous default behaviour
# but the committed key should be rotated.
api_key = os.environ.get(
    "GPUSTACK_API_KEY",
    "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705",
)
base_url = os.environ.get("GPUSTACK_BASE_URL", "http://gpustack.ddwtop.team/v1-openai")
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
|
||||
|
||||
def get_t1_response(messages):
    """Call MARS-T1 and return (reasoning text, list of tool-call dicts)."""
    with Progress(
        SpinnerColumn(),
        TextColumn("[bold blue]正在调用MARS-T1模型..."),
        transient=True,
    ) as progress:
        progress.add_task("", total=None)
        completion = client.chat.completions.create(
            model="MARS-T1",
            messages=messages,
            temperature=0.3,
            tools=tools_schemas,
        )

    choice = completion.choices[0]
    reasoning_content = choice.message.content

    # Normalise each requested tool call into a plain dict for downstream use.
    raw_calls = getattr(choice.message, 'tool_calls', None) or []
    tool_calls_list = [
        {
            "id": call.id,
            "type": "function",
            "function": {
                "name": call.function.name,
                "arguments": call.function.arguments,
            },
        }
        for call in raw_calls
    ]

    return reasoning_content, tool_calls_list
|
||||
|
||||
async def execute_tool(tool_name, tool_arguments):
    """Resolve *tool_name* in the module-level tool_map and invoke it.

    Args:
        tool_name: key into the tool_map dispatch table (KeyError if unknown).
        tool_arguments: dict, JSON string, or arbitrary string. A string that
            cannot be parsed as JSON (even after de-escaping) is forwarded as
            {"raw_string": <text>} so the tool can decide what to do with it.

    Returns:
        Whatever the tool returns; async tools are awaited, sync tools called
        directly.
    """
    # FIX: removed a dead `try/finally: pass` wrapper (leftover placeholder
    # for clearing an LLM-call context flag) — behaviour is unchanged.
    tool_func = tool_map[tool_name]
    arguments = {}
    if tool_arguments:
        if isinstance(tool_arguments, dict):
            # Already structured — use as-is.
            arguments = tool_arguments
        elif isinstance(tool_arguments, str):
            try:
                arguments = json.loads(tool_arguments)
            except json.JSONDecodeError:
                # Repair the most common escaping damage before giving up.
                fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
                try:
                    arguments = json.loads(fixed_str)
                except json.JSONDecodeError:
                    # Not JSON at all: hand the raw text to the tool.
                    arguments = {"raw_string": tool_arguments}

    if asyncio.iscoroutinefunction(tool_func):
        return await tool_func(**arguments)
    return tool_func(**arguments)
|
||||
|
||||
def get_all_tool_calls_results(tool_calls_list):
    """Execute every queued tool call and return formatted result strings."""
    collected = []

    with Progress(
        SpinnerColumn(),
        TextColumn("[bold green]正在执行工具调用..."),
        transient=True,
    ) as progress:
        task = progress.add_task("", total=len(tool_calls_list))

        for call in tool_calls_list:
            fn = call['function']
            name = fn['name']
            args = fn['arguments']

            # Show which tool is currently running.
            progress.update(task, description=f"执行 {name}")

            # Each tool runs to completion on its own event loop.
            outcome = asyncio.run(execute_tool(name, args))
            collected.append(f"[{name} content begin]\n{outcome}\n[{name} content end]\n")

            progress.update(task, advance=1)

    return collected
|
||||
|
||||
def get_response_from_r1(messages):
    """Call MARS-R1 (no tools) and return its answer text."""
    with Progress(
        SpinnerColumn(),
        TextColumn("[bold purple]正在调用MARS-R1模型..."),
        transient=True,
    ) as progress:
        progress.add_task("", total=None)
        completion = client.chat.completions.create(
            model="MARS-R1",
            messages=messages,
            temperature=0.3,
        )

    return completion.choices[0].message.content
|
||||
|
||||
def display_message(role, content, model=None, title_style="bold blue", border_style="blue"):
    """Render one conversation message in a role/model-specific rich panel.

    Args:
        role: "user", "assistant", or "tool".
        content: message text (or newline-joined tool results for "tool").
        model: originating model name; selects the assistant panel style.
        title_style / border_style: rich styles used for the "user" panel.
    """
    title = role.capitalize()
    if model:
        title = f"{model} {title}"

    if role == "user":
        console.print(Panel(
            content,
            title=f"[{title_style}]{title}[/{title_style}]",
            border_style=border_style,
            expand=False
        ))
    elif role == "assistant" and model == "MARS-T1":
        console.print(Panel(
            content,
            title=f"[bold yellow]{title}[/bold yellow]",
            border_style="yellow",
            expand=False
        ))
    elif role == "tool":
        # Tool results: one table row per output line, inside a green panel.
        table = Table(box=box.ROUNDED, expand=False, show_header=False)
        table.add_column("内容", style="green")

        results = content.split("\n")
        for result in results:
            table.add_row(result)

        console.print(Panel(
            table,
            title=f"[bold green]{title}[/bold green]",
            border_style="green",
            expand=False
        ))
    elif role == "assistant" and model == "MARS-R1":
        try:
            # Prefer markdown rendering for the final answer.
            md = Markdown(content)
            console.print(Panel(
                md,
                title=f"[bold purple]{title}[/bold purple]",
                border_style="purple",
                expand=False
            ))
        # FIX: was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit; narrowed to Exception.
        except Exception:
            # Fall back to plain text if markdown rendering fails.
            console.print(Panel(
                content,
                title=f"[bold purple]{title}[/bold purple]",
                border_style="purple",
                expand=False
            ))
|
||||
|
||||
def process_conversation_round(user_input, conversation_history=None):
    """Run one T1 -> tools -> R1 round and return the updated history.

    Flow: append the user turn, let MARS-T1 reason and request tool calls,
    execute any tool calls, then hand tool output + question to MARS-R1 for
    the final answer. With no tool calls, T1's reasoning doubles as the
    answer. NOTE(review): the caller's history list is mutated in place AND
    returned.
    """
    if conversation_history is None:
        conversation_history = []

    # Record the user turn.
    conversation_history.append({
        "role": "user",
        "content": user_input
    })

    # Show the user message.
    display_message("user", user_input)

    # T1 only accepts user/assistant turns — filter out "tool" entries.
    t1_messages = []
    for msg in conversation_history:
        if msg["role"] in ["user", "assistant"]:
            t1_messages.append({
                "role": msg["role"],
                "content": msg["content"]
            })

    # Ask MARS-T1 for reasoning and (optionally) tool-call requests.
    reasoning_content, tool_calls_list = get_t1_response(t1_messages)

    # Record T1's reasoning.
    conversation_history.append({
        "role": "assistant",
        "content": reasoning_content,
        "model": "MARS-T1"
    })

    # Show T1's reasoning.
    display_message("assistant", reasoning_content, model="MARS-T1")

    # If tools were requested, execute them and let R1 produce the answer.
    if tool_calls_list:
        tool_call_results = get_all_tool_calls_results(tool_calls_list)
        tool_call_results_str = "\n".join(tool_call_results)

        # Record the combined tool output.
        conversation_history.append({
            "role": "tool",
            "content": tool_call_results_str
        })

        # Show the tool output.
        display_message("tool", tool_call_results_str)

        # R1 gets a single synthetic user turn: tool output + the question.
        user_message = {
            "role": "user",
            "content": f"# 信息如下:\n{tool_call_results_str}\n# 问题如下:\n{user_input}"
        }

        # Ask MARS-R1 for the final answer.
        r1_response = get_response_from_r1([user_message])

        # Record R1's answer.
        conversation_history.append({
            "role": "assistant",
            "content": r1_response,
            "model": "MARS-R1"
        })

        # Show R1's answer.
        display_message("assistant", r1_response, model="MARS-R1")
    else:
        # No tool calls: reuse T1's reasoning as the final answer
        # (recorded under the MARS-R1 label on purpose).
        conversation_history.append({
            "role": "assistant",
            "content": reasoning_content,
            "model": "MARS-R1"
        })

        # Show the answer (actually T1's reasoning).
        display_message("assistant", reasoning_content, model="MARS-R1")

    return conversation_history
|
||||
|
||||
def run_demo():
    """Run the demo, seeding the conversation with the prepared CIF question."""
    console.print(Panel.fit(
        "[bold cyan]多轮对话演示[/bold cyan] - 使用 MARS-T1 和 MARS-R1 模型",
        border_style="cyan"
    ))

    # First round uses the module-level initial message.
    seed_question = initial_message[0]["content"]
    history = process_conversation_round(seed_question)

    # Interactive loop: keep feeding user turns until an exit command.
    while True:
        console.print("\n[bold cyan]输入问题继续对话,或输入 'exit' 或 'quit' 退出[/bold cyan]")
        user_input = input("> ")

        if user_input.lower() in ['exit', 'quit', '退出']:
            console.print("[bold cyan]演示结束,再见![/bold cyan]")
            return

        history = process_conversation_round(user_input, history)
|
||||
|
||||
if __name__ == '__main__':
    try:
        run_demo()
    except KeyboardInterrupt:
        # Ctrl-C is a normal way to leave the demo.
        console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
    except Exception as err:
        # Report unexpected failures with a full traceback for debugging.
        import traceback
        console.print(f"\n[bold red]发生错误: {str(err)}[/bold red]")
        console.print(traceback.format_exc())
|
||||
7
test_tools/general/test_searxng.py
Normal file
7
test_tools/general/test_searxng.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Smoke test for the SearXNG online-search tool.
import sys
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server/')
from sci_mcp.general_mcp.searxng_query.searxng_query_tools import search_online
import asyncio

# Run one search ("CsPbBr3", 5 results) and print the returned value
# (presumably a dict — confirm against search_online's implementation).
print(asyncio.run(search_online("CsPbBr3", 5)))
|
||||
14
test_tools/material/mattergen/extract_cif.py
Normal file
14
test_tools/material/mattergen/extract_cif.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Extract generated CIF text from a MatterGen output archive and format it."""
import os
import sys
import zipfile

sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
from sci_mcp.material_mcp.mattergen_gen.mattergen_service import format_cif_content


cif_zip_path = '/home/ubuntu/sas0/lzy/multi_mcp_server/temp/material/20250508110506/generated_crystals_cif.zip'
if os.path.exists(cif_zip_path):
    # FIX: the file is a ZIP container; decoding its raw bytes as UTF-8
    # produced compressed garbage. Extract each archive member and decode
    # the actual CIF text instead.
    with zipfile.ZipFile(cif_zip_path) as zf:
        cif_content = "\n".join(
            zf.read(name).decode('utf-8', errors='replace')
            for name in zf.namelist()
            if not name.endswith('/')  # skip directory entries
        )
    print(format_cif_content(cif_content))
|
||||
|
||||
|
||||
|
||||
|
||||
166
test_tools/material/mattergen/test_mattergen.py
Normal file
166
test_tools/material/mattergen/test_mattergen.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Test script for mattergen_gen/material_gen_tools.py
|
||||
|
||||
This script tests the generate_material function from the material_gen_tools module.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import unittest
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from sci_mcp.material_mcp.mattergen_gen.material_gen_tools import generate_material_MatterGen
|
||||
|
||||
|
||||
class TestMatterGen(unittest.TestCase):
    """Test cases for MatterGen material generation tools."""

    def test_unconditional_generation(self):
        """Generate one crystal structure with no property constraints."""
        outcome = generate_material_MatterGen(properties=None, batch_size=1, num_batches=1)

        # The tool reports its results as a human-readable string that
        # mentions the generated material/structures.
        self.assertIsInstance(outcome, str)
        self.assertIn("Material", outcome)
        self.assertIn("structures", outcome)

        print("无条件生成结果示例:")
        print(outcome[:500] + "...\n" if len(outcome) > 500 else outcome)

        return outcome

    # NOTE(review): disabled scenario tests (single-property, multi-property,
    # batch generation, guidance factor, invalid properties) existed here as
    # commented-out code and were removed; recover them from version control
    # if they are needed again.
|
||||
|
||||
|
||||
def run_tests():
    """Entry point: discover and run every test in this module."""
    unittest.main()


if __name__ == "__main__":
    run_tests()
|
||||
|
||||
73
test_tools/material/test_mgl.py
Normal file
73
test_tools/material/test_mgl.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import asyncio
|
||||
import sys
|
||||
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
|
||||
import matgl
|
||||
from sci_mcp.material_mcp.matgl_tools import relax_crystal_structure_M3GNet,predict_formation_energy_M3GNet,run_molecular_dynamics_M3GNet,calculate_single_point_energy_M3GNet
|
||||
print(matgl.get_available_pretrained_models())
|
||||
cif_file_name = 'GdPbGdHGd.cif'
|
||||
cif_content="""data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
"""
|
||||
|
||||
# Run formation-energy prediction on the Ti4V CIF defined above.
# NOTE(review): the commented-out calls below use pre-rename function names
# (relax_crystal_structure etc.); the imported names carry the _M3GNet
# suffix — update before re-enabling.
#print(asyncio.run(relax_crystal_structure(cif_file_name)))
print(asyncio.run(predict_formation_energy_M3GNet(cif_content)))
#print(asyncio.run(calculate_single_point_energy(cif_file_name)))
#print(asyncio.run(run_molecular_dynamics(cif_file_name)))
|
||||
|
||||
15
test_tools/material/test_mp.py
Normal file
15
test_tools/material/test_mp.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Smoke test for the Materials Project query tools.
# NOTE(review): imports target the `sci_mcp_server` package, which .gitignore
# excludes — confirm this path still exists in the working tree.
import sys
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server/')
from sci_mcp_server.material_mcp.mp_query.mp_query_tools import search_material_property_from_material_project,search_crystal_structures_from_materials_project
import asyncio
from sci_mcp_server.core.llm_tools import set_llm_context, clear_llm_context

# Compare property-query output with the LLM-call context flag on vs. off.
set_llm_context(True)
print(asyncio.run(search_material_property_from_material_project("CsPbBr3")))
clear_llm_context()
print(asyncio.run(search_material_property_from_material_project("CsPbBr3")))

# Same comparison for the crystal-structure query.
set_llm_context(True)
print(asyncio.run(search_crystal_structures_from_materials_project("CsPbBr3")))
clear_llm_context()
print(asyncio.run(search_crystal_structures_from_materials_project("CsPbBr3")))
|
||||
143
test_tools/material/test_property_pred.py
Normal file
143
test_tools/material/test_property_pred.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Test script for property_pred_tools.py
|
||||
|
||||
This script tests the predict_properties function from the property_pred_tools module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from sci_mcp.material_mcp.mattersim_pred.property_pred_tools import predict_properties_MatterSim
|
||||
|
||||
|
||||
class TestPropertyPrediction(unittest.TestCase):
    """Test cases for property prediction tools (CIF-string input)."""

    def setUp(self):
        """Set up test fixtures."""
        # Minimal CIF: diamond-cubic silicon in a P1 cell.
        self.simple_cif = """
data_Si
_cell_length_a 5.43
_cell_length_b 5.43
_cell_length_c 5.43
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 90
_symmetry_space_group_name_H-M 'P 1'
_symmetry_Int_Tables_number 1
loop_
_atom_site_label
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Si 0.0 0.0 0.0
Si 0.5 0.5 0.0
Si 0.5 0.0 0.5
Si 0.0 0.5 0.5
Si 0.25 0.25 0.25
Si 0.75 0.75 0.25
Si 0.75 0.25 0.75
Si 0.25 0.75 0.75
"""

    def test_predict_properties_async(self):
        """Predict properties for a simple CIF string (async path)."""
        async def _async_test():
            # FIX: the original called the undefined name `predict_properties`;
            # the imported tool is `predict_properties_MatterSim`.
            result = await predict_properties_MatterSim(self.simple_cif)

            # The tool reports its results as a formatted text block.
            self.assertIsInstance(result, str)
            self.assertIn("Crystal Structure Property Prediction Results", result)
            self.assertIn("Total Energy (eV):", result)
            self.assertIn("Energy per Atom (eV/atom):", result)
            self.assertIn("Forces (eV/Angstrom):", result)
            self.assertIn("Stress (GPa):", result)
            self.assertIn("Stress (eV/A^3):", result)

            print("预测结果示例:")
            print(result)
            return result

        # FIX: asyncio.get_event_loop() is deprecated outside a running loop;
        # asyncio.run() creates and closes a fresh loop per test.
        return asyncio.run(_async_test())

    def test_predict_properties_sync(self):
        """Run the async test through the synchronous unittest entry point."""
        self.test_predict_properties_async()
|
||||
|
||||
|
||||
class TestPropertyPredictionWithFile(unittest.TestCase):
    """Test property prediction using a CIF file on disk."""

    def setUp(self):
        """Create a temporary CIF file (alumina structure) as the fixture."""
        self.temp_cif_path = "temp_test_structure.cif"

        cif_content = """
data_Al2O3
_cell_length_a 4.76
_cell_length_b 4.76
_cell_length_c 12.99
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 120
_symmetry_space_group_name_H-M 'R -3 c'
_symmetry_Int_Tables_number 167
loop_
_atom_site_label
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Al 0.0 0.0 0.35
Al 0.0 0.0 0.85
O 0.31 0.0 0.25
"""

        with open(self.temp_cif_path, "w") as f:
            f.write(cif_content)

    def tearDown(self):
        """Remove the temporary CIF file created in setUp."""
        if os.path.exists(self.temp_cif_path):
            os.remove(self.temp_cif_path)

    def test_predict_properties_from_file_async(self):
        """Test property prediction from a file path (async version)."""
        async def _async_test():
            # BUG FIX: the original called the undefined name ``predict_properties``;
            # the imported function is ``predict_properties_MatterSim``.
            result = await predict_properties_MatterSim(self.temp_cif_path)

            # Verify the report structure.
            self.assertIsInstance(result, str)
            self.assertIn("Crystal Structure Property Prediction Results", result)
            self.assertIn("Total Energy (eV):", result)

            print("从文件预测结果示例:")
            print(result)
            return result

        # asyncio.run() replaces the deprecated get_event_loop() pattern.
        return asyncio.run(_async_test())

    def test_predict_properties_from_file_sync(self):
        """Run the async file-based test synchronously."""
        self.test_predict_properties_from_file_async()
|
||||
|
||||
|
||||
def run_tests():
    """Delegate to the standard unittest runner for this module."""
    unittest.main()


if __name__ == "__main__":
    # Running the module directly executes all test cases above.
    run_tests()
|
||||
208
test_tools/material/test_pymatgen_cal.py
Normal file
208
test_tools/material/test_pymatgen_cal.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
Test script for pymatgen_cal_tools.py
|
||||
|
||||
This script tests the functions from the pymatgen_cal_tools module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from sci_mcp.material_mcp.pymatgen_cal.pymatgen_cal_tools import (
|
||||
calculate_density,
|
||||
get_element_composition,
|
||||
calculate_symmetry
|
||||
)
|
||||
|
||||
|
||||
class TestPymatgenCalculations(unittest.TestCase):
    """Test cases for pymatgen calculation tools."""

    def setUp(self):
        """Set up test fixtures: a simple CIF string for a silicon crystal."""
        self.simple_cif = """
data_Si
_cell_length_a 5.43
_cell_length_b 5.43
_cell_length_c 5.43
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 90
_symmetry_space_group_name_H-M 'P 1'
_symmetry_Int_Tables_number 1
loop_
_atom_site_label
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Si 0.0 0.0 0.0
Si 0.5 0.5 0.0
Si 0.5 0.0 0.5
Si 0.0 0.5 0.5
Si 0.25 0.25 0.25
Si 0.75 0.75 0.25
Si 0.75 0.25 0.75
Si 0.25 0.75 0.75
"""

    def test_calculate_density_async(self):
        """Test calculate_density with a simple CIF string (async version)."""
        async def _async_test():
            result = await calculate_density(self.simple_cif)

            # Verify the report contains the expected key sections.
            self.assertIsInstance(result, str)
            self.assertIn("Density Calculation", result)
            self.assertIn("Density", result)
            self.assertIn("g/cm³", result)

            print("密度计算结果示例:")
            print(result)
            return result

        # FIX: asyncio.run() replaces get_event_loop()/run_until_complete,
        # which is deprecated since Python 3.10 and fails on 3.12+ when no
        # event loop is running in the main thread.
        return asyncio.run(_async_test())

    def test_get_element_composition_async(self):
        """Test get_element_composition with a simple CIF string (async version)."""
        async def _async_test():
            result = await get_element_composition(self.simple_cif)

            self.assertIsInstance(result, str)
            self.assertIn("Element Composition", result)
            self.assertIn("Composition", result)
            self.assertIn("Si", result)

            print("元素组成计算结果示例:")
            print(result)
            return result

        return asyncio.run(_async_test())

    def test_calculate_symmetry_async(self):
        """Test calculate_symmetry with a simple CIF string (async version)."""
        async def _async_test():
            result = await calculate_symmetry(self.simple_cif)

            self.assertIsInstance(result, str)
            self.assertIn("Symmetry Information", result)
            self.assertIn("Space Group", result)
            self.assertIn("Number", result)

            print("对称性计算结果示例:")
            print(result)
            return result

        return asyncio.run(_async_test())
||||
|
||||
|
||||
class TestPymatgenCalculationsWithFile(unittest.TestCase):
    """Test pymatgen calculation tools using a CIF file on disk."""

    def setUp(self):
        """Create a temporary CIF file (alumina structure) as the fixture."""
        self.temp_cif_path = "temp_test_structure.cif"

        cif_content = """
data_Al2O3
_cell_length_a 4.76
_cell_length_b 4.76
_cell_length_c 12.99
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 120
_symmetry_space_group_name_H-M 'R -3 c'
_symmetry_Int_Tables_number 167
_symmetry_equiv_pos_as_xyz 'x, y, z'
_symmetry_equiv_pos_as_xyz '-x, -y, -z'
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Al1 Al 0.0 0.0 0.35 1.0
Al2 Al 0.0 0.0 0.85 1.0
O1 O 0.31 0.0 0.25 1.0
"""

        with open(self.temp_cif_path, "w") as f:
            f.write(cif_content)

    def tearDown(self):
        """Remove the temporary CIF file created in setUp."""
        if os.path.exists(self.temp_cif_path):
            os.remove(self.temp_cif_path)

    def test_calculate_density_from_file_async(self):
        """Test density calculation from a file path (async version)."""
        async def _async_test():
            result = await calculate_density(self.temp_cif_path)

            self.assertIsInstance(result, str)
            self.assertIn("Density Calculation", result)
            self.assertIn("Density", result)

            print("从文件计算密度结果示例:")
            print(result)
            return result

        # FIX: asyncio.run() replaces the deprecated get_event_loop() pattern.
        return asyncio.run(_async_test())

    def test_get_element_composition_from_file_async(self):
        """Test element composition from a file path (async version)."""
        async def _async_test():
            result = await get_element_composition(self.temp_cif_path)

            self.assertIsInstance(result, str)
            self.assertIn("Element Composition", result)
            self.assertIn("Composition", result)

            print("从文件获取元素组成结果示例:")
            print(result)
            return result

        return asyncio.run(_async_test())

    def test_calculate_symmetry_from_file_async(self):
        """Test symmetry calculation from a file path (async version)."""
        async def _async_test():
            result = await calculate_symmetry(self.temp_cif_path)

            self.assertIsInstance(result, str)
            self.assertIn("Symmetry Information", result)
            self.assertIn("Space Group", result)

            print("从文件计算对称性结果示例:")
            print(result)
            return result

        return asyncio.run(_async_test())
||||
|
||||
|
||||
def run_tests():
    """Delegate to the standard unittest runner for this module."""
    unittest.main()


if __name__ == "__main__":
    run_tests()
|
||||
118
test_tools/material/test_structure_opt.py
Normal file
118
test_tools/material/test_structure_opt.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
测试晶体结构优化工具函数
|
||||
|
||||
此脚本测试改进后的optimize_crystal_structure函数,
|
||||
该函数接受单一的file_name_or_content_string参数,可以是文件路径或直接的结构内容。
|
||||
"""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
sys.path.append("/home/ubuntu/sas0/lzy/multi_mcp_server")
|
||||
|
||||
from sci_mcp.material_mcp.fairchem_structure_opt.structure_opt_tools import optimize_crystal_structure
|
||||
from sci_mcp.core.config import material_config
|
||||
|
||||
# Sample CIF structure: cubic perovskite SrTiO3 (space group Pm-3m, a = 3.905 Å),
# used as input for all optimization tests below.
SAMPLE_CIF = """
data_SrTiO3
_cell_length_a 3.905
_cell_length_b 3.905
_cell_length_c 3.905
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 90
_symmetry_space_group_name_H-M 'P m -3 m'
_symmetry_Int_Tables_number 221
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Sr1 Sr 0.0 0.0 0.0
Ti1 Ti 0.5 0.5 0.5
O1 O 0.5 0.5 0.0
O2 O 0.5 0.0 0.5
O3 O 0.0 0.5 0.5
"""
|
||||
|
||||
async def test_with_content():
    """Exercise the optimizer with the structure passed inline as a string."""
    print("\n=== 测试使用直接结构内容 ===")
    opt_report = await optimize_crystal_structure(
        file_name_or_content_string=SAMPLE_CIF,
        format_type="cif",
        optimization_level="quick",
    )
    print(opt_report)
|
||||
|
||||
async def test_with_file():
    """Exercise the optimizer with a path to a temporary CIF file."""
    print("\n=== 测试使用文件路径 ===")

    # Write the sample structure to a temp file that survives the `with` block.
    with tempfile.NamedTemporaryFile(suffix=".cif", mode="w", delete=False) as handle:
        handle.write(SAMPLE_CIF)
        cif_path = handle.name

    try:
        opt_report = await optimize_crystal_structure(
            file_name_or_content_string=cif_path,
            format_type="auto",
            optimization_level="quick",
        )
        print(opt_report)
    finally:
        # Always clean up the temp file, even if optimization raised.
        if os.path.exists(cif_path):
            os.unlink(cif_path)
|
||||
|
||||
async def test_with_temp_file():
    """Exercise the optimizer with a bare filename resolved against TEMP_ROOT."""
    print("\n=== 测试使用临时目录中的文件名 ===")

    # Make sure the configured temp directory exists before writing into it.
    os.makedirs(material_config.TEMP_ROOT, exist_ok=True)

    bare_name = "test_structure.cif"
    full_path = os.path.join(material_config.TEMP_ROOT, bare_name)

    with open(full_path, 'w', encoding='utf-8') as handle:
        handle.write(SAMPLE_CIF)

    try:
        # Pass only the filename; the tool is expected to resolve it
        # against TEMP_ROOT itself.
        opt_report = await optimize_crystal_structure(
            file_name_or_content_string=bare_name,
            format_type="auto",
            optimization_level="quick",
        )
        print(opt_report)
    finally:
        # Remove the fixture regardless of the outcome.
        if os.path.exists(full_path):
            os.unlink(full_path)
|
||||
|
||||
async def test_auto_format():
    """Exercise automatic format detection with inline content and defaults."""
    print("\n=== 测试自动格式检测 ===")
    opt_report = await optimize_crystal_structure(
        file_name_or_content_string=SAMPLE_CIF,
        format_type="auto",
    )
    print(opt_report)
|
||||
|
||||
async def main():
    """Run every optimization scenario in sequence."""
    print("测试改进后的optimize_crystal_structure函数")

    await test_with_content()
    await test_with_file()
    await test_with_temp_file()
    await test_auto_format()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
521
test_tools/multi_round_conversation.py
Normal file
521
test_tools/multi_round_conversation.py
Normal file
@@ -0,0 +1,521 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
from rich.table import Table
|
||||
from rich import box
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
|
||||
initial_message=messages = [ {"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}]
|
||||
# Initialize the rich console for formatted terminal output.
console = Console()

# Tool schemas (advertised to the LLM) and name->callable map (for execution).
tools_schemas = get_domain_tool_schemas(["material", 'general'])
tool_map = get_domain_tools(["material", 'general'])

# API configuration.
# SECURITY: a credential was hard-coded here. It can now be supplied via the
# GPUSTACK_API_KEY / GPUSTACK_BASE_URL environment variables; the original
# literals remain only as fallbacks for backward compatibility — rotate the
# key and remove the fallback as soon as possible.
api_key = os.environ.get(
    "GPUSTACK_API_KEY",
    "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705",
)
base_url = os.environ.get("GPUSTACK_BASE_URL", "http://gpustack.ddwtop.team/v1-openai")
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
|
||||
|
||||
def get_t1_response(messages):
    """Call the MARS-T1 model and return (reasoning text, tool-call dicts).

    Each tool call is converted from the SDK object into a plain dict with
    ``id``, ``type`` and ``function`` keys so it can be serialized freely.
    """
    with Progress(
        SpinnerColumn(),
        TextColumn("[bold blue]正在调用MARS-T1模型..."),
        transient=True,
    ) as progress:
        progress.add_task("", total=None)
        completion = client.chat.completions.create(
            model="MARS-T1",
            messages=messages,
            temperature=0.3,
            tools=tools_schemas,
        )

    message = completion.choices[0].message
    reasoning_content = message.content

    # Normalize SDK tool-call objects into plain dicts (empty list if none).
    raw_calls = getattr(message, 'tool_calls', None) or []
    tool_calls_list = [
        {
            "id": call.id,
            "type": "function",
            "function": {
                "name": call.function.name,
                "arguments": call.function.arguments,
            },
        }
        for call in raw_calls
    ]

    return reasoning_content, tool_calls_list
|
||||
|
||||
async def execute_tool(tool_name, tool_arguments):
    """Look up ``tool_name`` in the global tool_map and invoke it.

    ``tool_arguments`` may be a dict (used as-is) or a JSON string; malformed
    JSON is repaired once and, failing that, wrapped as ``{"raw_string": ...}``
    so the tool still receives something. Coroutine tools are awaited,
    synchronous tools are called directly.
    """
    def _decode(raw):
        # Try the string verbatim first, then with common escape damage
        # undone, and finally fall back to a raw wrapper.
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            repaired = raw.replace('\\"', '"').replace('\\\\', '\\')
            try:
                return json.loads(repaired)
            except json.JSONDecodeError:
                return {"raw_string": raw}

    tool_func = tool_map[tool_name]

    arguments = {}
    if tool_arguments:
        if isinstance(tool_arguments, dict):
            arguments = tool_arguments
        elif isinstance(tool_arguments, str):
            arguments = _decode(tool_arguments)

    if asyncio.iscoroutinefunction(tool_func):
        return await tool_func(**arguments)
    return tool_func(**arguments)
|
||||
|
||||
def get_all_tool_calls_results(tool_calls_list):
    """Execute every requested tool call and return its wrapped text results."""
    collected = []

    with Progress(
        SpinnerColumn(),
        TextColumn("[bold green]正在执行工具调用..."),
        transient=True,
    ) as progress:
        task = progress.add_task("", total=len(tool_calls_list))

        for call in tool_calls_list:
            name = call['function']['name']
            args = call['function']['arguments']

            # Show which tool is currently running.
            progress.update(task, description=f"执行 {name}")

            outcome = asyncio.run(execute_tool(name, args))
            collected.append(
                f"[{name} content begin]\n{outcome}\n[{name} content end]\n"
            )

            progress.update(task, advance=1)

    return collected
|
||||
|
||||
def get_response_from_r1(messages):
    """Call the MARS-R1 model (no tools) and return its text content."""
    with Progress(
        SpinnerColumn(),
        TextColumn("[bold purple]正在调用MARS-R1模型..."),
        transient=True,
    ) as progress:
        progress.add_task("", total=None)
        completion = client.chat.completions.create(
            model="MARS-R1",
            messages=messages,
            temperature=0.3,
        )

    return completion.choices[0].message.content
|
||||
|
||||
def display_message(role, content, model=None, title_style="bold blue", border_style="blue"):
    """Render a single chat message as a color-coded rich panel.

    Args:
        role: "user", "assistant" or "tool".
        content: message text (or joined tool results for role "tool").
        model: optional model tag ("MARS-T1" / "MARS-R1") selecting the style.
        title_style / border_style: rich styles used for the user panel.
    """
    title = role.capitalize()
    if model:
        title = f"{model} {title}"

    if role == "user":
        console.print(Panel(
            content,
            title=f"[{title_style}]{title}[/{title_style}]",
            border_style=border_style,
            expand=False
        ))
    elif role == "assistant" and model == "MARS-T1":
        console.print(Panel(
            content,
            title=f"[bold yellow]{title}[/bold yellow]",
            border_style="yellow",
            expand=False
        ))
    elif role == "tool":
        # Tool results are shown one line per table row.
        table = Table(box=box.ROUNDED, expand=False, show_header=False)
        table.add_column("内容", style="green")

        results = content.split("\n")
        for result in results:
            table.add_row(result)

        console.print(Panel(
            table,
            title=f"[bold green]{title}[/bold green]",
            border_style="green",
            expand=False
        ))
    elif role == "assistant" and model == "MARS-R1":
        try:
            # Prefer Markdown rendering for R1 answers.
            md = Markdown(content)
            console.print(Panel(
                md,
                title=f"[bold purple]{title}[/bold purple]",
                border_style="purple",
                expand=False
            ))
        # BUG FIX: was a bare ``except:`` which also swallowed
        # KeyboardInterrupt/SystemExit; only rendering errors should trigger
        # the plain-text fallback.
        except Exception:
            console.print(Panel(
                content,
                title=f"[bold purple]{title}[/bold purple]",
                border_style="purple",
                expand=False
            ))
|
||||
|
||||
def process_conversation_round(user_input, conversation_history=None):
    """Run one conversation round and return the updated history.

    Flow: T1 (tool-calling model) produces reasoning and optional tool calls;
    tools are executed; R1 (answering model) synthesizes a reply and may emit a
    <FOLLOW_UP_QUESTION> tag that feeds another T1 iteration, up to 3 times.

    Args:
        user_input: the user's question for this round.
        conversation_history: external history list; mutated in place and
            also returned. A new list is created when None.
    """
    if conversation_history is None:
        conversation_history = []

    # Record and display the user's message.
    conversation_history.append({
        "role": "user",
        "content": user_input
    })
    display_message("user", user_input)

    # Bound the inner T1/R1 loop to avoid ping-ponging forever.
    max_iterations = 3
    iterations = 0

    # T1 and R1 histories are managed separately: T1 keeps a running chat,
    # R1 is rebuilt from scratch each iteration.
    t1_messages = []
    r1_messages = []

    # Seed T1's history with the user/assistant turns from the external
    # history (tool messages are intentionally excluded).
    for msg in conversation_history:
        if msg["role"] in ["user", "assistant"]:
            t1_messages.append({
                "role": msg["role"],
                "content": msg["content"]
            })

    # The question currently being worked on (starts as the user's input,
    # later replaced by R1's follow-up questions).
    current_question = user_input

    while iterations < max_iterations:
        iterations += 1

        # From the second iteration on, the "user" turn is R1's follow-up.
        if iterations > 1:
            display_message("user", f"[后续问题] {current_question}")
            t1_messages.append({
                "role": "user",
                "content": current_question
            })

        # Ask T1 for reasoning and (possibly) tool calls.
        reasoning_content, tool_calls_list = get_t1_response(t1_messages)
        display_message("assistant", reasoning_content, model="MARS-T1")

        # Record T1's answer in both the external and T1-local histories.
        conversation_history.append({
            "role": "assistant",
            "content": reasoning_content,
            "model": "MARS-T1"
        })
        t1_messages.append({
            "role": "assistant",
            "content": reasoning_content
        })

        # No tool calls: T1's reasoning doubles as the final (R1) answer.
        if not tool_calls_list:
            conversation_history.append({
                "role": "assistant",
                "content": reasoning_content,
                "model": "MARS-R1"
            })
            display_message("assistant", reasoning_content, model="MARS-R1")
            break

        # Execute all tool calls and record the combined result text.
        tool_call_results = get_all_tool_calls_results(tool_calls_list)
        tool_call_results_str = "\n".join(tool_call_results)
        conversation_history.append({
            "role": "tool",
            "content": tool_call_results_str
        })
        display_message("tool", tool_call_results_str)

        # Rebuild R1's message list from scratch for this iteration.
        r1_messages = []

        # System prompt instructing R1 to either answer or emit a tagged
        # follow-up question. (Runtime string — kept verbatim.)
        r1_messages.append({
            "role": "system",
            "content": """你是一个能够分析工具调用结果并回答问题的助手。
请分析提供的信息,并执行以下操作之一:
1. 如果你能够基于提供的工具调用信息直接回答原始问题,请提供完整的回答。
2. 如果目前的工具调用信息不足以让你回答原始问题,请明确说明缺少哪些信息,并生成一个新的问题来获取这些信息。
新问题格式:<FOLLOW_UP_QUESTION>你的问题</FOLLOW_UP_QUESTION>

注意:如果你生成了后续问题,系统将自动将其发送给工具调用模型以获取更多信息。"""
        })

        # User message for R1: original question + T1 reasoning + tool output.
        r1_user_message = f"""# 原始问题
{user_input}

# 工具调用信息
{reasoning_content}

# 工具调用结果
{tool_call_results_str}"""

        # Append the current follow-up question, if any.
        if iterations > 1:
            r1_user_message += f"\n\n# 后续问题\n{current_question}"

        r1_messages.append({
            "role": "user",
            "content": r1_user_message
        })

        # Ask R1 for the synthesized answer.
        r1_response = get_response_from_r1(r1_messages)
        display_message("assistant", r1_response, model="MARS-R1")

        # Detect a tagged follow-up question in R1's response.
        follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', r1_response, re.DOTALL)

        if follow_up_match:
            follow_up_question = follow_up_match.group(1).strip()

            # Loop again with the follow-up as the new question.
            current_question = follow_up_question

            # Store R1's answer with the follow-up tag stripped.
            clean_response = r1_response.replace(follow_up_match.group(0), "")
            conversation_history.append({
                "role": "assistant",
                "content": clean_response,
                "model": "MARS-R1"
            })

            # Continue the loop: T1 is called with the new question.
        else:
            # R1 answered fully: record and finish the round.
            conversation_history.append({
                "role": "assistant",
                "content": r1_response,
                "model": "MARS-R1"
            })
            break

    return conversation_history
|
||||
|
||||
def run_demo():
    """Drive the interactive demo, starting from the bundled initial message."""
    console.print(Panel.fit(
        "[bold cyan]多轮对话演示[/bold cyan] - 使用 MARS-T1 和 MARS-R1 模型",
        border_style="cyan"
    ))

    # First round uses the pre-defined CIF analysis question.
    first_question = initial_message[0]["content"]
    conversation_history = process_conversation_round(first_question)
    auto_process_follow_up_questions(conversation_history)

    # Subsequent rounds come from stdin until the user quits.
    while True:
        console.print("\n[bold cyan]输入问题继续对话,或输入 'exit' 或 'quit' 退出[/bold cyan]")
        user_input = input("> ")

        if user_input.lower() in ['exit', 'quit', '退出']:
            console.print("[bold cyan]演示结束,再见![/bold cyan]")
            break

        conversation_history = process_conversation_round(user_input, conversation_history)
        auto_process_follow_up_questions(conversation_history)
|
||||
|
||||
def auto_process_follow_up_questions(conversation_history):
    """Automatically resolve <FOLLOW_UP_QUESTION> tags left by MARS-R1.

    If the most recent message is an R1 answer containing a follow-up tag,
    feed that question back through process_conversation_round, repeating
    (up to 3 times) as long as new follow-ups keep appearing.
    """
    # Nothing to do for an empty history.
    if not conversation_history or len(conversation_history) == 0:
        return

    last_message = conversation_history[-1]
    if last_message["role"] == "assistant" and last_message.get("model") == "MARS-R1":
        # Check whether R1's latest answer contains a follow-up question.
        follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', last_message["content"], re.DOTALL)
        if follow_up_match:
            follow_up_question = follow_up_match.group(1).strip()

            # Announce the detected follow-up.
            console.print(Panel(
                f"[bold yellow]检测到后续问题: {follow_up_question}[/bold yellow]",
                border_style="yellow",
                expand=False
            ))
            console.print("[bold cyan]自动处理后续问题...[/bold cyan]")

            # Keep resolving follow-ups until none remain or the cap is hit.
            max_auto_iterations = 3
            current_iterations = 0

            while current_iterations < max_auto_iterations:
                current_iterations += 1

                # Run the follow-up through a full conversation round.
                conversation_history = process_conversation_round(follow_up_question, conversation_history)

                # Re-scan the newest message for yet another follow-up.
                if len(conversation_history) > 0:
                    last_message = conversation_history[-1]
                    if last_message["role"] == "assistant" and last_message.get("model") == "MARS-R1":
                        follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', last_message["content"], re.DOTALL)
                        if follow_up_match:
                            follow_up_question = follow_up_match.group(1).strip()

                            console.print(Panel(
                                f"[bold yellow]检测到后续问题: {follow_up_question}[/bold yellow]",
                                border_style="yellow",
                                expand=False
                            ))
                            console.print("[bold cyan]自动处理后续问题...[/bold cyan]")
                            continue

                # No further follow-up found: stop.
                break
|
||||
|
||||
if __name__ == '__main__':
    try:
        run_demo()
    except KeyboardInterrupt:
        # Graceful exit on Ctrl-C.
        console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
    except Exception as e:
        # Surface unexpected failures with a full traceback.
        console.print(f"\n[bold red]发生错误: {str(e)}[/bold red]")
        import traceback
        console.print(traceback.format_exc())
|
||||
83
test_tools/test.py
Normal file
83
test_tools/test.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Test script for domain-specific tool retrieval functions
|
||||
"""
|
||||
|
||||
import json
|
||||
from pprint import pprint
|
||||
import sys
|
||||
sys.path.append("/home/ubuntu/sas0/lzy/multi_mcp_server")
|
||||
from sci_mcp.core.llm_tools import get_domain_tools, get_domain_tool_schemas, get_all_tool_schemas,get_all_tools
|
||||
|
||||
def test_get_domain_tools():
    """Smoke-test tool retrieval for selected domains."""
    print("\n=== Testing get_domain_tools ===")

    domains = ['material', 'general']
    print(f"Getting tools for domains: {domains}")

    domain_tools = get_domain_tools(domains)

    # Report every domain with its tool count and names.
    for domain, tools in domain_tools.items():
        print(f"\nDomain: {domain}")
        print(f"Number of tools: {len(tools)}")
        print("Tool names:")
        for tool_name in tools:
            print(f" - {tool_name}")
|
||||
|
||||
def test_get_domain_tool_schemas():
    """Smoke-test OpenAI-format schema retrieval for selected domains."""
    print("\n=== Testing get_domain_tool_schemas (OpenAI format) ===")

    domains = ['material', 'general']
    print(f"Getting tool schemas for domains: {domains}")

    domain_schemas = get_domain_tool_schemas(domains)

    # Report every domain with its schema count and the tool names inside.
    for domain, schemas in domain_schemas.items():
        print(f"\nDomain: {domain}")
        print(f"Number of schemas: {len(schemas)}")
        print("Tool names in schemas:")
        for schema in schemas:
            print(f" - {schema['function']['name']}")
|
||||
|
||||
|
||||
def test_all_tool_schemas():
    """Smoke-test global schema retrieval in both OpenAI and MCP formats."""
    print("\n=== Testing get_all_tool_schemas ===")

    # OpenAI-format schemas: names nested under "function".
    print("Getting all tool schemas in OpenAI format")
    openai_schemas = get_all_tool_schemas()
    print(f"Number of schemas: {len(openai_schemas)}")
    print("Tool names:")
    for schema in openai_schemas:
        print(f" - {schema['function']['name']}")

    # MCP-format schemas: names at the top level.
    print("\nGetting all tool schemas in MCP format")
    mcp_schemas = get_all_tool_schemas(use_mcp_format=True)
    print(f"Number of schemas: {len(mcp_schemas)}")
    print("Tool names:")
    for schema in mcp_schemas:
        print(f" - {schema['name']}")
|
||||
|
||||
def main():
    """Run the schema-retrieval smoke tests."""
    print("Testing domain-specific tool retrieval functions")
    test_get_domain_tool_schemas()
    test_all_tool_schemas()


if __name__ == "__main__":
    # NOTE: the script intentionally does not call main() here; it only
    # prints the chemistry-domain schemas via rich.
    from rich.console import Console
    console = Console()
    console.print("Testing domain-specific tool retrieval functions", style="bold green")
    console.print(get_domain_tool_schemas(['chemistry']))
|
||||
|
||||
253
test_tools/test_mars_t1.py
Normal file
253
test_tools/test_mars_t1.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp_server.core.llm_tools import set_llm_context, clear_llm_context
|
||||
|
||||
# 创建Rich控制台对象
|
||||
console = Console()
|
||||
|
||||
# Visual separator helper for console output.
def print_separator(title=""):
    """Print a centered, cyan-bordered panel with an optional magenta title."""
    banner = Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False)
    console.print(banner, justify="center")
|
||||
|
||||
# API configuration.
# SECURITY: a credential was hard-coded here. It can now be supplied via the
# GPUSTACK_API_KEY / GPUSTACK_BASE_URL environment variables; the original
# literals remain only as fallbacks for backward compatibility — rotate the
# key and remove the fallback as soon as possible.
api_key = os.environ.get(
    "GPUSTACK_API_KEY",
    "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705",
)
base_url = os.environ.get("GPUSTACK_BASE_URL", "http://gpustack.ddwtop.team/v1-openai")
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
from sci_mcp import *

# Tool schemas (advertised to the LLM) and name->callable map (for execution).
tools_schemas = get_domain_tool_schemas(["material", 'general'])
tool_map = get_domain_tools(["material", 'general'])
|
||||
|
||||
# 打印消息的函数
|
||||
def print_message(message):
    """Render one chat message as a colored Rich panel.

    Accepts either a plain dict message or an SDK ChatCompletionMessage
    object; tool messages additionally show the tool name in the label.
    """
    if hasattr(message, 'role'):
        # SDK message object.
        role = message.role
        content = getattr(message, 'content', "")
        tool_name = message.name if role == "tool" and hasattr(message, 'name') else None
    else:
        # Plain dict message.
        role = message.get("role", "unknown")
        content = message.get("content", "")
        tool_name = message.get("name") if role == "tool" else None

    # Role -> color mapping; unknown roles fall back to white.
    color = {
        "system": "bright_blue",
        "user": "green",
        "assistant": "yellow",
        "tool": "bright_red",
    }.get(role, "white")

    body = Text()
    if role == "tool" and tool_name:
        body.append(f"{role} ({tool_name}): ", style=f"bold {color}")
    else:
        body.append(f"{role}: ", style=f"bold {color}")
    body.append(str(content))
    console.print(Panel(body, border_style=color))
|
||||
|
||||
messages = [
|
||||
{"role": "system",
|
||||
"content": "You are MARS-R1, a professional assistant in materials science. You first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> <structured_answer> </structured_answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here <structured_answer> structured answer here <structured_answer> </answer>'"},
|
||||
{"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}
|
||||
]
|
||||
#how to synthesize CsPbBr3 at room temperature
|
||||
#
|
||||
# 打印初始消息
|
||||
print_separator("初始消息")
|
||||
for message in messages:
|
||||
print_message(message)
|
||||
finish_reason = None
|
||||
|
||||
async def execute_tool(tool_name, tool_arguments):
    """Look up *tool_name* in the tool map and invoke it with parsed arguments.

    Arguments may arrive as a dict (used as-is) or a JSON string (parsed,
    with a best-effort repair of over-escaped quoting). The LLM-context
    marker is set for the duration of the call and always cleared.
    """

    def _coerce_arguments(raw):
        # Normalize the model-supplied arguments into a kwargs dict.
        if not raw:
            return {}
        if isinstance(raw, dict):
            return raw
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except json.JSONDecodeError:
                # Repair common escape damage before giving up.
                repaired = raw.replace('\\"', '"').replace('\\\\', '\\')
                try:
                    return json.loads(repaired)
                except json.JSONDecodeError:
                    return {"raw_string": raw}
        return {}

    set_llm_context(True)
    try:
        tool_func = tool_map[tool_name]  # KeyError means the model chose an unknown tool
        arguments = _coerce_arguments(tool_arguments)
        # Await coroutine tools; call synchronous ones directly.
        if asyncio.iscoroutinefunction(tool_func):
            return await tool_func(**arguments)
        return tool_func(**arguments)
    finally:
        # Always clear the LLM-context marker, even on failure.
        clear_llm_context()
|
||||
|
||||
# Drive the conversation until the model stops asking for tool calls.
while finish_reason is None or finish_reason == "tool_calls":
    completion = client.chat.completions.create(
        model="MARS-R1",
        messages=messages,
        temperature=0.3,
        tools=tools_schemas,  # expose the tool schemas to the model
    )
    choice = completion.choices[0]
    finish_reason = choice.finish_reason
    if finish_reason == "tool_calls":  # the model requested one or more tool runs
        # Show the assistant turn that requested the tools.
        print_separator("Assistant消息")
        print_message(choice.message)

        # Convert the SDK message object into a plain dict for the context.
        assistant_message = {
            "role": "assistant",
            "content": choice.message.content if hasattr(choice.message, 'content') else None
        }

        if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
            tool_calls_list = []
            for tool_call in choice.message.tool_calls:
                tool_call_dict = {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    }
                }
                tool_calls_list.append(tool_call_dict)
            assistant_message["tool_calls"] = tool_calls_list

        # Keep the assistant turn in the context so the next request sees it.
        messages.append(assistant_message)

        print_separator("工具调用")
        for tool_call in choice.message.tool_calls:
            console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]")
            console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
            console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
            console.print("")

            tool_call_name = tool_call.function.name
            # Arguments arrive as serialized JSON; deserialize before dispatch.
            tool_call_arguments = json.loads(tool_call.function.arguments)
            try:
                tool_result = asyncio.run(execute_tool(tool_name=tool_call_name, tool_arguments=tool_call_arguments))
            except Exception as e:
                tool_result = f'工具调用失败{e}'

            # Tool results must be submitted to the model as strings (per the
            # convention below); str() guards against tools that return dicts,
            # lists, or numbers instead of text.
            tool_message = {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_call_name,
                "content": str(tool_result),
            }

            print_separator(f"工具响应: {tool_call_name}")
            print_message(tool_message)

            # Feed the tool result back into the conversation context.
            messages.append(tool_message)

# Print the final (non-tool-call) response, if the model produced content.
if choice.message.content:
    print_separator("最终响应")
    console.print(Panel(choice.message.content, border_style="green"))
|
||||
177
test_tools/test_mars_t1_.py
Normal file
177
test_tools/test_mars_t1_.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import asyncio
|
||||
import json
|
||||
from openai import OpenAI
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
|
||||
tools_schemas = get_domain_tool_schemas(["material",'general'])
|
||||
tool_map = get_domain_tools(["material",'general'])
|
||||
|
||||
|
||||
|
||||
# NOTE(security): credentials are hard-coded and committed to source control.
# Rotate this key and load it from an environment variable / secrets store.
api_key="gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
base_url="http://gpustack.ddwtop.team/v1-openai"
# OpenAI-compatible client pointed at the GPUStack gateway above.
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
|
||||
messages = [ {"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}]
|
||||
|
||||
def get_t1_response(messages):
    """Ask MARS-T1 for a completion and return its reasoning plus tool calls.

    Args:
        messages: chat history in OpenAI message-dict format.

    Returns:
        tuple[str, list]: the assistant's text content and a list of tool-call
        dicts ({"id", "type", "function": {"name", "arguments"}}). The list is
        empty when the model made no tool calls.
    """
    completion = client.chat.completions.create(
        model="MARS-T1",
        messages=messages,
        temperature=0.3,
        tools=tools_schemas,
    )
    choice = completion.choices[0]
    reasoning_content = choice.message.content
    # print("Reasoning content:", reasoning_content)

    # Bug fix: previously this function only returned when tool_calls were
    # present, so a plain-text answer made the caller's tuple-unpack fail on
    # the implicit None. Always return a (content, list) pair.
    tool_calls_list = []
    if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
        for tool_call in choice.message.tool_calls:
            tool_calls_list.append({
                "id": tool_call.id,
                "type": "function",
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments
                }
            })
    return reasoning_content, tool_calls_list
|
||||
|
||||
async def execute_tool(tool_name, tool_arguments):
    """Invoke the mapped tool; arguments may be a dict or a JSON string."""
    tool_func = tool_map[tool_name]  # KeyError here means an unknown tool name
    arguments = {}
    if tool_arguments:
        if isinstance(tool_arguments, dict):
            # Already a kwargs dict — use directly.
            arguments = tool_arguments
        elif isinstance(tool_arguments, str):
            try:
                arguments = json.loads(tool_arguments)
            except json.JSONDecodeError:
                # Repair over-escaped JSON before giving up entirely.
                repaired = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
                try:
                    arguments = json.loads(repaired)
                except json.JSONDecodeError:
                    arguments = {"raw_string": tool_arguments}

    # Await coroutine tools; call synchronous ones directly.
    if asyncio.iscoroutinefunction(tool_func):
        return await tool_func(**arguments)
    return tool_func(**arguments)
|
||||
|
||||
def get_all_tool_calls_results(tool_calls_list):
    """Execute every tool call and wrap each result in begin/end markers.

    Args:
        tool_calls_list: tool-call dicts as produced by get_t1_response().

    Returns:
        list[str]: one "[<tool> content begin]...[<tool> content end]" string
        per call, in the original order.
    """
    all_results = []
    for tool_call in tool_calls_list:
        tool_name = tool_call['function']['name']
        tool_arguments = tool_call['function']['arguments']
        result = asyncio.run(execute_tool(tool_name, tool_arguments))
        # str() guards against tools that return dicts/lists/numbers; the old
        # "+ result +" concatenation raised TypeError for non-string results.
        result_str = f"[{tool_name} content begin]\n" + str(result) + f"\n[{tool_name} content end]\n"
        all_results.append(result_str)

    return all_results
|
||||
def get_response_from_r1(messages):
    """Send *messages* to MARS-R1 (no tools) and return the text response."""
    completion = client.chat.completions.create(
        model="MARS-R1",
        messages=messages,
        temperature=0.3,
    )
    choice = completion.choices[0]
    content = choice.message.content
    # Bug fix: this print used to sit *after* the return and never executed;
    # the __main__ caller ignores the return value, so log before returning.
    print("R1 RESPONSE:", content)
    return content
|
||||
if __name__ == '__main__':
    # Stage 1: let MARS-T1 reason over the question and propose tool calls.
    t1_reasoning, t1_tool_calls = get_t1_response(messages)
    print("Reasoning content:", t1_reasoning)

    # Stage 2: execute the proposed tools and join their formatted outputs.
    formatted_results = get_all_tool_calls_results(t1_tool_calls)
    evidence = "\n".join(formatted_results)

    # Stage 3: hand the tool evidence plus the original question to MARS-R1.
    user_message = {
        "role": "user",
        "content": f"# 信息如下:{evidence}# 问题如下 {messages[0]['content']}"
    }
    print("user_message_for_r1:", user_message)
    get_response_from_r1([user_message])
|
||||
339
test_tools/test_mars_t1_r1.py
Normal file
339
test_tools/test_mars_t1_r1.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sci_mcp import *
|
||||
|
||||
# 创建Rich控制台对象
|
||||
console = Console()
|
||||
|
||||
# 创建一个列表来存储工具调用结果
|
||||
tool_results = []
|
||||
|
||||
# 定义分隔符样式
|
||||
def print_separator(title=""):
    """Emit a centered banner panel used to delimit console sections."""
    console.print(
        Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False),
        justify="center",
    )
|
||||
|
||||
# NOTE(security): credentials are hard-coded and committed to source control.
# Rotate this key and load it from an environment variable / secrets store.
api_key="gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
base_url="http://gpustack.ddwtop.team/v1-openai"
# OpenAI-compatible client pointed at the GPUStack gateway above.
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
|
||||
from sci_mcp import *
|
||||
|
||||
tools_schemas = get_domain_tool_schemas(["material",'general'])
|
||||
tool_map = get_domain_tools(["material",'general'])
|
||||
|
||||
|
||||
# 打印消息的函数
|
||||
def print_message(message):
    """Pretty-print a chat message (dict or SDK object) as a Rich panel."""
    # Normalize the two message shapes into (role, content, tool_name).
    if hasattr(message, 'role'):
        role = message.role
        content = getattr(message, 'content', "")
        tool_name = None
        if role == "tool" and hasattr(message, 'name'):
            tool_name = message.name
    else:
        role = message.get("role", "unknown")
        content = message.get("content", "")
        tool_name = message.get("name") if role == "tool" else None

    role_colors = {
        "system": "bright_blue",
        "user": "green",
        "assistant": "yellow",
        "tool": "bright_red",
    }
    color = role_colors.get(role, "white")

    # Build the label; tool messages include the tool's name.
    label = f"{role} ({tool_name}): " if (role == "tool" and tool_name) else f"{role}: "
    text = Text()
    text.append(label, style=f"bold {color}")
    text.append(str(content))
    console.print(Panel(text, border_style=color))
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": """data_Ti4V
|
||||
_symmetry_space_group_name_H-M Fmmm
|
||||
_cell_length_a 3.18353600
|
||||
_cell_length_b 4.52677200
|
||||
_cell_length_c 22.74397000
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 90.00000000
|
||||
_symmetry_Int_Tables_number 69
|
||||
_chemical_formula_structural Ti4V
|
||||
_chemical_formula_sum 'Ti16 V4'
|
||||
_cell_volume 327.76657340
|
||||
_cell_formula_units_Z 4
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
2 '-x, -y, -z'
|
||||
3 '-x, -y, z'
|
||||
4 'x, y, -z'
|
||||
5 'x, -y, -z'
|
||||
6 '-x, y, z'
|
||||
7 '-x, y, -z'
|
||||
8 'x, -y, z'
|
||||
9 'x+1/2, y, z+1/2'
|
||||
10 '-x+1/2, -y, -z+1/2'
|
||||
11 '-x+1/2, -y, z+1/2'
|
||||
12 'x+1/2, y, -z+1/2'
|
||||
13 'x+1/2, -y, -z+1/2'
|
||||
14 '-x+1/2, y, z+1/2'
|
||||
15 '-x+1/2, y, -z+1/2'
|
||||
16 'x+1/2, -y, z+1/2'
|
||||
17 'x+1/2, y+1/2, z'
|
||||
18 '-x+1/2, -y+1/2, -z'
|
||||
19 '-x+1/2, -y+1/2, z'
|
||||
20 'x+1/2, y+1/2, -z'
|
||||
21 'x+1/2, -y+1/2, -z'
|
||||
22 '-x+1/2, y+1/2, z'
|
||||
23 '-x+1/2, y+1/2, -z'
|
||||
24 'x+1/2, -y+1/2, z'
|
||||
25 'x, y+1/2, z+1/2'
|
||||
26 '-x, -y+1/2, -z+1/2'
|
||||
27 '-x, -y+1/2, z+1/2'
|
||||
28 'x, y+1/2, -z+1/2'
|
||||
29 'x, -y+1/2, -z+1/2'
|
||||
30 '-x, y+1/2, z+1/2'
|
||||
31 '-x, y+1/2, -z+1/2'
|
||||
32 'x, -y+1/2, z+1/2'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
|
||||
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
|
||||
V V2 4 0.00000000 0.00000000 0.00000000 1.0
|
||||
,根据上文提供的CIF文件,请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性,并用JSON格式回答。"""}
|
||||
]
|
||||
#how to synthesize CsPbBr3 at room temperature
|
||||
#
|
||||
# 打印初始消息
|
||||
print_separator("初始消息")
|
||||
for message in messages:
|
||||
print_message(message)
|
||||
finish_reason = None
|
||||
|
||||
async def execute_tool(tool_name, tool_arguments):
    """Run one tool from the tool map with dict or JSON-string arguments."""
    fn = tool_map[tool_name]  # unknown tool names raise KeyError here

    # Normalize the model-supplied arguments into a kwargs dict.
    parsed = {}
    if tool_arguments:
        if isinstance(tool_arguments, dict):
            parsed = tool_arguments
        elif isinstance(tool_arguments, str):
            try:
                parsed = json.loads(tool_arguments)
            except json.JSONDecodeError:
                # Try again after undoing common double-escaping.
                attempt = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
                try:
                    parsed = json.loads(attempt)
                except json.JSONDecodeError:
                    parsed = {"raw_string": tool_arguments}

    # Dispatch: await coroutine tools, call plain functions directly.
    if asyncio.iscoroutinefunction(fn):
        return await fn(**parsed)
    return fn(**parsed)
|
||||
|
||||
# 定义一个函数来估算消息的token数量(粗略估计)
|
||||
def estimate_tokens(messages):
    """Roughly estimate the token count of a message list.

    Uses ~1.3 tokens per character of message content and of tool-call
    argument text. Handles both plain dict messages and SDK message objects.

    Args:
        messages: list of message dicts or ChatCompletionMessage-like objects.

    Returns:
        int: estimated total token count.
    """
    total = 0
    for msg in messages:
        if isinstance(msg, dict):
            content = msg.get("content", "")
            tool_calls = msg.get("tool_calls") or []
        else:
            content = msg.content if hasattr(msg, "content") else ""
            tool_calls = (msg.tool_calls if hasattr(msg, "tool_calls") else None) or []
        if content:
            # Rough estimate of the content's token count.
            total += len(content) * 1.3

        # `or []` above guards the common SDK case where tool_calls is
        # explicitly None (no tool calls), which previously crashed the
        # for-loop below with TypeError.
        for tool_call in tool_calls:
            if isinstance(tool_call, dict):
                args = tool_call.get("function", {}).get("arguments", "")
            else:
                args = tool_call.function.arguments if hasattr(tool_call, "function") else ""
            total += len(args) * 1.3

    return int(total)
|
||||
|
||||
# 管理消息历史,保持在token限制以内
|
||||
def manage_message_history(messages, max_tokens=7000):
    """Trim the middle of the history so the estimated size stays in budget.

    The first message is always kept; recent messages are retained
    newest-first until the budget would be exceeded. Falls back to the last
    four messages if trimming would leave fewer than four.
    """
    if len(messages) <= 1:
        return messages

    # Nothing to do while we're under budget.
    if estimate_tokens(messages) <= max_tokens:
        return messages

    head = [messages[0]]
    kept = []
    # Walk the history newest-first, prepending while the budget allows.
    for msg in reversed(messages[1:]):
        kept.insert(0, msg)
        if estimate_tokens(head + kept) > max_tokens:
            kept.pop(0)  # the message just added pushed us over — drop it
            break

    # Guarantee some recent context even when the budget is very tight.
    if len(kept) < 4 and len(messages) > 4:
        kept = messages[-4:]

    return head + kept
|
||||
|
||||
# Drive the conversation until the model stops requesting tool calls.
# NOTE(review): unlike the t1 variant, tool responses are deliberately NOT
# appended to `messages` (see the commented-out append below) — results are
# collected in `tool_results` and dumped to a file instead. If the model keeps
# returning "tool_calls", this loop may not terminate; confirm intended.
while finish_reason is None or finish_reason == "tool_calls":
    # Trim the history to the token budget before each request.
    managed_messages = manage_message_history(messages)
    if len(managed_messages) < len(messages):
        print_separator(f"消息历史已裁剪,从{len(messages)}条减少到{len(managed_messages)}条")
        messages = managed_messages

    completion = client.chat.completions.create(
        model="MARS-T1",
        messages=messages,
        temperature=0.3,
        tools=tools_schemas,
        timeout=120,
    )
    choice = completion.choices[0]
    finish_reason = choice.finish_reason
    if finish_reason == "tool_calls":  # the model requested one or more tool runs
        # Show the assistant turn that requested the tools.
        print_separator("Assistant消息")
        print_message(choice.message)

        # Convert the SDK message object into a plain dict for the context.
        assistant_message = {
            "role": "assistant",
            "content": choice.message.content if hasattr(choice.message, 'content') else None
        }

        # If the turn carried tool calls, convert them to plain dicts too.
        if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
            tool_calls_list = []
            for tool_call in choice.message.tool_calls:
                tool_call_dict = {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    }
                }
                tool_calls_list.append(tool_call_dict)
            assistant_message["tool_calls"] = tool_calls_list

        # Keep the assistant turn in the context for the next request.
        messages.append(assistant_message)

        # Print and execute each requested tool call.
        print_separator("工具调用")
        for tool_call in choice.message.tool_calls:
            console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]")
            console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
            console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
            console.print("")

            tool_call_name = tool_call.function.name
            # Arguments arrive as serialized JSON; deserialize before dispatch.
            tool_call_arguments = json.loads(tool_call.function.arguments)
            try:
                tool_result = asyncio.run(execute_tool(tool_name=tool_call_name,tool_arguments=tool_call_arguments))
            except Exception as e:
                tool_result=f'工具调用失败{e}'

            # Save the result to the side list, wrapped in begin/end markers.
            formatted_result = f"[{tool_call_name} content begin]{tool_result}[{tool_call_name} content end]"
            tool_results.append({
                "tool_name": tool_call_name,
                "tool_id": tool_call.id,
                "formatted_result": formatted_result,
                "raw_result": tool_result
            })

            console.print(f"[bold green]已保存工具结果:[/] [cyan]{tool_call_name}[/]")

            # Build the tool-response message (printed only — not sent back).
            tool_message = {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_call_name,
                "content": tool_result,  # convention: results are submitted as strings
            }

            print_separator(f"工具响应: {tool_call_name}")
            print_message(tool_message)

            # Intentionally not feeding the tool response back into the context:
            #messages.append(tool_message)

# Print the final (non-tool-call) response, if the model produced content.
if choice.message.content:
    print_separator("最终响应")
    console.print(Panel(choice.message.content, border_style="green"))

# Dump every collected tool result to disk for inspection.
if tool_results:
    print_separator("所有工具调用结果")
    console.print(f"[bold cyan]共收集了 {len(tool_results)} 个工具调用结果[/]")

    with open("tool_results.txt", "w", encoding="utf-8") as f:
        for result in tool_results:
            f.write(f"{result['formatted_result']}\n\n")

    console.print(f"[bold green]所有工具调用结果已保存到 tool_results.txt 文件中[/]")
|
||||
Reference in New Issue
Block a user