Initial commit

lzy
2025-05-09 14:16:33 +08:00
commit 3a50afeec4
56 changed files with 9224 additions and 0 deletions

11
.gitignore vendored Normal file
View File

@@ -0,0 +1,11 @@
__pycache__/
# Ignore all .pth and .ckpt files, including those in subdirectories
**/*.pth
**/*.ckpt
/reference
/sci_mcp_server
/temp
/sci_mcp/material_mcp/support/pretrained_models
/sci_mcp/material_mcp/mattergen_gen/mattergen/*

15
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,15 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python 调试程序: 当前文件",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
}
]
}

View File

@@ -0,0 +1,11 @@
{
"folders": [
{
"path": "../../.."
},
{
"path": "../../../../SciToolAgent/ToolsAgent/ToolsFuns"
}
],
"settings": {}
}

315
requirements.txt Normal file
View File

@@ -0,0 +1,315 @@
absl-py==2.2.2
aiohappyeyeballs==2.6.1
aiohttp==3.11.16
aioitertools==0.12.0
aiosignal==1.3.2
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.9.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
ase==3.25.0
astroid==3.3.9
asttokens==3.0.0
async-lru==2.0.5
async-timeout==4.0.3
atomate2==0.0.18
attrs==25.3.0
autopep8==2.3.2
azure-core==1.33.0
azure-identity==1.21.0
azure-storage-blob==12.25.1
babel==2.17.0
bcrypt==4.3.0
beautifulsoup4==4.13.4
bleach==6.2.0
boto3==1.37.35
botocore==1.37.35
cachetools==5.5.2
certifi==2025.1.31
cffi==1.17.1
charset-normalizer==3.4.1
click==8.1.8
cloudpickle==3.1.1
colorama==0.4.6
comm==0.2.2
contextlib2==21.6.0
contourpy==1.3.2
cryptography==44.0.2
custodian==2025.4.14
cycler==0.12.1
dataclasses-json==0.6.7
debugpy==1.8.14
decorator==5.2.1
defusedxml==0.7.1
deprecated==1.2.18
dgl==2.4.0+cu124
dill==0.4.0
distro==1.9.0
dnspython==2.7.0
docker-pycreds==0.4.0
docstring-parser==0.16
e3nn==0.5.6
emmet-core==0.84.6
exceptiongroup==1.2.2
executing==2.2.0
fairchem-core==1.9.0
fastjsonschema==2.21.1
filelock==3.18.0
fire==0.7.0
fonttools==4.57.0
fqdn==1.5.1
frozenlist==1.5.0
fsspec==2025.3.2
gitdb==4.0.12
gitpython==3.1.44
greenlet==3.2.0
grpcio==1.71.0
h11==0.14.0
h5py==3.13.0
httpcore==1.0.8
httpx==0.28.1
httpx-sse==0.4.0
huggingface-hub==0.30.2
hydra-core==1.3.1
hydra-joblib-launcher==1.1.5
idna==3.10
iniconfig==2.1.0
ipykernel==6.29.5
ipython==8.35.0
isodate==0.7.2
isoduration==20.11.0
isort==6.0.1
jedi==0.19.2
jinja2==3.1.6
jiter==0.9.0
jmespath==1.0.1
jobflow==0.1.19
joblib==1.4.2
json5==0.12.0
jsonlines==4.0.0
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-client==8.6.3
jupyter-core==5.7.2
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter-server==2.15.0
jupyter-server-terminals==0.5.3
jupyterlab==4.4.0
jupyterlab-pygments==0.3.0
jupyterlab-server==2.27.3
kiwisolver==1.4.8
langchain==0.3.23
langchain-community==0.3.21
langchain-core==0.3.52
langchain-text-splitters==0.3.8
langsmith==0.3.31
latexcodec==3.0.0
lightning==2.5.1
lightning-utilities==0.14.3
llvmlite==0.44.0
lmdb==1.6.2
loguru==0.7.3
lxml==5.3.2
maggma==0.71.5
markdown==3.8
markdown-it-py==3.0.0
markupsafe==3.0.2
marshmallow==3.26.1
matgl==1.2.6
matplotlib==3.8.4
matplotlib-inline==0.1.7
matscipy==1.1.1
-e file:///home/ubuntu/sas0/lzy/mars-mcp/mattergen
mattersim==1.1.2
mccabe==0.7.0
mcp==1.6.0
mdurl==0.1.2
mistune==3.1.3
mongomock==4.3.0
monty==2025.3.3
mp-api==0.45.3
mpmath==1.3.0
msal==1.32.0
msal-extensions==1.3.1
msgpack==1.1.0
multidict==6.4.3
multiprocess==0.70.18
mypy-extensions==1.0.0
mysql-connector==2.2.9
narwhals==1.35.0
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
notebook==7.4.0
notebook-shim==0.2.4
numba==0.61.2
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cublas-cu12==12.4.2.65
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-cupti-cu12==12.4.99
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-nvrtc-cu12==12.4.99
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cuda-runtime-cu12==12.4.99
nvidia-cudnn-cu11==8.7.0.84
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu11==10.9.0.58
nvidia-cufft-cu12==11.2.0.44
nvidia-curand-cu11==10.3.0.86
nvidia-curand-cu12==10.3.5.119
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusolver-cu12==11.6.0.99
nvidia-cusparse-cu11==11.7.5.86
nvidia-cusparse-cu12==12.3.0.142
nvidia-nccl-cu11==2.19.3
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu11==11.8.86
nvidia-nvtx-cu12==12.4.99
omegaconf==2.3.0
openai==1.75.0
opt-einsum==3.4.0
opt-einsum-fx==0.1.4
orjson==3.10.16
overrides==7.7.0
packaging==24.2
palettable==3.3.3
pandarallel==1.6.5
pandas==2.2.3
pandocfilters==1.5.1
paramiko==3.5.1
parso==0.8.4
pathos==0.3.3
pexpect==4.9.0
phonopy==2.38.1
pillow==11.2.1
pip==25.0.1
platformdirs==4.3.7
plotly==6.0.1
pluggy==1.5.0
pox==0.3.6
ppft==1.7.7
prometheus-client==0.21.1
prompt-toolkit==3.0.51
propcache==0.3.1
protobuf==5.29.4
psutil==7.0.0
ptyprocess==0.7.0
pubchempy==1.0.4
pure-eval==0.2.3
pybtex==0.24.0
pycodestyle==2.13.0
pycparser==2.22
pydantic==2.11.3
pydantic-core==2.33.1
pydantic-settings==2.8.1
pydash==8.0.5
pyg-lib==0.4.0+pt24cu124
pygments==2.19.1
pyjwt==2.10.1
pylint==3.3.6
pymatgen==2024.10.29
pymongo==4.10.1
pynacl==1.5.0
pyparsing==3.2.3
pyre-extensions==0.0.32
pytest==8.3.5
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-json-logger==3.3.0
pytorch-lightning==2.0.6
pytz==2025.2
pyyaml==6.0.2
pyzmq==26.4.0
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
requests-toolbelt==1.0.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==14.0.0
rpds-py==0.24.0
ruamel-yaml==0.18.10
ruamel-yaml-clib==0.2.12
s3transfer==0.11.4
scikit-learn==1.6.1
scipy==1.15.2
seaborn==0.13.2
seekpath==2.1.0
send2trash==1.8.3
sentinels==1.0.0
sentry-sdk==2.26.1
setproctitle==1.3.5
setuptools==78.1.0
shellingham==1.5.4
six==1.17.0
smact==3.1.0
smart-open==7.1.0
smmap==5.0.2
sniffio==1.3.1
soupsieve==2.6
spglib==2.6.0
sqlalchemy==2.0.40
sse-starlette==2.2.1
sshtunnel==0.4.0
stack-data==0.6.3
starlette==0.46.2
submitit==1.5.2
symfc==1.3.4
sympy==1.13.3
tabulate==0.9.0
tenacity==9.1.2
tensorboard==2.19.0
tensorboard-data-server==0.7.2
termcolor==3.0.1
terminado==0.18.1
threadpoolctl==3.6.0
tiktoken==0.9.0
tinycss2==1.4.0
tomli==2.2.1
tomlkit==0.13.2
torch==2.4.0+cu124
torch-cluster==1.6.3+pt24cu124
torch-ema==0.3
torch-geometric==2.6.1
torch-runstats==0.2.0
torch-scatter==2.1.2+pt24cu124
torch-sparse==0.6.18+pt24cu124
torch-spline-conv==1.2.2+pt24cu124
torchaudio==2.4.0+cu124
torchdata==0.8.0
torchmetrics==1.7.1
torchtnt==0.2.4
torchvision==0.19.0+cu124
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
triton==3.0.0
typer==0.15.2
types-python-dateutil==2.9.0.20241206
typing-extensions==4.13.2
typing-inspect==0.9.0
typing-inspection==0.4.0
tzdata==2025.2
uncertainties==3.2.2
uri-template==1.3.0
urllib3==2.4.0
uvicorn==0.34.1
wandb==0.19.9
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
werkzeug==3.1.3
wheel==0.45.1
wrapt==1.17.2
yarl==1.20.0
zstandard==0.23.0

38
sci_mcp/__init__.py Normal file
View File

@@ -0,0 +1,38 @@
from .core.llm_tools import llm_tool,get_all_tools, get_all_tool_schemas, get_domain_tools, get_domain_tool_schemas
#general_mcp
#from .general_mcp.searxng_query.searxng_query_tools import search_online
# #material_mcp
from .material_mcp.mp_query.mp_query_tools import search_crystal_structures_from_materials_project,search_material_property_from_materials_project
from .material_mcp.oqmd_query.oqmd_query_tools import query_material_from_OQMD
from .material_mcp.knowledge_base_query.retrieval_from_knowledge_base_tools import retrieval_from_knowledge_base
from .material_mcp.mattersim_pred.property_pred_tools import predict_properties_MatterSim
from .material_mcp.mattergen_gen.material_gen_tools import generate_material_MatterGen
from .material_mcp.fairchem_structure_opt.structure_opt_tools import optimize_crystal_structure_FairChem
from .material_mcp.pymatgen_cal.pymatgen_cal_tools import calculate_density_Pymatgen,get_element_composition_Pymatgen,calculate_symmetry_Pymatgen
from .material_mcp.matgl_tools.matgl_tools import relax_crystal_structure_M3GNet,predict_formation_energy_M3GNet,run_molecular_dynamics_M3GNet,calculate_single_point_energy_M3GNet
#chemistry_mcp
from .chemistry_mcp.pubchem_tools.pubchem_tools import search_advanced_pubchem
from .chemistry_mcp.rdkit_tools.rdkit_tools import (
calculate_molecular_properties_rdkit,
calculate_drug_likeness_rdkit,
calculate_topological_descriptors_rdkit,
generate_molecular_fingerprints_rdkit,
calculate_molecular_similarity_rdkit,
analyze_molecular_structure_rdkit,
generate_molecular_conformer_rdkit,
identify_scaffolds_rdkit,
convert_between_chemical_formats_rdkit,
standardize_molecule_rdkit,
enumerate_stereoisomers_rdkit,
perform_substructure_search_rdkit
)
from .chemistry_mcp.rxn_tools.rxn_tools import (
predict_reaction_outcome_rxn,
predict_reaction_topn_rxn,
predict_reaction_properties_rxn,
extract_reaction_actions_rxn
)
__all__ = ["llm_tool", "get_all_tools", "get_all_tool_schemas", "get_domain_tools", "get_domain_tool_schemas"]

View File

@@ -0,0 +1,9 @@
"""
Chemistry MCP Module
This module provides tools for chemistry-related operations.
"""
from .pubchem_tools import search_advanced_pubchem
__all__ = ["search_advanced_pubchem"]

View File

@@ -0,0 +1,9 @@
"""
PubChem Tools Module
This module provides tools for accessing and processing chemical data from PubChem.
"""
from .pubchem_tools import search_advanced_pubchem
__all__ = ["search_advanced_pubchem"]

View File

@@ -0,0 +1,325 @@
"""
PubChem Tools Module
This module provides tools for searching and retrieving chemical compound information
from the PubChem database using the PubChemPy library.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Dict, List, Union, Optional, Any
import pubchempy as pcp
from ...core.llm_tools import llm_tool
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def compound_to_dict(compound: pcp.Compound) -> Dict[str, Any]:
"""
Convert a PubChem compound to a structured dictionary with relevant information.
Args:
compound: PubChem compound object
Returns:
Dictionary containing organized compound information
"""
if not compound:
return {}
# Basic information
result = {
"basic_info": {
"cid": compound.cid,
"iupac_name": compound.iupac_name,
"molecular_formula": compound.molecular_formula,
"molecular_weight": compound.molecular_weight,
"canonical_smiles": compound.canonical_smiles,
"isomeric_smiles": compound.isomeric_smiles,
},
"identifiers": {
"inchi": compound.inchi,
"inchikey": compound.inchikey,
},
"physical_properties": {
"xlogp": compound.xlogp,
"exact_mass": compound.exact_mass,
"monoisotopic_mass": compound.monoisotopic_mass,
"tpsa": compound.tpsa,
"complexity": compound.complexity,
"charge": compound.charge,
},
"molecular_features": {
"h_bond_donor_count": compound.h_bond_donor_count,
"h_bond_acceptor_count": compound.h_bond_acceptor_count,
"rotatable_bond_count": compound.rotatable_bond_count,
"heavy_atom_count": compound.heavy_atom_count,
"atom_stereo_count": compound.atom_stereo_count,
"defined_atom_stereo_count": compound.defined_atom_stereo_count,
"undefined_atom_stereo_count": compound.undefined_atom_stereo_count,
"bond_stereo_count": compound.bond_stereo_count,
"defined_bond_stereo_count": compound.defined_bond_stereo_count,
"undefined_bond_stereo_count": compound.undefined_bond_stereo_count,
"covalent_unit_count": compound.covalent_unit_count,
}
}
# Add synonyms if available
if hasattr(compound, 'synonyms') and compound.synonyms:
result["alternative_names"] = {
"synonyms": compound.synonyms[:10] if len(compound.synonyms) > 10 else compound.synonyms
}
return result
async def _search_by_name(name: str, max_results: int = 5) -> List[Dict[str, Any]]:
"""
Search compounds by name asynchronously.
Args:
name: Chemical compound name
max_results: Maximum number of results to return
Returns:
List of compound dictionaries
"""
try:
compounds = await asyncio.to_thread(
pcp.get_compounds, name, 'name', max_records=max_results
)
#print(compounds[0].to_dict())
return [compound.to_dict() for compound in compounds]
except Exception as e:
logging.error(f"Error searching by name '{name}': {str(e)}")
return [{"error": f"Error: {str(e)}"}]
async def _search_by_smiles(smiles: str, max_results: int = 5) -> List[Dict[str, Any]]:
"""
Search compounds by SMILES notation asynchronously.
Args:
smiles: SMILES notation of chemical compound
max_results: Maximum number of results to return
Returns:
List of compound dictionaries
"""
try:
compounds = await asyncio.to_thread(
pcp.get_compounds, smiles, 'smiles', max_records=max_results
)
return [compound.to_dict() for compound in compounds]
except Exception as e:
logging.error(f"Error searching by SMILES '{smiles}': {str(e)}")
return [{"error": f"Error: {str(e)}"}]
async def _search_by_formula(
formula: str,
max_results: int = 5,
listkey_count: int = 5,
listkey_start: int = 0
) -> List[Dict[str, Any]]:
"""
Search compounds by molecular formula asynchronously.
Uses pagination with listkey parameters to avoid timeout errors when searching
formulas that might return many results.
Args:
formula: Molecular formula
max_results: Maximum number of results to return
listkey_count: Number of results per page (default: 5)
listkey_start: Starting position for pagination (default: 0)
Returns:
List of compound dictionaries
"""
try:
# Use listkey parameters to avoid timeout errors
compounds = await asyncio.to_thread(
pcp.get_compounds,
formula,
'formula',
max_records=max_results,
listkey_count=listkey_count,
listkey_start=listkey_start
)
return [compound.to_dict() for compound in compounds]
except Exception as e:
logging.error(f"Error searching by formula '{formula}': {str(e)}")
return [{"error": f"Error: {str(e)}"}]
def _format_results_as_markdown(results: List[Dict[str, Any]], query_type: str, query_value: str) -> str:
"""
Format search results as a structured Markdown string.
Args:
results: List of compound dictionaries from compound.to_dict()
query_type: Type of search query (name, SMILES, formula)
query_value: Value of the search query
Returns:
Formatted Markdown string
"""
if not results:
return f"## PubChem Search Results\n\nNo compounds found for {query_type}: `{query_value}`"
if "error" in results[0]:
return f"## PubChem Search Error\n\n{results[0]['error']}"
markdown = f"## PubChem Search Results\n\nSearch by {query_type}: `{query_value}`\n\nFound {len(results)} compound(s)\n\n"
for i, compound in enumerate(results):
# Extract information from the compound.to_dict() structure
cid = compound.get("cid", "N/A")
iupac_name = compound.get("iupac_name", "Unknown")
molecular_formula = compound.get("molecular_formula", "N/A")
molecular_weight = compound.get("molecular_weight", "N/A")
canonical_smiles = compound.get("canonical_smiles", "N/A")
isomeric_smiles = compound.get("isomeric_smiles", "N/A")
inchi = compound.get("inchi", "N/A")
inchikey = compound.get("inchikey", "N/A")
xlogp = compound.get("xlogp", "N/A")
exact_mass = compound.get("exact_mass", "N/A")
tpsa = compound.get("tpsa", "N/A")
h_bond_donor_count = compound.get("h_bond_donor_count", "N/A")
h_bond_acceptor_count = compound.get("h_bond_acceptor_count", "N/A")
rotatable_bond_count = compound.get("rotatable_bond_count", "N/A")
heavy_atom_count = compound.get("heavy_atom_count", "N/A")
# Get atoms and bonds information if available
atoms = compound.get("atoms", [])
bonds = compound.get("bonds", [])
# Format the markdown output
markdown += f"### Compound {i+1}: {iupac_name}\n\n"
# Basic information section
markdown += "#### Basic Information\n\n"
markdown += f"- **CID**: {cid}\n"
markdown += f"- **Formula**: {molecular_formula}\n"
markdown += f"- **Molecular Weight**: {molecular_weight} g/mol\n"
markdown += f"- **Canonical SMILES**: `{canonical_smiles}`\n"
markdown += f"- **Isomeric SMILES**: `{isomeric_smiles}`\n"
# Identifiers section
markdown += "\n#### Identifiers\n\n"
markdown += f"- **InChI**: `{inchi}`\n"
markdown += f"- **InChIKey**: `{inchikey}`\n"
# Physical properties section
markdown += "\n#### Physical Properties\n\n"
markdown += f"- **XLogP**: {xlogp}\n"
markdown += f"- **Exact Mass**: {exact_mass}\n"
markdown += f"- **TPSA**: {tpsa} Ų\n"
# Molecular features section
markdown += "\n#### Molecular Features\n\n"
markdown += f"- **H-Bond Donors**: {h_bond_donor_count}\n"
markdown += f"- **H-Bond Acceptors**: {h_bond_acceptor_count}\n"
markdown += f"- **Rotatable Bonds**: {rotatable_bond_count}\n"
markdown += f"- **Heavy Atoms**: {heavy_atom_count}\n"
# Structure information
markdown += "\n#### Structure Information\n\n"
markdown += f"- **Atoms Count**: {len(atoms)}\n"
markdown += f"- **Bonds Count**: {len(bonds)}\n"
# Add a summary of atom elements if available
if atoms:
elements = {}
for atom in atoms:
element = atom.get("element", "")
if element:
elements[element] = elements.get(element, 0) + 1
if elements:
markdown += "- **Elements**: "
elements_str = ", ".join([f"{element}: {count}" for element, count in elements.items()])
markdown += f"{elements_str}\n"
markdown += "\n---\n\n" if i < len(results) - 1 else "\n"
return markdown
@llm_tool(name="search_advanced_pubchem",
description="Search for chemical compounds on PubChem database using name, SMILES notation, or molecular formula via PubChemPy library")
async def search_advanced_pubchem(
name: Optional[str] = None,
smiles: Optional[str] = None,
formula: Optional[str] = None,
max_results: int = 3
) -> str:
"""
Perform an advanced search for chemical compounds on PubChem using various identifiers.
This function allows searching by compound name, SMILES notation, or molecular formula.
At least one search parameter must be provided. If multiple parameters are provided,
the search will prioritize in the order: name > smiles > formula.
Args:
name: Name of the chemical compound (e.g., "Aspirin", "Caffeine")
smiles: SMILES notation of the chemical compound (e.g., "CC(=O)OC1=CC=CC=C1C(=O)O" for Aspirin)
formula: Molecular formula (e.g., "C9H8O4" for Aspirin)
max_results: Maximum number of results to return (default: 3)
Returns:
A formatted Markdown string with search results
Examples:
>>> search_advanced_pubchem(name="Aspirin")
# Returns information about Aspirin
>>> search_advanced_pubchem(smiles="CC(=O)OC1=CC=CC=C1C(=O)O")
# Returns information about compounds matching the SMILES notation
>>> search_advanced_pubchem(formula="C9H8O4", max_results=5)
# Returns up to 5 compounds with the formula C9H8O4
"""
logging.info(f"Performing advanced PubChem search with parameters: name={name}, smiles={smiles}, formula={formula}, max_results={max_results}")
# Validate input parameters
if name is None and smiles is None and formula is None:
return "## PubChem Search Error\n\nError: At least one search parameter (name, smiles, or formula) must be provided"
# Validate max_results
if max_results < 1:
max_results = 1
elif max_results > 10:
max_results = 10 # Limit to 10 results to prevent overwhelming responses
try:
results = []
query_type = ""
query_value = ""
# Prioritize search by name, then SMILES, then formula
if name is not None:
results = await _search_by_name(name, max_results)
query_type = "name"
query_value = name
elif smiles is not None:
results = await _search_by_smiles(smiles, max_results)
query_type = "SMILES"
query_value = smiles
elif formula is not None:
# Use pagination parameters for formula searches to avoid timeout
# Using the default values from _search_by_formula
results = await _search_by_formula(formula, max_results)
query_type = "formula"
query_value = formula
# Return results as markdown
return _format_results_as_markdown(results, query_type, query_value)
except Exception as e:
return f"## PubChem Search Error\n\nError: {str(e)}"

View File

@@ -0,0 +1,9 @@
"""
RDKit Tools Package
This package provides tools for molecular analysis, manipulation, and visualization
using the RDKit library. It includes functions for calculating molecular descriptors,
generating molecular fingerprints, analyzing molecular structures, and more.
"""
from .rdkit_tools import *

File diff suppressed because it is too large

View File

@@ -0,0 +1,6 @@
"""
RXN Tools Module
This module provides tools for chemical reaction prediction and analysis
using the IBM RXN for Chemistry API.
"""

View File

@@ -0,0 +1,772 @@
"""
RXN Tools Module
This module provides tools for chemical reaction prediction and analysis
using the IBM RXN for Chemistry API through the rxn4chemistry package.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Dict, List, Union, Optional, Any, Tuple
from rxn4chemistry import RXN4ChemistryWrapper
from ...core.llm_tools import llm_tool
from ...core.config import Chemistry_Config
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Constants
DEFAULT_MAX_RESULTS = 3
DEFAULT_TIMEOUT = 180 # seconds - increased from 60 to 180
def _get_rxn_wrapper() -> RXN4ChemistryWrapper:
"""
Get an initialized RXN4Chemistry wrapper with API key and project set.
Returns:
Initialized RXN4ChemistryWrapper instance with project set
Raises:
ValueError: If API key is not available
"""
# Try to get API key from environment or config
api_key = Chemistry_Config.RXN4_CHEMISTRY_KEY or os.environ.get("RXN_API_KEY")
if not api_key:
raise ValueError("Error: RXN API key not found. Please set the RXN_API_KEY environment variable.")
# Initialize the wrapper
wrapper = RXN4ChemistryWrapper(api_key=api_key)
try:
# Create a new project
project_name = f"RXN_Tools_Project_{os.getpid()}" # Add process ID to make name unique
project_response = wrapper.create_project(project_name)
# Extract project ID from response
# The API response format is nested: {'response': {'payload': {'id': '...'}}
if project_response and isinstance(project_response, dict):
# Try to extract project ID from different possible response formats
project_id = None
# Direct format: {"project_id": "..."}
if "project_id" in project_response:
project_id = project_response["project_id"]
# Nested format: {"response": {"payload": {"id": "..."}}}
elif "response" in project_response and isinstance(project_response["response"], dict):
payload = project_response["response"].get("payload", {})
if isinstance(payload, dict) and "id" in payload:
project_id = payload["id"]
if project_id:
wrapper.set_project(project_id)
logger.info(f"RXN project '{project_name}' created and set successfully with ID: {project_id}")
else:
logger.warning(f"Could not extract project ID from response: {project_response}")
else:
logger.warning(f"Unexpected project creation response: {project_response}")
except Exception as e:
logger.error(f"Error creating RXN project: {e}")
# Check if project is set
if not hasattr(wrapper, "project_id") or not wrapper.project_id:
logger.warning("No project set. Some API calls may fail.")
return wrapper
def _format_reaction_markdown(reactants: str, products: List[str],
confidence: Optional[List[float]] = None) -> str:
"""
Format reaction results as Markdown.
Args:
reactants: SMILES of reactants
products: List of product SMILES
confidence: Optional list of confidence scores
Returns:
Formatted Markdown string
"""
markdown = f"## 反应预测结果\n\n"
markdown += f"### 输入反应物\n\n`{reactants}`\n\n"
markdown += f"### 预测产物\n\n"
for i, product in enumerate(products):
conf_str = f" (置信度: {confidence[i]:.2f})" if confidence and i < len(confidence) else ""
markdown += f"{i+1}. `{product}`{conf_str}\n"
return markdown
@llm_tool(name="predict_reaction_outcome_rxn",
description="Predict chemical reaction outcomes for given reactants using IBM RXN for Chemistry API")
async def predict_reaction_outcome_rxn(reactants: str) -> str:
"""
Predict chemical reaction outcomes for given reactants.
This function uses the IBM RXN for Chemistry API to predict the most likely
products formed when the given reactants are combined.
Args:
reactants: SMILES notation of reactants, multiple reactants separated by dots (.).
Returns:
Formatted Markdown string containing the predicted reaction results.
Examples:
>>> predict_reaction_outcome_rxn("BrBr.c1ccc2cc3ccccc3cc2c1")
# Returns predicted products of bromine and anthracene reaction
"""
try:
# Get RXN wrapper
wrapper = _get_rxn_wrapper()
# Clean input
reactants = reactants.strip()
# Submit prediction
response = await asyncio.to_thread(
wrapper.predict_reaction, reactants
)
if not response or "prediction_id" not in response:
return "Error: 无法提交反应预测请求"
# 直接获取结果而不是通过_wait_for_result
results = await asyncio.to_thread(
wrapper.get_predict_reaction_results,
response["prediction_id"]
)
# Extract products
try:
attempts = results.get("response", {}).get("payload", {}).get("attempts", [])
if not attempts:
return "Error: 未找到预测结果"
# Get the top predicted product
product_smiles = attempts[0].get("smiles", "")
confidence = attempts[0].get("confidence", None)
# Format results
return _format_reaction_markdown(
reactants,
[product_smiles] if product_smiles else ["Unable to predict product"],
[confidence] if confidence is not None else None
)
except (KeyError, IndexError) as e:
logger.error(f"Error parsing prediction results: {e}")
return f"Error: 解析预测结果时出错: {str(e)}"
except Exception as e:
logger.error(f"Error in predict_reaction_outcome: {e}")
return f"Error: {str(e)}"
@llm_tool(name="predict_reaction_topn_rxn",
description="Predict multiple possible products for chemical reactions using IBM RXN for Chemistry API")
async def predict_reaction_topn_rxn(reactants: Union[str, List[str], List[List[str]]], topn: int = 3) -> str:
"""
Predict multiple possible products for chemical reactions.
This function uses the IBM RXN for Chemistry API to predict multiple products
that may be formed from given reactants, ranked by likelihood. Suitable for
scenarios where multiple reaction pathways need to be considered.
Args:
reactants: Reactants in one of the following formats:
- String: SMILES notation for a single reaction, multiple reactants separated by dots (.)
- List of strings: Multiple reactants for a single reaction, each reactant as a SMILES string
- List of lists of strings: Multiple reactions, each reaction composed of multiple reactant SMILES strings
topn: Number of predicted products to return for each reaction, default is 3.
Returns:
Formatted Markdown string containing multiple predicted reaction results.
Examples:
>>> predict_reaction_topn_rxn("BrBr.c1ccc2cc3ccccc3cc2c1", 5)
# Returns top 5 possible products for bromine and anthracene reaction
>>> predict_reaction_topn_rxn(["BrBr", "c1ccc2cc3ccccc3cc2c1"], 3)
# Returns top 3 possible products for bromine and anthracene reaction
>>> predict_reaction_topn_rxn([
... ["BrBr", "c1ccc2cc3ccccc3cc2c1"],
... ["BrBr", "c1ccc2cc3ccccc3cc2c1CCO"]
... ], 3)
# Returns top 3 possible products for two different reactions
"""
try:
# Get RXN wrapper
wrapper = _get_rxn_wrapper()
# Validate topn
if topn < 1:
topn = 1
elif topn > 10:
topn = 10
logger.warning("topn限制为最大10个结果")
# Process input to create precursors_lists
precursors_lists = []
if isinstance(reactants, str):
# Single reaction as string (e.g., "BrBr.c1ccc2cc3ccccc3cc2c1")
reactants = reactants.strip()
precursors_lists = [reactants.split(".")]
# For display in results
reactants_display = [reactants]
elif isinstance(reactants, list):
if all(isinstance(r, str) for r in reactants):
# Single reaction as list of strings (e.g., ["BrBr", "c1ccc2cc3ccccc3cc2c1"])
precursors_lists = [reactants]
# For display in results
reactants_display = [".".join(reactants)]
elif all(isinstance(r, list) for r in reactants):
# Multiple reactions as list of lists (e.g., [["BrBr", "c1ccc2cc3ccccc3cc2c1"], ["BrBr", "c1ccc2cc3ccccc3cc2c1CCO"]])
precursors_lists = reactants
# For display in results
reactants_display = [".".join(r) for r in reactants]
else:
return "Error: 反应物列表格式无效,必须是字符串列表或字符串列表的列表"
else:
return "Error: 反应物参数类型无效,必须是字符串或列表"
# Submit prediction
response = await asyncio.to_thread(
wrapper.predict_reaction_batch_topn,
precursors_lists=precursors_lists,
topn=topn
)
if not response or "task_id" not in response:
return "Error: 无法提交多产物反应预测请求"
# 直接获取结果,不使用循环等待
results = await asyncio.to_thread(
wrapper.get_predict_reaction_batch_topn_results,
response["task_id"]
)
# Extract products
try:
# Log the structure of the results for debugging
logger.info(f"Results structure: {results.keys()}")
# Retrieve the results defensively: use .get() with a default value
reaction_results = results.get("result", [])
# If the result is empty, try other possible keys
if not reaction_results and "predictions" in results:
reaction_results = results.get("predictions", [])
logger.info("Using 'predictions' key instead of 'result'")
# If still empty, fall back to using the entire response as the result list
if not reaction_results and isinstance(results, list):
reaction_results = results
logger.info("Using entire results as list")
if not reaction_results:
logger.warning(f"No reaction results found. Available keys: {results.keys()}")
return "Error: 未找到预测结果。请检查API响应格式。"
# Format results for all reactions
markdown = "## 反应预测结果\n\n"
# 确保reaction_results和reactants_display长度匹配
if len(reaction_results) != len(reactants_display):
logger.warning(f"Mismatch between results ({len(reaction_results)}) and reactants ({len(reactants_display)})")
# 如果不匹配,使用较短的长度
min_len = min(len(reaction_results), len(reactants_display))
reaction_results = reaction_results[:min_len]
reactants_display = reactants_display[:min_len]
for i, (reaction_result, reactants_str) in enumerate(zip(reaction_results, reactants_display), 1):
if not reaction_result:
markdown += f"### 反应 {i}: 未找到预测结果\n\n"
continue
# 记录每个反应结果的结构
logger.info(f"Reaction {i} result structure: {type(reaction_result)}")
# Extract products and confidences for this reaction
products = []
confidences = []
# 处理不同格式的反应结果
if isinstance(reaction_result, list):
# 标准格式:列表中的每个项目是一个预测
for pred in reaction_result:
if isinstance(pred, dict) and "smiles" in pred:
# 检查smiles是否为列表
if isinstance(pred["smiles"], list) and pred["smiles"]:
products.append(pred["smiles"][0]) # 取列表中的第一个元素
else:
products.append(pred["smiles"])
confidences.append(pred.get("confidence", 0.0))
elif isinstance(reaction_result, dict):
# Based on user feedback, check for a 'results' key
if "results" in reaction_result:
# Iterate over the results list
for pred in reaction_result.get("results", []):
if isinstance(pred, dict) and "smiles" in pred:
# Check whether smiles is a list
if isinstance(pred["smiles"], list) and pred["smiles"]:
products.append(pred["smiles"][0])  # take the first element of the list
else:
products.append(pred["smiles"])
confidences.append(pred.get("confidence", 0.0))
# Alternative format: the dict contains the prediction directly
elif "smiles" in reaction_result:
# Check whether smiles is a list
if isinstance(reaction_result["smiles"], list) and reaction_result["smiles"]:
products.append(reaction_result["smiles"][0])  # take the first element of the list
else:
products.append(reaction_result["smiles"])
confidences.append(reaction_result.get("confidence", 0.0))
# Another possible format
elif "products" in reaction_result:
for prod in reaction_result.get("products", []):
if isinstance(prod, dict) and "smiles" in prod:
# Check whether smiles is a list
if isinstance(prod["smiles"], list) and prod["smiles"]:
products.append(prod["smiles"][0])  # take the first element of the list
else:
products.append(prod["smiles"])
confidences.append(prod.get("confidence", 0.0))
# Add results for this reaction
markdown += f"### 反应 {i}\n\n"
markdown += f"**输入反应物:** `{reactants_str}`\n\n"
if products:
markdown += "**预测产物:**\n\n"
for j, (product, confidence) in enumerate(zip(products, confidences), 1):
markdown += f"{j}. `{product}` (置信度: {confidence:.2f})\n"
else:
markdown += "**预测产物:** 无法解析产物结构\n\n"
# 添加原始结果以便调试
markdown += f"**原始结果:** `{reaction_result}`\n\n"
markdown += "\n"
return markdown
except Exception as e:
logger.error(f"Error parsing topn prediction results: {e}", exc_info=True)
return f"Error: 解析多产物预测结果时出错: {str(e)}"
except Exception as e:
logger.error(f"Error in predict_reaction_topn: {e}")
return f"Error: {str(e)}"
# @llm_tool(name="predict_retrosynthesis",
# description="预测目标分子的逆合成路径")
# async def predict_retrosynthesis(target_molecule: str, max_steps: int = 3) -> str:
# """
# 预测目标分子的逆合成路径。
# 此函数使用IBM RXN for Chemistry API建议可能的合成路线
# 将目标分子分解为可能商业可得的更简单前体。
# Args:
# target_molecule: 目标分子的SMILES表示法。
# max_steps: 考虑的最大逆合成步骤数默认为3。
# Returns:
# 包含预测逆合成路径的格式化Markdown字符串。
# Examples:
# >>> predict_retrosynthesis("Brc1c2ccccc2c(Br)c2ccccc12")
# # 返回目标分子的可能合成路线
# """
# try:
# # Get RXN wrapper
# wrapper = _get_rxn_wrapper()
# # Clean input
# target_molecule = target_molecule.strip()
# # Validate max_steps
# if max_steps < 1:
# max_steps = 1
# elif max_steps > 5:
# max_steps = 5
# logger.warning("max_steps限制为最大5步")
# # Submit prediction
# response = await asyncio.to_thread(
# wrapper.predict_automatic_retrosynthesis,
# product=target_molecule,
# max_steps=max_steps
# )
# if not response or "prediction_id" not in response:
# return "Error: 无法提交逆合成预测请求"
# # 直接获取结果而不是通过_wait_for_result
# results = await asyncio.to_thread(
# wrapper.get_predict_automatic_retrosynthesis_results,
# response["prediction_id"]
# )
# # Extract retrosynthetic paths
# try:
# paths = results.get("retrosynthetic_paths", [])
# if not paths:
# return "## 逆合成分析结果\n\n未找到可行的逆合成路径。目标分子可能太复杂或结构有问题。"
# # Format results
# markdown = f"## 逆合成分析结果\n\n"
# markdown += f"### 目标分子\n\n`{target_molecule}`\n\n"
# markdown += f"### 找到的合成路径: {len(paths)}\n\n"
# # Limit to top 3 paths for readability
# display_paths = paths[:3]
# for i, path in enumerate(display_paths, 1):
# markdown += f"#### 路径 {i}\n\n"
# # Extract sequence information
# sequence_id = path.get("sequenceId", "未知")
# confidence = path.get("confidence", 0.0)
# markdown += f"**置信度:** {confidence:.2f}\n\n"
# # Extract steps
# steps = path.get("steps", [])
# if steps:
# markdown += "**合成步骤:**\n\n"
# for j, step in enumerate(steps, 1):
# # Extract reactants and products
# reactants = step.get("reactants", [])
# reactant_smiles = [r.get("smiles", "") for r in reactants if "smiles" in r]
# product = step.get("product", {})
# product_smiles = product.get("smiles", "")
# markdown += f"步骤 {j}: "
# if reactant_smiles and product_smiles:
# markdown += f"`{'.' if len(reactant_smiles) > 1 else ''.join(reactant_smiles)}` → `{product_smiles}`\n\n"
# else:
# markdown += "反应细节不可用\n\n"
# else:
# markdown += "**合成步骤:** 未提供详细步骤\n\n"
# markdown += "---\n\n"
# if len(paths) > 3:
# markdown += f"*注: 仅显示前3条路径共找到{len(paths)}条可能的合成路径。*\n"
# return markdown
# except (KeyError, IndexError) as e:
# logger.error(f"Error parsing retrosynthesis results: {e}")
# return f"Error: 解析逆合成结果时出错: {str(e)}"
# except Exception as e:
# logger.error(f"Error in predict_retrosynthesis: {e}")
# return f"Error: {str(e)}"
# @llm_tool(name="predict_biocatalytic_retrosynthesis",
# description="使用生物催化模型预测目标分子的逆合成路径")
# async def predict_biocatalytic_retrosynthesis(target_molecule: str) -> str:
# """
# 使用生物催化模型预测目标分子的逆合成路径。
# 此函数使用IBM RXN for Chemistry API的专门生物催化模型
# 建议可能的酶催化合成路线来创建目标分子。
# Args:
# target_molecule: 目标分子的SMILES表示法。
# Returns:
# 包含预测生物催化逆合成路径的格式化Markdown字符串。
# Examples:
# >>> predict_biocatalytic_retrosynthesis("OC1C(O)C=C(Br)C=C1")
# # 返回目标分子的可能酶催化合成路线
# """
# try:
# # Get RXN wrapper
# wrapper = _get_rxn_wrapper()
# # Clean input
# target_molecule = target_molecule.strip()
# # Submit prediction with enzymatic model
# # Note: The model name might change in future API versions
# response = await asyncio.to_thread(
# wrapper.predict_automatic_retrosynthesis,
# product=target_molecule,
# ai_model="enzymatic-2021-04-16" # Use the enzymatic model
# )
# if not response or "prediction_id" not in response:
# return "Error: 无法提交生物催化逆合成预测请求"
# # 直接获取结果而不是通过_wait_for_result
# results = await asyncio.to_thread(
# wrapper.get_predict_automatic_retrosynthesis_results,
# response["prediction_id"]
# )
# # Extract retrosynthetic paths
# try:
# paths = results.get("retrosynthetic_paths", [])
# if not paths:
# return "## 生物催化逆合成分析结果\n\n未找到可行的酶催化合成路径。目标分子可能不适合酶催化或结构有问题。"
# # Format results
# markdown = f"## 生物催化逆合成分析结果\n\n"
# markdown += f"### 目标分子\n\n`{target_molecule}`\n\n"
# markdown += f"### 找到的酶催化合成路径: {len(paths)}\n\n"
# # Limit to top 3 paths for readability
# display_paths = paths[:3]
# for i, path in enumerate(display_paths, 1):
# markdown += f"#### 路径 {i}\n\n"
# # Extract sequence information
# sequence_id = path.get("sequenceId", "未知")
# confidence = path.get("confidence", 0.0)
# markdown += f"**置信度:** {confidence:.2f}\n\n"
# # Extract steps
# steps = path.get("steps", [])
# if steps:
# markdown += "**酶催化步骤:**\n\n"
# for j, step in enumerate(steps, 1):
# # Extract reactants and products
# reactants = step.get("reactants", [])
# reactant_smiles = [r.get("smiles", "") for r in reactants if "smiles" in r]
# product = step.get("product", {})
# product_smiles = product.get("smiles", "")
# markdown += f"步骤 {j}: "
# if reactant_smiles and product_smiles:
# markdown += f"`{'.' if len(reactant_smiles) > 1 else ''.join(reactant_smiles)}` → `{product_smiles}`\n\n"
# else:
# markdown += "反应细节不可用\n\n"
# # Add enzyme information if available
# if "enzymeClass" in step:
# markdown += f"*可能的酶类别: {step['enzymeClass']}*\n\n"
# else:
# markdown += "**合成步骤:** 未提供详细步骤\n\n"
# markdown += "---\n\n"
# if len(paths) > 3:
# markdown += f"*注: 仅显示前3条路径共找到{len(paths)}条可能的酶催化合成路径。*\n"
# return markdown
# except (KeyError, IndexError) as e:
# logger.error(f"Error parsing biocatalytic retrosynthesis results: {e}")
# return f"Error: 解析生物催化逆合成结果时出错: {str(e)}"
# except Exception as e:
# logger.error(f"Error in predict_biocatalytic_retrosynthesis: {e}")
# return f"Error: {str(e)}"
@llm_tool(name="predict_reaction_properties_rxn",
description="Predict chemical reaction properties such as atom mapping and yield using IBM RXN for Chemistry API")
async def predict_reaction_properties_rxn(
reaction: str,
property_type: str = "atom-mapping"
) -> str:
"""
Predict chemical reaction properties such as atom mapping and yield.
This function uses the IBM RXN for Chemistry API to predict various properties
of chemical reactions, including atom-to-atom mapping (showing how atoms in
reactants correspond to atoms in products) and potential reaction yields.
Args:
reaction: SMILES notation of the reaction (reactants>>products).
property_type: Type of property to predict. Options: "atom-mapping", "yield".
Returns:
Formatted Markdown string containing predicted reaction properties.
Examples:
>>> predict_reaction_properties_rxn("CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F", "atom-mapping")
# Returns atom mapping for the reaction
"""
try:
# Get RXN wrapper
wrapper = _get_rxn_wrapper()
# Clean input
reaction = reaction.strip()
# Validate property_type
valid_property_types = ["atom-mapping", "yield"]
if property_type not in valid_property_types:
return f"Error: 无效的属性类型 '{property_type}'。支持的类型: {', '.join(valid_property_types)}"
# Determine model based on property type
ai_model = "atom-mapping-2020" if property_type == "atom-mapping" else "yield-2020-08-10"
# Submit prediction
response = await asyncio.to_thread(
wrapper.predict_reaction_properties,
reactions=[reaction],
ai_model=ai_model
)
if not response or "response" not in response:
return f"Error: 无法提交{property_type}预测请求"
# Extract results
try:
content = response.get("response", {}).get("payload", {}).get("content", [])
if not content:
return f"Error: 未找到{property_type}预测结果"
# Format results based on property type
markdown = f"## 反应{property_type}预测结果\n\n"
markdown += f"### 输入反应\n\n`{reaction}`\n\n"
if property_type == "atom-mapping":
# Extract mapped reaction
mapped_reaction = content[0].get("value", "")
if not mapped_reaction:
return "Error: 无法生成原子映射"
markdown += "### 原子映射结果\n\n"
markdown += f"`{mapped_reaction}`\n\n"
# Split into reactants and products for explanation
if ">>" in mapped_reaction:
reactants, products = mapped_reaction.split(">>")
markdown += "### 映射解释\n\n"
markdown += "原子映射显示了反应物中的原子如何对应到产物中的原子。\n"
markdown += "每个原子上的数字表示映射ID相同ID的原子在反应前后是同一个原子。\n\n"
markdown += f"**映射的反应物:** `{reactants}`\n\n"
markdown += f"**映射的产物:** `{products}`\n"
elif property_type == "yield":
# Extract predicted yield
predicted_yield = content[0].get("value", "")
if not predicted_yield:
return "Error: 无法预测反应产率"
try:
yield_value = float(predicted_yield)
markdown += "### 产率预测结果\n\n"
markdown += f"**预测产率:** {yield_value:.1f}%\n\n"
# Add interpretation
if yield_value < 30:
markdown += "**解释:** 预测产率较低,反应可能效率不高。考虑优化反应条件或探索替代路线。\n"
elif yield_value < 70:
markdown += "**解释:** 预测产率中等,反应可能是可行的,但有优化空间。\n"
else:
markdown += "**解释:** 预测产率较高,反应可能非常有效。\n"
except ValueError:
markdown += f"**预测产率:** {predicted_yield}\n"
return markdown
except (KeyError, IndexError) as e:
logger.error(f"Error parsing reaction properties results: {e}")
return f"Error: 解析反应属性预测结果时出错: {str(e)}"
except Exception as e:
logger.error(f"Error in predict_reaction_properties: {e}")
return f"Error: {str(e)}"
@llm_tool(name="extract_reaction_actions_rxn",
description="Extract structured reaction steps from text descriptions using IBM RXN for Chemistry API")
async def extract_reaction_actions_rxn(reaction_text: str) -> str:
"""
Extract structured reaction steps from text descriptions.
This function uses the IBM RXN for Chemistry API to parse text descriptions
of chemical procedures and extract structured actions representing the steps
of the procedure.
Args:
reaction_text: Text description of a chemical reaction procedure.
Returns:
Formatted Markdown string containing the extracted reaction steps.
Examples:
>>> extract_reaction_actions_rxn("To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour.")
# Returns structured steps extracted from the text
"""
try:
# Get RXN wrapper
wrapper = _get_rxn_wrapper()
# Clean input
reaction_text = reaction_text.strip()
if not reaction_text:
return "Error: 反应文本为空"
# Submit extraction request
response = await asyncio.to_thread(
wrapper.paragraph_to_actions,
paragraph=reaction_text
)
# Check that a response was returned
if not response:
return "Error: failed to extract reaction steps from the text"
# Return the response directly without further processing;
# this follows the reference code, which simply prints results['actions'],
# so we assume the response itself is the result we need.
return f"""## Reaction Step Extraction Results
### Input Text
{reaction_text}
### Extracted Reaction Steps
```
{response}
```
"""
except Exception as e:
logger.error(f"Error in extract_reaction_actions: {e}")
return f"Error: {str(e)}"

84
sci_mcp/core/config.py Executable file
View File

@@ -0,0 +1,84 @@
"""
Configuration Module
This module provides configuration settings for the Mars Toolkit.
It includes API keys, endpoints, paths, and other configuration parameters.
"""
from typing import Dict, Any
class Config:
@classmethod
def as_dict(cls) -> Dict[str, Any]:
"""Return all configuration settings as a dictionary"""
return {
key: value for key, value in cls.__dict__.items()
if not key.startswith('__') and not callable(value)
}
@classmethod
def update(cls, **kwargs):
"""Update configuration settings"""
for key, value in kwargs.items():
if hasattr(cls, key):
setattr(cls, key, value)
class General_Config(Config):
"""Configuration class for General MCP"""
# Searxng
SEARXNG_HOST="http://192.168.168.1:40032/"
SEARXNG_MAX_RESULTS=10
class Material_Config(Config):
# Materials Project
MP_API_KEY = 'PMASAg256b814q3OaSRWeVc7MKx4mlKI'
MP_ENDPOINT = 'https://api.materialsproject.org/'
MP_TOPK = 3
LOCAL_MP_PROPS_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Props/'
LOCAL_MP_CIF_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/MPDatasets/'
# Proxy
HTTP_PROXY = ''#'http://192.168.168.1:20171'
HTTPS_PROXY = ''#'http://192.168.168.1:20171'
# FairChem
FAIRCHEM_MODEL_PATH = '/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/support/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
FMAX = 0.05
# MatterGen
MATTERGENMODEL_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/support/pretrained_models/mattergen_ckpt'
MATTERGEN_ROOT='/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/mattergen_gen/mattergen'
MATTERGENMODEL_RESULT_PATH = 'results/'
# Dify
DIFY_ROOT_URL = 'http://192.168.191.101:6080'
DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'
#temp root
TEMP_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/temp/material'
class Chemistry_Config(Config):
TEMP_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/temp/chemistry'
RXN4_CHEMISTRY_KEY='apk-8928522a146c2503f30b16d9909222d7583f412ee8f1049f08d32a089ba88d34'
material_config = Material_Config()
general_config = General_Config()
chemistry_config = Chemistry_Config()
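The Config base class exposes update() and as_dict() for runtime overrides, so a deployment can adjust values without editing this file; a small sketch (the override values below are placeholders):

```python
# Hypothetical runtime override of configuration values (editor's sketch, not part of this commit).
from sci_mcp.core.config import Material_Config, material_config

# Class-level update: setattr is only applied to attributes that already exist on the class.
Material_Config.update(MP_TOPK=5, FMAX=0.02)

print(material_config.MP_TOPK)            # 5 -- instances read the updated class attribute
print(Material_Config.as_dict()["FMAX"])  # 0.02 -- as_dict() skips dunders and callables
```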

315
sci_mcp/core/llm_tools.py Executable file
View File

@@ -0,0 +1,315 @@
"""
LLM Tools Module
This module provides decorators and utilities for defining, registering, and managing LLM tools.
It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving
registered tools for use with LLM APIs.
"""
import asyncio
import inspect
import importlib
import pkgutil
import os
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args
import docstring_parser
from pydantic import BaseModel, create_model, Field
# Registry to store all registered tools
_TOOL_REGISTRY = {}
# Mapping of domain names to their module paths
_DOMAIN_MODULE_MAPPING = {
'material': 'sci_mcp.material_mcp',
'general': 'sci_mcp.general_mcp',
'biology': 'sci_mcp.biology_mcp',
'chemistry': 'sci_mcp.chemistry_mcp'
}
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
"""
Decorator to mark a function as an LLM tool.
This decorator registers the function as an LLM tool, generates a JSON schema for it,
and makes it available for retrieval through the get_tools function.
Args:
name: Optional custom name for the tool. If not provided, the function name will be used.
description: Optional custom description for the tool. If not provided, the function's
docstring will be used.
Returns:
The decorated function with additional attributes for LLM tool functionality.
Example:
@llm_tool(name="weather_lookup", description="Get current weather for a location")
def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
'''Get weather information for a specific location.'''
# Implementation...
return {"temperature": 22.5, "conditions": "sunny"}
"""
# Handle case when decorator is used without parentheses: @llm_tool
if callable(name):
func = name
name = None
description = None
return _llm_tool_impl(func, name, description)
# Handle case when decorator is used with parentheses: @llm_tool() or @llm_tool(name="xyz")
def decorator(func: Callable) -> Callable:
return _llm_tool_impl(func, name, description)
return decorator
def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
"""Implementation of the llm_tool decorator."""
# Get function signature and docstring
sig = inspect.signature(func)
doc = inspect.getdoc(func) or ""
parsed_doc = docstring_parser.parse(doc)
# Determine tool name
tool_name = name or func.__name__
# Determine tool description
tool_description = description or doc
# Create parameter properties for JSON schema
properties = {}
required = []
for param_name, param in sig.parameters.items():
# Skip self parameter for methods
if param_name == "self":
continue
param_type = param.annotation
param_default = None if param.default is inspect.Parameter.empty else param.default
param_required = param.default is inspect.Parameter.empty
# Get parameter description from docstring if available
param_desc = ""
for param_doc in parsed_doc.params:
if param_doc.arg_name == param_name:
param_desc = param_doc.description
break
# Handle Annotated types
if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
args = get_args(param_type)
param_type = args[0] # The actual type
if len(args) > 1 and isinstance(args[1], str):
param_desc = args[1] # The description
# Create property for parameter
param_schema = {
"type": _get_json_type(param_type),
"description": param_desc,
"title": param_name.replace("_", " ").title()
}
# Add default value if available
if param_default is not None:
param_schema["default"] = param_default
properties[param_name] = param_schema
# Add to required list if no default value
if param_required:
required.append(param_name)
# Create OpenAI format JSON schema
openai_schema = {
"type": "function",
"function": {
"name": tool_name,
"description": tool_description,
"parameters": {
"type": "object",
"properties": properties,
"required": required
}
}
}
# Create MCP format JSON schema
mcp_schema = {
"name": tool_name,
"description": tool_description,
"inputSchema": {
"type": "object",
"properties": properties,
"required": required
}
}
# Create Pydantic model for args schema
field_definitions = {}
for param_name, param in sig.parameters.items():
if param_name == "self":
continue
param_type = param.annotation
param_default = ... if param.default is inspect.Parameter.empty else param.default
# Handle Annotated types
if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
args = get_args(param_type)
param_type = args[0]
description = args[1] if len(args) > 1 and isinstance(args[1], str) else ""
field_definitions[param_name] = (param_type, Field(default=param_default, description=description))
else:
field_definitions[param_name] = (param_type, Field(default=param_default))
# Create args schema model
model_name = f"{tool_name.title().replace('_', '')}Schema"
args_schema = create_model(model_name, **field_definitions)
# Create a wrapper of the matching kind, depending on whether the original function is async
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)
else:
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
# Attach metadata to function
wrapper.is_llm_tool = True
wrapper.tool_name = tool_name
wrapper.tool_description = tool_description
wrapper.openai_schema = openai_schema
wrapper.mcp_schema = mcp_schema
wrapper.args_schema = args_schema
# Register the tool
_TOOL_REGISTRY[tool_name] = wrapper
return wrapper
def get_all_tools() -> Dict[str, Callable]:
"""
Get all registered LLM tools.
Returns:
A dictionary mapping tool names to their corresponding functions.
"""
return _TOOL_REGISTRY
def get_all_tool_schemas(schema_type='openai') -> List[Dict[str, Any]]:
"""
Get JSON schemas for all registered LLM tools.
Args:
schema_type: Schema format to return ('openai' or 'mcp')
Returns:
A list of JSON schemas for all registered tools, suitable for use with LLM APIs.
"""
return [tool.mcp_schema for tool in _TOOL_REGISTRY.values()] if schema_type == 'mcp' else [tool.openai_schema for tool in _TOOL_REGISTRY.values()]
def import_domain_tools(domains: List[str]) -> None:
"""
Import tools from specified domains to ensure they are registered.
This function dynamically imports modules from the specified domains to ensure
that all tools decorated with @llm_tool are registered in the _TOOL_REGISTRY.
Args:
domains: List of domain names (e.g., ['material', 'general'])
"""
for domain in domains:
if domain not in _DOMAIN_MODULE_MAPPING:
continue
module_path = _DOMAIN_MODULE_MAPPING[domain]
try:
# Import the base module
base_module = importlib.import_module(module_path)
base_path = os.path.dirname(base_module.__file__)
# Recursively import all submodules
for _, name, is_pkg in pkgutil.walk_packages([base_path], f"{module_path}."):
try:
importlib.import_module(name)
except ImportError as e:
print(f"Error importing {name}: {e}")
except ImportError as e:
print(f"Error importing domain {domain}: {e}")
def get_domain_tools(domains: List[str]) -> Dict[str, Dict[str, Callable]]:
"""
Get tools from specified domains.
Args:
domains: List of domain names (e.g., ['material', 'general'])
Returns:
A dictionary that maps tool names and their functions
"""
# First, ensure all tools from the specified domains are imported and registered
import_domain_tools(domains)
domain_tools = {}
for domain in domains:
if domain not in _DOMAIN_MODULE_MAPPING:
continue
domain_module_prefix = _DOMAIN_MODULE_MAPPING[domain]
for tool_name, tool_func in _TOOL_REGISTRY.items():
# Check if the tool's module belongs to this domain
if hasattr(tool_func, "__module__") and tool_func.__module__.startswith(domain_module_prefix):
domain_tools[tool_name] = tool_func
return domain_tools
def get_domain_tool_schemas(domains: List[str], schema_type='openai') -> List[Dict[str, Any]]:
"""
Get JSON schemas for tools from specified domains.
Args:
domains: List of domain names (e.g., ['material', 'general'])
schema_type: Schema format to return ('openai' or 'mcp')
Returns:
A list of tool schemas for the tools in the specified domains
"""
# First, get all domain tools
domain_tools = get_domain_tools(domains)
if schema_type == 'mcp':
tools_schema_list = [tool.mcp_schema for tool in domain_tools.values()]
else:
tools_schema_list = [tool.openai_schema for tool in domain_tools.values()]
return tools_schema_list
def _get_json_type(python_type: Any) -> str:
"""
Convert Python type to JSON schema type.
Args:
python_type: Python type annotation
Returns:
Corresponding JSON schema type as string
"""
if python_type is str:
return "string"
elif python_type is int:
return "integer"
elif python_type is float:
return "number"
elif python_type is bool:
return "boolean"
elif python_type is list or python_type is List:
return "array"
elif python_type is dict or python_type is Dict:
return "object"
else:
# Default to string for complex types
return "string"

View File

View File

@@ -0,0 +1,78 @@
"""
Search Online Module
This module provides functions for searching information on the web.
"""
from ...core.llm_tools import llm_tool
from ...core.config import general_config
import asyncio
import os
from typing import Annotated, Any, Dict, List, Union
from langchain_community.utilities import SearxSearchWrapper
import mcp.types as types
@llm_tool(name="search_online_searxng", description="Search scientific information online using searxng")
async def search_online_searxng(
query: Annotated[str, "Search term"],
num_results: Annotated[int, "Number of results"] = 5
) -> str:
"""
Searches for scientific information online and returns results as a formatted string.
Args:
query: Search term for scientific content
num_results: Number of results to return
Returns:
Formatted string with search results (titles, snippets, links)
"""
# lzy: this block may need to be removed before the official release; SearxNG is deployed locally, so no proxy needs to be set for local debugging
os.environ['HTTP_PROXY'] = ''
os.environ['HTTPS_PROXY'] = ''
try:
max_results = min(int(num_results), general_config.SEARXNG_MAX_RESULTS)
search = SearxSearchWrapper(
searx_host=general_config.SEARXNG_HOST,
categories=["science",],
k=max_results
)
# Execute search in a separate thread to avoid blocking the event loop
# since SearxSearchWrapper doesn't have native async support
loop = asyncio.get_event_loop()
raw_results = await loop.run_in_executor(
None,
lambda: search.results(query, language=['en','zh'], num_results=max_results)
)
# Transform results into structured format
formatted_results = []
for result in raw_results:
formatted_results.append({
"title": result.get("title", ""),
"snippet": result.get("snippet", ""),
"link": result.get("link", ""),
"source": result.get("source", "")
})
# Format results into a readable Markdown string
result_str = f"Search Results for '{query}' ({len(formatted_results)} items):\n\n"
if len(formatted_results) > 0:
for i, res in enumerate(formatted_results):
title = res.get("title", "No Title")
snippet = res.get("snippet", "No Snippet")
link = res.get("link", "No Link")
source = res.get("source", "No Source")
result_str += f"{i + 1}. **{title}**\n"
result_str += f" - Snippet: {snippet}\n"
result_str += f" - Link: [{link}]({link})\n"
result_str += f" - Source: {source}\n\n"
else:
result_str += "No results found.\n"
return result_str
except Exception as e:
return f"Error: {str(e)}"

View File

@@ -0,0 +1,34 @@
# # Core modules
# from mars_toolkit.core.config import config
# # Basic tools
# from mars_toolkit.misc.misc_tools import get_current_time
# # Compute modules
# from mars_toolkit.compute.material_gen import generate_material
# from mars_toolkit.compute.property_pred import predict_properties
# from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure
# # Query modules
# from mars_toolkit.query.mp_query import (
# search_material_property_from_material_project,
# get_crystal_structures_from_materials_project,
# get_mpid_from_formula
# )
# from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
# from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
# from mars_toolkit.query.web_search import search_online
# # Visualization modules
# from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
# __version__ = "0.1.0"
# __all__ = ["llm_tool", "get_tools", "get_tool_schemas"]

View File

@@ -0,0 +1,191 @@
"""
Structure Optimization Module
This module provides functions for optimizing crystal structures using the FairChem model.
"""
import asyncio
from io import StringIO
import sys
import tempfile
import os
import logging
from typing import Optional, Dict, Any
from ase.io import read, write
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from ..support.utils import convert_structure, remove_symmetry_equiv_xyz, read_structure_from_file_name_or_content_string
from ...core.llm_tools import llm_tool
from ...core.config import material_config
logger = logging.getLogger(__name__)
# FairChem model calculator, initialized lazily on first use
calc = None
def init_model():
"""初始化FairChem模型"""
global calc
if calc is not None:
return
try:
from fairchem.core import OCPCalculator
calc = OCPCalculator(checkpoint_path=material_config.FAIRCHEM_MODEL_PATH)
logger.info("FairChem model initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize FairChem model: {str(e)}")
raise
def generate_symmetry_cif(structure: Structure) -> str:
"""
Generate a symmetry-refined CIF.
Args:
structure: Pymatgen Structure object
Returns:
CIF-formatted string
"""
analyzer = SpacegroupAnalyzer(structure)
structure_refined = analyzer.get_refined_structure()
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
cif_writer.write_file(tmp_file.name)
tmp_file.seek(0)
content = tmp_file.read()
os.unlink(tmp_file.name)
return content
def optimize_structure(atoms: Atoms, output_format: str, fmax: float = 0.05) -> str:
"""
Optimize a crystal structure.
Args:
atoms: ASE Atoms object
output_format: Output format (cif, xyz, vasp, etc.)
fmax: Force convergence criterion
Returns:
Formatted string containing the optimization results
"""
atoms.calc = calc
try:
# Capture the output produced during optimization
temp_output = StringIO()
original_stdout = sys.stdout
sys.stdout = temp_output
# Run the optimization
dyn = FIRE(FrechetCellFilter(atoms))
dyn.run(fmax=fmax)
# Restore stdout and collect the optimization log
sys.stdout = original_stdout
optimization_log = temp_output.getvalue()
temp_output.close()
# Get the total energy
total_energy = atoms.get_potential_energy()
# Post-process the optimized structure
if output_format == "cif":
optimized_structure = Structure.from_ase_atoms(atoms)
content = generate_symmetry_cif(optimized_structure)
content = remove_symmetry_equiv_xyz(content)
else:
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
write(tmp_file.name, atoms)
tmp_file.seek(0)
content = tmp_file.read()
os.unlink(tmp_file.name)
# Format the returned result
format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**
#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
return format_result
except Exception as e:
return f"Error: Failed to optimize structure: {str(e)}"
@llm_tool(name="optimize_crystal_structure_FairChem",
description="Optimizes crystal structures using the FairChem model")
async def optimize_crystal_structure_FairChem(
structure_source: str,
format_type: str = "auto",
optimization_level: str = "normal"
) -> str:
"""
Optimizes a crystal structure to find its lowest energy configuration.
Args:
structure_source: Either a file name or direct structure content (CIF, XYZ, POSCAR)
format_type: Structure format type (auto, cif, xyz, poscar). Default "auto" will attempt to detect format.
optimization_level: Optimization precision level (quick, normal, precise)
Returns:
Optimized structure with total energy and optimization details
"""
# Make sure the model has been initialized
if calc is None:
init_model()
# Set optimization parameters
fmax_values = {
"quick": 0.1,
"normal": 0.05,
"precise": 0.01
}
fmax = fmax_values.get(optimization_level, 0.05)
# Run the potentially blocking work asynchronously via asyncio.to_thread
def run_optimization():
try:
# Parse the input structure
content, actual_format = read_structure_from_file_name_or_content_string(structure_source, format_type)
# Map the detected format onto ASE/pymatgen format names
format_mapping = {
"cif": "cif",
"xyz": "xyz",
"poscar": "vasp",
"vasp": "vasp"
}
final_format = format_mapping.get(actual_format.lower(), "cif")
# Convert to an ASE Atoms object
atoms = convert_structure(final_format, content)
if atoms is None:
return f"Error: Unable to convert input structure. Please check if the format is correct."
# Optimize the structure
return optimize_structure(atoms, final_format, fmax=fmax)
except Exception as e:
return f"Error: Failed to optimize structure: {str(e)}"
return await asyncio.to_thread(run_optimization)
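# Usage sketch (illustrative only): calling the FairChem optimizer on a CIF file that is
# already available to the toolkit; the file name and import path are assumptions.
#
#   import asyncio
#   from mars_toolkit.compute.structure_opt import optimize_crystal_structure_FairChem
#   report = asyncio.run(optimize_crystal_structure_FairChem("NaCl.cif", format_type="cif", optimization_level="quick"))
#   print(report)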

View File

@@ -0,0 +1,70 @@
import codecs
import json
import requests
from ...core.llm_tools import llm_tool
from ...core.config import material_config
@llm_tool(
name="retrieval_from_knowledge_base",
description="Retrieve information from local materials science literature knowledge base"
)
async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
"""
Retrieve relevant information from the local materials-science literature knowledge base.
Args:
query: Query string, e.g. a material name such as "CsPbBr3"
topk: Number of results to return (default: 3)
Returns:
A JSON string containing the document ID, title, content, and relevance score
"""
# Set the Dify API endpoint URL
url = f'{material_config.DIFY_ROOT_URL}/v1/chat-messages'
# Build the request headers with the API key and content type
headers = {
'Authorization': f'Bearer {material_config.DIFY_API_KEY}',
'Content-Type': 'application/json'
}
# Prepare the request payload
data = {
"inputs": {"topK": topk}, # Maximum number of results to return
"query": query, # The query string
"response_mode": "blocking", # Blocking mode: wait for the complete response
"conversation_id": "", # No conversation ID; every call is an independent query
"user": "abc-123" # User identifier
}
try:
# Send the POST request to the Dify API
# Use a long timeout (1111 seconds) to allow for slow responses
response = requests.post(url, headers=headers, json=data, timeout=1111)
# Get the response text
response_text = response.text
# Decode Unicode escape sequences in the response text
response_text = codecs.decode(response_text, 'unicode_escape')
# Parse the response text as JSON
result_json = json.loads(response_text)
# Extract the metadata from the response
metadata = result_json.get("metadata", {})
# Build a result dict with the key information
useful_info = {
"id": metadata.get("document_id"), # Document ID
"title": result_json.get("title"), # Document title
"content": result_json.get("answer", ""), # Content; the 'answer' field holds the content
"score": metadata.get("score") # Relevance score
}
# Return the extracted information
return json.dumps(useful_info, ensure_ascii=False, indent=2)
except Exception as e:
# Catch all exceptions and return the error message
return f"Error: {str(e)}"

View File

@@ -0,0 +1,8 @@
"""
MatGL Tools Module
This module provides tools for material structure relaxation and property prediction
using MatGL (Materials Graph Library) models.
"""
from .matgl_tools import *

View File

@@ -0,0 +1,487 @@
"""
MatGL Tools Module
This module provides tools for material structure relaxation and property prediction
using MatGL (Materials Graph Library) models.
"""
from __future__ import annotations
from ...core.config import material_config
import warnings
import json
from typing import Dict, List, Union, Optional, Any
import torch
from pymatgen.core import Lattice, Structure
from pymatgen.ext.matproj import MPRester
from pymatgen.io.ase import AseAtomsAdaptor
import matgl
from matgl.ext.ase import Relaxer, MolecularDynamics, PESCalculator
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ...core.llm_tools import llm_tool
import os
from ..support.utils import read_structure_from_file_name_or_content_string
# To suppress warnings for clearer output
warnings.simplefilter("ignore")
@llm_tool(name="relax_crystal_structure_M3GNet",
description="Optimize crystal structure geometry using M3GNet universal potential from a structure file or content string")
async def relax_crystal_structure_M3GNet(
structure_source: str,
fmax: float = 0.01
) -> str:
"""
Optimize crystal structure geometry to find its equilibrium configuration.
Uses M3GNet universal potential for fast and accurate structure relaxation without DFT.
Accepts a structure file or content string.
Args:
structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
fmax: Maximum force tolerance for convergence in eV/Å (default: 0.01).
Returns:
A Markdown formatted string with the relaxation results or an error message.
"""
try:
# Read the structure via read_structure_from_file_name_or_content_string
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
# Parse the structure with pymatgen
structure = Structure.from_str(structure_content, fmt=content_format)
if structure is None:
return "Error: Failed to obtain a valid structure"
# Load the M3GNet universal potential model
pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")
# Create a relaxer and relax the structure
relaxer = Relaxer(potential=pot)
relax_results = relaxer.relax(structure, fmax=fmax)
# Get the relaxed structure
relaxed_structure = relax_results["final_structure"]
reduced_formula = relaxed_structure.composition.reduced_formula
# Collect structure information
lattice_info = relaxed_structure.lattice
volume = relaxed_structure.volume
density = relaxed_structure.density
symmetry = relaxed_structure.get_space_group_info()
# Build the atomic-position table
sites_table = " # SP a b c\n"
sites_table += "--- ---- -------- -------- --------\n"
for i, site in enumerate(relaxed_structure):
frac_coords = site.frac_coords
sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"
return (f"## Structure Relaxation\n\n"
f"- **Structure**: `{reduced_formula}`\n"
f"- **Force Tolerance**: `{fmax} eV/Å`\n"
f"- **Status**: `Successfully relaxed`\n\n"
f"### Relaxed Structure Information\n\n"
f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
f"- **Volume**: `{volume:.2f} ų`\n"
f"- **Density**: `{density:.2f} g/cm³`\n"
f"- **Lattice Parameters**:\n"
f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
f"### Atomic Positions (Fractional Coordinates)\n\n"
f"```\n"
f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
f"pbc : {relaxed_structure.lattice.pbc[0]!s:5s} {relaxed_structure.lattice.pbc[1]!s:5s} {relaxed_structure.lattice.pbc[2]!s:5s}\n"
f"Sites ({len(relaxed_structure)})\n"
f"{sites_table}```\n")
except Exception as e:
return f"Error during structure relaxation: {str(e)}"
# Internal helper for structure relaxation; returns a Structure object instead of a formatted string
async def _relax_crystal_structure_M3GNet_internal(
structure_source: str,
fmax: float = 0.01
) -> Union[Structure, str]:
"""
Internal structure-relaxation helper that returns a Structure object instead of a formatted string.
Args:
structure_source: Structure file name or content string
fmax: Force convergence threshold (eV/Å)
Returns:
The relaxed Structure object, or an error message string
"""
try:
# Read the structure via read_structure_from_file_name_or_content_string
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
# Parse the structure with pymatgen
structure = Structure.from_str(structure_content, fmt=content_format)
if structure is None:
return "Error: Failed to obtain a valid structure"
# Load the M3GNet universal potential model
pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")
# Create a relaxer and relax the structure
relaxer = Relaxer(potential=pot)
relax_results = relaxer.relax(structure, fmax=fmax)
# Get the relaxed structure
relaxed_structure = relax_results["final_structure"]
return relaxed_structure
except Exception as e:
return f"Error during structure relaxation: {str(e)}"
@llm_tool(name="predict_formation_energy_M3GNet",
description="Predict the formation energy of a crystal structure using the M3GNet formation energy model from a structure file or content string, with optional structure optimization")
async def predict_formation_energy_M3GNet(
structure_source: str,
optimize_structure: bool = True,
fmax: float = 0.01
) -> str:
"""
Predict the formation energy of a crystal structure using the M3GNet formation energy model.
Args:
structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
optimize_structure: Whether to optimize the structure before prediction (default: True).
fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).
Returns:
A Markdown formatted string containing the predicted formation energy in eV/atom or an error message.
"""
try:
# Obtain the structure (relaxed or as provided)
if optimize_structure:
# Relax the structure with the internal helper
structure = await _relax_crystal_structure_M3GNet_internal(
structure_source=structure_source,
fmax=fmax
)
# Check whether the relaxation succeeded
if isinstance(structure, str) and structure.startswith("Error"):
return structure
else:
# Read the structure directly without relaxation
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
structure = Structure.from_str(structure_content, fmt=content_format)
if structure is None:
return "Error: Failed to obtain a valid structure"
# Load the pretrained model
model = matgl.load_model("M3GNet-MP-2018.6.1-Eform")
# Predict the formation energy
eform = model.predict_structure(structure)
reduced_formula = structure.composition.reduced_formula
# Build the result string
optimization_status = "optimized" if optimize_structure else "non-optimized"
# Collect structure information
lattice_info = structure.lattice
volume = structure.volume
density = structure.density
symmetry = structure.get_space_group_info()
# Build the atomic-position table
sites_table = " # SP a b c\n"
sites_table += "--- ---- -------- -------- --------\n"
for i, site in enumerate(structure):
frac_coords = site.frac_coords
sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"
return (f"## Formation Energy Prediction\n\n"
f"- **Structure**: `{reduced_formula}`\n"
f"- **Structure Status**: `{optimization_status}`\n"
f"- **Formation Energy**: `{float(eform):.3f} eV/atom`\n\n"
f"### Structure Information\n\n"
f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
f"- **Volume**: `{volume:.2f} ų`\n"
f"- **Density**: `{density:.2f} g/cm³`\n"
f"- **Lattice Parameters**:\n"
f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
f"### Atomic Positions (Fractional Coordinates)\n\n"
f"```\n"
f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
f"pbc : {structure.lattice.pbc[0]!s:5s} {structure.lattice.pbc[1]!s:5s} {structure.lattice.pbc[2]!s:5s}\n"
f"Sites ({len(structure)})\n"
f"{sites_table}```\n")
except Exception as e:
return f"Error: {str(e)}"
@llm_tool(name="run_molecular_dynamics_M3GNet",
description="Run molecular dynamics simulation on a crystal structure using M3GNet universal potential, with optional structure optimization")
async def run_molecular_dynamics_M3GNet(
structure_source: str,
temperature_K: float = 300,
steps: int = 100,
optimize_structure: bool = True,
fmax: float = 0.01
) -> str:
"""
Run molecular dynamics simulation on a crystal structure using M3GNet universal potential.
Args:
structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
temperature_K: Temperature for MD simulation in Kelvin (default: 300).
steps: Number of MD steps to run (default: 100).
optimize_structure: Whether to optimize the structure before simulation (default: True).
fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).
Returns:
A Markdown formatted string containing the simulation results, including final potential energy.
"""
try:
# Obtain the structure (relaxed or as provided)
if optimize_structure:
# Relax the structure with the internal helper
structure = await _relax_crystal_structure_M3GNet_internal(
structure_source=structure_source,
fmax=fmax
)
# Check whether the relaxation succeeded
if isinstance(structure, str) and structure.startswith("Error"):
return structure
else:
# Read the structure directly without relaxation
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
structure = Structure.from_str(structure_content, fmt=content_format)
if structure is None:
return "Error: Failed to obtain a valid structure"
# Load the M3GNet universal potential model
pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")
# Convert pymatgen structure to ASE atoms
ase_adaptor = AseAtomsAdaptor()
atoms = ase_adaptor.get_atoms(structure)
# Initialize the velocity according to Maxwell Boltzmann distribution
MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_K)
# Create the MD class and run simulation
driver = MolecularDynamics(atoms, potential=pot, temperature=temperature_K)
driver.run(steps)
# Get final potential energy
final_energy = atoms.get_potential_energy()
# Get final structure
final_structure = ase_adaptor.get_structure(atoms)
reduced_formula = final_structure.composition.reduced_formula
# Build the result string
optimization_status = "optimized" if optimize_structure else "non-optimized"
# Collect structure information
lattice_info = final_structure.lattice
volume = final_structure.volume
density = final_structure.density
symmetry = final_structure.get_space_group_info()
# Build the atomic-position table
sites_table = " # SP a b c\n"
sites_table += "--- ---- -------- -------- --------\n"
for i, site in enumerate(final_structure):
frac_coords = site.frac_coords
sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"
return (f"## Molecular Dynamics Simulation\n\n"
f"- **Structure**: `{reduced_formula}`\n"
f"- **Structure Status**: `{optimization_status}`\n"
f"- **Temperature**: `{temperature_K} K`\n"
f"- **Steps**: `{steps}`\n"
f"- **Final Potential Energy**: `{float(final_energy):.3f} eV`\n\n"
f"### Final Structure Information\n\n"
f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
f"- **Volume**: `{volume:.2f} ų`\n"
f"- **Density**: `{density:.2f} g/cm³`\n"
f"- **Lattice Parameters**:\n"
f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
f"### Atomic Positions (Fractional Coordinates)\n\n"
f"```\n"
f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
f"pbc : {final_structure.lattice.pbc[0]!s:5s} {final_structure.lattice.pbc[1]!s:5s} {final_structure.lattice.pbc[2]!s:5s}\n"
f"Sites ({len(final_structure)})\n"
f"{sites_table}```\n")
except Exception as e:
return f"Error: {str(e)}"
@llm_tool(name="calculate_single_point_energy_M3GNet",
description="Calculate single point energy of a crystal structure using M3GNet universal potential, with optional structure optimization")
async def calculate_single_point_energy_M3GNet(
structure_source: str,
optimize_structure: bool = True,
fmax: float = 0.01
) -> str:
"""
Calculate single point energy of a crystal structure using M3GNet universal potential.
Args:
structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
optimize_structure: Whether to optimize the structure before calculation (default: True).
fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).
Returns:
A Markdown formatted string containing the calculated potential energy in eV.
"""
try:
# Obtain the structure (relaxed or as provided)
if optimize_structure:
# Relax the structure with the internal helper
structure = await _relax_crystal_structure_M3GNet_internal(
structure_source=structure_source,
fmax=fmax
)
# Check whether the relaxation succeeded
if isinstance(structure, str) and structure.startswith("Error"):
return structure
else:
# Read the structure directly without relaxation
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
structure = Structure.from_str(structure_content, fmt=content_format)
if structure is None:
return "Error: Failed to obtain a valid structure"
# Load the M3GNet universal potential model
pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")
# Convert pymatgen structure to ASE atoms
ase_adaptor = AseAtomsAdaptor()
atoms = ase_adaptor.get_atoms(structure)
# Set up the calculator for atoms object
calc = PESCalculator(pot)
atoms.calc = calc  # assign the calculator (set_calculator is deprecated in recent ASE)
# Calculate potential energy
energy = atoms.get_potential_energy()
reduced_formula = structure.composition.reduced_formula
# Build the result string
optimization_status = "optimized" if optimize_structure else "non-optimized"
# Collect structure information
lattice_info = structure.lattice
volume = structure.volume
density = structure.density
symmetry = structure.get_space_group_info()
# Build the atomic-position table
sites_table = " # SP a b c\n"
sites_table += "--- ---- -------- -------- --------\n"
for i, site in enumerate(structure):
frac_coords = site.frac_coords
sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"
return (f"## Single Point Energy Calculation\n\n"
f"- **Structure**: `{reduced_formula}`\n"
f"- **Structure Status**: `{optimization_status}`\n"
f"- **Potential Energy**: `{float(energy):.3f} eV`\n\n"
f"### Structure Information\n\n"
f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
f"- **Volume**: `{volume:.2f} ų`\n"
f"- **Density**: `{density:.2f} g/cm³`\n"
f"- **Lattice Parameters**:\n"
f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
f"### Atomic Positions (Fractional Coordinates)\n\n"
f"```\n"
f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
f"pbc : {structure.lattice.pbc[0]!s:5s} {structure.lattice.pbc[1]!s:5s} {structure.lattice.pbc[2]!s:5s}\n"
f"Sites ({len(structure)})\n"
f"{sites_table}```\n")
except Exception as e:
return f"Error: {str(e)}"
#Error: Bad serialized model or bad model name. It is possible that you have an older model cached. Please clear your cache by running `python -c "import matgl; matgl.clear_cache()"`
# @llm_tool(name="predict_band_gap",
# description="Predict the band gap of a crystal structure using MEGNet multi-fidelity model from either a chemical formula or CIF file, with structure optimization")
# async def predict_band_gap(
# formula: str = None,
# cif_file_name: str = None,
# method: str = "PBE",
# fmax: float = 0.01
# ) -> str:
# """
# Predict the band gap of a crystal structure using the MEGNet multi-fidelity band gap model.
# First optimizes the crystal structure using M3GNet universal potential, then predicts
# the band gap based on the relaxed structure for more accurate results.
# Accepts either a chemical formula (searches Materials Project database) or a CIF file.
# Args:
# formula: Chemical formula to retrieve from Materials Project (e.g., "Fe2O3").
# cif_file_name: Name of CIF file in temp directory to use as structure source.
# method: The DFT method to use for the prediction. Options are "PBE", "GLLB-SC", "HSE", or "SCAN".
# Default is "PBE".
# fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).
# Returns:
# A string containing the predicted band gap in eV or an error message.
# """
# try:
# # First, relax the crystal structure
# relaxed_result = await relax_crystal_structure(
# formula=formula,
# cif_file_name=cif_file_name,
# fmax=fmax
# )
# # Check if relaxation was successful
# if isinstance(relaxed_result, str) and relaxed_result.startswith("Error"):
# return relaxed_result
# # Use the relaxed structure for band gap prediction
# structure = relaxed_result
# if structure is None:
# return "Error: Failed to obtain a valid relaxed structure"
# # Load the pre-trained MEGNet band gap model
# model = matgl.load_model("MEGNet-MP-2019.4.1-BandGap-mfi")
# # Map method name to index
# method_map = {"PBE": 0, "GLLB-SC": 1, "HSE": 2, "SCAN": 3}
# if method not in method_map:
# return f"Error: Unsupported method: {method}. Choose from PBE, GLLB-SC, HSE, or SCAN."
# # Set the graph label based on the method
# graph_attrs = torch.tensor([method_map[method]])
# # Predict the band gap using the relaxed structure
# bandgap = model.predict_structure(structure=structure, state_attr=graph_attrs)
# reduced_formula = structure.reduced_formula
# # Return the band gap as a string
# return f"The predicted band gap for relaxed {reduced_formula} using {method} method is {float(bandgap):.3f} eV."
# except Exception as e:
# return f"Error: {str(e)}"

View File

@@ -0,0 +1,240 @@
import ast
import json
import logging
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
import tempfile
import os
import datetime
import asyncio
import zipfile
import shutil
import re
import multiprocessing
from multiprocessing import Process, Queue
from pathlib import Path
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
import logging
# Use the 'spawn' multiprocessing start method to avoid CUDA initialization errors
try:
multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
# set_start_method raises RuntimeError if the start method has already been set
pass
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
# Import paths have been updated
from ...core.llm_tools import llm_tool
from .mattergen_wrapper import *
# Provided via mattergen_wrapper
import sys
import os
def convert_values(data_str):
"""
Convert a JSON string into a Python object.
Args:
data_str: JSON string
Returns:
The parsed data, or the original string if parsing fails
"""
try:
data = json.loads(data_str)
except json.JSONDecodeError:
return data_str  # Return the original string if it cannot be parsed as JSON
return data
def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
"""
Preprocess a property value based on its name, converting it to the appropriate type.
Args:
property_name: Name of the property
property_value: Value of the property (can be string, float, or int)
Returns:
Tuple of (property_name, processed_value)
Raises:
ValueError: If the property value is invalid for the given property name
"""
valid_properties = [
"dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
"energy_above_hull", "formation_energy_per_atom", "space_group",
"hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap"
]
if property_name not in valid_properties:
raise ValueError(f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}")
# Process property_value if it's a string
if isinstance(property_value, str):
try:
# Try to convert string to float for numeric properties
if property_name != "chemical_system":
property_value = float(property_value)
except ValueError:
# If conversion fails, keep as string (for chemical_system)
pass
# Handle special cases for properties that need specific types
if property_name == "chemical_system":
if isinstance(property_value, (int, float)):
logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
property_value = str(property_value)
elif property_name == "space_group" :
space_group = property_value
if space_group < 1 or space_group > 230:
raise ValueError(f"Invalid space_group value: {space_group}. Must be an integer between 1 and 230.")
return property_name, property_value
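# Behavior sketch (illustrative only): preprocess_property normalizes property values;
# numeric strings are converted to floats, while chemical_system values stay strings.
#
#   print(preprocess_property("dft_band_gap", "1.5"))         # -> ("dft_band_gap", 1.5)
#   print(preprocess_property("chemical_system", "Li-Fe-O"))  # -> ("chemical_system", "Li-Fe-O")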
def main(
output_path: str,
pretrained_name: PRETRAINED_MODEL_NAME | None = None,
model_path: str | None = None,
batch_size: int = 2,
num_batches: int = 1,
config_overrides: list[str] | None = None,
checkpoint_epoch: Literal["best", "last"] | int = "last",
properties_to_condition_on: TargetProperty | None = None,
sampling_config_path: str | None = None,
sampling_config_name: str = "default",
sampling_config_overrides: list[str] | None = None,
record_trajectories: bool = True,
diffusion_guidance_factor: float | None = None,
strict_checkpoint_loading: bool = True,
target_compositions: list[dict[str, int]] | None = None,
):
"""
Generate crystal structures with a MatterGen diffusion model and write them to the output directory.
Args:
model_path: Path to DiffusionLightningModule checkpoint directory.
output_path: Path to output directory.
config_overrides: Overrides for the model config, e.g., `model.num_layers=3 model.hidden_dim=128`.
properties_to_condition_on: Property value to draw conditional sampling with respect to. When this value is an empty dictionary (default), unconditional samples are drawn.
sampling_config_path: Path to the sampling config file. (default: None, in which case we use `DEFAULT_SAMPLING_CONFIG_PATH` from explorers.common.utils.utils.py)
sampling_config_name: Name of the sampling config (corresponds to `{sampling_config_path}/{sampling_config_name}.yaml` on disk). (default: default)
sampling_config_overrides: Overrides for the sampling config, e.g., `condition_loader_partial.batch_size=32`.
checkpoint_epoch: Epoch to load from the checkpoint ("best", "last", or an integer epoch). (default: "last")
record_trajectories: Whether to record the trajectories of the generated structures. (default: True)
strict_checkpoint_loading: Whether to raise an exception when not all parameters from the checkpoint can be matched to the model.
target_compositions: List of dictionaries with target compositions to condition on. Each dictionary should have the form `{element: number_of_atoms}`. If None, the target compositions are not conditioned on.
Only supported for models trained for crystal structure prediction (CSP) (default: None)
NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
"""
assert (
pretrained_name is not None or model_path is not None
), "Either pretrained_name or model_path must be provided."
assert (
pretrained_name is None or model_path is None
), "Only one of pretrained_name or model_path can be provided."
if not os.path.exists(output_path):
os.makedirs(output_path)
sampling_config_overrides = sampling_config_overrides or []
config_overrides = config_overrides or []
properties_to_condition_on = properties_to_condition_on or {}
target_compositions = target_compositions or []
if pretrained_name is not None:
checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
pretrained_name, config_overrides=config_overrides
)
else:
checkpoint_info = MatterGenCheckpointInfo(
model_path=Path(model_path).resolve(),
load_epoch=checkpoint_epoch,
config_overrides=config_overrides,
strict_checkpoint_loading=strict_checkpoint_loading,
)
_sampling_config_path = Path(sampling_config_path) if sampling_config_path is not None else None
generator = CrystalGenerator(
checkpoint_info=checkpoint_info,
properties_to_condition_on=properties_to_condition_on,
batch_size=batch_size,
num_batches=num_batches,
sampling_config_name=sampling_config_name,
sampling_config_path=_sampling_config_path,
sampling_config_overrides=sampling_config_overrides,
record_trajectories=record_trajectories,
diffusion_guidance_factor=(
diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
),
target_compositions_dict=target_compositions,
)
generator.generate(output_dir=Path(output_path))
@llm_tool(name="generate_material_MatterGen", description="Generate crystal structures with optional property constraints using MatterGen model")
def generate_material_MatterGen(
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
batch_size: int = 2,
num_batches: int = 1,
diffusion_guidance_factor: float = 2.0
) -> str:
"""
Generate crystal structures with optional property constraints.
This unified function can generate materials in three modes:
1. Unconditional generation (no properties specified)
2. Single property conditional generation (one property specified)
3. Multi-property conditional generation (multiple properties specified)
Args:
properties: Optional property constraints. Can be:
- None or empty dict for unconditional generation
- Dict with single key-value pair for single property conditioning
- Dict with multiple key-value pairs for multi-property conditioning
Valid property names include: "dft_band_gap", "chemical_system", etc.
batch_size: Number of structures per batch
num_batches: Number of batches to generate
diffusion_guidance_factor: Controls adherence to target properties
Returns:
Descriptive text with generated crystal structures in CIF format
"""
# Import MatterGenService
from .mattergen_service import MatterGenService
logger.info("Subprocess imported MatterGenService successfully")
# Get the MatterGenService singleton
service = MatterGenService.get_instance()
logger.info("Subprocess obtained the MatterGenService instance")
# Generate materials via the service
logger.info("Subprocess calling the generate method...")
result = service.generate(properties=properties, batch_size=batch_size, num_batches=num_batches, diffusion_guidance_factor=diffusion_guidance_factor)
logger.info("Subprocess generate call completed")
if "Error generating structures" in result:
return f"Error: Invalid properties {properties}."
else:
return result
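# Usage sketch (illustrative only): single-property conditional generation; assumes the
# corresponding pretrained model directory exists under material_config.MATTERGENMODEL_ROOT
# and the import path matches the actual package layout.
#
#   from mars_toolkit.compute.material_gen import generate_material_MatterGen  # hypothetical path
#   text = generate_material_MatterGen(properties={"dft_band_gap": 2.0}, batch_size=2, num_batches=1)
#   print(text)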

View File

@@ -0,0 +1,466 @@
"""
MatterGen service for mars_toolkit.
This module provides a service for generating crystal structures using MatterGen.
The service initializes the CrystalGenerator once and reuses it for multiple
generation requests, improving performance.
"""
import datetime
import os
import logging
import json
from pathlib import Path
import re
from typing import Dict, Any, Optional, Union, List
import threading
import torch
from .mattergen_wrapper import *
from ...core.config import material_config
logger = logging.getLogger(__name__)
def format_cif_content(content):
"""
Format CIF content by removing unnecessary headers and organizing each CIF file.
Args:
content: String containing CIF content, possibly with PK headers
Returns:
Formatted string with each CIF file properly labeled and formatted
"""
# Return an empty string if the content is empty
if not content or content.strip() == '':
return ''
# Remove everything from "PK" up to the first _chemical_formula_structural
content = re.sub(r'PK.*?(?=_chemical_formula_structural)', '', content, flags=re.DOTALL)
# Remove trailing "PK" content that contains no _chemical_formula_structural
content = re.sub(r'PK[^_]*$', '', content, flags=re.DOTALL)
content = re.sub(r'PK.*?(?!.*_chemical_formula_structural)$', '', content, flags=re.DOTALL)
# Split the content into separate CIF files using _chemical_formula_structural as the delimiter,
# keeping that field inside each CIF block
cif_blocks = []
# Find the position of every _chemical_formula_structural occurrence
formula_positions = [m.start() for m in re.finditer(r'_chemical_formula_structural', content)]
# Return an empty string if no _chemical_formula_structural was found
if not formula_positions:
return ''
# Split into CIF blocks
for i in range(len(formula_positions)):
start_pos = formula_positions[i]
# For the last block, the end position is the end of the string
end_pos = formula_positions[i+1] if i < len(formula_positions)-1 else len(content)
cif_block = content[start_pos:end_pos].strip()
# Extract the formula value
formula_match = re.search(r'_chemical_formula_structural\s+(\S+)', cif_block)
if formula_match:
formula = formula_match.group(1)
cif_blocks.append((formula, cif_block))
# Format the output
result = []
for i, (formula, cif_content) in enumerate(cif_blocks, 1):
formatted = f"[cif {i} begin]\ndata_{formula}\n{cif_content}\n[cif {i} end]\n"
result.append(formatted)
return "\n".join(result)
def extract_cif_file_from_zip(cifs_zip_path: str):
"""
Extract CIF files from a zip archive, extract formula from each CIF file,
and save each CIF file with its formula as the filename.
Args:
cifs_zip_path: Path to the zip file
Returns:
list: List of tuples containing (index, formula, cif_path)
"""
result_dict = {}
if os.path.exists(cifs_zip_path):
with open(cifs_zip_path, 'rb') as f:
result_dict['cif_content'] = f.read()
cifs_content = format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))
pattern = r'\[cif (\d+) begin\]\n(.*?)\n\[cif \1 end\]'
matches = re.findall(pattern, cifs_content, re.DOTALL)
# For each match, extract the formula and save the CIF file
saved_files = []
for idx, cif_content in matches:
# Extract the formula from data_{formula}
formula_match = re.search(r'data_([^\s]+)', cif_content)
if formula_match:
formula = formula_match.group(1)
# Build the output path
cif_path = os.path.join(material_config.TEMP_ROOT, f"{formula}.cif")
# Save the CIF file
with open(cif_path, 'w') as f:
f.write(cif_content)
saved_files.append((idx, formula, cif_path))
return saved_files
class MatterGenService:
"""
Service for generating crystal structures using MatterGen.
This service initializes the CrystalGenerator once and reuses it for multiple
generation requests, improving performance.
"""
_instance = None
_lock = threading.Lock()
# Mapping from model name to GPU ID
MODEL_TO_GPU = {
"mattergen_base": "0", # base model on GPU 0
"dft_mag_density": "1", # magnetic density model on GPU 1
"dft_bulk_modulus": "2", # bulk modulus model on GPU 2
"dft_shear_modulus": "3", # shear modulus model on GPU 3
"energy_above_hull": "4", # energy-above-hull model on GPU 4
"formation_energy_per_atom": "5", # formation energy model on GPU 5
"space_group": "6", # space group model on GPU 6
"hhi_score": "7", # HHI score model on GPU 7
"ml_bulk_modulus": "0", # ML bulk modulus model on GPU 0
"chemical_system": "1", # chemical system model on GPU 1
"dft_band_gap": "2", # band gap model on GPU 2
"dft_mag_density_hhi_score": "3", # multi-property model on GPU 3
"chemical_system_energy_above_hull": "4" # multi-property model on GPU 4
}
@classmethod
def get_instance(cls):
"""
Get the singleton instance of MatterGenService.
Returns:
MatterGenService: The singleton instance.
"""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
"""
Initialize the MatterGenService.
This initializes the base generator without any property conditioning.
Specific generators for different property conditions will be initialized
on demand.
"""
self._generators = {}
self._output_dir = material_config.TEMP_ROOT
# Make sure the output directory exists
if not os.path.exists(self._output_dir):
os.makedirs(self._output_dir)
# Initialize the base generator (unconditional generation)
self._init_base_generator()
def _init_base_generator(self):
"""
Initialize the base generator for unconditional generation.
"""
model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, "mattergen_base")
if not os.path.exists(model_path):
logger.warning(f"Base model directory not found at {model_path}. MatterGen service may not work properly.")
return
logger.info(f"Initializing base MatterGen generator from {model_path}")
try:
checkpoint_info = MatterGenCheckpointInfo(
model_path=Path(model_path).resolve(),
load_epoch="last",
config_overrides=[],
strict_checkpoint_loading=True,
)
generator = CrystalGenerator(
checkpoint_info=checkpoint_info,
properties_to_condition_on=None,
batch_size=2, # default; can be overridden at generation time
num_batches=1, # default; can be overridden at generation time
sampling_config_name="default",
sampling_config_path=None,
sampling_config_overrides=[],
record_trajectories=True,
diffusion_guidance_factor=0.0,
target_compositions_dict=[],
)
self._generators["base"] = generator
logger.info("Base MatterGen generator initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize base MatterGen generator: {e}")
def _get_or_create_generator(
self,
properties: Optional[Dict[str, Any]] = None,
batch_size: int = 2,
num_batches: int = 1,
diffusion_guidance_factor: float = 2.0
):
"""
Get or create a generator for the specified properties.
Args:
properties: Optional property constraints
batch_size: Number of structures per batch
num_batches: Number of batches to generate
diffusion_guidance_factor: Controls adherence to target properties
Returns:
tuple: (generator, generator_key, properties_to_condition_on, gpu_id)
"""
# Use the base generator when there are no property constraints
if not properties:
if "base" not in self._generators:
self._init_base_generator()
gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0") # GPU 0 by default
return self._generators.get("base"), "base", None, gpu_id
# Process the property constraints
properties_to_condition_on = {}
for property_name, property_value in properties.items():
properties_to_condition_on[property_name] = property_value
# Determine the model directory
if len(properties) == 1:
# Single-property conditioning
property_name = list(properties.keys())[0]
property_to_model = {
"dft_mag_density": "dft_mag_density",
"dft_bulk_modulus": "dft_bulk_modulus",
"dft_shear_modulus": "dft_shear_modulus",
"energy_above_hull": "energy_above_hull",
"formation_energy_per_atom": "formation_energy_per_atom",
"space_group": "space_group",
"hhi_score": "hhi_score",
"ml_bulk_modulus": "ml_bulk_modulus",
"chemical_system": "chemical_system",
"dft_band_gap": "dft_band_gap"
}
model_dir = property_to_model.get(property_name, property_name)
generator_key = f"single_{property_name}"
else:
# Multi-property conditioning
property_keys = set(properties.keys())
if property_keys == {"dft_mag_density", "hhi_score"}:
model_dir = "dft_mag_density_hhi_score"
generator_key = "multi_dft_mag_density_hhi_score"
elif property_keys == {"chemical_system", "energy_above_hull"}:
model_dir = "chemical_system_energy_above_hull"
generator_key = "multi_chemical_system_energy_above_hull"
else:
# If no dedicated multi-property model exists, fall back to the model for the first property
first_property = list(properties.keys())[0]
model_dir = first_property
generator_key = f"multi_{first_property}_etc"
# Get the corresponding GPU ID
gpu_id = self.MODEL_TO_GPU.get(model_dir, "0") # GPU 0 by default
# Build the full model path
model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, model_dir)
# Check whether the model directory exists
if not os.path.exists(model_path):
# Fall back to the base model if the specific model does not exist
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, "mattergen_base")
generator_key = "base"
# Check whether this generator already exists
if generator_key in self._generators:
# Update the generator's parameters
generator = self._generators[generator_key]
generator.batch_size = batch_size
generator.num_batches = num_batches
generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0
return generator, generator_key, properties_to_condition_on, gpu_id
# Create a new generator
try:
logger.info(f"Initializing new MatterGen generator for {generator_key} from {model_path}")
checkpoint_info = MatterGenCheckpointInfo(
model_path=Path(model_path).resolve(),
load_epoch="last",
config_overrides=[],
strict_checkpoint_loading=True,
)
generator = CrystalGenerator(
checkpoint_info=checkpoint_info,
properties_to_condition_on=properties_to_condition_on,
batch_size=batch_size,
num_batches=num_batches,
sampling_config_name="default",
sampling_config_path=None,
sampling_config_overrides=[],
record_trajectories=True,
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0,
target_compositions_dict=[],
)
self._generators[generator_key] = generator
logger.info(f"MatterGen generator for {generator_key} initialized successfully")
return generator, generator_key, properties_to_condition_on, gpu_id
except Exception as e:
logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
# Fall back to the base generator
if "base" not in self._generators:
self._init_base_generator()
base_gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")
return self._generators.get("base"), "base", None, base_gpu_id
def generate(
self,
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
batch_size: int = 2,
num_batches: int = 1,
diffusion_guidance_factor: float = 2.0
) -> str:
"""
Generate crystal structures with optional property constraints.
Args:
properties: Optional property constraints
batch_size: Number of structures per batch
num_batches: Number of batches to generate
diffusion_guidance_factor: Controls adherence to target properties
Returns:
str: Descriptive text with generated crystal structures in CIF format
"""
# Handle string input, if provided
if isinstance(properties, str):
try:
properties = json.loads(properties)
except json.JSONDecodeError:
raise ValueError(f"Invalid properties JSON string: {properties}")
# Default to an empty dict if None
properties = properties or {}
# Get or create the generator and its GPU ID
generator, generator_key, properties_to_condition_on, gpu_id = self._get_or_create_generator(
properties, batch_size, num_batches, diffusion_guidance_factor
)
print("gpu_id",gpu_id)
if generator is None:
return "Error: Failed to initialize MatterGen generator"
# Set the current GPU directly with torch.cuda.set_device()
try:
# Convert the string gpu_id to an integer
cuda_device_id = int(gpu_id)
torch.cuda.set_device(cuda_device_id)
logger.info(f"Setting CUDA device to GPU {cuda_device_id} for model {generator_key}")
print(f"Using GPU {cuda_device_id} (CUDA device index) for model {generator_key}")
except Exception as e:
logger.warning(f"Error setting CUDA device: {e}. Falling back to default device.")
# Generate the structures
try:
output_dir = Path(self._output_dir + f'/{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}')
Path.mkdir(output_dir, parents=True, exist_ok=True)
generator.generate(output_dir=output_dir)
except Exception as e:
logger.error(f"Error generating structures: {e}")
return f"Error generating structures: {e}"
# Dict to hold the file contents
result_dict = {}
# Define the file paths
cif_zip_path = os.path.join(str(output_dir), "generated_crystals_cif.zip")
xyz_file_path = os.path.join(str(output_dir), "generated_crystals.extxyz")
trajectories_zip_path = os.path.join(str(output_dir), "generated_trajectories.zip")
# Read the zipped CIF archive
if os.path.exists(cif_zip_path):
with open(cif_zip_path, 'rb') as f:
result_dict['cif_content'] = f.read()
# Build a descriptive prompt according to the generation type
if not properties:
generation_type = "unconditional"
title = "Generated Material Structures"
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
property_description = "unconditionally"
elif len(properties) == 1:
generation_type = "single_property"
property_name = list(properties.keys())[0]
property_value = properties[property_name]
title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
property_description = f"conditioned on {property_name} = {property_value}"
else:
generation_type = "multi_property"
title = "Generated Material Structures Conditioned on Multiple Properties"
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
# Assemble the full prompt
prompt = f"""
# {title}
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
{'' if generation_type == 'unconditional' else f'''
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
'''}
## CIF Files (Crystallographic Information Files)
- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools
```
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
```
{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""
# print("prompt",prompt)
# Clean up files (delete after reading)
# try:
# if os.path.exists(cif_zip_path):
# os.remove(cif_zip_path)
# if os.path.exists(xyz_file_path):
# os.remove(xyz_file_path)
# if os.path.exists(trajectories_zip_path):
# os.remove(trajectories_zip_path)
# except Exception as e:
# logger.warning(f"Error cleaning up files: {e}")
# The GPU device was already set with torch.cuda.set_device() before generation; no extra cleanup is needed
logger.info(f"Generation completed on GPU for model {generator_key}")
return prompt

View File

@@ -0,0 +1,26 @@
"""
This is a wrapper module that provides access to the mattergen modules
by modifying the Python path at runtime.
"""
import sys
import os
from pathlib import Path
from ...core.config import material_config
# Add the mattergen directory to the Python path
mattergen_dir = material_config.MATTERGEN_ROOT
sys.path.insert(0, mattergen_dir)
# Import the necessary modules from the mattergen package
try:
from mattergen import generator
from mattergen.common.data import chemgraph
from mattergen.common.data.types import TargetProperty
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME
except ImportError as e:
print(f"Error importing mattergen modules: {e}")
print(f"Python path: {sys.path}")
raise
CrystalGenerator = generator.CrystalGenerator
# Re-export the modules
__all__ = ['generator', 'chemgraph', 'TargetProperty', 'MatterGenCheckpointInfo', 'PRETRAINED_MODEL_NAME','CrystalGenerator']

View File

@@ -0,0 +1,73 @@
"""
Property Prediction Module
This module provides functions for predicting properties of crystal structures.
"""
import asyncio
import torch
import numpy as np
from ase.units import GPa
from mattersim.forcefield import MatterSimCalculator
from ...core.llm_tools import llm_tool
from ..support.utils import convert_structure,read_structure_from_file_name_or_content_string
@llm_tool(
name="predict_properties_MatterSim",
description="Predict energy, forces, and stress of crystal structures using MatterSim model based on CIF string",
)
async def predict_properties_MatterSim(structure_source: str) -> str:
"""
Use MatterSim model to predict energy, forces, and stress of crystal structures.
Args:
structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string
Returns:
String containing prediction results
"""
# Run the potentially blocking work asynchronously via asyncio.to_thread
def run_prediction():
# Convert the structure string to an ASE Atoms object via convert_structure
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
structure = convert_structure(content_format, structure_content)
if structure is None:
return "Unable to parse CIF string. Please check if the format is correct."
# Select the device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Attach the MatterSimCalculator
structure.calc = MatterSimCalculator(device=device)
# Compute energy, forces, and stress
energy = structure.get_potential_energy()
forces = structure.get_forces()
stresses = structure.get_stress(voigt=False)
# Energy per atom
num_atoms = len(structure)
energy_per_atom = energy / num_atoms
# Stress in both eV/Å^3 and GPa
stresses_ev_a3 = stresses
stresses_gpa = stresses / GPa
# Build the returned summary
prompt = f"""
## {structure.get_chemical_formula()} Crystal Structure Property Prediction Results
Prediction results using the provided CIF structure:
- Total Energy (eV): {energy}
- Energy per Atom (eV/atom): {energy_per_atom:.4f}
- Forces (eV/Angstrom): {forces[0]} # Forces on the first atom
- Stress (GPa): {stresses_gpa[0][0]} # First component of the stress tensor
- Stress (eV/A^3): {stresses_ev_a3[0][0]} # First component of the stress tensor
"""
return prompt
# Execute the prediction asynchronously
return await asyncio.to_thread(run_prediction)
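# Usage sketch (illustrative only): predicting energy/forces/stress for a structure file;
# the import path is an assumption, and a GPU is used automatically if available.
#
#   import asyncio
#   from mars_toolkit.compute.property_pred import predict_properties_MatterSim
#   print(asyncio.run(predict_properties_MatterSim("CsPbBr3.cif")))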

View File

@@ -0,0 +1,42 @@
import os
from typing import List
from mp_api.client import MPRester
from ...core.config import material_config
async def get_mpid_from_formula(formula: str) -> List[str]:
"""
Get material IDs (mpid) from Materials Project database by chemical formula.
Returns mpids for the lowest energy structures.
Args:
formula: Chemical formula (e.g., "Fe2O3")
Returns:
List of material IDs
"""
os.environ['HTTP_PROXY'] = material_config.HTTP_PROXY or ''
os.environ['HTTPS_PROXY'] = material_config.HTTPS_PROXY or ''
try:
id_list = []
cleaned_formula = formula.replace(" ", "").replace("\n", "").replace("\'", "").replace("\"", "")
if "=" in cleaned_formula:
name, id = cleaned_formula.split("=")
else:
id = cleaned_formula
formula_list = [id]
with MPRester(material_config.MP_API_KEY) as mpr:
docs = mpr.materials.summary.search(formula=formula_list)
if not docs:
return "No materials found"
else:
for doc in docs:
id_list.append(doc.material_id)
return id_list
except Exception as e:
return f"Error: get_mpid_from_formula: {str(e)}"

View File

@@ -0,0 +1,168 @@
import glob
import json
from typing import Dict, Any, Union
from ...core.llm_tools import llm_tool
from .get_mp_id import get_mpid_from_formula
from ..support.utils import extract_cif_info, remove_symmetry_equiv_xyz
from ...core.config import material_config
from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
@llm_tool(name="search_crystal_structures_from_materials_project",
description="Retrieve and optimize crystal structures from Materials Project database using a chemical formula")
async def search_crystal_structures_from_materials_project(
formula: str,
conventional_unit_cell: bool = True,
symprec: float = 0.1
) -> str:
"""
Retrieves crystal structures for a given chemical formula from Materials Project database and applies symmetry optimization.
Args:
formula: Chemical formula to search for (e.g., "Fe2O3")
conventional_unit_cell: If True, returns conventional unit cell; if False, returns primitive cell
symprec: Symmetry precision parameter for structure refinement (default: 0.1)
Returns:
Formatted CIF data for the retrieved crystal structures with symmetry analysis
"""
try:
structures = {}
mp_id_list = await get_mpid_from_formula(formula=formula)
if isinstance(mp_id_list, str):
return mp_id_list # Propagate the error message directly
for i, mp_id in enumerate(mp_id_list):
try:
# File operations may raise exceptions
cif_files = glob.glob(material_config.LOCAL_MP_CIF_ROOT + f"/{mp_id}.cif")
if not cif_files:
continue # Skip this mp_id if no file was found
cif_file = cif_files[0]
structure = Structure.from_file(cif_file)
# Structure processing may raise exceptions
if conventional_unit_cell:
structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure()
# Symmetrize the structure
sga = SpacegroupAnalyzer(structure, symprec=symprec)
symmetrized_structure = sga.get_refined_structure()
# Generate CIF data with CifWriter
cif_writer = CifWriter(symmetrized_structure, symprec=symprec, refine_struct=True)
cif_data = str(cif_writer)
# Strip the symmetry-operation block from the CIF
cif_data = remove_symmetry_equiv_xyz(cif_data)
cif_data = cif_data.replace('# generated using pymatgen', "")
# Generate a unique key
formula_key = structure.composition.reduced_formula
key = f"{formula_key}_{i}"
structures[key] = cif_data
# Keep only the first MP_TOPK results
if len(structures) >= material_config.MP_TOPK:
break
except (FileNotFoundError, IndexError) as file_error:
# Handle file-related errors
continue # Skip this mp_id and continue with the next one
except ValueError as value_error:
# Handle value errors raised during structure processing
continue # Skip this mp_id and continue with the next one
except Exception as process_error:
# Log the error for this particular structure but keep processing the others
print(f"Error processing structure {mp_id}: {str(process_error)}")
continue
# If no structure was processed successfully
if not structures:
return f"No valid crystal structures found for formula: {formula}"
# Format the results as a readable string
prompt = f"""
# Materials Project Symmetrized Crystal Structure Data
Below are symmetrized crystal structure data for {len(structures)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
"""
for i, (key, cif_data) in enumerate(structures.items(), 1):
prompt += f"[cif {i} begin]\n"
prompt += cif_data
prompt += f"\n[cif {i} end]\n\n"
return prompt
except Exception as e:
# Catch any unhandled exception raised anywhere in the function
return f"Error: An unexpected error occurred while processing crystal structures: {str(e)}"
@llm_tool(name="search_material_property_from_material_project",
description="Query material properties from Materials Project database using chemical formula")
async def search_material_property_from_materials_project(
formula: str,
) -> str:
"""
Retrieve detailed property data for materials matching a chemical formula from Materials Project database.
Args:
formula: Chemical formula of the material(s) to search for (e.g. 'Fe2O3', 'LiFePO4')
Returns:
Formatted string containing material properties including structure, electronic, thermodynamic and mechanical data
"""
# Get the list of MP IDs
mp_id_list = await get_mpid_from_formula(formula=formula)
# Check the return type of get_mpid_from_formula:
# a string means an error occurred or no material was found
if isinstance(mp_id_list, str):
return mp_id_list # Propagate the error message directly
# At this point mp_id_list is a valid list of IDs
try:
# Retrieve the material properties
properties = []
for mp_id in mp_id_list:
try:
file_path = material_config.LOCAL_MP_PROPS_ROOT + f"/{mp_id}.json"
crystal_props = extract_cif_info(file_path, ['all_fields'])
properties.append(crystal_props)
except Exception as file_error:
# Skip this file on error but keep processing the other IDs
continue
# Check whether any results were found
if len(properties) == 0:
return "No material properties found for the given formula, please try again."
# Keep only the first MP_TOPK results
properties = properties[:material_config.MP_TOPK]
# Format the results
formatted_results = []
for i, item in enumerate(properties, 1):
formatted_result = f"[property {i} begin]\n"
formatted_result += json.dumps(item, indent=2)
formatted_result += f"\n[property {i} end]\n\n"
formatted_results.append(formatted_result)
# Join all results into a single string
res_chunk = "\n\n".join(formatted_results)
res_template = f"""
Here are the search material property from the Materials Project database:
Due to length limitations, only the top {len(properties)} results are shown below:\n
{res_chunk}
"""
return res_template
except Exception as e:
return f"Error: processing material properties: {str(e)}"

View File

@@ -0,0 +1,92 @@
import logging
import httpx
import pandas as pd
from bs4 import BeautifulSoup
from io import StringIO
from typing import Annotated, Any, Dict, List
import mcp.types as types
from ...core.llm_tools import llm_tool
@llm_tool(name="query_material_from_OQMD", description="Query material properties by chemical formula from OQMD database")
async def query_material_from_OQMD(
formula: Annotated[str, "Chemical formula (e.g., Fe2O3, LiFePO4)"]
) -> str:
"""
Query material information by chemical formula from OQMD database.
Args:
formula: Chemical formula of the material (e.g., Fe2O3, LiFePO4)
Returns:
Formatted text with material information and property tables
"""
# Fetch data from OQMD
url = f"https://www.oqmd.org/materials/composition/{formula}"
try:
async with httpx.AsyncClient(timeout=100.0) as client:
response = await client.get(url)
response.raise_for_status()
# Validate response content
if not response.text or len(response.text) < 100:
raise ValueError("Invalid response content from OQMD API")
# Parse HTML data
html = response.text
soup = BeautifulSoup(html, 'html.parser')
# Parse basic data
basic_data = []
h1_element = soup.find('h1')
if h1_element:
basic_data.append(h1_element.text.strip())
else:
basic_data.append(f"Material: {formula}")
for paragraph in soup.find_all('p'):
if paragraph:
combined_text = ""
for element in paragraph.contents:
if hasattr(element, 'name') and element.name == 'a' and 'href' in element.attrs:
# Build an absolute link without shadowing the request URL above
link_url = "https://www.oqmd.org" + element['href']
combined_text += f"[{element.text.strip()}]({link_url}) "
elif hasattr(element, 'text'):
combined_text += element.text.strip() + " "
else:
combined_text += str(element).strip() + " "
basic_data.append(combined_text.strip())
# Parse table data
table_data = ""
table = soup.find('table')
if table:
try:
df = pd.read_html(StringIO(str(table)))[0]
df = df.fillna('')
df = df.replace([float('inf'), float('-inf')], '')
table_data = df.to_markdown(index=False)
except Exception:
table_data = "Error: failed to parse table data"
# Integrate data into a single text
combined_text = "\n\n".join(basic_data)
if table_data:
combined_text += "\n\n## Material Properties Table\n\n" + table_data
return combined_text
except httpx.HTTPStatusError as e:
return f"Error: OQMD API request failed - {str(e)}"
except httpx.TimeoutException:
return "Error: OQMD API request timed out"
except httpx.NetworkError as e:
return f"Error: Network error occurred - {str(e)}"
except ValueError as e:
return f"Error: Invalid response content - {str(e)}"
except Exception as e:
return f"Error: Unexpected error occurred - {str(e)}"

View File

@@ -0,0 +1,95 @@
import os
import asyncio
from pymatgen.core import Structure
from ...core.config import material_config
from ...core.llm_tools import llm_tool
from ..support.utils import read_structure_from_file_name_or_content_string
@llm_tool(name="calculate_density_Pymatgen", description="Calculate the density of a crystal structure from a file or content string using Pymatgen")
async def calculate_density_Pymatgen(structure_source: str) -> str:
"""
Calculates the density of a structure from a file or content string.
Args:
structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.
Returns:
str: A Markdown formatted string with the density or an error message if the calculation fails.
"""
# Note: the read and parse steps run synchronously here; wrap them in asyncio.to_thread if they ever block
try:
# Resolve the input (file name or raw content) to structure text and format
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
# Parse the structure with pymatgen
structure = Structure.from_str(structure_content, fmt=content_format)
density = structure.density
return (f"## Density Calculation\n\n"
f"- **Structure**: `{structure.composition.reduced_formula}`\n"
f"- **Density**: `{density:.2f} g/cm³`\n")
except Exception as e:
return f"Error: error occurred while calculating density: {str(e)}\n"
@llm_tool(name="get_element_composition_Pymatgen", description="Analyze and retrieve the elemental composition of a crystal structure from a file or content string using Pymatgen")
async def get_element_composition_Pymatgen(structure_source: str) -> str:
"""
Returns the elemental composition of a structure from a file or content string.
Args:
structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.
Returns:
str: A Markdown formatted string with the elemental composition or an error message if the operation fails.
"""
# Note: the read and parse steps run synchronously here; wrap them in asyncio.to_thread if they ever block
try:
# Resolve the input (file name or raw content) to structure text and format
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
# Parse the structure with pymatgen
structure = Structure.from_str(structure_content, fmt=content_format)
composition = structure.composition
return (f"## Element Composition\n\n"
f"- **Structure**: `{structure.composition.reduced_formula}`\n"
f"- **Composition**: `{composition}`\n")
except Exception as e:
return f"Error: error occurred while getting element composition: {str(e)}\n"
@llm_tool(name="calculate_symmetry_Pymatgen", description="Determine the space group and symmetry operations of a crystal structure from a file or content string using Pymatgen")
async def calculate_symmetry_Pymatgen(structure_source: str) -> str:
"""
Calculates the symmetry of a structure from a file or content string.
Args:
structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.
Returns:
str: A Markdown formatted string with the symmetry information or an error message if the operation fails.
"""
# Note: the read and parse steps run synchronously here; wrap them in asyncio.to_thread if they ever block
try:
# Resolve the input (file name or raw content) to structure text and format
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
# Parse the structure with pymatgen
structure = Structure.from_str(structure_content, fmt=content_format)
symmetry = structure.get_space_group_info()
return (f"## Symmetry Information\n\n"
f"- **Structure**: `{structure.composition.reduced_formula}`\n"
f"- **Space Group**: `{symmetry[0]}`\n"
f"- **Number**: `{symmetry[1]}`\n")
except Exception as e:
return f"Error: error occurred while calculating symmetry: {str(e)}\n"

View File

View File

@@ -0,0 +1,212 @@
"""
CIF Utilities Module
This module provides basic functions for handling CIF (Crystallographic Information File) files,
which are commonly used in materials science for representing crystal structures.
"""
import json
import logging
import os
from ase.io import read
import tempfile
from typing import Optional, Tuple
from ase import Atoms
from ...core.config import material_config
logger = logging.getLogger(__name__)
def read_cif_txt_file(file_path):
"""
Read the CIF file and return its content.
Args:
file_path: Path to the CIF file
Returns:
String content of the CIF file or None if an error occurs
"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
logger.error(f"Error reading file {file_path}: {e}")
return None
def extract_cif_info(path: str, fields_name: list):
"""
Extract specific fields from the CIF description JSON file.
Args:
path: Path to the JSON file containing CIF information
fields_name: List of field categories to extract. Use 'all_fields' to extract all fields.
Other options include 'basic_fields', 'energy_electronic_fields', 'metal_magentic_fields'
Returns:
Dictionary containing the extracted fields
"""
basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites']
field_groups = {
'basic_fields': basic_fields,
'energy_electronic_fields': energy_electronic_fields,
'metal_magentic_fields': metal_magentic_fields,
}
selected_fields = []
if 'all_fields' in fields_name:
selected_fields = basic_fields + energy_electronic_fields + metal_magentic_fields
else:
for field in fields_name:
# Look up the requested group explicitly instead of via locals()
selected_fields.extend(field_groups.get(field, []))
with open(path, 'r') as f:
docs = json.load(f)
new_docs = {}
for field_name in selected_fields:
new_docs[field_name] = docs.get(field_name, '')
return new_docs
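# Example (illustrative, path is hypothetical): pull only the basic and
# energy/electronic field groups from a cached Materials Project JSON document:
#   props = extract_cif_info("/data/mp_props/mp-149.json",
#                            ["basic_fields", "energy_electronic_fields"])
#   props["band_gap"], props["density"]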
def remove_symmetry_equiv_xyz(cif_content):
"""
Remove symmetry operations section from CIF file content.
This is often useful when working with CIF files in certain visualization tools
or when focusing on the basic structure without symmetry operations.
Args:
cif_content: CIF file content string
Returns:
Cleaned CIF content string with symmetry operations removed
"""
lines = cif_content.split('\n')
output_lines = []
i = 0
while i < len(lines):
line = lines[i].strip()
# 检测循环开始
if line == 'loop_':
# 查看下一行,检查是否是对称性循环
next_lines = []
j = i + 1
while j < len(lines) and lines[j].strip().startswith('_'):
next_lines.append(lines[j].strip())
j += 1
# 检查是否包含对称性操作标签
if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
# 跳过整个循环块
while i < len(lines):
if i + 1 >= len(lines):
break
next_line = lines[i + 1].strip()
# 检查是否到达下一个循环或数据块
if next_line == 'loop_' or next_line.startswith('data_'):
break
# 检查是否到达原子位置部分
if next_line.startswith('_atom_site_'):
break
i += 1
else:
# 不是对称性循环保留loop_行
output_lines.append(lines[i])
else:
# 非循环开始行,直接保留
output_lines.append(lines[i])
i += 1
return '\n'.join(output_lines)
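# Example (illustrative): strip the _symmetry_equiv_pos_as_xyz loop from a CIF
# string before handing it to tools that only need the lattice and atom sites:
#   cleaned = remove_symmetry_equiv_xyz(read_cif_txt_file("structure.cif"))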
def read_structure_from_file_name_or_content_string(file_name_or_content_string: str, format_type: str = "auto") -> Tuple[str, str]:
"""
处理结构输入,判断是文件名还是直接内容
当file_name_or_content_string被视为文件名时会在material_config.TEMP_ROOT目录下查找该文件。
这适用于大模型生成的临时文件,这些文件通常存储在临时目录中。
Args:
file_name_or_content_string: 文件名或结构内容字符串
format_type: 结构格式类型,"auto"表示自动检测
Returns:
tuple: (内容字符串, 实际格式类型)
"""
# First, check whether the input is an existing file given by a full path
if os.path.exists(file_name_or_content_string) and os.path.isfile(file_name_or_content_string):
# It is a full file path: read the file content
with open(file_name_or_content_string, 'r', encoding='utf-8') as f:
content = f.read()
# If the format is "auto", infer it from the file extension
if format_type == "auto":
ext = os.path.splitext(file_name_or_content_string)[1].lower().lstrip('.')
if ext in ['cif', 'xyz', 'vasp', 'poscar']:
format_type = 'cif' if ext == 'cif' else 'xyz' if ext == 'xyz' else 'vasp'
else:
# Default to CIF
format_type = 'cif'
else:
# Check whether it is a file name inside the temporary directory
temp_path = os.path.join(material_config.TEMP_ROOT, file_name_or_content_string)
if os.path.exists(temp_path) and os.path.isfile(temp_path):
# It is a file in the temporary directory: read the file content
with open(temp_path, 'r', encoding='utf-8') as f:
content = f.read()
# If the format is "auto", infer it from the file extension
if format_type == "auto":
ext = os.path.splitext(temp_path)[1].lower().lstrip('.')
if ext in ['cif', 'xyz', 'vasp', 'poscar']:
format_type = 'cif' if ext == 'cif' else 'xyz' if ext == 'xyz' else 'vasp'
else:
# Default to CIF
format_type = 'cif'
else:
# Not a file path: treat the input as raw structure content
content = file_name_or_content_string
# If the format is "auto", try to infer it from the content
if format_type == "auto":
# Simple heuristics:
# CIF files usually contain "data_" and "_cell_"
if "data_" in content and "_cell_" in content:
format_type = "cif"
# XYZ files usually start with the number of atoms on the first line
elif content.strip().split('\n')[0].strip().isdigit():
format_type = "xyz"
# POSCAR/VASP files usually start with a comment line followed by the lattice vectors
elif len(content.strip().split('\n')) > 5 and all(len(line.split()) == 3 for line in content.strip().split('\n')[2:5]):
format_type = "vasp"
# Default to CIF
else:
format_type = "cif"
return content, format_type
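# Example (illustrative): both calls below resolve to CIF content. The first
# assumes "HEu2H3EuH2EuH5.cif" exists under material_config.TEMP_ROOT; the second
# passes raw CIF text, which the "data_"/"_cell_" heuristic detects as CIF.
#   content, fmt = read_structure_from_file_name_or_content_string("HEu2H3EuH2EuH5.cif")
#   content, fmt = read_structure_from_file_name_or_content_string("data_X\n_cell_length_a 3.0 ...")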
def convert_structure(input_format: str='cif', content: str=None) -> Optional[Atoms]:
"""
将输入内容转换为Atoms对象
Args:
input_format: 输入格式 (cif, xyz, vasp等)
content: 结构内容字符串
Returns:
ASE Atoms对象如果转换失败则返回None
"""
try:
with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name
atoms = read(tmp_path)
os.unlink(tmp_path)
return atoms
except Exception as e:
logger.error(f"Failed to convert structure: {str(e)}")
return None
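# Minimal manual smoke test (a sketch): detect the format of an inline CIF snippet
# and convert it to an ASE Atoms object. The CIF content is a toy single-atom cell.
if __name__ == "__main__":
example_cif = (
"data_example\n"
"_cell_length_a 3.0\n_cell_length_b 3.0\n_cell_length_c 3.0\n"
"_cell_angle_alpha 90\n_cell_angle_beta 90\n_cell_angle_gamma 90\n"
"loop_\n_atom_site_type_symbol\n_atom_site_label\n"
"_atom_site_fract_x\n_atom_site_fract_y\n_atom_site_fract_z\n"
"Fe Fe0 0.0 0.0 0.0\n"
)
text, fmt = read_structure_from_file_name_or_content_string(example_cif)
print(fmt)  # expected: "cif" (detected from the "data_"/"_cell_" markers)
print(convert_structure(fmt, text))  # ASE Atoms object, or None on failure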

306
server.py Executable file
View File

@@ -0,0 +1,306 @@
"""Mars Toolkit MCP Server implementation."""
import anyio
import asyncio
import click
import json
import logging
import os
import sys
import traceback
from typing import Any, Dict, List, Optional, Union
import time
from prompts.material_synthesis import create_messages
# 添加mars_toolkit模块的路径
sys.path.append('/home/ubuntu/50T/lzy/mars-mcp')
import mcp.types as types
from mcp.server.lowlevel import Server
# 导入提示词处理器
#from prompts.material_synthesis import register_prompt_handlers
# 导入Mars Toolkit工具函数
try:
# 获取当前时间
from mars_toolkit.misc.misc_tools import get_current_time
# 网络搜索
from mars_toolkit.query.web_search import search_online
# 从Materials Project查询材料属性
from mars_toolkit.query.mp_query import search_material_property_from_material_project
# 从Materials Project获取晶体结构
from mars_toolkit.query.mp_query import get_crystal_structures_from_materials_project
# 从化学式获取Materials Project ID
from mars_toolkit.query.mp_query import get_mpid_from_formula
# 优化晶体结构
from mars_toolkit.compute.structure_opt import optimize_crystal_structure
# 生成材料
from mars_toolkit.compute.material_gen import generate_material
# 从OQMD获取化学成分
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
# 从知识库检索
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
# 预测属性
from mars_toolkit.compute.property_pred import predict_properties
# 获取所有工具函数
from mars_toolkit import get_tools, get_tool_schemas
MARS_TOOLKIT_AVAILABLE = True
except ImportError as e:
print(f"警告: 无法导入Mars Toolkit: {e}", file=sys.stderr)
MARS_TOOLKIT_AVAILABLE = False
# 配置日志
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
app = Server("mars-toolkit-server")
async def call_mars_toolkit_function(func_name: str, arguments: Dict[str, Any]) -> Any:
"""
调用Mars Toolkit中的工具函数
Args:
func_name: 工具函数名称
arguments: 工具函数参数
Returns:
工具函数的执行结果
"""
if not MARS_TOOLKIT_AVAILABLE:
raise ValueError("Mars Toolkit不可用")
# 获取所有注册的工具函数
tools = get_tools()
# 检查函数名是否存在于工具函数字典中
if func_name not in tools:
raise ValueError(f"函数 '{func_name}' 不存在于工具函数字典中")
# 获取对应的工具函数
tool_func = tools[func_name]
# 调用工具函数
if asyncio.iscoroutinefunction(tool_func):
# 如果是异步函数使用await调用
result = await tool_func(**arguments)
print("result1",result)
else:
# 如果是同步函数,直接调用
result = tool_func(**arguments)
return result
def get_tool_schemas_dict() -> Dict[str, Dict[str, Any]]:
"""
获取所有工具函数的模式字典
Returns:
工具函数名称到模式的映射字典
"""
if not MARS_TOOLKIT_AVAILABLE:
return {}
schemas = get_tool_schemas()
schemas_dict = {}
for schema in schemas:
func_name = schema["function"]["name"]
schemas_dict[func_name] = schema
return schemas_dict
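# Illustrative shape of the mapping returned by get_tool_schemas_dict(), assuming
# OpenAI-style function schemas from get_tool_schemas() (names and fields are examples):
#   {
#       "get_current_time": {
#           "type": "function",
#           "function": {"name": "get_current_time", "description": "...", "parameters": {...}},
#       },
#       ...
#   }
# call_mars_toolkit_function("get_current_time", {}) then dispatches to the registered
# coroutine or plain function of that name.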
@click.command()
@click.option("--port", default=5666, help="Port to listen on for SSE")
@click.option(
"--transport",
type=click.Choice(["stdio", "sse"]),
default="sse",
help="Transport type",
)
def main(port: int, transport: str = "sse") -> int:
"""
Mars Toolkit MCP Server主函数
Args:
port: SSE传输的端口号
transport: 传输类型stdio或sse
Returns:
退出码
"""
if not MARS_TOOLKIT_AVAILABLE:
print("错误: Mars Toolkit不可用请确保已正确安装", file=sys.stderr)
return 1
# 获取工具函数模式字典
schemas_dict = get_tool_schemas_dict()
# 注册提示词处理器
#register_prompt_handlers(app)
@app.list_prompts()
async def list_prompts() -> list[types.Prompt]:
return [
types.Prompt(
name="material_synthesis",
description="生成材料并设计合成方案使用mermaid绘制合成流程图",
arguments=[
types.PromptArgument(
name="properties",
description="材料性质及其值的JSON字符串例如 {\"dft_band_gap\": \"2.0\"}",
required=False,
),
types.PromptArgument(
name="batch_size",
description="生成材料的数量默认为2",
required=False,
),
],
)
]
@app.get_prompt()
async def get_prompt(
name: str, arguments: dict[str, str] | None = None
) -> types.GetPromptResult:
if name != "material_synthesis":
raise ValueError(f"未知的提示词: {name}")
if arguments is None:
arguments = {}
# 解析properties参数
properties = {}
if "properties" in arguments and arguments["properties"]:
try:
import json
properties = json.loads(arguments["properties"])
except json.JSONDecodeError:
properties = {}
# 解析batch_size参数
batch_size = 2 # 默认值
if "batch_size" in arguments and arguments["batch_size"]:
try:
batch_size = int(arguments["batch_size"])
except ValueError:
pass # 使用默认值
return types.GetPromptResult(
messages=create_messages(properties=properties, batch_size=batch_size),
description="生成材料并设计合成方案使用mermaid绘制合成流程图",
)
@app.call_tool()
async def call_tool(
name: str, arguments: Dict[str, Any]
) -> List[types.TextContent | types.ImageContent | types.EmbeddedResource]:
"""
调用工具函数
Args:
name: 工具函数名称
arguments: 工具函数参数
Returns:
工具函数的执行结果
"""
try:
print(f"调用{name},参数为{arguments}")
result = await call_mars_toolkit_function(name, arguments)
print("result",result)
# 将结果转换为字符串
if isinstance(result, (dict, list)):
result_str = json.dumps(result, ensure_ascii=False, indent=2)
else:
result_str = str(result)
return [types.TextContent(type="text", text=result_str)]
except Exception as e:
error_msg = f"调用工具函数 {name} 时出错: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
@app.list_tools()
async def list_tools() -> List[types.Tool]:
"""
列出所有可用的工具函数
Returns:
工具函数列表
"""
tools = []
print("列举所有可用的工具函数")
for func_name, schema in schemas_dict.items():
# 获取函数描述
description = schema["function"].get("description", f"Mars Toolkit工具: {func_name}")
# 获取参数模式
parameters = schema["function"].get("parameters", {})
# 创建工具
tool = types.Tool(
name=func_name,
description=description,
inputSchema=parameters,
)
tools.append(tool)
return tools
if transport == "sse":
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.routing import Mount, Route
sse = SseServerTransport("/messages/")
async def handle_sse(request):
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await app.run(
streams[0], streams[1], app.create_initialization_options()
)
starlette_app = Starlette(
debug=True,
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)
import uvicorn
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
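# With the SSE transport the server exposes GET /sse for the event stream and
# POST /messages/ for client-to-server messages, e.g. (illustrative):
#   curl -N http://localhost:5666/sse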
else:
from mcp.server.stdio import stdio_server
async def arun():
async with stdio_server() as streams:
await app.run(
streams[0], streams[1], app.create_initialization_options()
)
anyio.run(arun)
return 0
if __name__ == "__main__":
print(get_tool_schemas_dict())
main()

437
test_tools/agent_test.py Executable file
View File

@@ -0,0 +1,437 @@
import asyncio
from api_key import *
from openai import OpenAI
import json
from typing import Dict, List, Any, Union, Optional
from rich.console import Console
import sys
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
from sci_mcp import *
# 获取工具
all_tools_schemas = get_all_tool_schemas()
tools = get_all_tools()
chemistry_tools = get_domain_tools("chemistry")
console = Console()
#console.print(all_tools_schemas)
class ModelAgent:
"""
只支持 gpt-4o 模型的代理类
处理返回值格式并提供统一的工具调用接口
"""
def __init__(self, model_name: str = "gpt-4o"):
"""
初始化模型客户端
Args:
model_name: 模型名称
"""
# 初始化客户端
self.client = OpenAI(
api_key=OPENAI_API_KEY,
base_url=OPENAI_API_URL,
)
# 模型名称
self.model_name = model_name
# 定义工具列表
self.tools = all_tools_schemas
def get_response(self, messages: List[Dict[str, Any]]) -> Any:
"""
获取模型的响应
Args:
messages: 消息列表
Returns:
响应对象
"""
completion = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
tools=self.tools,
tool_choice="auto",
temperature=0.6,
)
return completion
def extract_tool_calls(self, response: Any) -> Optional[List[Any]]:
"""
从响应中提取工具调用信息
Args:
response: 响应对象
Returns:
工具调用列表,如果没有则返回 None
"""
if hasattr(response.choices[0].message, 'tool_calls'):
return response.choices[0].message.tool_calls
return None
def extract_content(self, response: Any) -> str:
"""
从响应中提取内容
Args:
response: 响应对象
Returns:
内容字符串
"""
content = response.choices[0].message.content
return content if content is not None else ""
def extract_finish_reason(self, response: Any) -> str:
"""
从响应中提取完成原因
Args:
response: 响应对象
Returns:
完成原因
"""
return response.choices[0].finish_reason
async def call_tool(self, tool_name: str, tool_arguments: Dict[str, Any]) -> str:
"""
调用工具函数,支持同步和异步函数
Args:
tool_name: 工具名称
tool_arguments: 工具参数
Returns:
工具执行结果
"""
if tool_name in tools:
tool_function = tools[tool_name]
try:
# Check whether the tool function is synchronous or asynchronous
# (asyncio is already imported at module level; iscoroutinefunction is the
# relevant test, since the function object itself is never awaitable)
if asyncio.iscoroutinefunction(tool_function):
# Await asynchronous tool functions
tool_result = await tool_function(**tool_arguments)
else:
# Call synchronous tool functions directly
tool_result = tool_function(**tool_arguments)
return tool_result
except Exception as e:
return f"工具调用错误: {str(e)}"
else:
return f"未找到工具: {tool_name}"
async def chat(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
"""
与模型对话,支持工具调用
Args:
messages: 初始消息列表
max_turns: 最大对话轮数
Returns:
最终回答
"""
current_messages = messages.copy()
turn = 0
while turn < max_turns:
turn += 1
console.print(f"\n[bold magenta]第 {turn} 轮对话[/bold magenta]")
# 获取响应
response = self.get_response(current_messages)
assistant_message = response.choices[0].message
# 将助手消息添加到上下文
current_messages.append(assistant_message.model_dump())
# 提取内容和工具调用
content = self.extract_content(response)
tool_calls = self.extract_tool_calls(response)
finish_reason = self.extract_finish_reason(response)
console.print(f"[green]助手回复:[/green] {content}")
# 如果没有工具调用或已完成,返回内容
if tool_calls is None or finish_reason != "tool_calls":
return content
# 处理工具调用
for tool_call in tool_calls:
tool_call_name = tool_call.function.name
tool_call_arguments = json.loads(tool_call.function.arguments)
console.print(f"[yellow]调用工具:[/yellow] {tool_call_name}")
console.print(f"[yellow]工具参数:[/yellow] {tool_call_arguments}")
# 执行工具调用
tool_result = await self.call_tool(tool_call_name, tool_call_arguments)
console.print(f"[blue]工具结果:[/blue] {tool_result}")
# 添加工具结果到上下文
current_messages.append({
"role": "tool",
"tool_call_id": tool_call.id,
"name": tool_call_name,
"content": tool_result
})
return "达到最大对话轮数限制"
async def main():
"""主函数"""
# 创建模型代理
agent = ModelAgent()
while True:
# 获取用户输入
user_input = input("\n请输入问题(输入 'exit' 退出): ")
if user_input.lower() == 'exit':
break
try:
# 使用 GPT-4o 模型
messages = [
{"role": "system", "content": "你是一个有用的助手。"},
{"role": "user", "content": user_input}
]
response = await agent.chat(messages)
console.print(f"\n[bold magenta]最终回答:[/bold magenta] {response}")
except Exception as e:
console.print(f"[bold red]错误:[/bold red] {str(e)}")
# 为每个工具函数生成的问题列表
def get_tool_questions():
"""
获取为每个工具函数生成的问题列表
这些问题设计为能够引导大模型调用相应的工具函数,但不直接命令模型调用特定工具
Returns:
包含工具名称和对应问题的字典
"""
questions = {
# 材料科学相关工具函数
"search_crystal_structures_from_materials_project": [
"我想了解氧化铁(Fe2O3)的晶体结构,能帮我找一下相关信息吗?",
"锂离子电池中常用的LiFePO4材料有什么样的晶体结构",
"能否帮我查询一下钙钛矿(CaTiO3)的晶体结构数据?"
],
"search_material_property_from_material_project": [
"二氧化钛(TiO2)有哪些重要的物理和化学性质?",
"我正在研究锂电池材料能告诉我LiCoO2的主要性质吗",
"硅(Si)的带隙和电子性质是什么?能帮我查一下详细数据吗?"
],
"query_material_from_OQMD": [
"能帮我从OQMD数据库中查询一下铝合金(Al-Cu)的形成能吗?",
"我想了解镍基超合金在OQMD数据库中的热力学稳定性数据",
"OQMD数据库中有关于锌氧化物(ZnO)的什么信息?"
],
"retrieval_from_knowledge_base": [
"有关高温超导体的最新研究进展是什么?",
"能否从材料科学知识库中找到关于石墨烯应用的信息?",
"我想了解钙钛矿太阳能电池的工作原理和效率限制"
],
"predict_properties": [
"这个化学式为Li2FeSiO4的材料可能有什么样的电子性质",
"能预测一下Na3V2(PO4)3这种材料的离子导电性吗",
"如果我设计一个新的钙钛矿结构,能预测它的稳定性和带隙吗?"
],
"generate_material": [
"能生成一种可能具有铁磁性的新材料结构吗?"
],
"optimize_crystal_structure": [
"我有一个CIF文件HEu2H3EuH2EuH5.cif能帮我优化一下使其更稳定吗"
],
"calculate_density": [
"我有一个CIF文件HEu2H3EuH2EuH5.cif能计算一下它的密度吗"
],
"get_element_composition": [
"我有一个CIF文件HEu2H3EuH2EuH5.cif能分析一下它的元素组成吗"
],
"calculate_symmetry": [
"我有一个CIF文件HEu2H3EuH2EuH5.cif能分析一下它的对称性和空间群吗"
],
# 化学相关工具函数
"search_pubchem_advanced": [
"阿司匹林的分子结构和性质是什么?"
],
"calculate_molecular_properties": [
"对乙酰氨基酚的SMILES是CC(=O)NC1=CC=C(C=C1)O计算它的物理化学性质"
],
"calculate_drug_likeness": [
"布洛芬的SMILES是CC(C)CC1=CC=C(C=C1)C(C)C(=O)O能计算它的药物性吗"
],
"calculate_topological_descriptors": [
"咖啡因的SMILES是CN1C=NC2=C1C(=O)N(C(=O)N2C)C能计算它的拓扑描述符吗"
],
"generate_molecular_fingerprints": [
"尼古丁的SMILES是CN1CCCC1C2=CN=CC=C2能为它生成Morgan指纹吗"
],
"calculate_molecular_similarity": [
"阿司匹林的SMILES是CC(=O)OC1=CC=CC=C1C(=O)O对乙酰氨基酚的SMILES是CC(=O)NC1=CC=C(C=C1)O它们的分子相似性如何"
],
"analyze_molecular_structure": [
"苯甲酸的SMILES是C1=CC=C(C=C1)C(=O)O能分析它的结构特征吗"
],
"generate_molecular_conformer": [
"甲基苯并噻唑的SMILES是CC1=NC2=CC=CC=C2S1能生成它的3D构象吗"
],
"identify_scaffolds": [
"奎宁的SMILES是COC1=CC2=C(C=CN=C2C=C1)C(C3CC4CCN3CC4C=C)O它的核心骨架是什么"
],
"convert_between_chemical_formats": [
"将乙醇的SMILESCCO转换为InChI格式"
],
"standardize_molecule": [
"将四环素的SMILESCC1C2C(C(=O)C3(C(CC4C(C3C(=O)C2C(=O)C(=C1O)C(=O)N)O)(C(=O)CO4)O)O)N(C)C标准化处理"
],
"enumerate_stereoisomers": [
"2-丁醇的SMILES是CCC(C)O它可能有哪些立体异构体,不要单纯靠你自身的知识,如果不确定可以使用工具。"
],
"perform_substructure_search": [
"在阿莫西林的SMILESCC1(C(N2C(S1)C(C2=O)NC(=O)C(C3=CC=C(C=C3)O)N)C(=O)O)C中搜索羧酸基团"
],
# RXN工具函数的测试问题
"predict_reaction_outcome": [
"我正在研究乙酸和乙醇的酯化反应想知道这个反应的产物是什么。反应物的SMILES表示法是'CC(=O)O.CCO'。能帮我预测一下这个反应最可能的结果吗?"
],
"predict_reaction_batch": [
"我在实验室中设计了三个酯化反应系列想同时了解它们的可能产物。这三个反应的SMILES分别是'CC(=O)O.CCO'(乙酸和乙醇)、'CC(=O)O.CCCO'(乙酸和丙醇)和'CC(=O)O.CCCCO'(乙酸和丁醇)。能否一次性预测这些反应的结果?"
],
"predict_reaction_topn": [
"我在研究丙烯醛和甲胺的反应机理这个反应可能有多种产物路径。反应物的SMILES是'C=CC=O.CN'。能帮我分析出最可能的前3种产物及它们的相对可能性吗"
],
"predict_retrosynthesis": [
"我需要为实验室合成阿司匹林但不确定最佳的合成路线。阿司匹林的SMILES是'CC(=O)OC1=CC=CC=C1C(=O)O'。能帮我分析一下可能的合成路径,将其分解为更简单的前体化合物吗?"
],
"predict_biocatalytic_retrosynthesis": [
"我们实验室正在研究绿色化学合成方法想知道是否可以使用酶催化方式合成这个含溴的芳香化合物SMILES: 'OC1C(O)C=C(Br)C=C1')。能帮我分析可能的生物催化合成路径吗?"
],
"predict_reaction_properties": [
"我正在研究这个有机反应的机理:'CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F'特别想了解反应中的原子映射关系。能帮我分析一下反应前后各原子的对应关系吗我需要atom-mapping属性。"
],
"extract_reaction_actions": [
"我从一篇有机合成文献中找到了这段实验步骤:'To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour.' 能帮我将这段文本转换为结构化的反应步骤吗?这样我可以更清晰地理解每个操作。"
]
}
return questions
# 测试特定工具函数的问题
async def test_tool_with_question(question_index: int = 0):
"""
使用预设问题测试特定工具函数
Args:
question_index: 问题索引默认为0
"""
# 获取所有工具问题
all_questions = get_tool_questions()
# 创建工具名称到问题的映射
tool_questions = {}
for tool_name, questions in all_questions.items():
if questions:
tool_questions[tool_name] = questions[min(question_index, len(questions)-1)]
# 打印可用的工具和问题
console.print("[bold]可用的工具和问题:[/bold]")
for i, (tool_name, question) in enumerate(tool_questions.items(), 1):
console.print(f"{i}. [cyan]{tool_name}[/cyan]: {question}")
# 选择要测试的工具
choice = input("\n请选择要测试的工具编号(输入'all'测试所有工具): ")
agent = ModelAgent()
if choice.lower() == 'all':
# 测试所有工具
for tool_name, question in tool_questions.items():
console.print(f"\n[bold]测试工具: [cyan]{tool_name}[/cyan][/bold]")
console.print(f"问题: {question}")
messages = [
{"role": "system", "content": "你是一个有用的助手。"},
{"role": "user", "content": question}
]
try:
response = await agent.chat(messages)
#console.print(f"[green]回答:[/green] {response[:200]}...")
except Exception as e:
console.print(f"[bold red]错误:[/bold red] {str(e)}")
else:
try:
index = int(choice) - 1
if 0 <= index < len(tool_questions):
tool_name = list(tool_questions.keys())[index]
question = tool_questions[tool_name]
console.print(f"\n[bold]测试工具: [cyan]{tool_name}[/cyan][/bold]")
console.print(f"问题: {question}")
messages = [
{"role": "system", "content": "你是一个有用的助手。"},
{"role": "user", "content": question+'如果你不确定答案,请使用工具'}
]
response = await agent.chat(messages)
#console.print(f"[green]回答:[/green] {response}")
else:
console.print("[bold red]无效的选择[/bold red]")
except ValueError:
console.print("[bold red]请输入有效的数字[/bold red]")
if __name__ == "__main__":
# 取消注释以运行主函数
# asyncio.run(main())
# 取消注释以测试工具函数问题
asyncio.run(test_tool_with_question())
# pass
# 知识检索API的接口 数据库

8
test_tools/api_key.py Executable file
View File

@@ -0,0 +1,8 @@
OPENAI_API_KEY='sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d'
OPENAI_API_URL='https://vip.apiyi.com/v1'
#OPENAI_API_KEY='gpustack_56f0adc61a865d22_c61cdbf601fa2cb95979d417618060e6'
#OPENAI_API_URL='http://192.168.191.100:5080/v1'

View File

@@ -0,0 +1,123 @@
"""
Test script for PubChem tools
This script tests the search_pubchem_advanced function from the chemistry_mcp module.
"""
import sys
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
import asyncio
from sci_mcp.chemistry_mcp.pubchem_tools.pubchem_tools import _search_by_formula
from sci_mcp.chemistry_mcp import search_pubchem_advanced
async def test_search_by_name():
"""Test searching compounds by name"""
print("\n=== Testing search by name ===")
result = await search_pubchem_advanced(name="Aspirin")
print(result)
async def test_search_by_smiles():
"""Test searching compounds by SMILES notation"""
print("\n=== Testing search by SMILES ===")
# SMILES for Caffeine
result = await search_pubchem_advanced(smiles="CN1C=NC2=C1C(=O)N(C(=O)N2C)C")
print(result)
async def test_search_by_formula():
"""Test searching compounds by molecular formula"""
print("\n=== Testing search by formula ===")
# Formula for Aspirin
result = await search_pubchem_advanced(formula="C9H8O4", max_results=2)
print(result)
async def test_complex_formula():
"""Test searching with a more complex formula that might cause timeout"""
print("\n=== Testing complex formula search ===")
# A more complex formula that might return many results
result = await search_pubchem_advanced(
formula="C6H12O6", # Glucose and isomers
max_results=5
)
print(result)
async def test_complex_molecules():
"""Test searching for complex molecules with rich molecular features"""
print("\n=== Testing complex molecules with rich features ===")
# 1. Paclitaxel (Taxol) - Complex anticancer drug with many rotatable bonds and H-bond donors/acceptors
print("\n--- Testing Paclitaxel (anticancer drug) ---")
result = await search_pubchem_advanced(name="Paclitaxel")
print(result)
# 2. Vancomycin - Complex antibiotic with many H-bond donors/acceptors
print("\n--- Testing Vancomycin (antibiotic) ---")
result = await search_pubchem_advanced(name="Vancomycin")
print(result)
# 3. Cholesterol - Steroid with complex ring structure
print("\n--- Testing Cholesterol (steroid) ---")
result = await search_pubchem_advanced(name="Cholesterol")
print(result)
# 4. Ibuprofen - Common NSAID with rotatable bonds
print("\n--- Testing Ibuprofen (NSAID) ---")
result = await search_pubchem_advanced(name="Ibuprofen")
print(result)
# 5. Amoxicillin - Antibiotic with multiple functional groups
print("\n--- Testing Amoxicillin (antibiotic) ---")
result = await search_pubchem_advanced(name="Amoxicillin")
print(result)
async def test_molecules_by_smiles():
"""Test searching for complex molecules using SMILES notation"""
print("\n=== Testing complex molecules by SMILES ===")
# 1. Atorvastatin (Lipitor) - Cholesterol-lowering drug with complex structure
print("\n--- Testing Atorvastatin (Lipitor) ---")
result = await search_pubchem_advanced(
smiles="CC(C)C1=C(C(=C(C=C1)C(C)C)C2=CC(=C(C=C2)F)F)C(CC(CC(=O)O)O)NC(=O)C3=CC=C(C=C3)F"
)
print(result)
# 2. Morphine - Opioid with multiple rings and H-bond features
print("\n--- Testing Morphine ---")
result = await search_pubchem_advanced(
smiles="CN1CCC23C4C1CC5=C2C(=C(C=C5)O)OC3C(C=C4)O"
)
print(result)
async def test_invalid_search():
"""Test searching with invalid parameters"""
print("\n=== Testing invalid search ===")
# No parameters provided
result = await search_pubchem_advanced()
print(result)
# Invalid SMILES
print("\n=== Testing invalid SMILES ===")
result = await search_pubchem_advanced(smiles="INVALID_SMILES_STRING")
print(result)
async def run_all_tests():
"""Run all test functions"""
await test_search_by_name()
await test_search_by_smiles()
await test_search_by_formula()
await test_complex_formula()
# await test_complex_molecules()
# await test_molecules_by_smiles()
#await test_invalid_search()
# from sci_mcp.chemistry_mcp.pubchem_tools.pubchem_tools import _search_by_name
# compounds=await _search_by_formula('C6H12O6')
# print(compounds[0])
if __name__ == "__main__":
print("Testing PubChem search tools...")
asyncio.run(run_all_tests())
print("\nAll tests completed.")
# import pubchempy
# compunnds = pubchempy.get_compounds('Aspirin', 'name')
# print(compunnds[0].to_dict())

View File

@@ -0,0 +1,159 @@
"""
Test script for RDKit tools.
This script tests the functionality of the RDKit tools implemented in the
sci_mcp/chemistry_mcp/rdkit_tools module.
"""
import sys
import os
# Add the project root directory to the Python path
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
from sci_mcp.chemistry_mcp.rdkit_tools.rdkit_tools import (
calculate_molecular_properties,
calculate_drug_likeness,
calculate_topological_descriptors,
generate_molecular_fingerprints,
calculate_molecular_similarity,
analyze_molecular_structure,
generate_molecular_conformer,
identify_scaffolds,
convert_between_chemical_formats,
standardize_molecule,
enumerate_stereoisomers,
perform_substructure_search
)
def test_molecular_properties():
"""Test the calculation of molecular properties."""
print("Testing calculate_molecular_properties...")
# Aspirin
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
result = calculate_molecular_properties(smiles)
print(result)
print("-" * 80)
def test_drug_likeness():
"""Test the calculation of drug-likeness properties."""
print("Testing calculate_drug_likeness...")
# Aspirin
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
result = calculate_drug_likeness(smiles)
print(result)
print("-" * 80)
def test_topological_descriptors():
"""Test the calculation of topological descriptors."""
print("Testing calculate_topological_descriptors...")
# Aspirin
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
result = calculate_topological_descriptors(smiles)
print(result)
print("-" * 80)
def test_molecular_fingerprints():
"""Test the generation of molecular fingerprints."""
print("Testing generate_molecular_fingerprints...")
# Aspirin
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
result = generate_molecular_fingerprints(smiles, fingerprint_type="morgan")
print(result)
print("-" * 80)
def test_molecular_similarity():
"""Test the calculation of molecular similarity."""
print("Testing calculate_molecular_similarity...")
# Aspirin and Ibuprofen
smiles1 = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin
smiles2 = "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O" # Ibuprofen
result = calculate_molecular_similarity(smiles1, smiles2)
print(result)
print("-" * 80)
def test_molecular_structure():
"""Test the analysis of molecular structure."""
print("Testing analyze_molecular_structure...")
# Aspirin
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
result = analyze_molecular_structure(smiles)
print(result)
print("-" * 80)
def test_molecular_conformer():
"""Test the generation of molecular conformers."""
print("Testing generate_molecular_conformer...")
# Aspirin
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
result = generate_molecular_conformer(smiles)
print(result)
print("-" * 80)
def test_scaffolds():
"""Test the identification of molecular scaffolds."""
print("Testing identify_scaffolds...")
# Aspirin
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
result = identify_scaffolds(smiles)
print(result)
print("-" * 80)
def test_format_conversion():
"""Test the conversion between chemical formats."""
print("Testing convert_between_chemical_formats...")
# Aspirin
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
result = convert_between_chemical_formats(smiles, "smiles", "inchi")
print(result)
print("-" * 80)
def test_standardize_molecule():
"""Test the standardization of molecules."""
print("Testing standardize_molecule...")
# Betaine with charges
smiles = "C[N+](C)(C)CC(=O)[O-]"
result = standardize_molecule(smiles)
print(result)
print("-" * 80)
def test_stereoisomers():
"""Test the enumeration of stereoisomers."""
print("Testing enumerate_stereoisomers...")
# 3-penten-2-ol (has both a stereocenter and a stereobond)
smiles = "CC(O)C=CC"
result = enumerate_stereoisomers(smiles)
print(result)
print("-" * 80)
def test_substructure_search():
"""Test the substructure search."""
print("Testing perform_substructure_search...")
# Aspirin, search for carboxylic acid group
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
pattern = "C(=O)O"
result = perform_substructure_search(smiles, pattern)
print(result)
print("-" * 80)
def main():
"""Run all tests."""
print("Testing RDKit tools...\n")
# Uncomment the tests you want to run
test_molecular_properties()
test_drug_likeness()
test_topological_descriptors()
test_molecular_fingerprints()
test_molecular_similarity()
test_molecular_structure()
test_molecular_conformer()
test_scaffolds()
test_format_conversion()
test_standardize_molecule()
test_stereoisomers()
test_substructure_search()
print("All tests completed.")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,289 @@
"""
测试RXN工具函数模块
此模块包含用于测试rxn_tools模块中化学反应预测和分析工具函数的测试用例。
"""
import asyncio
import sys
import os
from pathlib import Path
from rich.console import Console
# 添加项目根目录到Python路径
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
from sci_mcp.chemistry_mcp.rxn_tools.rxn_tools import (
predict_reaction_outcome_rxn,
predict_reaction_topn_rxn,
predict_reaction_properties_rxn,
extract_reaction_actions_rxn
)
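# Note: test_predict_retrosynthesis and test_predict_biocatalytic_retrosynthesis below
# call predict_retrosynthesis / predict_biocatalytic_retrosynthesis, which are not
# imported here; add the corresponding imports before running those tests via test_all().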
# 创建控制台对象用于格式化输出
console = Console()
async def test_predict_reaction_outcome():
"""测试反应结果预测功能"""
console.print("[bold cyan]测试反应结果预测功能[/bold cyan]")
# 使用固定参数:溴和蒽的反应
fixed_reactants = "BrBr.c1ccc2cc3ccccc3cc2c1"
console.print(f"固定反应物: {fixed_reactants}")
try:
result = await predict_reaction_outcome_rxn(fixed_reactants)
console.print("[green]结果:[/green]")
console.print(result)
except Exception as e:
console.print(f"[bold red]错误:[/bold red] {str(e)}")
async def test_predict_reaction_topn():
"""测试多产物预测功能"""
console.print("\n[bold cyan]测试多产物预测功能[/bold cyan]")
# 测试1单个反应字符串格式
fixed_reactants = "C=CC=O.CN" # 丙烯醛和甲胺
fixed_topn = 2
console.print(f"测试1 - 单个反应(字符串格式)")
console.print(f"固定反应物: {fixed_reactants}")
console.print(f"固定预测产物数量: {fixed_topn}")
try:
result = await predict_reaction_topn_rxn(fixed_reactants, fixed_topn)
console.print("[green]结果:[/green]")
console.print(result)
except Exception as e:
console.print(f"[bold red]错误:[/bold red] {str(e)}")
# 测试2单个反应列表格式
fixed_reactants_list = ["BrBr", "c1ccc2cc3ccccc3cc2c1"] # 溴和蒽
console.print(f"\n测试2 - 单个反应(列表格式)")
console.print(f"固定反应物: {fixed_reactants_list}")
console.print(f"固定预测产物数量: {fixed_topn}")
try:
result = await predict_reaction_topn_rxn(fixed_reactants_list, fixed_topn)
console.print("[green]结果:[/green]")
console.print(result)
except Exception as e:
console.print(f"[bold red]错误:[/bold red] {str(e)}")
# 测试3多个反应列表的列表格式
fixed_reactants_batch = [
["BrBr", "c1ccc2cc3ccccc3cc2c1"], # 溴和蒽
["BrBr", "c1ccc2cc3ccccc3cc2c1CCO"] # 溴和修饰的蒽
]
console.print(f"\n测试3 - 多个反应(列表的列表格式)")
console.print(f"固定反应物批量: {fixed_reactants_batch}")
console.print(f"固定预测产物数量: {fixed_topn}")
try:
result = await predict_reaction_topn_rxn(fixed_reactants_batch, fixed_topn)
console.print("[green]结果:[/green]")
console.print(result)
except Exception as e:
console.print(f"[bold red]错误:[/bold red] {str(e)}")
console.print(f"[bold red]错误:[/bold red] {str(e)}")
async def test_predict_retrosynthesis():
"""测试逆合成分析功能"""
console.print("\n[bold cyan]测试逆合成分析功能[/bold cyan]")
# 使用固定参数:阿司匹林的逆合成分析
fixed_target_molecule = "CC(=O)OC1=CC=CC=C1C(=O)O" # 阿司匹林
fixed_max_steps = 1
console.print(f"固定目标分子: {fixed_target_molecule}")
console.print(f"固定最大步骤数: {fixed_max_steps}")
try:
result = await predict_retrosynthesis(fixed_target_molecule, fixed_max_steps)
console.print("[green]结果:[/green]")
console.print(result)
except Exception as e:
console.print(f"[bold red]错误:[/bold red] {str(e)}")
async def test_predict_biocatalytic_retrosynthesis():
"""测试生物催化逆合成分析功能"""
console.print("\n[bold cyan]测试生物催化逆合成分析功能[/bold cyan]")
# 使用固定参数:一个可能适合酶催化的分子
fixed_target_molecule = "OC1C(O)C=C(Br)C=C1"
console.print(f"固定目标分子: {fixed_target_molecule}")
try:
result = await predict_biocatalytic_retrosynthesis(fixed_target_molecule)
console.print("[green]结果:[/green]")
console.print(result)
except Exception as e:
console.print(f"[bold red]错误:[/bold red] {str(e)}")
async def test_predict_reaction_properties():
"""测试反应属性预测功能"""
console.print("\n[bold cyan]测试反应属性预测功能[/bold cyan]")
# 使用固定参数:原子映射
fixed_reaction = "CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F"
fixed_property_type = "atom-mapping"
console.print(f"固定反应: {fixed_reaction}")
console.print(f"固定属性类型: {fixed_property_type}")
try:
result = await predict_reaction_properties_rxn(fixed_reaction, fixed_property_type)
console.print("[green]结果:[/green]")
console.print(result)
except Exception as e:
console.print(f"[bold red]错误:[/bold red] {str(e)}")
async def test_extract_reaction_actions():
"""测试从文本提取反应步骤功能"""
console.print("\n[bold cyan]测试从文本提取反应步骤功能[/bold cyan]")
# 使用固定参数:从文本描述中提取反应步骤
fixed_reaction_text = """To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour."""
console.print(f"固定反应文本: {fixed_reaction_text}")
try:
result = await extract_reaction_actions_rxn(fixed_reaction_text)
console.print("[green]结果:[/green]")
console.print(result)
except Exception as e:
console.print(f"[bold red]错误:[/bold red] {str(e)}")
async def test_all():
"""测试所有RXN工具函数"""
console.print("[bold magenta]===== 开始测试RXN工具函数 =====[/bold magenta]\n")
# 测试各个功能
await test_predict_reaction_outcome()
await test_predict_reaction_topn()
await test_predict_retrosynthesis()
await test_predict_biocatalytic_retrosynthesis()
await test_predict_reaction_properties()
await test_extract_reaction_actions()
console.print("\n[bold magenta]===== RXN工具函数测试完成 =====[/bold magenta]")
def get_rxn_tool_questions():
"""
获取为RXN工具函数生成的问题列表
这些问题设计为能够引导大模型调用相应的工具函数
Returns:
包含工具名称和对应问题的字典
"""
questions = {
"predict_reaction_outcome": [
"如果我将溴和蒽混合在一起,会形成什么产物?",
"乙酸和乙醇反应会生成什么?",
"预测一下丙烯醛和甲胺反应的结果"
],
"predict_reaction_topn": [
"丙烯醛和甲胺反应可能生成哪几种主要产物?",
"预测溴和蒽反应可能的前3个产物",
"乙酸和乙醇反应可能有哪些不同的结果?请给出最可能的几种产物"
],
"predict_retrosynthesis": [
"如何合成阿司匹林?请给出可能的合成路线",
"对于分子CC(=O)OC1=CC=CC=C1C(=O)O有哪些可能的合成路径",
"请分析一下布洛芬的可能合成路线"
],
"predict_biocatalytic_retrosynthesis": [
"有没有可能用酶催化合成OC1C(O)C=C(Br)C=C1这个分子",
"请提供一种使用生物催化方法合成对羟基苯甲醇的路线",
"我想用酶催化方法合成一些复杂分子能否分析一下OC1C(O)C=C(Br)C=C1的可能合成路径"
],
"predict_reaction_properties": [
"在这个反应中CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F原子是如何映射的",
"这个反应的产率可能是多少Clc1ccccn1.Cc1ccc(N)cc1>>Cc1ccc(Nc2ccccn2)cc1",
"能分析一下这个反应中原子的去向吗CC(=O)O.CCO>>CC(=O)OCC"
],
"extract_reaction_actions": [
"能否将这段实验描述转换为结构化的反应步骤?'To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour.'",
"请从这段文本中提取出具体的反应操作步骤:'A solution of benzoic acid (1.0 g, 8.2 mmol) in thionyl chloride (10 mL) was heated under reflux for 2 hours. The excess thionyl chloride was removed under reduced pressure to give benzoyl chloride as a colorless liquid.'",
"帮我解析这个实验步骤,提取出关键操作:'The aldehyde (5 mmol) was dissolved in methanol (20 mL) and sodium borohydride (7.5 mmol) was added portionwise at 0°C. The mixture was allowed to warm to room temperature and stirred for 3 hours.'"
]
}
return questions
def update_agent_test_questions():
"""
更新agent_test.py中的工具问题字典添加RXN工具函数的问题
"""
try:
# 获取agent_test.py文件路径
agent_test_path = Path('/home/ubuntu/sas0/lzy/multi_mcp_server/test_tools/agent_test.py')
# 读取文件内容
with open(agent_test_path, 'r') as f:
content = f.read()
# 获取RXN工具函数的问题
rxn_questions = get_rxn_tool_questions()
# 检查文件中是否已包含RXN工具函数的问题
rxn_tools_exist = any(tool in content for tool in rxn_questions.keys())
if not rxn_tools_exist:
# 找到questions字典的结束位置
dict_end_pos = content.find(' return questions')
if dict_end_pos != -1:
# 构建要插入的RXN工具函数问题
rxn_questions_str = ""
for tool_name, questions_list in rxn_questions.items():
rxn_questions_str += f'\n "{tool_name}": [\n'
for q in questions_list:
rxn_questions_str += f' "{q}",\n'
rxn_questions_str += ' ],'
# 在字典结束前插入RXN工具函数问题
new_content = content[:dict_end_pos] + rxn_questions_str + content[dict_end_pos:]
# 写回文件
with open(agent_test_path, 'w') as f:
f.write(new_content)
console.print("[green]成功更新agent_test.py添加了RXN工具函数的测试问题[/green]")
else:
console.print("[yellow]无法找到questions字典的结束位置未更新agent_test.py[/yellow]")
else:
console.print("[yellow]agent_test.py中已包含RXN工具函数的问题无需更新[/yellow]")
except Exception as e:
console.print(f"[bold red]更新agent_test.py时出错:[/bold red] {str(e)}")
if __name__ == "__main__":
# 运行所有测试
# asyncio.run(test_all())
# # 更新agent_test.py中的工具问题
# update_agent_test_questions()
api_key = 'apk-8928522a146c2503f30b16d9909222d7583f412ee8f1049f08d32a089ba88d34'
from rxn4chemistry import RXN4ChemistryWrapper
rxn4chemistry_wrapper = RXN4ChemistryWrapper(api_key=api_key)
rxn4chemistry_wrapper.create_project('test_wrapper')
response = rxn4chemistry_wrapper.predict_automatic_retrosynthesis(
'Brc1c2ccccc2c(Br)c2ccccc12')
results = rxn4chemistry_wrapper.get_predict_automatic_retrosynthesis_results(response['prediction_id'])
print(results['status'])
# NOTE: upon 'SUCCESS' you can inspect the predicted retrosynthetic paths.
print(results['retrosynthetic_paths'][0])

View File

@@ -0,0 +1,55 @@
import asyncio
import sys
import os
from openai import OpenAI
from rich.console import Console
from rich.panel import Panel
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp import *
from test_tools.multi_round_conversation import process_conversation_round
# 初始化rich控制台
console = Console()
# 设计一个简单但需要多轮查询的问题可能会调用mattergen
# complex_question = """我想了解LiFePO4材料在不同温度下的性能变化。请先告诉我这种材料的基本结构特性。"""
# 设计一个不调用mattergen但仍然可以触发多轮工具调用的问题之前的尝试
# complex_question = """我想比较TiO2和ZnO这两种材料作为光催化剂的性能。请先告诉我TiO2的晶体结构和能带特性。"""
# 设计一个需要先获取信息然后基于这些信息进行进一步分析的问题
complex_question = """我需要分析一种名为Na2Fe2(SO4)3的钠离子电池材料。请先查询这种材料的晶体结构。"""
def run_complex_query():
"""运行复杂的材料科学查询演示"""
console.print(Panel.fit(
"[bold cyan]复杂材料科学查询演示[/bold cyan] - 测试多轮对话逻辑",
border_style="cyan"
))
# 处理复杂问题
conversation_history = process_conversation_round(complex_question)
# 多轮对话循环
while True:
console.print("\n[bold cyan]输入问题继续对话,或输入 'exit''quit' 退出[/bold cyan]")
user_input = input("> ")
# 检查是否退出
if user_input.lower() in ['exit', 'quit', '退出']:
console.print("[bold cyan]演示结束,再见![/bold cyan]")
break
# 处理用户输入
conversation_history = process_conversation_round(user_input, conversation_history)
if __name__ == '__main__':
try:
run_complex_query()
except KeyboardInterrupt:
console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
except Exception as e:
console.print(f"\n[bold red]发生错误: {str(e)}[/bold red]")
import traceback
console.print(traceback.format_exc())

View File

@@ -0,0 +1,375 @@
import asyncio
import json
import sys
import os
from openai import OpenAI
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
from rich import box
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp import *
initial_message=messages = [ {"role": "user", "content": """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
根据上文提供的CIF文件请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性并用JSON格式回答。"""}]
# 初始化rich控制台
console = Console()
# 获取工具模式和映射
tools_schemas = get_domain_tool_schemas(["material", 'general'])
tool_map = get_domain_tools(["material", 'general'])
# API配置
api_key = "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
base_url = "http://gpustack.ddwtop.team/v1-openai"
client = OpenAI(
api_key=api_key,
base_url=base_url,
)
def get_t1_response(messages):
"""获取T1模型的响应"""
with Progress(
SpinnerColumn(),
TextColumn("[bold blue]正在调用MARS-T1模型..."),
transient=True,
) as progress:
progress.add_task("", total=None)
completion = client.chat.completions.create(
model="MARS-T1",
messages=messages,
temperature=0.3,
tools=tools_schemas,
)
choice = completion.choices[0]
reasoning_content = choice.message.content
tool_calls_list = []
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
for tool_call in choice.message.tool_calls:
tool_call_dict = {
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
tool_calls_list.append(tool_call_dict)
return reasoning_content, tool_calls_list
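# Illustrative shape of the returned values when the model decides to call a tool
# (tool name and arguments are examples): reasoning_content holds the assistant text,
# and tool_calls_list looks like
#   [{"id": "call_abc123", "type": "function",
#     "function": {"name": "search_material_property_from_material_project",
#                  "arguments": "{\"formula\": \"Na2Fe2(SO4)3\"}"}}]
# It is consumed by get_all_tool_calls_results() below.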
async def execute_tool(tool_name, tool_arguments):
"""执行工具调用"""
try:
tool_func = tool_map[tool_name] # 获取工具函数
arguments = {}
if tool_arguments:
# 检查arguments是字符串还是字典
if isinstance(tool_arguments, dict):
# 如果已经是字典,直接使用
arguments = tool_arguments
elif isinstance(tool_arguments, str):
# 如果是字符串尝试解析为JSON
try:
# 尝试直接解析为JSON对象
arguments = json.loads(tool_arguments)
except json.JSONDecodeError:
# 如果解析失败,可能是因为字符串中包含转义字符
# 尝试修复常见的JSON字符串问题
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
try:
arguments = json.loads(fixed_str)
except json.JSONDecodeError:
# 如果仍然失败,尝试将字符串作为原始字符串处理
arguments = {"raw_string": tool_arguments}
# 调用工具函数
if asyncio.iscoroutinefunction(tool_func):
# 如果是异步函数使用await调用
result = await tool_func(**arguments)
else:
# 如果是同步函数,直接调用
result = tool_func(**arguments)
return result
finally:
# Clear LLM-call context markers here if needed (currently a no-op placeholder)
pass
def get_all_tool_calls_results(tool_calls_list):
"""获取所有工具调用的结果"""
all_results = []
with Progress(
SpinnerColumn(),
TextColumn("[bold green]正在执行工具调用..."),
transient=True,
) as progress:
task = progress.add_task("", total=len(tool_calls_list))
for tool_call in tool_calls_list:
tool_name = tool_call['function']['name']
tool_arguments = tool_call['function']['arguments']
# 显示当前执行的工具
progress.update(task, description=f"执行 {tool_name}")
result = asyncio.run(execute_tool(tool_name, tool_arguments))
result_str = f"[{tool_name} content begin]\n{result}\n[{tool_name} content end]\n"
all_results.append(result_str)
# 更新进度
progress.update(task, advance=1)
return all_results
def get_response_from_r1(messages):
"""获取R1模型的响应"""
with Progress(
SpinnerColumn(),
TextColumn("[bold purple]正在调用MARS-R1模型..."),
transient=True,
) as progress:
progress.add_task("", total=None)
completion = client.chat.completions.create(
model="MARS-R1",
messages=messages,
temperature=0.3,
)
choice = completion.choices[0]
return choice.message.content
def display_message(role, content, model=None, title_style="bold blue", border_style="blue"):
"""显示单条消息"""
title = role.capitalize()
if model:
title = f"{model} {title}"
if role == "user":
console.print(Panel(
content,
title=f"[{title_style}]{title}[/{title_style}]",
border_style=border_style,
expand=False
))
elif role == "assistant" and model == "MARS-T1":
console.print(Panel(
content,
title=f"[bold yellow]{title}[/bold yellow]",
border_style="yellow",
expand=False
))
elif role == "tool":
# 创建一个表格来显示工具调用结果
table = Table(box=box.ROUNDED, expand=False, show_header=False)
table.add_column("内容", style="green")
# 分割工具调用结果并添加到表格
results = content.split("\n")
for result in results:
table.add_row(result)
console.print(Panel(
table,
title=f"[bold green]{title}[/bold green]",
border_style="green",
expand=False
))
elif role == "assistant" and model == "MARS-R1":
try:
# 尝试将内容解析为Markdown
md = Markdown(content)
console.print(Panel(
md,
title=f"[bold purple]{title}[/bold purple]",
border_style="purple",
expand=False
))
except:
# 如果解析失败,直接显示文本
console.print(Panel(
content,
title=f"[bold purple]{title}[/bold purple]",
border_style="purple",
expand=False
))
def process_conversation_round(user_input, conversation_history=None):
"""处理一轮对话,返回更新后的对话历史"""
if conversation_history is None:
conversation_history = []
# 添加用户消息到历史
conversation_history.append({
"role": "user",
"content": user_input
})
# 显示用户消息
display_message("user", user_input)
# 准备发送给T1模型的消息
t1_messages = []
for msg in conversation_history:
if msg["role"] in ["user", "assistant"]:
t1_messages.append({
"role": msg["role"],
"content": msg["content"]
})
# 获取T1模型的响应
reasoning_content, tool_calls_list = get_t1_response(t1_messages)
# 添加T1推理到历史
conversation_history.append({
"role": "assistant",
"content": reasoning_content,
"model": "MARS-T1"
})
# 显示T1推理
display_message("assistant", reasoning_content, model="MARS-T1")
# 如果有工具调用,执行并获取结果
if tool_calls_list:
tool_call_results = get_all_tool_calls_results(tool_calls_list)
tool_call_results_str = "\n".join(tool_call_results)
# 添加工具调用结果到历史
conversation_history.append({
"role": "tool",
"content": tool_call_results_str
})
# 显示工具调用结果
display_message("tool", tool_call_results_str)
# 准备发送给R1模型的消息
user_message = {
"role": "user",
"content": f"# 信息如下:\n{tool_call_results_str}\n# 问题如下:\n{user_input}"
}
# 获取R1模型的响应
r1_response = get_response_from_r1([user_message])
# 添加R1回答到历史
conversation_history.append({
"role": "assistant",
"content": r1_response,
"model": "MARS-R1"
})
# 显示R1回答
display_message("assistant", r1_response, model="MARS-R1")
else:
# 如果没有工具调用直接使用T1的推理作为回答
conversation_history.append({
"role": "assistant",
"content": reasoning_content,
"model": "MARS-R1"
})
# 显示R1回答实际上是T1的推理
display_message("assistant", reasoning_content, model="MARS-R1")
return conversation_history
def run_demo():
"""运行演示,使用初始消息作为第一个用户问题"""
console.print(Panel.fit(
"[bold cyan]多轮对话演示[/bold cyan] - 使用 MARS-T1 和 MARS-R1 模型",
border_style="cyan"
))
# 获取初始用户问题
initial_user_input = initial_message[0]["content"]
# 处理第一轮对话
conversation_history = process_conversation_round(initial_user_input)
# 多轮对话循环
while True:
console.print("\n[bold cyan]输入问题继续对话,或输入 'exit''quit' 退出[/bold cyan]")
user_input = input("> ")
# 检查是否退出
if user_input.lower() in ['exit', 'quit', '退出']:
console.print("[bold cyan]演示结束,再见![/bold cyan]")
break
# 处理用户输入
conversation_history = process_conversation_round(user_input, conversation_history)
if __name__ == '__main__':
try:
run_demo()
except KeyboardInterrupt:
console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
except Exception as e:
console.print(f"\n[bold red]发生错误: {str(e)}[/bold red]")
import traceback
console.print(traceback.format_exc())

View File

@@ -0,0 +1,7 @@
import sys
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server/')
from sci_mcp.general_mcp.searxng_query.searxng_query_tools import search_online
import asyncio
# 字典
print(asyncio.run(search_online("CsPbBr3", 5)))

View File

@@ -0,0 +1,14 @@
import os
import sys
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
from sci_mcp.material_mcp.mattergen_gen.mattergen_service import format_cif_content
cif_zip_path = '/home/ubuntu/sas0/lzy/multi_mcp_server/temp/material/20250508110506/generated_crystals_cif.zip'
if os.path.exists(cif_zip_path):
with open(cif_zip_path, 'rb') as f:
cif_content = f.read().decode('utf-8', errors='replace')
print(format_cif_content(cif_content))

View File

@@ -0,0 +1,166 @@
"""
Test script for mattergen_gen/material_gen_tools.py
This script tests the generate_material function from the material_gen_tools module.
"""
import sys
import asyncio
import unittest
import json
import re
from pathlib import Path
# 添加项目根目录到Python路径
sys.path.append(str(Path(__file__).resolve().parents[2]))
from sci_mcp.material_mcp.mattergen_gen.material_gen_tools import generate_material_MatterGen
class TestMatterGen(unittest.TestCase):
"""Test cases for MatterGen material generation tools."""
def test_unconditional_generation(self):
"""Test unconditional crystal structure generation."""
# 无条件生成(不指定属性)
result = generate_material_MatterGen(properties=None, batch_size=1, num_batches=1)
# 验证结果是否包含预期的关键信息
self.assertIsInstance(result, str)
# 检查结果是否包含一些常见的描述性文本
self.assertIn("Material", result)
self.assertIn("structures", result)
print("无条件生成结果示例:")
print(result[:500] + "...\n" if len(result) > 500 else result)
return result
# def test_single_property_generation(self):
# """Test crystal structure generation with a single property constraint."""
# # 单属性条件生成 - 使用化学系统属性
# properties = {"chemical_system": "Si-O"}
# result = generate_material(properties=properties, batch_size=1, num_batches=1)
# # 验证结果是否包含预期的关键信息
# self.assertIsInstance(result, str)
# # 检查结果是否包含相关的化学元素
# self.assertIn("Si-O", result)
# print("单属性条件生成结果示例:")
# print(result[:500] + "...\n" if len(result) > 500 else result)
# return result
# def test_multi_property_generation(self):
# """Test crystal structure generation with multiple property constraints."""
# # 多属性条件生成
# properties = {
# "chemical_system": "Fe-O",
# "space_group": 227 # 立方晶系空间群Fd-3m
# }
# result = generate_material(properties=properties, batch_size=1, num_batches=1)
# # 验证结果是否为字符串
# self.assertIsInstance(result, str)
# # 检查结果 - 可能是成功生成或错误信息
# if "Error" in result:
# # 如果是错误信息,验证它包含相关的属性信息
# self.assertIn("properties", result)
# print("多属性条件生成返回错误 (这是预期的,因为可能不支持多属性):")
# else:
# # 如果成功,验证包含相关元素
# self.assertIn("Fe", result)
# self.assertIn("O", result)
# print("多属性条件生成成功:")
# print(result[:500] + "...\n" if len(result) > 500 else result)
# return result
# def test_batch_generation(self):
# """Test generating multiple structures in batches."""
# # 测试批量生成
# result = generate_material(properties=None, batch_size=2, num_batches=2)
# # 验证结果是否包含预期的关键信息
# self.assertIsInstance(result, str)
# # 检查结果是否提到了批量生成
# self.assertIn("structures", result)
# print("批量生成结果示例:")
# print(result[:500] + "...\n" if len(result) > 500 else result)
# return result
# def test_guidance_factor(self):
# """Test the effect of diffusion guidance factor."""
# # 测试不同的diffusion_guidance_factor值
# properties = {"chemical_system": "Al-O"}
# # 使用较低的指导因子
# result_low = generate_material(
# properties=properties,
# batch_size=1,
# num_batches=1,
# diffusion_guidance_factor=1.0
# )
# # 使用较高的指导因子
# result_high = generate_material(
# properties=properties,
# batch_size=1,
# num_batches=1,
# diffusion_guidance_factor=3.0
# )
# # 验证两个结果都是有效的
# self.assertIsInstance(result_low, str)
# self.assertIsInstance(result_high, str)
# self.assertIn("Al-O", result_low)
# self.assertIn("Al-O", result_high)
# # 验证两个结果都提到了diffusion guidance factor
# self.assertIn("guidance factor", result_low)
# self.assertIn("guidance factor", result_high)
# print("不同指导因子的生成结果示例:")
# print("低指导因子 (1.0):")
# print(result_low[:300] + "...\n" if len(result_low) > 300 else result_low)
# print("高指导因子 (3.0):")
# print(result_high[:300] + "...\n" if len(result_high) > 300 else result_high)
# return result_low, result_high
# def test_invalid_properties(self):
# """Test handling of invalid properties."""
# # 测试无效属性
# invalid_properties = {"invalid_property": "value"}
# result = generate_material(properties=invalid_properties)
# # 验证结果是否为字符串
# self.assertIsInstance(result, str)
# # 检查结果 - 可能返回错误信息或尝试生成
# if "Error" in result:
# print("无效属性测试返回错误 (预期行为):")
# else:
# # 如果没有返回错误,至少应该包含我们请求的属性名称
# self.assertIn("invalid_property", result)
# print("无效属性测试尝试生成:")
# print(result)
# return result
def run_tests():
"""运行所有测试。"""
unittest.main()
if __name__ == "__main__":
run_tests()

View File

@@ -0,0 +1,73 @@
import asyncio
import sys
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server')
import matgl
from sci_mcp.material_mcp.matgl_tools import relax_crystal_structure_M3GNet,predict_formation_energy_M3GNet,run_molecular_dynamics_M3GNet,calculate_single_point_energy_M3GNet
print(matgl.get_available_pretrained_models())
cif_file_name = 'GdPbGdHGd.cif'
cif_content="""data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
"""
#print(asyncio.run(relax_crystal_structure(cif_file_name)))
print(asyncio.run(predict_formation_energy_M3GNet(cif_content)))
#print(asyncio.run(calculate_single_point_energy(cif_file_name)))
#print(asyncio.run(run_molecular_dynamics(cif_file_name)))

View File

@@ -0,0 +1,15 @@
import sys
sys.path.append('/home/ubuntu/sas0/lzy/multi_mcp_server/')
from sci_mcp_server.material_mcp.mp_query.mp_query_tools import search_material_property_from_material_project,search_crystal_structures_from_materials_project
import asyncio
from sci_mcp_server.core.llm_tools import set_llm_context, clear_llm_context
set_llm_context(True)
print(asyncio.run(search_material_property_from_material_project("CsPbBr3")))
clear_llm_context()
print(asyncio.run(search_material_property_from_material_project("CsPbBr3")))
set_llm_context(True)
print(asyncio.run(search_crystal_structures_from_materials_project("CsPbBr3")))
clear_llm_context()
print(asyncio.run(search_crystal_structures_from_materials_project("CsPbBr3")))

View File

@@ -0,0 +1,143 @@
"""
Test script for property_pred_tools.py
This script tests the predict_properties function from the property_pred_tools module.
"""
import os
import sys
import asyncio
import unittest
from pathlib import Path
# 添加项目根目录到Python路径
sys.path.append(str(Path(__file__).resolve().parents[2]))
from sci_mcp.material_mcp.mattersim_pred.property_pred_tools import predict_properties_MatterSim
class TestPropertyPrediction(unittest.TestCase):
"""Test cases for property prediction tools."""
def setUp(self):
"""Set up test fixtures."""
# 简单的CIF字符串示例 - 硅晶体结构
self.simple_cif = """
data_Si
_cell_length_a 5.43
_cell_length_b 5.43
_cell_length_c 5.43
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 90
_symmetry_space_group_name_H-M 'P 1'
_symmetry_Int_Tables_number 1
loop_
_atom_site_label
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Si 0.0 0.0 0.0
Si 0.5 0.5 0.0
Si 0.5 0.0 0.5
Si 0.0 0.5 0.5
Si 0.25 0.25 0.25
Si 0.75 0.75 0.25
Si 0.75 0.25 0.75
Si 0.25 0.75 0.75
"""
def test_predict_properties_async(self):
"""Test predict_properties function with a simple CIF string (异步版本)."""
async def _async_test():
            result = await predict_properties_MatterSim(self.simple_cif)
# 验证结果是否包含预期的关键信息
self.assertIsInstance(result, str)
self.assertIn("Crystal Structure Property Prediction Results", result)
self.assertIn("Total Energy (eV):", result)
self.assertIn("Energy per Atom (eV/atom):", result)
self.assertIn("Forces (eV/Angstrom):", result)
self.assertIn("Stress (GPa):", result)
self.assertIn("Stress (eV/A^3):", result)
print("预测结果示例:")
print(result)
return result
loop = asyncio.get_event_loop()
return loop.run_until_complete(_async_test())
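        # Note: asyncio.run(_async_test()) would be the more idiomatic driver on newer Python
        # versions, where calling get_event_loop() outside a running loop is deprecated.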
def test_predict_properties_sync(self):
"""同步方式测试predict_properties函数。"""
self.test_predict_properties_async()
class TestPropertyPredictionWithFile(unittest.TestCase):
"""使用文件测试属性预测工具。"""
def setUp(self):
"""设置测试夹具创建临时CIF文件。"""
self.temp_cif_path = "temp_test_structure.cif"
# 简单的CIF内容 - 氧化铝结构
cif_content = """
data_Al2O3
_cell_length_a 4.76
_cell_length_b 4.76
_cell_length_c 12.99
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 120
_symmetry_space_group_name_H-M 'R -3 c'
_symmetry_Int_Tables_number 167
loop_
_atom_site_label
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Al 0.0 0.0 0.35
Al 0.0 0.0 0.85
O 0.31 0.0 0.25
"""
# 创建临时文件
with open(self.temp_cif_path, "w") as f:
f.write(cif_content)
def tearDown(self):
"""清理测试夹具,删除临时文件。"""
if os.path.exists(self.temp_cif_path):
os.remove(self.temp_cif_path)
def test_predict_properties_from_file_async(self):
"""测试从文件预测属性(异步版本)。"""
async def _async_test():
            result = await predict_properties_MatterSim(self.temp_cif_path)
# 验证结果
self.assertIsInstance(result, str)
self.assertIn("Crystal Structure Property Prediction Results", result)
self.assertIn("Total Energy (eV):", result)
print("从文件预测结果示例:")
print(result)
return result
loop = asyncio.get_event_loop()
return loop.run_until_complete(_async_test())
def test_predict_properties_from_file_sync(self):
"""同步方式测试从文件预测属性。"""
self.test_predict_properties_from_file_async()
def run_tests():
"""运行所有测试。"""
unittest.main()
if __name__ == "__main__":
run_tests()

View File

@@ -0,0 +1,208 @@
"""
Test script for pymatgen_cal_tools.py
This script tests the functions from the pymatgen_cal_tools module.
"""
import os
import sys
import asyncio
import unittest
from pathlib import Path
# 添加项目根目录到Python路径
sys.path.append(str(Path(__file__).resolve().parents[2]))
from sci_mcp.material_mcp.pymatgen_cal.pymatgen_cal_tools import (
calculate_density,
get_element_composition,
calculate_symmetry
)
class TestPymatgenCalculations(unittest.TestCase):
"""Test cases for pymatgen calculation tools."""
def setUp(self):
"""Set up test fixtures."""
# 简单的CIF字符串示例 - 硅晶体结构
self.simple_cif = """
data_Si
_cell_length_a 5.43
_cell_length_b 5.43
_cell_length_c 5.43
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 90
_symmetry_space_group_name_H-M 'P 1'
_symmetry_Int_Tables_number 1
loop_
_atom_site_label
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Si 0.0 0.0 0.0
Si 0.5 0.5 0.0
Si 0.5 0.0 0.5
Si 0.0 0.5 0.5
Si 0.25 0.25 0.25
Si 0.75 0.75 0.25
Si 0.75 0.25 0.75
Si 0.25 0.75 0.75
"""
def test_calculate_density_async(self):
"""Test calculate_density function with a simple CIF string (异步版本)."""
async def _async_test():
result = await calculate_density(self.simple_cif)
# 验证结果是否包含预期的关键信息
self.assertIsInstance(result, str)
self.assertIn("Density Calculation", result)
self.assertIn("Density", result)
self.assertIn("g/cm³", result)
print("密度计算结果示例:")
print(result)
return result
loop = asyncio.get_event_loop()
return loop.run_until_complete(_async_test())
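        # As in the MatterSim test above, asyncio.run(_async_test()) is the more idiomatic
        # driver on newer Python versions; kept as-is here for consistency with the original.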
def test_get_element_composition_async(self):
"""Test get_element_composition function with a simple CIF string (异步版本)."""
async def _async_test():
result = await get_element_composition(self.simple_cif)
# 验证结果是否包含预期的关键信息
self.assertIsInstance(result, str)
self.assertIn("Element Composition", result)
self.assertIn("Composition", result)
self.assertIn("Si", result)
print("元素组成计算结果示例:")
print(result)
return result
loop = asyncio.get_event_loop()
return loop.run_until_complete(_async_test())
def test_calculate_symmetry_async(self):
"""Test calculate_symmetry function with a simple CIF string (异步版本)."""
async def _async_test():
result = await calculate_symmetry(self.simple_cif)
# 验证结果是否包含预期的关键信息
self.assertIsInstance(result, str)
self.assertIn("Symmetry Information", result)
self.assertIn("Space Group", result)
self.assertIn("Number", result)
print("对称性计算结果示例:")
print(result)
return result
loop = asyncio.get_event_loop()
return loop.run_until_complete(_async_test())
class TestPymatgenCalculationsWithFile(unittest.TestCase):
"""使用文件测试pymatgen计算工具。"""
def setUp(self):
"""设置测试夹具创建临时CIF文件。"""
self.temp_cif_path = "temp_test_structure.cif"
# 简单的CIF内容 - 氧化铝结构
cif_content = """
data_Al2O3
_cell_length_a 4.76
_cell_length_b 4.76
_cell_length_c 12.99
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 120
_symmetry_space_group_name_H-M 'R -3 c'
_symmetry_Int_Tables_number 167
_symmetry_equiv_pos_as_xyz 'x, y, z'
_symmetry_equiv_pos_as_xyz '-x, -y, -z'
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Al1 Al 0.0 0.0 0.35 1.0
Al2 Al 0.0 0.0 0.85 1.0
O1 O 0.31 0.0 0.25 1.0
"""
# 创建临时文件
with open(self.temp_cif_path, "w") as f:
f.write(cif_content)
def tearDown(self):
"""清理测试夹具,删除临时文件。"""
if os.path.exists(self.temp_cif_path):
os.remove(self.temp_cif_path)
def test_calculate_density_from_file_async(self):
"""测试从文件计算密度(异步版本)。"""
async def _async_test():
result = await calculate_density(self.temp_cif_path)
# 验证结果
self.assertIsInstance(result, str)
self.assertIn("Density Calculation", result)
self.assertIn("Density", result)
print("从文件计算密度结果示例:")
print(result)
return result
loop = asyncio.get_event_loop()
return loop.run_until_complete(_async_test())
def test_get_element_composition_from_file_async(self):
"""测试从文件获取元素组成(异步版本)。"""
async def _async_test():
result = await get_element_composition(self.temp_cif_path)
# 验证结果
self.assertIsInstance(result, str)
self.assertIn("Element Composition", result)
self.assertIn("Composition", result)
print("从文件获取元素组成结果示例:")
print(result)
return result
loop = asyncio.get_event_loop()
return loop.run_until_complete(_async_test())
def test_calculate_symmetry_from_file_async(self):
"""测试从文件计算对称性(异步版本)。"""
async def _async_test():
result = await calculate_symmetry(self.temp_cif_path)
# 验证结果
self.assertIsInstance(result, str)
self.assertIn("Symmetry Information", result)
self.assertIn("Space Group", result)
print("从文件计算对称性结果示例:")
print(result)
return result
loop = asyncio.get_event_loop()
return loop.run_until_complete(_async_test())
def run_tests():
"""运行所有测试。"""
unittest.main()
if __name__ == "__main__":
run_tests()

View File

@@ -0,0 +1,118 @@
"""
测试晶体结构优化工具函数
此脚本测试改进后的optimize_crystal_structure函数
该函数接受单一的file_name_or_content_string参数可以是文件路径或直接的结构内容。
"""
import sys
import asyncio
import os
import tempfile
sys.path.append("/home/ubuntu/sas0/lzy/multi_mcp_server")
from sci_mcp.material_mcp.fairchem_structure_opt.structure_opt_tools import optimize_crystal_structure
from sci_mcp.core.config import material_config
# 简单的CIF结构示例
SAMPLE_CIF = """
data_SrTiO3
_cell_length_a 3.905
_cell_length_b 3.905
_cell_length_c 3.905
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 90
_symmetry_space_group_name_H-M 'P m -3 m'
_symmetry_Int_Tables_number 221
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Sr1 Sr 0.0 0.0 0.0
Ti1 Ti 0.5 0.5 0.5
O1 O 0.5 0.5 0.0
O2 O 0.5 0.0 0.5
O3 O 0.0 0.5 0.5
"""
async def test_with_content():
"""测试使用直接结构内容"""
print("\n=== 测试使用直接结构内容 ===")
result = await optimize_crystal_structure(
file_name_or_content_string=SAMPLE_CIF,
format_type="cif",
optimization_level="quick"
)
print(result)
async def test_with_file():
"""测试使用文件路径(如果文件存在)"""
print("\n=== 测试使用文件路径 ===")
# 创建临时CIF文件
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w", delete=False) as tmp_file:
tmp_file.write(SAMPLE_CIF)
tmp_path = tmp_file.name
try:
result = await optimize_crystal_structure(
file_name_or_content_string=tmp_path,
format_type="auto",
optimization_level="quick"
)
print(result)
finally:
# 清理临时文件
if os.path.exists(tmp_path):
os.unlink(tmp_path)
async def test_with_temp_file():
"""测试使用临时目录中的文件名"""
print("\n=== 测试使用临时目录中的文件名 ===")
# 确保临时目录存在
os.makedirs(material_config.TEMP_ROOT, exist_ok=True)
# 在临时目录中创建文件
temp_filename = "test_structure.cif"
temp_filepath = os.path.join(material_config.TEMP_ROOT, temp_filename)
with open(temp_filepath, 'w', encoding='utf-8') as f:
f.write(SAMPLE_CIF)
try:
# 只传递文件名,而不是完整路径
result = await optimize_crystal_structure(
file_name_or_content_string=temp_filename,
format_type="auto",
optimization_level="quick"
)
print(result)
finally:
# 清理临时文件
if os.path.exists(temp_filepath):
os.unlink(temp_filepath)
async def test_auto_format():
"""测试自动格式检测"""
print("\n=== 测试自动格式检测 ===")
result = await optimize_crystal_structure(
file_name_or_content_string=SAMPLE_CIF,
format_type="auto"
)
print(result)
async def main():
"""运行所有测试"""
print("测试改进后的optimize_crystal_structure函数")
await test_with_content()
await test_with_file()
await test_with_temp_file()
await test_auto_format()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,521 @@
import asyncio
import json
import sys
import os
import re
from openai import OpenAI
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
from rich import box
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp import *
initial_message = [{"role": "user", "content": """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
根据上文提供的CIF文件请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性并用JSON格式回答。"""}]
# 初始化rich控制台
console = Console()
# 获取工具模式和映射
tools_schemas = get_domain_tool_schemas(["material", 'general'])
tool_map = get_domain_tools(["material", 'general'])
# API配置
api_key = "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
base_url = "http://gpustack.ddwtop.team/v1-openai"
client = OpenAI(
api_key=api_key,
base_url=base_url,
)
def get_t1_response(messages):
"""获取T1模型的响应"""
with Progress(
SpinnerColumn(),
TextColumn("[bold blue]正在调用MARS-T1模型..."),
transient=True,
) as progress:
progress.add_task("", total=None)
completion = client.chat.completions.create(
model="MARS-T1",
messages=messages,
temperature=0.3,
tools=tools_schemas,
)
choice = completion.choices[0]
reasoning_content = choice.message.content
tool_calls_list = []
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
for tool_call in choice.message.tool_calls:
tool_call_dict = {
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
tool_calls_list.append(tool_call_dict)
return reasoning_content, tool_calls_list
async def execute_tool(tool_name, tool_arguments):
"""执行工具调用"""
try:
tool_func = tool_map[tool_name] # 获取工具函数
arguments = {}
if tool_arguments:
# 检查arguments是字符串还是字典
if isinstance(tool_arguments, dict):
# 如果已经是字典,直接使用
arguments = tool_arguments
elif isinstance(tool_arguments, str):
# 如果是字符串尝试解析为JSON
try:
# 尝试直接解析为JSON对象
arguments = json.loads(tool_arguments)
except json.JSONDecodeError:
# 如果解析失败,可能是因为字符串中包含转义字符
# 尝试修复常见的JSON字符串问题
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
try:
arguments = json.loads(fixed_str)
except json.JSONDecodeError:
# 如果仍然失败,尝试将字符串作为原始字符串处理
arguments = {"raw_string": tool_arguments}
# 调用工具函数
if asyncio.iscoroutinefunction(tool_func):
# 如果是异步函数使用await调用
result = await tool_func(**arguments)
else:
# 如果是同步函数,直接调用
result = tool_func(**arguments)
return result
    finally:
        # 此处不清除 LLM 调用上下文标记;保留 finally 结构以便将来需要时补充
        pass
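# For reference, the argument shapes execute_tool tolerates (the tool name and values below are
# purely illustrative, not real registered tools):
#   await execute_tool("some_tool", {"query": "CsPbBr3"})      # dict: used as-is
#   await execute_tool("some_tool", '{"query": "CsPbBr3"}')    # JSON string: parsed
#   await execute_tool("some_tool", 'not-json')                # fallback: {"raw_string": 'not-json'}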
def get_all_tool_calls_results(tool_calls_list):
"""获取所有工具调用的结果"""
all_results = []
with Progress(
SpinnerColumn(),
TextColumn("[bold green]正在执行工具调用..."),
transient=True,
) as progress:
task = progress.add_task("", total=len(tool_calls_list))
for tool_call in tool_calls_list:
tool_name = tool_call['function']['name']
tool_arguments = tool_call['function']['arguments']
# 显示当前执行的工具
progress.update(task, description=f"执行 {tool_name}")
result = asyncio.run(execute_tool(tool_name, tool_arguments))
result_str = f"[{tool_name} content begin]\n{result}\n[{tool_name} content end]\n"
all_results.append(result_str)
# 更新进度
progress.update(task, advance=1)
return all_results
def get_response_from_r1(messages):
"""获取R1模型的响应"""
with Progress(
SpinnerColumn(),
TextColumn("[bold purple]正在调用MARS-R1模型..."),
transient=True,
) as progress:
progress.add_task("", total=None)
completion = client.chat.completions.create(
model="MARS-R1",
messages=messages,
temperature=0.3,
)
choice = completion.choices[0]
return choice.message.content
def display_message(role, content, model=None, title_style="bold blue", border_style="blue"):
"""显示单条消息"""
title = role.capitalize()
if model:
title = f"{model} {title}"
if role == "user":
console.print(Panel(
content,
title=f"[{title_style}]{title}[/{title_style}]",
border_style=border_style,
expand=False
))
elif role == "assistant" and model == "MARS-T1":
console.print(Panel(
content,
title=f"[bold yellow]{title}[/bold yellow]",
border_style="yellow",
expand=False
))
elif role == "tool":
# 创建一个表格来显示工具调用结果
table = Table(box=box.ROUNDED, expand=False, show_header=False)
table.add_column("内容", style="green")
# 分割工具调用结果并添加到表格
results = content.split("\n")
for result in results:
table.add_row(result)
console.print(Panel(
table,
title=f"[bold green]{title}[/bold green]",
border_style="green",
expand=False
))
elif role == "assistant" and model == "MARS-R1":
try:
# 尝试将内容解析为Markdown
md = Markdown(content)
console.print(Panel(
md,
title=f"[bold purple]{title}[/bold purple]",
border_style="purple",
expand=False
))
except:
# 如果解析失败,直接显示文本
console.print(Panel(
content,
title=f"[bold purple]{title}[/bold purple]",
border_style="purple",
expand=False
))
def process_conversation_round(user_input, conversation_history=None):
"""处理一轮对话,返回更新后的对话历史"""
if conversation_history is None:
conversation_history = []
# 添加用户消息到外部历史
conversation_history.append({
"role": "user",
"content": user_input
})
# 显示用户消息
display_message("user", user_input)
# 内部循环变量
max_iterations = 3 # 防止无限循环
iterations = 0
# 分别管理T1和R1的对话历史
t1_messages = []
r1_messages = []
# 初始化T1消息历史从外部历史中提取用户和助手消息
for msg in conversation_history:
if msg["role"] in ["user", "assistant"]:
t1_messages.append({
"role": msg["role"],
"content": msg["content"]
})
# 当前问题(初始为用户输入)
current_question = user_input
while iterations < max_iterations:
iterations += 1
# 如果不是第一次迭代添加R1生成的后续问题作为新的用户消息
if iterations > 1:
# 显示后续问题
display_message("user", f"[后续问题] {current_question}")
# 添加到T1消息历史
t1_messages.append({
"role": "user",
"content": current_question
})
# 获取T1模型的响应
reasoning_content, tool_calls_list = get_t1_response(t1_messages)
# 显示T1推理
display_message("assistant", reasoning_content, model="MARS-T1")
# 添加T1的回答到外部历史
conversation_history.append({
"role": "assistant",
"content": reasoning_content,
"model": "MARS-T1"
})
# 添加T1的回答到T1消息历史
t1_messages.append({
"role": "assistant",
"content": reasoning_content
})
# 如果没有工具调用使用T1的推理作为最终答案
if not tool_calls_list:
# 添加相同的回答作为R1的回答因为没有工具调用
conversation_history.append({
"role": "assistant",
"content": reasoning_content,
"model": "MARS-R1"
})
display_message("assistant", reasoning_content, model="MARS-R1")
break
# 执行工具调用并获取结果
tool_call_results = get_all_tool_calls_results(tool_calls_list)
tool_call_results_str = "\n".join(tool_call_results)
# 添加工具调用结果到外部历史
conversation_history.append({
"role": "tool",
"content": tool_call_results_str
})
# 显示工具调用结果
display_message("tool", tool_call_results_str)
# 重置R1消息历史每次迭代都重新构建
r1_messages = []
# 添加系统消息指导R1如何处理信息
r1_messages.append({
"role": "system",
"content": """你是一个能够分析工具调用结果并回答问题的助手。
请分析提供的信息,并执行以下操作之一:
1. 如果你能够基于提供的工具调用信息直接回答原始问题,请提供完整的回答。
2. 如果目前的工具调用信息不足以让你回答原始问题,请明确说明缺少哪些信息,并生成一个新的问题来获取这些信息。
新问题格式:<FOLLOW_UP_QUESTION>你的问题</FOLLOW_UP_QUESTION>
注意:如果你生成了后续问题,系统将自动将其发送给工具调用模型以获取更多信息。"""
})
# 构建R1的用户消息包含原始问题、工具调用信息和结果
r1_user_message = f"""# 原始问题
{user_input}
# 工具调用信息
{reasoning_content}
# 工具调用结果
{tool_call_results_str}"""
# 如果有后续问题添加到R1用户消息
if iterations > 1:
r1_user_message += f"\n\n# 后续问题\n{current_question}"
# 添加构建好的用户消息
r1_messages.append({
"role": "user",
"content": r1_user_message
})
# 获取R1模型的响应
r1_response = get_response_from_r1(r1_messages)
# 显示R1回答
display_message("assistant", r1_response, model="MARS-R1")
# 检查R1是否生成了后续问题
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', r1_response, re.DOTALL)
if follow_up_match:
# 提取后续问题
follow_up_question = follow_up_match.group(1).strip()
# 将后续问题作为新的当前问题
current_question = follow_up_question
# 添加R1的回答到外部历史不包括后续问题标记
clean_response = r1_response.replace(follow_up_match.group(0), "")
conversation_history.append({
"role": "assistant",
"content": clean_response,
"model": "MARS-R1"
})
# 继续循环使用新问题调用T1
else:
# R1能够回答问题添加回答到历史并结束循环
conversation_history.append({
"role": "assistant",
"content": r1_response,
"model": "MARS-R1"
})
break
return conversation_history
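# Helper for illustration only (defined but not called): the follow-up protocol above assumes R1
# wraps any follow-up exactly in <FOLLOW_UP_QUESTION>...</FOLLOW_UP_QUESTION> tags.
def _extract_follow_up(r1_text):
    """Return the follow-up question contained in r1_text, or None if there is none."""
    m = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', r1_text, re.DOTALL)
    return m.group(1).strip() if m else None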
def run_demo():
"""运行演示,使用初始消息作为第一个用户问题"""
console.print(Panel.fit(
"[bold cyan]多轮对话演示[/bold cyan] - 使用 MARS-T1 和 MARS-R1 模型",
border_style="cyan"
))
# 获取初始用户问题
initial_user_input = initial_message[0]["content"]
# 处理第一轮对话
conversation_history = process_conversation_round(initial_user_input)
# 检查R1是否生成了后续问题并自动处理
auto_process_follow_up_questions(conversation_history)
# 多轮对话循环
while True:
console.print("\n[bold cyan]输入问题继续对话,或输入 'exit''quit' 退出[/bold cyan]")
user_input = input("> ")
# 检查是否退出
if user_input.lower() in ['exit', 'quit', '退出']:
console.print("[bold cyan]演示结束,再见![/bold cyan]")
break
# 处理用户输入
conversation_history = process_conversation_round(user_input, conversation_history)
# 检查R1是否生成了后续问题并自动处理
auto_process_follow_up_questions(conversation_history)
def auto_process_follow_up_questions(conversation_history):
"""自动处理R1生成的后续问题"""
# 检查最后一条消息是否是R1的回答
if not conversation_history or len(conversation_history) == 0:
return
last_message = conversation_history[-1]
if last_message["role"] == "assistant" and last_message.get("model") == "MARS-R1":
# 检查是否包含后续问题
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', last_message["content"], re.DOTALL)
if follow_up_match:
# 提取后续问题
follow_up_question = follow_up_match.group(1).strip()
# 显示检测到的后续问题
console.print(Panel(
f"[bold yellow]检测到后续问题: {follow_up_question}[/bold yellow]",
border_style="yellow",
expand=False
))
# 自动处理后续问题
console.print("[bold cyan]自动处理后续问题...[/bold cyan]")
# 递归处理后续问题,直到没有更多后续问题或达到最大迭代次数
max_auto_iterations = 3
current_iterations = 0
while current_iterations < max_auto_iterations:
current_iterations += 1
# 处理后续问题
conversation_history = process_conversation_round(follow_up_question, conversation_history)
# 检查是否还有后续问题
if len(conversation_history) > 0:
last_message = conversation_history[-1]
if last_message["role"] == "assistant" and last_message.get("model") == "MARS-R1":
follow_up_match = re.search(r'<FOLLOW_UP_QUESTION>(.*?)</FOLLOW_UP_QUESTION>', last_message["content"], re.DOTALL)
if follow_up_match:
# 提取后续问题
follow_up_question = follow_up_match.group(1).strip()
# 显示检测到的后续问题
console.print(Panel(
f"[bold yellow]检测到后续问题: {follow_up_question}[/bold yellow]",
border_style="yellow",
expand=False
))
# 自动处理后续问题
console.print("[bold cyan]自动处理后续问题...[/bold cyan]")
continue
# 如果没有更多后续问题,退出循环
break
if __name__ == '__main__':
try:
run_demo()
except KeyboardInterrupt:
console.print("\n[bold cyan]程序被用户中断,再见![/bold cyan]")
except Exception as e:
console.print(f"\n[bold red]发生错误: {str(e)}[/bold red]")
import traceback
console.print(traceback.format_exc())

83
test_tools/test.py Normal file
View File

@@ -0,0 +1,83 @@
"""
Test script for domain-specific tool retrieval functions
"""
import json
from pprint import pprint
import sys
sys.path.append("/home/ubuntu/sas0/lzy/multi_mcp_server")
from sci_mcp.core.llm_tools import get_domain_tools, get_domain_tool_schemas, get_all_tool_schemas,get_all_tools
def test_get_domain_tools():
"""Test retrieving tools from specific domains"""
print("\n=== Testing get_domain_tools ===")
# Test with material and general domains
domains = ['material', 'general']
print(f"Getting tools for domains: {domains}")
domain_tools = get_domain_tools(domains)
# Print results
for domain, tools in domain_tools.items():
print(f"\nDomain: {domain}")
print(f"Number of tools: {len(tools)}")
print("Tool names:")
for tool_name in tools.keys():
print(f" - {tool_name}")
def test_get_domain_tool_schemas():
"""Test retrieving tool schemas from specific domains"""
print("\n=== Testing get_domain_tool_schemas (OpenAI format) ===")
# Test with material and general domains
domains = ['material', 'general']
print(f"Getting tool schemas for domains: {domains}")
domain_schemas = get_domain_tool_schemas(domains)
# Print results
for domain, schemas in domain_schemas.items():
print(f"\nDomain: {domain}")
print(f"Number of schemas: {len(schemas)}")
print("Tool names in schemas:")
for schema in schemas:
print(f" - {schema['function']['name']}")
def test_all_tool_schemas():
"""Test retrieving all tool schemas in both formats"""
print("\n=== Testing get_all_tool_schemas ===")
# OpenAI format
print("Getting all tool schemas in OpenAI format")
openai_schemas = get_all_tool_schemas()
print(f"Number of schemas: {len(openai_schemas)}")
print("Tool names:")
for schema in openai_schemas:
print(f" - {schema['function']['name']}")
# MCP format
print("\nGetting all tool schemas in MCP format")
mcp_schemas = get_all_tool_schemas(use_mcp_format=True)
print(f"Number of schemas: {len(mcp_schemas)}")
print("Tool names:")
for schema in mcp_schemas:
print(f" - {schema['name']}")
def main():
"""Main function to run tests"""
print("Testing domain-specific tool retrieval functions")
#test_get_domain_tools()
test_get_domain_tool_schemas()
test_all_tool_schemas()
if __name__ == "__main__":
#print(get_all_tool_schemas())
from rich.console import Console
console = Console()
console.print("Testing domain-specific tool retrieval functions", style="bold green")
console.print(get_domain_tool_schemas(['chemistry']))

253
test_tools/test_mars_t1.py Normal file
View File

@@ -0,0 +1,253 @@
import asyncio
import json
import sys
import os
from openai import OpenAI
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
# 添加项目根目录到sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp_server.core.llm_tools import set_llm_context, clear_llm_context
# 创建Rich控制台对象
console = Console()
# 定义分隔符样式
def print_separator(title=""):
console.print(Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False), justify="center")
api_key="gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
base_url="http://gpustack.ddwtop.team/v1-openai"
client = OpenAI(
api_key=api_key,
base_url=base_url,
)
from sci_mcp import *
tools_schemas = get_domain_tool_schemas(["material",'general'])
tool_map = get_domain_tools(["material",'general'])
# 打印消息的函数
def print_message(message):
# 处理不同类型的消息对象
if hasattr(message, 'role'): # ChatCompletionMessage 对象
role = message.role
content = message.content if hasattr(message, 'content') else ""
# 如果是工具消息,获取工具名称
tool_name = None
if role == "tool" and hasattr(message, 'name'):
tool_name = message.name
else: # 字典类型
role = message.get("role", "unknown")
content = message.get("content", "")
# 如果是工具消息,获取工具名称
tool_name = message.get("name") if role == "tool" else None
# 根据角色选择不同的颜色
role_colors = {
"system": "bright_blue",
"user": "green",
"assistant": "yellow",
"tool": "bright_red"
}
color = role_colors.get(role, "white")
# 创建富文本面板
text = Text()
# 如果是工具消息,添加工具名称
if role == "tool" and tool_name:
text.append(f"{role} ({tool_name}): ", style=f"bold {color}")
else:
text.append(f"{role}: ", style=f"bold {color}")
text.append(str(content))
console.print(Panel(text, border_style=color))
messages = [
{"role": "system",
"content": "You are MARS-R1, a professional assistant in materials science. You first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> <structured_answer> </structured_answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here <structured_answer> structured answer here <structured_answer> </answer>'"},
{"role": "user", "content": """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
根据上文提供的CIF文件请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性并用JSON格式回答。"""}
]
#how to synthesize CsPbBr3 at room temperature
#
# 打印初始消息
print_separator("初始消息")
for message in messages:
print_message(message)
finish_reason = None
async def execute_tool(tool_name,tool_arguments):
# 设置LLM调用上下文标记
set_llm_context(True)
try:
tool_func = tool_map[tool_name] # 获取工具函数
arguments = {}
if tool_arguments:
# 检查arguments是字符串还是字典
if isinstance(tool_arguments, dict):
# 如果已经是字典,直接使用
arguments = tool_arguments
elif isinstance(tool_arguments, str):
# 如果是字符串尝试解析为JSON
try:
# 尝试直接解析为JSON对象
arguments = json.loads(tool_arguments)
except json.JSONDecodeError:
# 如果解析失败,可能是因为字符串中包含转义字符
# 尝试修复常见的JSON字符串问题
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
try:
arguments = json.loads(fixed_str)
except json.JSONDecodeError:
# 如果仍然失败,尝试将字符串作为原始字符串处理
arguments = {"raw_string": tool_arguments}
# 调用工具函数
if asyncio.iscoroutinefunction(tool_func):
# 如果是异步函数使用await调用
result = await tool_func(**arguments)
else:
# 如果是同步函数,直接调用
result = tool_func(**arguments)
# if func_name=='generate_material':
# print("xxxxx",result)
return result
finally:
# 清除LLM调用上下文标记
clear_llm_context()
while finish_reason is None or finish_reason == "tool_calls":
completion = client.chat.completions.create(
model="MARS-R1",
messages=messages,
temperature=0.3,
        tools=tools_schemas, # <-- 我们通过 tools 参数,将定义好的工具列表提交给 MARS-R1 模型
)
choice = completion.choices[0]
finish_reason = choice.finish_reason
if finish_reason == "tool_calls": # <-- 判断当前返回内容是否包含 tool_calls
# 打印assistant消息
print_separator("Assistant消息")
print_message(choice.message)
# 将ChatCompletionMessage对象转换为字典
assistant_message = {
"role": "assistant",
"content": choice.message.content if hasattr(choice.message, 'content') else None
}
# 如果有工具调用,添加到字典中
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
# 将tool_calls对象转换为字典列表
tool_calls_list = []
for tool_call in choice.message.tool_calls:
tool_call_dict = {
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
tool_calls_list.append(tool_call_dict)
assistant_message["tool_calls"] = tool_calls_list
# 添加消息到上下文
messages.append(assistant_message) # <-- 我们将模型返回给我们的 assistant 消息也添加到上下文中,以便于下次请求时模型能理解我们的诉求
# 打印工具调用信息
print_separator("工具调用")
for tool_call in choice.message.tool_calls:
console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]")
console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
console.print("")
tool_call_name = tool_call.function.name
tool_call_arguments = json.loads(tool_call.function.arguments) # <-- arguments 是序列化后的 JSON Object我们需要使用 json.loads 反序列化一下
try:
tool_result = asyncio.run(execute_tool(tool_name=tool_call_name,tool_arguments=tool_call_arguments)) # <-- 通过 tool_map 快速找到需要执行哪个函数
except Exception as e:
tool_result=f'工具调用失败{e}'
# 构造工具响应消息
tool_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"name": tool_call_name,
"content": tool_result, # <-- 我们约定使用字符串格式向模型提交工具调用结果
}
# 打印工具响应
print_separator(f"工具响应: {tool_call_name}")
print_message(tool_message)
# 添加消息到上下文
messages.append(tool_message)
# 打印最终响应
if choice.message.content:
print_separator("最终响应")
console.print(Panel(choice.message.content, border_style="green"))
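# For reference (a sketch, not executed above): after one successful tool round-trip the context
# typically contains, in order,
#   {"role": "user", "content": "..."}
#   {"role": "assistant", "content": ..., "tool_calls": [{"id": ..., "type": "function", ...}]}
#   {"role": "tool", "tool_call_id": ..., "name": ..., "content": "..."}
# and the loop exits once finish_reason is no longer "tool_calls".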

177
test_tools/test_mars_t1_.py Normal file
View File

@@ -0,0 +1,177 @@
import asyncio
import json
from openai import OpenAI
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp import *
tools_schemas = get_domain_tool_schemas(["material",'general'])
tool_map = get_domain_tools(["material",'general'])
api_key="gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
base_url="http://gpustack.ddwtop.team/v1-openai"
client = OpenAI(
api_key=api_key,
base_url=base_url,
)
messages = [ {"role": "user", "content": """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
根据上文提供的CIF文件请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性并用JSON格式回答。"""}]
def get_t1_response(messages):
completion = client.chat.completions.create(
model="MARS-T1",
messages=messages,
temperature=0.3,
tools=tools_schemas,
)
choice = completion.choices[0]
reasoning_content = choice.message.content
#print("Reasoning content:", reasoning_content)
    tool_calls_list = []
    if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
        for tool_call in choice.message.tool_calls:
            tool_call_dict = {
                "id": tool_call.id,
                "type": "function",
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments
                }
            }
            tool_calls_list.append(tool_call_dict)
    # 始终返回二元组,避免没有工具调用时返回 None 导致调用处解包失败
    return reasoning_content, tool_calls_list
async def execute_tool(tool_name,tool_arguments):
try:
tool_func = tool_map[tool_name] # 获取工具函数
arguments = {}
if tool_arguments:
# 检查arguments是字符串还是字典
if isinstance(tool_arguments, dict):
# 如果已经是字典,直接使用
arguments = tool_arguments
elif isinstance(tool_arguments, str):
# 如果是字符串尝试解析为JSON
try:
# 尝试直接解析为JSON对象
arguments = json.loads(tool_arguments)
except json.JSONDecodeError:
# 如果解析失败,可能是因为字符串中包含转义字符
# 尝试修复常见的JSON字符串问题
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
try:
arguments = json.loads(fixed_str)
except json.JSONDecodeError:
# 如果仍然失败,尝试将字符串作为原始字符串处理
arguments = {"raw_string": tool_arguments}
# 调用工具函数
if asyncio.iscoroutinefunction(tool_func):
# 如果是异步函数使用await调用
result = await tool_func(**arguments)
else:
# 如果是同步函数,直接调用
result = tool_func(**arguments)
# if func_name=='generate_material':
# print("xxxxx",result)
return result
    finally:
        # 此处不清除 LLM 调用上下文标记;保留 finally 结构以便将来需要时补充
        pass
def get_all_tool_calls_results(tool_calls_list):
all_results = []
for tool_call in tool_calls_list:
tool_name = tool_call['function']['name']
tool_arguments = tool_call['function']['arguments']
result = asyncio.run(execute_tool(tool_name,tool_arguments))
result_str = f"[{tool_name} content begin]\n"+result+f"\n[{tool_name} content end]\n"
all_results.append(result_str)
return all_results
def get_response_from_r1(messages):
completion = client.chat.completions.create(
model="MARS-R1",
messages=messages,
temperature=0.3,
)
    choice = completion.choices[0]
    print("R1 RESPONSE:", choice.message.content)
    return choice.message.content
if __name__ == '__main__':
reasoning_content, tool_calls_list=get_t1_response(messages)
print("Reasoning content:", reasoning_content)
tool_call_results=get_all_tool_calls_results(tool_calls_list)
tool_call_results_str = "\n".join(tool_call_results)
# for tool_call in tool_call_results:
# print(tool_call)
user_message = {
"role": "user",
"content": f"# 信息如下:{tool_call_results_str}# 问题如下 {messages[0]['content']}"
}
print("user_message_for_r1:", user_message)
get_response_from_r1([user_message])

View File

@@ -0,0 +1,339 @@
import asyncio
import json
import sys
import os
from openai import OpenAI
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
# 添加项目根目录到sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp import *
# 创建Rich控制台对象
console = Console()
# 创建一个列表来存储工具调用结果
tool_results = []
# 定义分隔符样式
def print_separator(title=""):
console.print(Panel(f"[bold magenta]{title}[/]", border_style="cyan", expand=False), justify="center")
api_key="gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705"
base_url="http://gpustack.ddwtop.team/v1-openai"
client = OpenAI(
api_key=api_key,
base_url=base_url,
)
tools_schemas = get_domain_tool_schemas(["material",'general'])
tool_map = get_domain_tools(["material",'general'])
# 打印消息的函数
def print_message(message):
# 处理不同类型的消息对象
if hasattr(message, 'role'): # ChatCompletionMessage 对象
role = message.role
content = message.content if hasattr(message, 'content') else ""
# 如果是工具消息,获取工具名称
tool_name = None
if role == "tool" and hasattr(message, 'name'):
tool_name = message.name
else: # 字典类型
role = message.get("role", "unknown")
content = message.get("content", "")
# 如果是工具消息,获取工具名称
tool_name = message.get("name") if role == "tool" else None
# 根据角色选择不同的颜色
role_colors = {
"system": "bright_blue",
"user": "green",
"assistant": "yellow",
"tool": "bright_red"
}
color = role_colors.get(role, "white")
# 创建富文本面板
text = Text()
# 如果是工具消息,添加工具名称
if role == "tool" and tool_name:
text.append(f"{role} ({tool_name}): ", style=f"bold {color}")
else:
text.append(f"{role}: ", style=f"bold {color}")
text.append(str(content))
console.print(Panel(text, border_style=color))
messages = [
{"role": "user", "content": """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
根据上文提供的CIF文件请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性并用JSON格式回答。"""}
]
#how to synthesize CsPbBr3 at room temperature
#
# 打印初始消息
print_separator("初始消息")
for message in messages:
print_message(message)
finish_reason = None
async def execute_tool(tool_name,tool_arguments):
    # 此脚本中不设置 LLM 调用上下文标记
try:
tool_func = tool_map[tool_name] # 获取工具函数
arguments = {}
if tool_arguments:
# 检查arguments是字符串还是字典
if isinstance(tool_arguments, dict):
# 如果已经是字典,直接使用
arguments = tool_arguments
elif isinstance(tool_arguments, str):
# 如果是字符串尝试解析为JSON
try:
# 尝试直接解析为JSON对象
arguments = json.loads(tool_arguments)
except json.JSONDecodeError:
# 如果解析失败,可能是因为字符串中包含转义字符
# 尝试修复常见的JSON字符串问题
fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
try:
arguments = json.loads(fixed_str)
except json.JSONDecodeError:
# 如果仍然失败,尝试将字符串作为原始字符串处理
arguments = {"raw_string": tool_arguments}
# 调用工具函数
if asyncio.iscoroutinefunction(tool_func):
# 如果是异步函数使用await调用
result = await tool_func(**arguments)
else:
# 如果是同步函数,直接调用
result = tool_func(**arguments)
# if func_name=='generate_material':
# print("xxxxx",result)
return result
    finally:
        # 此处不清除 LLM 调用上下文标记;保留 finally 结构以便将来需要时补充
        pass
# 定义一个函数来估算消息的token数量粗略估计
def estimate_tokens(messages):
# 简单估计每个英文单词约1.3个token每个中文字符约1个token
total = 0
for msg in messages:
content = msg.get("content", "") if isinstance(msg, dict) else (msg.content if hasattr(msg, "content") else "")
if content:
# 粗略估计内容的token数
total += len(content) * 1.3
# 估计工具调用的token数
tool_calls = msg.get("tool_calls", []) if isinstance(msg, dict) else (msg.tool_calls if hasattr(msg, "tool_calls") else [])
for tool_call in tool_calls:
if isinstance(tool_call, dict):
args = tool_call.get("function", {}).get("arguments", "")
total += len(args) * 1.3
else:
args = tool_call.function.arguments if hasattr(tool_call, "function") else ""
total += len(args) * 1.3
return int(total)
# 管理消息历史保持在token限制以内
def manage_message_history(messages, max_tokens=7000):
# 保留第一条消息(通常是系统消息或初始用户消息)
if len(messages) <= 1:
return messages
# 估算当前消息的token数
current_tokens = estimate_tokens(messages)
# 如果当前token数已经接近限制开始裁剪历史消息
if current_tokens > max_tokens:
# 保留第一条消息和最近的消息
preserved_messages = [messages[0]]
# 从最新的消息开始添加直到接近但不超过token限制
temp_messages = []
for msg in reversed(messages[1:]):
temp_messages.insert(0, msg)
if estimate_tokens([messages[0]] + temp_messages) > max_tokens:
# 如果添加这条消息会超过限制,则停止添加
temp_messages.pop(0)
break
# 如果裁剪后的消息太少,至少保留最近的几条消息
if len(temp_messages) < 4 and len(messages) > 4:
temp_messages = messages[-4:]
return preserved_messages + temp_messages
return messages
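# Illustration only (defined but never called): with the rough 1.3-tokens-per-character estimate
# above, a long history is trimmed to the first message plus the most recent ones.
def _demo_history_trimming():
    demo = [{"role": "user", "content": "x" * 2000}] + [
        {"role": "assistant", "content": "y" * 2000} for _ in range(5)
    ]
    trimmed = manage_message_history(demo, max_tokens=7000)
    print(f"history trimmed from {len(demo)} to {len(trimmed)} messages")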
while finish_reason is None or finish_reason == "tool_calls":
# 在发送请求前管理消息历史
managed_messages = manage_message_history(messages)
if len(managed_messages) < len(messages):
print_separator(f"消息历史已裁剪,从{len(messages)}条减少到{len(managed_messages)}")
messages = managed_messages
completion = client.chat.completions.create(
model="MARS-T1",
messages=messages,
temperature=0.3,
tools=tools_schemas,
timeout=120,
)
choice = completion.choices[0]
finish_reason = choice.finish_reason
if finish_reason == "tool_calls": # <-- 判断当前返回内容是否包含 tool_calls
# 打印assistant消息
print_separator("Assistant消息")
print_message(choice.message)
# 将ChatCompletionMessage对象转换为字典
assistant_message = {
"role": "assistant",
"content": choice.message.content if hasattr(choice.message, 'content') else None
}
# 如果有工具调用,添加到字典中
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
# 将tool_calls对象转换为字典列表
tool_calls_list = []
for tool_call in choice.message.tool_calls:
tool_call_dict = {
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
tool_calls_list.append(tool_call_dict)
assistant_message["tool_calls"] = tool_calls_list
# 添加消息到上下文
messages.append(assistant_message) # <-- 我们将模型返回给我们的 assistant 消息也添加到上下文中,以便于下次请求时模型能理解我们的诉求
# 打印工具调用信息
print_separator("工具调用")
for tool_call in choice.message.tool_calls:
console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]")
console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
console.print("")
tool_call_name = tool_call.function.name
tool_call_arguments = json.loads(tool_call.function.arguments) # <-- arguments 是序列化后的 JSON Object我们需要使用 json.loads 反序列化一下
try:
tool_result = asyncio.run(execute_tool(tool_name=tool_call_name,tool_arguments=tool_call_arguments)) # <-- 通过 tool_map 快速找到需要执行哪个函数
except Exception as e:
tool_result=f'工具调用失败{e}'
# 将工具调用结果保存到单独的列表中,使用指定格式包裹
formatted_result = f"[{tool_call_name} content begin]{tool_result}[{tool_call_name} content end]"
tool_results.append({
"tool_name": tool_call_name,
"tool_id": tool_call.id,
"formatted_result": formatted_result,
"raw_result": tool_result
})
# 打印保存的工具结果信息
console.print(f"[bold green]已保存工具结果:[/] [cyan]{tool_call_name}[/]")
# 构造工具响应消息
tool_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"name": tool_call_name,
"content": tool_result, # <-- 我们约定使用字符串格式向模型提交工具调用结果
}
# 打印工具响应
print_separator(f"工具响应: {tool_call_name}")
print_message(tool_message)
# 添加消息到上下文
#messages.append(tool_message)
# 打印最终响应
if choice.message.content:
print_separator("最终响应")
console.print(Panel(choice.message.content, border_style="green"))
# 打印收集的所有工具调用结果
if tool_results:
print_separator("所有工具调用结果")
console.print(f"[bold cyan]共收集了 {len(tool_results)} 个工具调用结果[/]")
# 将所有格式化的结果写入文件
with open("tool_results.txt", "w", encoding="utf-8") as f:
for result in tool_results:
f.write(f"{result['formatted_result']}\n\n")
console.print(f"[bold green]所有工具调用结果已保存到 tool_results.txt 文件中[/]")