初次提交
This commit is contained in:
38
sci_mcp/__init__.py
Normal file
38
sci_mcp/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from .core.llm_tools import llm_tool,get_all_tools, get_all_tool_schemas, get_domain_tools, get_domain_tool_schemas
|
||||
|
||||
#general_mcp
|
||||
#from .general_mcp.searxng_query.searxng_query_tools import search_online
|
||||
|
||||
# #material_mcp
|
||||
from .material_mcp.mp_query.mp_query_tools import search_crystal_structures_from_materials_project,search_material_property_from_materials_project
|
||||
from .material_mcp.oqmd_query.oqmd_query_tools import query_material_from_OQMD
|
||||
from .material_mcp.knowledge_base_query.retrieval_from_knowledge_base_tools import retrieval_from_knowledge_base
|
||||
|
||||
from .material_mcp.mattersim_pred.property_pred_tools import predict_properties_MatterSim
|
||||
from .material_mcp.mattergen_gen.material_gen_tools import generate_material_MatterGen
|
||||
from .material_mcp.fairchem_structure_opt.structure_opt_tools import optimize_crystal_structure_FairChem
|
||||
from .material_mcp.pymatgen_cal.pymatgen_cal_tools import calculate_density_Pymatgen,get_element_composition_Pymatgen,calculate_symmetry_Pymatgen
|
||||
from .material_mcp.matgl_tools.matgl_tools import relax_crystal_structure_M3GNet,predict_formation_energy_M3GNet,run_molecular_dynamics_M3GNet,calculate_single_point_energy_M3GNet
|
||||
#chemistry_mcp
|
||||
from .chemistry_mcp.pubchem_tools.pubchem_tools import search_advanced_pubchem
|
||||
from .chemistry_mcp.rdkit_tools.rdkit_tools import (
|
||||
calculate_molecular_properties_rdkit,
|
||||
calculate_drug_likeness_rdkit,
|
||||
calculate_topological_descriptors_rdkit,
|
||||
generate_molecular_fingerprints_rdkit,
|
||||
calculate_molecular_similarity_rdkit,
|
||||
analyze_molecular_structure_rdkit,
|
||||
generate_molecular_conformer_rdkit,
|
||||
identify_scaffolds_rdkit,
|
||||
convert_between_chemical_formats_rdkit,
|
||||
standardize_molecule_rdkit,
|
||||
enumerate_stereoisomers_rdkit,
|
||||
perform_substructure_search_rdkit
|
||||
)
|
||||
from .chemistry_mcp.rxn_tools.rxn_tools import (
|
||||
predict_reaction_outcome_rxn,
|
||||
predict_reaction_topn_rxn,
|
||||
predict_reaction_properties_rxn,
|
||||
extract_reaction_actions_rxn
|
||||
)
|
||||
__all__ = ["llm_tool", "get_all_tools", "get_all_tool_schemas", "get_domain_tools", "get_domain_tool_schemas"]
|
||||
9
sci_mcp/chemistry_mcp/__init__.py
Normal file
9
sci_mcp/chemistry_mcp/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Chemistry MCP Module
|
||||
|
||||
This module provides tools for chemistry-related operations.
|
||||
"""
|
||||
|
||||
from .pubchem_tools import search_advanced_pubchem
|
||||
|
||||
__all__ = ["search_advanced_pubchem"]
|
||||
9
sci_mcp/chemistry_mcp/pubchem_tools/__init__.py
Normal file
9
sci_mcp/chemistry_mcp/pubchem_tools/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
PubChem Tools Module
|
||||
|
||||
This module provides tools for accessing and processing chemical data from PubChem.
|
||||
"""
|
||||
|
||||
from .pubchem_tools import search_advanced_pubchem
|
||||
|
||||
__all__ = ["search_advanced_pubchem"]
|
||||
325
sci_mcp/chemistry_mcp/pubchem_tools/pubchem_tools.py
Normal file
325
sci_mcp/chemistry_mcp/pubchem_tools/pubchem_tools.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
PubChem Tools Module
|
||||
|
||||
This module provides tools for searching and retrieving chemical compound information
|
||||
from the PubChem database using the PubChemPy library.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Union, Optional, Any
|
||||
|
||||
import pubchempy as pcp
|
||||
from ...core.llm_tools import llm_tool
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
|
||||
def compound_to_dict(compound: pcp.Compound) -> Dict[str, Any]:
    """
    Convert a PubChem compound to a structured dictionary with relevant information.

    The output groups the compound's attributes into four sections
    (``basic_info``, ``identifiers``, ``physical_properties`` and
    ``molecular_features``) plus an optional ``alternative_names`` section
    holding at most the first 10 synonyms.

    Args:
        compound: PubChem compound object.

    Returns:
        Dictionary containing organized compound information; an empty dict
        when ``compound`` is falsy (e.g. ``None``).
    """
    if not compound:
        return {}

    result = {
        "basic_info": {
            "cid": compound.cid,
            "iupac_name": compound.iupac_name,
            "molecular_formula": compound.molecular_formula,
            "molecular_weight": compound.molecular_weight,
            "canonical_smiles": compound.canonical_smiles,
            "isomeric_smiles": compound.isomeric_smiles,
        },
        "identifiers": {
            "inchi": compound.inchi,
            "inchikey": compound.inchikey,
        },
        "physical_properties": {
            "xlogp": compound.xlogp,
            "exact_mass": compound.exact_mass,
            "monoisotopic_mass": compound.monoisotopic_mass,
            "tpsa": compound.tpsa,
            "complexity": compound.complexity,
            "charge": compound.charge,
        },
        "molecular_features": {
            "h_bond_donor_count": compound.h_bond_donor_count,
            "h_bond_acceptor_count": compound.h_bond_acceptor_count,
            "rotatable_bond_count": compound.rotatable_bond_count,
            "heavy_atom_count": compound.heavy_atom_count,
            "atom_stereo_count": compound.atom_stereo_count,
            "defined_atom_stereo_count": compound.defined_atom_stereo_count,
            "undefined_atom_stereo_count": compound.undefined_atom_stereo_count,
            "bond_stereo_count": compound.bond_stereo_count,
            "defined_bond_stereo_count": compound.defined_bond_stereo_count,
            "undefined_bond_stereo_count": compound.undefined_bond_stereo_count,
            "covalent_unit_count": compound.covalent_unit_count,
        }
    }

    # Add up to 10 synonyms if available. A plain slice already handles lists
    # shorter than 10, so no length check is needed (fixes the redundant
    # conditional in the original).
    if hasattr(compound, 'synonyms') and compound.synonyms:
        result["alternative_names"] = {
            "synonyms": compound.synonyms[:10]
        }

    return result
||||
|
||||
|
||||
async def _search_by_name(name: str, max_results: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search compounds by name asynchronously.
|
||||
|
||||
Args:
|
||||
name: Chemical compound name
|
||||
max_results: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of compound dictionaries
|
||||
"""
|
||||
try:
|
||||
compounds = await asyncio.to_thread(
|
||||
pcp.get_compounds, name, 'name', max_records=max_results
|
||||
)
|
||||
#print(compounds[0].to_dict())
|
||||
return [compound.to_dict() for compound in compounds]
|
||||
except Exception as e:
|
||||
logging.error(f"Error searching by name '{name}': {str(e)}")
|
||||
return [{"error": f"Error: {str(e)}"}]
|
||||
|
||||
|
||||
async def _search_by_smiles(smiles: str, max_results: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search compounds by SMILES notation asynchronously.
|
||||
|
||||
Args:
|
||||
smiles: SMILES notation of chemical compound
|
||||
max_results: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of compound dictionaries
|
||||
"""
|
||||
try:
|
||||
compounds = await asyncio.to_thread(
|
||||
pcp.get_compounds, smiles, 'smiles', max_records=max_results
|
||||
)
|
||||
return [compound.to_dict() for compound in compounds]
|
||||
except Exception as e:
|
||||
logging.error(f"Error searching by SMILES '{smiles}': {str(e)}")
|
||||
return [{"error": f"Error: {str(e)}"}]
|
||||
|
||||
|
||||
async def _search_by_formula(
|
||||
formula: str,
|
||||
max_results: int = 5,
|
||||
listkey_count: int = 5,
|
||||
listkey_start: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search compounds by molecular formula asynchronously.
|
||||
|
||||
Uses pagination with listkey parameters to avoid timeout errors when searching
|
||||
formulas that might return many results.
|
||||
|
||||
Args:
|
||||
formula: Molecular formula
|
||||
max_results: Maximum number of results to return
|
||||
listkey_count: Number of results per page (default: 5)
|
||||
listkey_start: Starting position for pagination (default: 0)
|
||||
|
||||
Returns:
|
||||
List of compound dictionaries
|
||||
"""
|
||||
try:
|
||||
# Use listkey parameters to avoid timeout errors
|
||||
compounds = await asyncio.to_thread(
|
||||
pcp.get_compounds,
|
||||
formula,
|
||||
'formula',
|
||||
max_records=max_results,
|
||||
listkey_count=listkey_count,
|
||||
listkey_start=listkey_start
|
||||
)
|
||||
|
||||
return [compound.to_dict() for compound in compounds]
|
||||
except Exception as e:
|
||||
logging.error(f"Error searching by formula '{formula}': {str(e)}")
|
||||
return [{"error": f"Error: {str(e)}"}]
|
||||
|
||||
|
||||
def _format_results_as_markdown(results: List[Dict[str, Any]], query_type: str, query_value: str) -> str:
|
||||
"""
|
||||
Format search results as a structured Markdown string.
|
||||
|
||||
Args:
|
||||
results: List of compound dictionaries from compound.to_dict()
|
||||
query_type: Type of search query (name, SMILES, formula)
|
||||
query_value: Value of the search query
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string
|
||||
"""
|
||||
if not results:
|
||||
return f"## PubChem Search Results\n\nNo compounds found for {query_type}: `{query_value}`"
|
||||
|
||||
if "error" in results[0]:
|
||||
return f"## PubChem Search Error\n\n{results[0]['error']}"
|
||||
|
||||
markdown = f"## PubChem Search Results\n\nSearch by {query_type}: `{query_value}`\n\nFound {len(results)} compound(s)\n\n"
|
||||
|
||||
for i, compound in enumerate(results):
|
||||
# Extract information from the compound.to_dict() structure
|
||||
cid = compound.get("cid", "N/A")
|
||||
iupac_name = compound.get("iupac_name", "Unknown")
|
||||
molecular_formula = compound.get("molecular_formula", "N/A")
|
||||
molecular_weight = compound.get("molecular_weight", "N/A")
|
||||
canonical_smiles = compound.get("canonical_smiles", "N/A")
|
||||
isomeric_smiles = compound.get("isomeric_smiles", "N/A")
|
||||
inchi = compound.get("inchi", "N/A")
|
||||
inchikey = compound.get("inchikey", "N/A")
|
||||
xlogp = compound.get("xlogp", "N/A")
|
||||
exact_mass = compound.get("exact_mass", "N/A")
|
||||
tpsa = compound.get("tpsa", "N/A")
|
||||
h_bond_donor_count = compound.get("h_bond_donor_count", "N/A")
|
||||
h_bond_acceptor_count = compound.get("h_bond_acceptor_count", "N/A")
|
||||
rotatable_bond_count = compound.get("rotatable_bond_count", "N/A")
|
||||
heavy_atom_count = compound.get("heavy_atom_count", "N/A")
|
||||
|
||||
# Get atoms and bonds information if available
|
||||
atoms = compound.get("atoms", [])
|
||||
bonds = compound.get("bonds", [])
|
||||
|
||||
# Format the markdown output
|
||||
markdown += f"### Compound {i+1}: {iupac_name}\n\n"
|
||||
|
||||
# Basic information section
|
||||
markdown += "#### Basic Information\n\n"
|
||||
markdown += f"- **CID**: {cid}\n"
|
||||
markdown += f"- **Formula**: {molecular_formula}\n"
|
||||
markdown += f"- **Molecular Weight**: {molecular_weight} g/mol\n"
|
||||
markdown += f"- **Canonical SMILES**: `{canonical_smiles}`\n"
|
||||
markdown += f"- **Isomeric SMILES**: `{isomeric_smiles}`\n"
|
||||
|
||||
# Identifiers section
|
||||
markdown += "\n#### Identifiers\n\n"
|
||||
markdown += f"- **InChI**: `{inchi}`\n"
|
||||
markdown += f"- **InChIKey**: `{inchikey}`\n"
|
||||
|
||||
# Physical properties section
|
||||
markdown += "\n#### Physical Properties\n\n"
|
||||
markdown += f"- **XLogP**: {xlogp}\n"
|
||||
markdown += f"- **Exact Mass**: {exact_mass}\n"
|
||||
markdown += f"- **TPSA**: {tpsa} Ų\n"
|
||||
|
||||
# Molecular features section
|
||||
markdown += "\n#### Molecular Features\n\n"
|
||||
markdown += f"- **H-Bond Donors**: {h_bond_donor_count}\n"
|
||||
markdown += f"- **H-Bond Acceptors**: {h_bond_acceptor_count}\n"
|
||||
markdown += f"- **Rotatable Bonds**: {rotatable_bond_count}\n"
|
||||
markdown += f"- **Heavy Atoms**: {heavy_atom_count}\n"
|
||||
|
||||
# Structure information
|
||||
markdown += "\n#### Structure Information\n\n"
|
||||
markdown += f"- **Atoms Count**: {len(atoms)}\n"
|
||||
markdown += f"- **Bonds Count**: {len(bonds)}\n"
|
||||
|
||||
# Add a summary of atom elements if available
|
||||
if atoms:
|
||||
elements = {}
|
||||
for atom in atoms:
|
||||
element = atom.get("element", "")
|
||||
if element:
|
||||
elements[element] = elements.get(element, 0) + 1
|
||||
|
||||
if elements:
|
||||
markdown += "- **Elements**: "
|
||||
elements_str = ", ".join([f"{element}: {count}" for element, count in elements.items()])
|
||||
markdown += f"{elements_str}\n"
|
||||
|
||||
markdown += "\n---\n\n" if i < len(results) - 1 else "\n"
|
||||
|
||||
return markdown
|
||||
|
||||
|
||||
@llm_tool(name="search_advanced_pubchem",
          description="Search for chemical compounds on PubChem database using name, SMILES notation, or molecular formula via PubChemPy library")
async def search_advanced_pubchem(
    name: Optional[str] = None,
    smiles: Optional[str] = None,
    formula: Optional[str] = None,
    max_results: int = 3
) -> str:
    """
    Perform an advanced PubChem search by name, SMILES, or molecular formula.

    At least one search parameter must be supplied. When several are given,
    they are prioritized in the order name > smiles > formula.

    Args:
        name: Name of the chemical compound (e.g., "Aspirin", "Caffeine").
        smiles: SMILES notation of the chemical compound
            (e.g., "CC(=O)OC1=CC=CC=C1C(=O)O" for Aspirin).
        formula: Molecular formula (e.g., "C9H8O4" for Aspirin).
        max_results: Maximum number of results to return (default: 3),
            clamped to the range [1, 10].

    Returns:
        A formatted Markdown string with search results, or a Markdown error
        section describing the failure.

    Examples:
        >>> search_advanced_pubchem(name="Aspirin")
        # Returns information about Aspirin

        >>> search_advanced_pubchem(smiles="CC(=O)OC1=CC=CC=C1C(=O)O")
        # Returns information about compounds matching the SMILES notation

        >>> search_advanced_pubchem(formula="C9H8O4", max_results=5)
        # Returns up to 5 compounds with the formula C9H8O4
    """
    logging.info(f"Performing advanced PubChem search with parameters: name={name}, smiles={smiles}, formula={formula}, max_results={max_results}")

    # Require at least one identifier.
    if name is None and smiles is None and formula is None:
        return "## PubChem Search Error\n\nError: At least one search parameter (name, smiles, or formula) must be provided"

    # Clamp max_results into [1, 10] to keep responses manageable.
    max_results = max(1, min(max_results, 10))

    try:
        # Priority order: name, then SMILES, then formula.
        if name is not None:
            hits = await _search_by_name(name, max_results)
            kind, value = "name", name
        elif smiles is not None:
            hits = await _search_by_smiles(smiles, max_results)
            kind, value = "SMILES", smiles
        else:
            # Formula searches rely on _search_by_formula's default
            # pagination parameters to avoid timeouts.
            hits = await _search_by_formula(formula, max_results)
            kind, value = "formula", formula

        return _format_results_as_markdown(hits, kind, value)

    except Exception as e:
        return f"## PubChem Search Error\n\nError: {str(e)}"
||||
9
sci_mcp/chemistry_mcp/rdkit_tools/__init__.py
Normal file
9
sci_mcp/chemistry_mcp/rdkit_tools/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
RDKit Tools Package
|
||||
|
||||
This package provides tools for molecular analysis, manipulation, and visualization
|
||||
using the RDKit library. It includes functions for calculating molecular descriptors,
|
||||
generating molecular fingerprints, analyzing molecular structures, and more.
|
||||
"""
|
||||
|
||||
from .rdkit_tools import *
|
||||
1154
sci_mcp/chemistry_mcp/rdkit_tools/rdkit_tools.py
Normal file
1154
sci_mcp/chemistry_mcp/rdkit_tools/rdkit_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
6
sci_mcp/chemistry_mcp/rxn_tools/__init__.py
Normal file
6
sci_mcp/chemistry_mcp/rxn_tools/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
RXN Tools Module
|
||||
|
||||
This module provides tools for chemical reaction prediction and analysis
|
||||
using the IBM RXN for Chemistry API.
|
||||
"""
|
||||
772
sci_mcp/chemistry_mcp/rxn_tools/rxn_tools.py
Normal file
772
sci_mcp/chemistry_mcp/rxn_tools/rxn_tools.py
Normal file
@@ -0,0 +1,772 @@
|
||||
"""
|
||||
RXN Tools Module
|
||||
|
||||
This module provides tools for chemical reaction prediction and analysis
|
||||
using the IBM RXN for Chemistry API through the rxn4chemistry package.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Union, Optional, Any, Tuple
|
||||
|
||||
from rxn4chemistry import RXN4ChemistryWrapper
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ...core.config import Chemistry_Config
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
DEFAULT_MAX_RESULTS = 3
|
||||
DEFAULT_TIMEOUT = 180 # seconds - increased from 60 to 180
|
||||
|
||||
|
||||
def _get_rxn_wrapper() -> RXN4ChemistryWrapper:
    """
    Build an RXN4Chemistry wrapper with the API key and a fresh project set.

    The key is taken from ``Chemistry_Config.RXN4_CHEMISTRY_KEY`` first, then
    from the ``RXN_API_KEY`` environment variable. A new project (named after
    the current process id) is created on every call; failures during project
    creation are logged but do not abort wrapper construction.

    Returns:
        Initialized RXN4ChemistryWrapper instance with project set.

    Raises:
        ValueError: If no API key is available.
    """
    api_key = Chemistry_Config.RXN4_CHEMISTRY_KEY or os.environ.get("RXN_API_KEY")
    if not api_key:
        raise ValueError("Error: RXN API key not found. Please set the RXN_API_KEY environment variable.")

    wrapper = RXN4ChemistryWrapper(api_key=api_key)

    try:
        # Process id makes the project name unique per worker.
        project_name = f"RXN_Tools_Project_{os.getpid()}"
        creation = wrapper.create_project(project_name)

        if creation and isinstance(creation, dict):
            # The project id may appear top-level ({"project_id": ...}) or
            # nested ({"response": {"payload": {"id": ...}}}).
            project_id = None
            if "project_id" in creation:
                project_id = creation["project_id"]
            elif "response" in creation and isinstance(creation["response"], dict):
                payload = creation["response"].get("payload", {})
                if isinstance(payload, dict) and "id" in payload:
                    project_id = payload["id"]

            if project_id:
                wrapper.set_project(project_id)
                logger.info(f"RXN project '{project_name}' created and set successfully with ID: {project_id}")
            else:
                logger.warning(f"Could not extract project ID from response: {creation}")
        else:
            logger.warning(f"Unexpected project creation response: {creation}")
    except Exception as e:
        logger.error(f"Error creating RXN project: {e}")

    # Warn when no project could be attached; some endpoints require one.
    if not getattr(wrapper, "project_id", None):
        logger.warning("No project set. Some API calls may fail.")

    return wrapper
||||
|
||||
|
||||
|
||||
def _format_reaction_markdown(reactants: str, products: List[str],
|
||||
confidence: Optional[List[float]] = None) -> str:
|
||||
"""
|
||||
Format reaction results as Markdown.
|
||||
|
||||
Args:
|
||||
reactants: SMILES of reactants
|
||||
products: List of product SMILES
|
||||
confidence: Optional list of confidence scores
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string
|
||||
"""
|
||||
markdown = f"## 反应预测结果\n\n"
|
||||
markdown += f"### 输入反应物\n\n`{reactants}`\n\n"
|
||||
|
||||
markdown += f"### 预测产物\n\n"
|
||||
|
||||
for i, product in enumerate(products):
|
||||
conf_str = f" (置信度: {confidence[i]:.2f})" if confidence and i < len(confidence) else ""
|
||||
markdown += f"{i+1}. `{product}`{conf_str}\n"
|
||||
|
||||
return markdown
|
||||
|
||||
|
||||
@llm_tool(name="predict_reaction_outcome_rxn",
          description="Predict chemical reaction outcomes for given reactants using IBM RXN for Chemistry API")
async def predict_reaction_outcome_rxn(reactants: str) -> str:
    """
    Predict the most likely products for the given reactants.

    Submits a prediction to the IBM RXN for Chemistry API, then fetches the
    result by prediction id (no polling loop).

    Args:
        reactants: SMILES notation of reactants, multiple reactants separated by dots (.).

    Returns:
        Formatted Markdown string containing the predicted reaction results,
        or an ``Error: ...`` string on failure.

    Examples:
        >>> predict_reaction_outcome_rxn("BrBr.c1ccc2cc3ccccc3cc2c1")
        # Returns predicted products of bromine and anthracene reaction
    """
    try:
        rxn = _get_rxn_wrapper()
        smiles_input = reactants.strip()

        # Submit the prediction off the event loop.
        submission = await asyncio.to_thread(rxn.predict_reaction, smiles_input)
        if not submission or "prediction_id" not in submission:
            return "Error: 无法提交反应预测请求"

        # Fetch the result directly by id (no _wait_for_result loop).
        outcome = await asyncio.to_thread(
            rxn.get_predict_reaction_results,
            submission["prediction_id"]
        )

        try:
            attempts = outcome.get("response", {}).get("payload", {}).get("attempts", [])
            if not attempts:
                return "Error: 未找到预测结果"

            # Use the top-ranked attempt.
            top = attempts[0]
            product_smiles = top.get("smiles", "")
            score = top.get("confidence", None)

            return _format_reaction_markdown(
                smiles_input,
                [product_smiles] if product_smiles else ["无法预测产物"],
                [score] if score is not None else None
            )

        except (KeyError, IndexError) as e:
            logger.error(f"Error parsing prediction results: {e}")
            return f"Error: 解析预测结果时出错: {str(e)}"

    except Exception as e:
        logger.error(f"Error in predict_reaction_outcome: {e}")
        return f"Error: {str(e)}"
|
||||
|
||||
|
||||
def _extract_topn_products(reaction_result: Any) -> Tuple[List[str], List[float]]:
    """Extract (products, confidences) from one reaction's topn result.

    Handles the known response shapes: a list of prediction dicts, a dict
    with a ``results`` list, a dict that is itself a prediction, or a dict
    with a ``products`` list. A ``smiles`` field may be a string or a list
    (in which case the first entry is used).
    """
    products: List[str] = []
    confidences: List[float] = []

    def _append(pred: Any) -> None:
        # Accept a single prediction dict carrying a "smiles" field.
        if isinstance(pred, dict) and "smiles" in pred:
            smiles = pred["smiles"]
            if isinstance(smiles, list) and smiles:
                products.append(smiles[0])
            else:
                products.append(smiles)
            confidences.append(pred.get("confidence", 0.0))

    if isinstance(reaction_result, list):
        for pred in reaction_result:
            _append(pred)
    elif isinstance(reaction_result, dict):
        if "results" in reaction_result:
            for pred in reaction_result.get("results", []):
                _append(pred)
        elif "smiles" in reaction_result:
            _append(reaction_result)
        elif "products" in reaction_result:
            for pred in reaction_result.get("products", []):
                _append(pred)

    return products, confidences


@llm_tool(name="predict_reaction_topn_rxn",
          description="Predict multiple possible products for chemical reactions using IBM RXN for Chemistry API")
async def predict_reaction_topn_rxn(reactants: Union[str, List[str], List[List[str]]], topn: int = 3) -> str:
    """
    Predict multiple possible products for chemical reactions.

    Uses the IBM RXN for Chemistry API to predict multiple products that may
    be formed from given reactants, ranked by likelihood. Suitable for
    scenarios where multiple reaction pathways need to be considered.

    Args:
        reactants: Reactants in one of the following formats:
            - String: SMILES for a single reaction, reactants separated by dots (.)
            - List of strings: Multiple reactants for a single reaction
            - List of lists of strings: Multiple reactions, each composed of
              multiple reactant SMILES strings
        topn: Number of predicted products to return per reaction (default 3,
            clamped to [1, 10]).

    Returns:
        Formatted Markdown string containing the predicted reaction results,
        or an ``Error: ...`` string on failure.

    Examples:
        >>> predict_reaction_topn_rxn("BrBr.c1ccc2cc3ccccc3cc2c1", 5)
        # Returns top 5 possible products for bromine and anthracene reaction

        >>> predict_reaction_topn_rxn(["BrBr", "c1ccc2cc3ccccc3cc2c1"], 3)
        # Returns top 3 possible products for bromine and anthracene reaction

        >>> predict_reaction_topn_rxn([
        ...     ["BrBr", "c1ccc2cc3ccccc3cc2c1"],
        ...     ["BrBr", "c1ccc2cc3ccccc3cc2c1CCO"]
        ... ], 3)
        # Returns top 3 possible products for two different reactions
    """
    try:
        wrapper = _get_rxn_wrapper()

        # Clamp topn to [1, 10].
        if topn < 1:
            topn = 1
        elif topn > 10:
            topn = 10
            logger.warning("topn限制为最大10个结果")

        # Normalize input into precursor lists plus display strings.
        if isinstance(reactants, str):
            reactants = reactants.strip()
            precursors_lists = [reactants.split(".")]
            reactants_display = [reactants]
        elif isinstance(reactants, list):
            if all(isinstance(r, str) for r in reactants):
                precursors_lists = [reactants]
                reactants_display = [".".join(reactants)]
            elif all(isinstance(r, list) for r in reactants):
                precursors_lists = reactants
                reactants_display = [".".join(r) for r in reactants]
            else:
                return "Error: 反应物列表格式无效,必须是字符串列表或字符串列表的列表"
        else:
            return "Error: 反应物参数类型无效,必须是字符串或列表"

        # Submit the batch prediction off the event loop.
        response = await asyncio.to_thread(
            wrapper.predict_reaction_batch_topn,
            precursors_lists=precursors_lists,
            topn=topn
        )
        if not response or "task_id" not in response:
            return "Error: 无法提交多产物反应预测请求"

        # Fetch results directly by task id (no polling loop).
        results = await asyncio.to_thread(
            wrapper.get_predict_reaction_batch_topn_results,
            response["task_id"]
        )

        try:
            # BUG FIX: type-check before dict access. The original called
            # results.keys()/.get() unconditionally, so a list-shaped API
            # response raised AttributeError and its list fallback branch
            # was unreachable.
            if isinstance(results, list):
                reaction_results = results
                logger.info("Using entire results as list")
            else:
                logger.info(f"Results structure: {results.keys()}")
                reaction_results = results.get("result", [])
                # Fall back to the alternative "predictions" key.
                if not reaction_results and "predictions" in results:
                    reaction_results = results.get("predictions", [])
                    logger.info("Using 'predictions' key instead of 'result'")

            if not reaction_results:
                if isinstance(results, dict):
                    logger.warning(f"No reaction results found. Available keys: {results.keys()}")
                return "Error: 未找到预测结果。请检查API响应格式。"

            markdown = "## 反应预测结果\n\n"

            # Keep results and display strings aligned; trim to the shorter.
            if len(reaction_results) != len(reactants_display):
                logger.warning(f"Mismatch between results ({len(reaction_results)}) and reactants ({len(reactants_display)})")
                min_len = min(len(reaction_results), len(reactants_display))
                reaction_results = reaction_results[:min_len]
                reactants_display = reactants_display[:min_len]

            for i, (reaction_result, reactants_str) in enumerate(zip(reaction_results, reactants_display), 1):
                if not reaction_result:
                    markdown += f"### 反应 {i}: 未找到预测结果\n\n"
                    continue

                logger.info(f"Reaction {i} result structure: {type(reaction_result)}")

                products, confidences = _extract_topn_products(reaction_result)

                markdown += f"### 反应 {i}\n\n"
                markdown += f"**输入反应物:** `{reactants_str}`\n\n"

                if products:
                    markdown += "**预测产物:**\n\n"
                    for j, (product, confidence) in enumerate(zip(products, confidences), 1):
                        markdown += f"{j}. `{product}` (置信度: {confidence:.2f})\n"
                else:
                    markdown += "**预测产物:** 无法解析产物结构\n\n"
                    # Keep the raw payload visible for debugging unparsable shapes.
                    markdown += f"**原始结果:** `{reaction_result}`\n\n"

                markdown += "\n"

            return markdown

        except Exception as e:
            logger.error(f"Error parsing topn prediction results: {e}", exc_info=True)
            return f"Error: 解析多产物预测结果时出错: {str(e)}"

    except Exception as e:
        logger.error(f"Error in predict_reaction_topn: {e}")
        return f"Error: {str(e)}"
||||
|
||||
|
||||
# @llm_tool(name="predict_retrosynthesis",
|
||||
# description="预测目标分子的逆合成路径")
|
||||
# async def predict_retrosynthesis(target_molecule: str, max_steps: int = 3) -> str:
|
||||
# """
|
||||
# 预测目标分子的逆合成路径。
|
||||
|
||||
# 此函数使用IBM RXN for Chemistry API建议可能的合成路线,
|
||||
# 将目标分子分解为可能商业可得的更简单前体。
|
||||
|
||||
# Args:
|
||||
# target_molecule: 目标分子的SMILES表示法。
|
||||
# max_steps: 考虑的最大逆合成步骤数,默认为3。
|
||||
|
||||
# Returns:
|
||||
# 包含预测逆合成路径的格式化Markdown字符串。
|
||||
|
||||
# Examples:
|
||||
# >>> predict_retrosynthesis("Brc1c2ccccc2c(Br)c2ccccc12")
|
||||
# # 返回目标分子的可能合成路线
|
||||
# """
|
||||
# try:
|
||||
# # Get RXN wrapper
|
||||
# wrapper = _get_rxn_wrapper()
|
||||
|
||||
# # Clean input
|
||||
# target_molecule = target_molecule.strip()
|
||||
|
||||
# # Validate max_steps
|
||||
# if max_steps < 1:
|
||||
# max_steps = 1
|
||||
# elif max_steps > 5:
|
||||
# max_steps = 5
|
||||
# logger.warning("max_steps限制为最大5步")
|
||||
|
||||
# # Submit prediction
|
||||
# response = await asyncio.to_thread(
|
||||
# wrapper.predict_automatic_retrosynthesis,
|
||||
# product=target_molecule,
|
||||
# max_steps=max_steps
|
||||
# )
|
||||
|
||||
# if not response or "prediction_id" not in response:
|
||||
# return "Error: 无法提交逆合成预测请求"
|
||||
|
||||
# # 直接获取结果,而不是通过_wait_for_result
|
||||
# results = await asyncio.to_thread(
|
||||
# wrapper.get_predict_automatic_retrosynthesis_results,
|
||||
# response["prediction_id"]
|
||||
# )
|
||||
|
||||
# # Extract retrosynthetic paths
|
||||
# try:
|
||||
# paths = results.get("retrosynthetic_paths", [])
|
||||
|
||||
# if not paths:
|
||||
# return "## 逆合成分析结果\n\n未找到可行的逆合成路径。目标分子可能太复杂或结构有问题。"
|
||||
|
||||
# # Format results
|
||||
# markdown = f"## 逆合成分析结果\n\n"
|
||||
# markdown += f"### 目标分子\n\n`{target_molecule}`\n\n"
|
||||
# markdown += f"### 找到的合成路径: {len(paths)}\n\n"
|
||||
|
||||
# # Limit to top 3 paths for readability
|
||||
# display_paths = paths[:3]
|
||||
|
||||
# for i, path in enumerate(display_paths, 1):
|
||||
# markdown += f"#### 路径 {i}\n\n"
|
||||
|
||||
# # Extract sequence information
|
||||
# sequence_id = path.get("sequenceId", "未知")
|
||||
# confidence = path.get("confidence", 0.0)
|
||||
|
||||
# markdown += f"**置信度:** {confidence:.2f}\n\n"
|
||||
|
||||
# # Extract steps
|
||||
# steps = path.get("steps", [])
|
||||
|
||||
# if steps:
|
||||
# markdown += "**合成步骤:**\n\n"
|
||||
|
||||
# for j, step in enumerate(steps, 1):
|
||||
# # Extract reactants and products
|
||||
# reactants = step.get("reactants", [])
|
||||
# reactant_smiles = [r.get("smiles", "") for r in reactants if "smiles" in r]
|
||||
|
||||
# product = step.get("product", {})
|
||||
# product_smiles = product.get("smiles", "")
|
||||
|
||||
# markdown += f"步骤 {j}: "
|
||||
|
||||
# if reactant_smiles and product_smiles:
|
||||
# markdown += f"`{'.' if len(reactant_smiles) > 1 else ''.join(reactant_smiles)}` → `{product_smiles}`\n\n"
|
||||
# else:
|
||||
# markdown += "反应细节不可用\n\n"
|
||||
# else:
|
||||
# markdown += "**合成步骤:** 未提供详细步骤\n\n"
|
||||
|
||||
# markdown += "---\n\n"
|
||||
|
||||
# if len(paths) > 3:
|
||||
# markdown += f"*注: 仅显示前3条路径,共找到{len(paths)}条可能的合成路径。*\n"
|
||||
|
||||
# return markdown
|
||||
|
||||
# except (KeyError, IndexError) as e:
|
||||
# logger.error(f"Error parsing retrosynthesis results: {e}")
|
||||
# return f"Error: 解析逆合成结果时出错: {str(e)}"
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in predict_retrosynthesis: {e}")
|
||||
# return f"Error: {str(e)}"
|
||||
|
||||
|
||||
# @llm_tool(name="predict_biocatalytic_retrosynthesis",
|
||||
# description="使用生物催化模型预测目标分子的逆合成路径")
|
||||
# async def predict_biocatalytic_retrosynthesis(target_molecule: str) -> str:
|
||||
# """
|
||||
# 使用生物催化模型预测目标分子的逆合成路径。
|
||||
|
||||
# 此函数使用IBM RXN for Chemistry API的专门生物催化模型,
|
||||
# 建议可能的酶催化合成路线来创建目标分子。
|
||||
|
||||
# Args:
|
||||
# target_molecule: 目标分子的SMILES表示法。
|
||||
|
||||
# Returns:
|
||||
# 包含预测生物催化逆合成路径的格式化Markdown字符串。
|
||||
|
||||
# Examples:
|
||||
# >>> predict_biocatalytic_retrosynthesis("OC1C(O)C=C(Br)C=C1")
|
||||
# # 返回目标分子的可能酶催化合成路线
|
||||
# """
|
||||
# try:
|
||||
# # Get RXN wrapper
|
||||
# wrapper = _get_rxn_wrapper()
|
||||
|
||||
# # Clean input
|
||||
# target_molecule = target_molecule.strip()
|
||||
|
||||
# # Submit prediction with enzymatic model
|
||||
# # Note: The model name might change in future API versions
|
||||
# response = await asyncio.to_thread(
|
||||
# wrapper.predict_automatic_retrosynthesis,
|
||||
# product=target_molecule,
|
||||
# ai_model="enzymatic-2021-04-16" # Use the enzymatic model
|
||||
# )
|
||||
|
||||
# if not response or "prediction_id" not in response:
|
||||
# return "Error: 无法提交生物催化逆合成预测请求"
|
||||
|
||||
# # 直接获取结果,而不是通过_wait_for_result
|
||||
# results = await asyncio.to_thread(
|
||||
# wrapper.get_predict_automatic_retrosynthesis_results,
|
||||
# response["prediction_id"]
|
||||
# )
|
||||
|
||||
# # Extract retrosynthetic paths
|
||||
# try:
|
||||
# paths = results.get("retrosynthetic_paths", [])
|
||||
|
||||
# if not paths:
|
||||
# return "## 生物催化逆合成分析结果\n\n未找到可行的酶催化合成路径。目标分子可能不适合酶催化或结构有问题。"
|
||||
|
||||
# # Format results
|
||||
# markdown = f"## 生物催化逆合成分析结果\n\n"
|
||||
# markdown += f"### 目标分子\n\n`{target_molecule}`\n\n"
|
||||
# markdown += f"### 找到的酶催化合成路径: {len(paths)}\n\n"
|
||||
|
||||
# # Limit to top 3 paths for readability
|
||||
# display_paths = paths[:3]
|
||||
|
||||
# for i, path in enumerate(display_paths, 1):
|
||||
# markdown += f"#### 路径 {i}\n\n"
|
||||
|
||||
# # Extract sequence information
|
||||
# sequence_id = path.get("sequenceId", "未知")
|
||||
# confidence = path.get("confidence", 0.0)
|
||||
|
||||
# markdown += f"**置信度:** {confidence:.2f}\n\n"
|
||||
|
||||
# # Extract steps
|
||||
# steps = path.get("steps", [])
|
||||
|
||||
# if steps:
|
||||
# markdown += "**酶催化步骤:**\n\n"
|
||||
|
||||
# for j, step in enumerate(steps, 1):
|
||||
# # Extract reactants and products
|
||||
# reactants = step.get("reactants", [])
|
||||
# reactant_smiles = [r.get("smiles", "") for r in reactants if "smiles" in r]
|
||||
|
||||
# product = step.get("product", {})
|
||||
# product_smiles = product.get("smiles", "")
|
||||
|
||||
# markdown += f"步骤 {j}: "
|
||||
|
||||
# if reactant_smiles and product_smiles:
|
||||
# markdown += f"`{'.' if len(reactant_smiles) > 1 else ''.join(reactant_smiles)}` → `{product_smiles}`\n\n"
|
||||
# else:
|
||||
# markdown += "反应细节不可用\n\n"
|
||||
|
||||
# # Add enzyme information if available
|
||||
# if "enzymeClass" in step:
|
||||
# markdown += f"*可能的酶类别: {step['enzymeClass']}*\n\n"
|
||||
# else:
|
||||
# markdown += "**合成步骤:** 未提供详细步骤\n\n"
|
||||
|
||||
# markdown += "---\n\n"
|
||||
|
||||
# if len(paths) > 3:
|
||||
# markdown += f"*注: 仅显示前3条路径,共找到{len(paths)}条可能的酶催化合成路径。*\n"
|
||||
|
||||
# return markdown
|
||||
|
||||
# except (KeyError, IndexError) as e:
|
||||
# logger.error(f"Error parsing biocatalytic retrosynthesis results: {e}")
|
||||
# return f"Error: 解析生物催化逆合成结果时出错: {str(e)}"
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in predict_biocatalytic_retrosynthesis: {e}")
|
||||
# return f"Error: {str(e)}"
|
||||
|
||||
|
||||
@llm_tool(name="predict_reaction_properties_rxn",
|
||||
description="Predict chemical reaction properties such as atom mapping and yield using IBM RXN for Chemistry API")
|
||||
async def predict_reaction_properties_rxn(
|
||||
reaction: str,
|
||||
property_type: str = "atom-mapping"
|
||||
) -> str:
|
||||
"""
|
||||
Predict chemical reaction properties such as atom mapping and yield.
|
||||
|
||||
This function uses the IBM RXN for Chemistry API to predict various properties
|
||||
of chemical reactions, including atom-to-atom mapping (showing how atoms in
|
||||
reactants correspond to atoms in products) and potential reaction yields.
|
||||
|
||||
Args:
|
||||
reaction: SMILES notation of the reaction (reactants>>products).
|
||||
property_type: Type of property to predict. Options: "atom-mapping", "yield".
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string containing predicted reaction properties.
|
||||
|
||||
Examples:
|
||||
>>> predict_reaction_properties_rxn("CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F", "atom-mapping")
|
||||
# Returns atom mapping for the reaction
|
||||
"""
|
||||
try:
|
||||
# Get RXN wrapper
|
||||
wrapper = _get_rxn_wrapper()
|
||||
|
||||
# Clean input
|
||||
reaction = reaction.strip()
|
||||
|
||||
# Validate property_type
|
||||
valid_property_types = ["atom-mapping", "yield"]
|
||||
if property_type not in valid_property_types:
|
||||
return f"Error: 无效的属性类型 '{property_type}'。支持的类型: {', '.join(valid_property_types)}"
|
||||
|
||||
# Determine model based on property type
|
||||
ai_model = "atom-mapping-2020" if property_type == "atom-mapping" else "yield-2020-08-10"
|
||||
|
||||
# Submit prediction
|
||||
response = await asyncio.to_thread(
|
||||
wrapper.predict_reaction_properties,
|
||||
reactions=[reaction],
|
||||
ai_model=ai_model
|
||||
)
|
||||
|
||||
if not response or "response" not in response:
|
||||
return f"Error: 无法提交{property_type}预测请求"
|
||||
|
||||
# Extract results
|
||||
try:
|
||||
content = response.get("response", {}).get("payload", {}).get("content", [])
|
||||
|
||||
if not content:
|
||||
return f"Error: 未找到{property_type}预测结果"
|
||||
|
||||
# Format results based on property type
|
||||
markdown = f"## 反应{property_type}预测结果\n\n"
|
||||
markdown += f"### 输入反应\n\n`{reaction}`\n\n"
|
||||
|
||||
if property_type == "atom-mapping":
|
||||
# Extract mapped reaction
|
||||
mapped_reaction = content[0].get("value", "")
|
||||
|
||||
if not mapped_reaction:
|
||||
return "Error: 无法生成原子映射"
|
||||
|
||||
markdown += "### 原子映射结果\n\n"
|
||||
markdown += f"`{mapped_reaction}`\n\n"
|
||||
|
||||
# Split into reactants and products for explanation
|
||||
if ">>" in mapped_reaction:
|
||||
reactants, products = mapped_reaction.split(">>")
|
||||
markdown += "### 映射解释\n\n"
|
||||
markdown += "原子映射显示了反应物中的原子如何对应到产物中的原子。\n"
|
||||
markdown += "每个原子上的数字表示映射ID,相同ID的原子在反应前后是同一个原子。\n\n"
|
||||
markdown += f"**映射的反应物:** `{reactants}`\n\n"
|
||||
markdown += f"**映射的产物:** `{products}`\n"
|
||||
|
||||
elif property_type == "yield":
|
||||
# Extract predicted yield
|
||||
predicted_yield = content[0].get("value", "")
|
||||
|
||||
if not predicted_yield:
|
||||
return "Error: 无法预测反应产率"
|
||||
|
||||
try:
|
||||
yield_value = float(predicted_yield)
|
||||
markdown += "### 产率预测结果\n\n"
|
||||
markdown += f"**预测产率:** {yield_value:.1f}%\n\n"
|
||||
|
||||
# Add interpretation
|
||||
if yield_value < 30:
|
||||
markdown += "**解释:** 预测产率较低,反应可能效率不高。考虑优化反应条件或探索替代路线。\n"
|
||||
elif yield_value < 70:
|
||||
markdown += "**解释:** 预测产率中等,反应可能是可行的,但有优化空间。\n"
|
||||
else:
|
||||
markdown += "**解释:** 预测产率较高,反应可能非常有效。\n"
|
||||
except ValueError:
|
||||
markdown += f"**预测产率:** {predicted_yield}\n"
|
||||
|
||||
return markdown
|
||||
|
||||
except (KeyError, IndexError) as e:
|
||||
logger.error(f"Error parsing reaction properties results: {e}")
|
||||
return f"Error: 解析反应属性预测结果时出错: {str(e)}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in predict_reaction_properties: {e}")
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
|
||||
@llm_tool(name="extract_reaction_actions_rxn",
|
||||
description="Extract structured reaction steps from text descriptions using IBM RXN for Chemistry API")
|
||||
async def extract_reaction_actions_rxn(reaction_text: str) -> str:
|
||||
"""
|
||||
Extract structured reaction steps from text descriptions.
|
||||
|
||||
This function uses the IBM RXN for Chemistry API to parse text descriptions
|
||||
of chemical procedures and extract structured actions representing the steps
|
||||
of the procedure.
|
||||
|
||||
Args:
|
||||
reaction_text: Text description of a chemical reaction procedure.
|
||||
|
||||
Returns:
|
||||
Formatted Markdown string containing the extracted reaction steps.
|
||||
|
||||
Examples:
|
||||
>>> extract_reaction_actions_rxn("To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour.")
|
||||
# Returns structured steps extracted from the text
|
||||
"""
|
||||
try:
|
||||
# Get RXN wrapper
|
||||
wrapper = _get_rxn_wrapper()
|
||||
|
||||
# Clean input
|
||||
reaction_text = reaction_text.strip()
|
||||
|
||||
if not reaction_text:
|
||||
return "Error: 反应文本为空"
|
||||
|
||||
# Submit extraction request
|
||||
response = await asyncio.to_thread(
|
||||
wrapper.paragraph_to_actions,
|
||||
paragraph=reaction_text
|
||||
)
|
||||
|
||||
# 检查response是否存在
|
||||
if not response:
|
||||
return "Error: 无法从文本中提取反应步骤"
|
||||
|
||||
# 直接返回response,不做任何处理
|
||||
# 这是基于参考代码中直接打印results['actions']的方式
|
||||
# 我们假设response本身就是我们需要的结果
|
||||
return f"""## 反应步骤提取结果
|
||||
|
||||
### 输入文本
|
||||
|
||||
{reaction_text}
|
||||
|
||||
### 提取的反应步骤
|
||||
|
||||
```
|
||||
{response}
|
||||
```
|
||||
"""
|
||||
|
||||
return markdown
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in extract_reaction_actions: {e}")
|
||||
return f"Error: {str(e)}"
|
||||
84
sci_mcp/core/config.py
Executable file
84
sci_mcp/core/config.py
Executable file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Configuration Module
|
||||
|
||||
This module provides configuration settings for the Mars Toolkit.
|
||||
It includes API keys, endpoints, paths, and other configuration parameters.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
class Config:
    """Base class for configuration containers.

    Subclasses declare their settings as plain class attributes; the helpers
    below expose those attributes as a dict and allow bulk in-place updates.
    """

    @classmethod
    def as_dict(cls) -> Dict[str, Any]:
        """Return this class's own non-dunder, non-callable settings as a dict.

        Only attributes defined directly on *cls* are included (inherited
        attributes live in the parent's __dict__ and are skipped).
        """
        settings = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith('__') or callable(attr_value):
                continue
            settings[attr_name] = attr_value
        return settings

    @classmethod
    def update(cls, **kwargs):
        """Overwrite existing settings in place; unknown keys are silently ignored."""
        for attr_name, attr_value in kwargs.items():
            if not hasattr(cls, attr_name):
                continue
            setattr(cls, attr_name, attr_value)
|
||||
|
||||
|
||||
class General_Config(Config):
    """Configuration class for General MCP"""

    # Searxng metasearch endpoint (self-hosted instance on the internal network)
    SEARXNG_HOST="http://192.168.168.1:40032/"
    # Hard upper bound on how many results a single search may return
    SEARXNG_MAX_RESULTS=10
|
||||
|
||||
|
||||
class Material_Config(Config):
    """Configuration for the material-science MCP tools."""

    # Materials Project
    # SECURITY NOTE(review): API key hardcoded in source — move to an
    # environment variable or secrets store before publishing this repo.
    MP_API_KEY = 'PMASAg256b814q3OaSRWeVc7MKx4mlKI'
    MP_ENDPOINT = 'https://api.materialsproject.org/'
    MP_TOPK = 3  # number of top matches returned per MP query

    # Local mirrors of Materials Project property / CIF datasets
    LOCAL_MP_PROPS_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Props/'
    LOCAL_MP_CIF_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/MPDatasets/'
    # Proxy (empty string disables the proxy; previous value kept inline)
    HTTP_PROXY = ''#'http://192.168.168.1:20171'
    HTTPS_PROXY = ''#'http://192.168.168.1:20171'

    # FairChem
    FAIRCHEM_MODEL_PATH = '/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/support/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
    FMAX = 0.05  # force-convergence threshold for relaxation — presumably eV/Å, confirm

    # MatterGen
    MATTERGENMODEL_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/support/pretrained_models/mattergen_ckpt'
    MATTERGEN_ROOT='/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/mattergen_gen/mattergen'
    MATTERGENMODEL_RESULT_PATH = 'results/'

    # Dify knowledge-base service
    # SECURITY NOTE(review): hardcoded API key — externalize before release.
    DIFY_ROOT_URL = 'http://192.168.191.101:6080'
    DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'

    #temp root — scratch directory for material tool outputs
    TEMP_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/temp/material'
|
||||
|
||||
|
||||
class Chemistry_Config(Config):
    """Configuration for the chemistry MCP tools."""

    # Scratch directory for chemistry tool outputs
    TEMP_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/temp/chemistry'

    # SECURITY NOTE(review): IBM RXN for Chemistry API key hardcoded in
    # source — move to an environment variable before releasing.
    RXN4_CHEMISTRY_KEY='apk-8928522a146c2503f30b16d9909222d7583f412ee8f1049f08d32a089ba88d34'
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Module-level singleton configuration instances shared across the package.
material_config = Material_Config()
general_config = General_Config()
chemistry_config = Chemistry_Config()
|
||||
|
||||
|
||||
|
||||
315
sci_mcp/core/llm_tools.py
Executable file
315
sci_mcp/core/llm_tools.py
Executable file
@@ -0,0 +1,315 @@
|
||||
"""
|
||||
LLM Tools Module
|
||||
|
||||
This module provides decorators and utilities for defining, registering, and managing LLM tools.
|
||||
It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving
|
||||
registered tools for use with LLM APIs.
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
import os
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args
|
||||
import docstring_parser
|
||||
from pydantic import BaseModel, create_model, Field
|
||||
|
||||
# Registry to store all registered tools (tool name -> wrapped callable).
# Populated as a side effect of the @llm_tool decorator below.
_TOOL_REGISTRY = {}
# Mapping of domain names to their module paths; used by import_domain_tools
# to discover and register domain-specific tools.
_DOMAIN_MODULE_MAPPING = {
    'material': 'sci_mcp.material_mcp',
    'general': 'sci_mcp.general_mcp',
    'biology': 'sci_mcp.biology_mcp',
    'chemistry': 'sci_mcp.chemistry_mcp'
}
|
||||
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
    """
    Decorator that registers a function as an LLM tool.

    Registration generates OpenAI/MCP JSON schemas for the function and makes
    it retrievable through the get_*_tools helpers. The decorator works both
    bare (@llm_tool) and parenthesized (@llm_tool(name=..., description=...)).

    Args:
        name: Optional custom tool name; defaults to the function's own name.
        description: Optional custom tool description; defaults to the
            function's docstring.

    Returns:
        The decorated function, augmented with tool metadata attributes.

    Example:
        @llm_tool(name="weather_lookup", description="Get current weather for a location")
        def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
            '''Get weather information for a specific location.'''
            return {"temperature": 22.5, "conditions": "sunny"}
    """
    # Bare usage (@llm_tool): the decorated function itself arrives in `name`.
    if callable(name):
        target = name
        return _llm_tool_impl(target, None, None)

    # Parenthesized usage (@llm_tool() / @llm_tool(name="xyz")).
    def decorator(func: Callable) -> Callable:
        return _llm_tool_impl(func, name, description)

    return decorator
|
||||
|
||||
def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
    """Implementation of the llm_tool decorator.

    Builds OpenAI- and MCP-format JSON schemas plus a Pydantic args model from
    the function's signature and docstring, wraps the function (preserving its
    sync/async nature), attaches the metadata, and registers it in
    _TOOL_REGISTRY under the tool name.
    """
    # Get function signature and docstring
    sig = inspect.signature(func)
    doc = inspect.getdoc(func) or ""
    parsed_doc = docstring_parser.parse(doc)

    # Determine tool name
    tool_name = name or func.__name__

    # Determine tool description (whole docstring if none given explicitly)
    tool_description = description or doc

    # Create parameter properties for JSON schema
    properties = {}
    required = []

    for param_name, param in sig.parameters.items():
        # Skip self parameter for methods
        if param_name == "self":
            continue

        param_type = param.annotation
        param_default = None if param.default is inspect.Parameter.empty else param.default
        param_required = param.default is inspect.Parameter.empty

        # Get parameter description from docstring if available
        param_desc = ""
        for param_doc in parsed_doc.params:
            if param_doc.arg_name == param_name:
                param_desc = param_doc.description
                break

        # Handle Annotated types: unwrap to the real type and prefer the
        # inline string annotation as the description.
        # NOTE(review): detecting Annotated via get_origin(...).__name__ is
        # fragile across Python versions — confirm on the targeted runtime.
        if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
            args = get_args(param_type)
            param_type = args[0]  # The actual type
            if len(args) > 1 and isinstance(args[1], str):
                param_desc = args[1]  # The description

        # Create property for parameter
        param_schema = {
            "type": _get_json_type(param_type),
            "description": param_desc,
            "title": param_name.replace("_", " ").title()
        }

        # Add default value if available
        # NOTE(review): a parameter whose default is literally None gets no
        # "default" entry here — confirm that is acceptable for consumers.
        if param_default is not None:
            param_schema["default"] = param_default

        properties[param_name] = param_schema

        # Add to required list if no default value
        if param_required:
            required.append(param_name)

    # Create OpenAI format JSON schema
    openai_schema = {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required
            }
        }
    }

    # Create MCP format JSON schema (same properties, different envelope)
    mcp_schema = {
        "name": tool_name,
        "description": tool_description,
        "inputSchema": {
            "type": "object",
            "properties": properties,
            "required": required
        }
    }

    # Create Pydantic model for args schema (second pass over the signature;
    # here `...` marks a required field per Pydantic convention)
    field_definitions = {}
    for param_name, param in sig.parameters.items():
        if param_name == "self":
            continue

        param_type = param.annotation
        param_default = ... if param.default is inspect.Parameter.empty else param.default

        # Handle Annotated types (same unwrapping as above)
        if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
            args = get_args(param_type)
            param_type = args[0]
            description = args[1] if len(args) > 1 and isinstance(args[1], str) else ""
            field_definitions[param_name] = (param_type, Field(default=param_default, description=description))
        else:
            field_definitions[param_name] = (param_type, Field(default=param_default))

    # Create args schema model
    model_name = f"{tool_name.title().replace('_', '')}Schema"
    args_schema = create_model(model_name, **field_definitions)

    # Create a wrapper matching the wrapped function's sync/async nature so
    # async tools stay awaitable and sync tools stay plain callables.
    if asyncio.iscoroutinefunction(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
    else:
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

    # Attach metadata to function
    wrapper.is_llm_tool = True
    wrapper.tool_name = tool_name
    wrapper.tool_description = tool_description
    wrapper.openai_schema = openai_schema
    wrapper.mcp_schema = mcp_schema
    wrapper.args_schema = args_schema

    # Register the tool (last registration wins on name collisions)
    _TOOL_REGISTRY[tool_name] = wrapper

    return wrapper
|
||||
|
||||
def get_all_tools() -> Dict[str, Callable]:
    """
    Get all registered LLM tools.

    Returns:
        A dictionary mapping tool names to their corresponding functions.
        NOTE: this is the live registry object, not a copy — mutating the
        returned dict affects tool registration globally.
    """
    return _TOOL_REGISTRY
|
||||
|
||||
def get_all_tool_schemas(schema_type='openai') -> List[Dict[str, Any]]:
    """
    Return JSON schemas for every registered LLM tool.

    Args:
        schema_type: 'mcp' yields MCP-style schemas; any other value yields
            OpenAI function-calling schemas.

    Returns:
        A list of schema dicts, one per registered tool, suitable for use
        with LLM APIs.
    """
    if schema_type == 'mcp':
        return [registered.mcp_schema for registered in _TOOL_REGISTRY.values()]
    return [registered.openai_schema for registered in _TOOL_REGISTRY.values()]
|
||||
|
||||
def import_domain_tools(domains: List[str]) -> None:
    """
    Import tools from the given domains so they get registered.

    Dynamically imports every submodule of each domain package; importing a
    module executes its @llm_tool decorators, which populate _TOOL_REGISTRY.
    Import failures are reported on stdout and do not abort the remaining work.

    Args:
        domains: List of domain names (e.g., ['material', 'general'])
    """
    for requested in domains:
        if requested not in _DOMAIN_MODULE_MAPPING:
            continue

        package_path = _DOMAIN_MODULE_MAPPING[requested]
        try:
            # Import the domain package itself, then walk its directory tree.
            package = importlib.import_module(package_path)
            package_dir = os.path.dirname(package.__file__)

            # Import each submodule individually so one failure does not
            # prevent the rest of the domain from registering.
            for _finder, submodule_name, _is_pkg in pkgutil.walk_packages([package_dir], f"{package_path}."):
                try:
                    importlib.import_module(submodule_name)
                except ImportError as e:
                    print(f"Error importing {submodule_name}: {e}")
        except ImportError as e:
            print(f"Error importing domain {requested}: {e}")
|
||||
|
||||
def get_domain_tools(domains: List[str]) -> Dict[str, Callable]:
    """
    Get tools from specified domains.

    Ensures the domains' modules are imported (registering their tools), then
    filters the global registry down to tools whose defining module lives
    under one of the requested domain packages.

    Args:
        domains: List of domain names (e.g., ['material', 'general'])

    Returns:
        A dictionary mapping tool names to their functions.
        (BUGFIX: the return annotation previously claimed a nested
        Dict[str, Dict[str, Callable]], but the function has always
        returned a flat name -> callable mapping.)
    """
    # First, ensure all tools from the specified domains are imported and registered
    import_domain_tools(domains)

    domain_tools: Dict[str, Callable] = {}
    for domain in domains:
        if domain not in _DOMAIN_MODULE_MAPPING:
            continue

        domain_module_prefix = _DOMAIN_MODULE_MAPPING[domain]

        for tool_name, tool_func in _TOOL_REGISTRY.items():
            # A tool belongs to this domain if its defining module path
            # starts with the domain's package path.
            if hasattr(tool_func, "__module__") and tool_func.__module__.startswith(domain_module_prefix):
                domain_tools[tool_name] = tool_func

    return domain_tools
|
||||
|
||||
def get_domain_tool_schemas(domains: List[str], schema_type='openai') -> List[Dict[str, Any]]:
    """
    Get JSON schemas for tools from specified domains.

    Args:
        domains: List of domain names (e.g., ['material', 'general'])
        schema_type: 'mcp' yields MCP-style schemas; anything else yields
            OpenAI function-calling schemas.

    Returns:
        A list of schema dicts, one per tool in the requested domains.
        (BUGFIX: the return annotation previously claimed a
        Dict[str, List[...]] keyed by domain, but the function has always
        returned a flat list of schemas.)
    """
    # First, get all domain tools (this also imports/registers them)
    domain_tools = get_domain_tools(domains)

    if schema_type == 'mcp':
        return [tool.mcp_schema for tool in domain_tools.values()]
    return [tool.openai_schema for tool in domain_tools.values()]
|
||||
|
||||
|
||||
def _get_json_type(python_type: Any) -> str:
|
||||
"""
|
||||
Convert Python type to JSON schema type.
|
||||
|
||||
Args:
|
||||
python_type: Python type annotation
|
||||
|
||||
Returns:
|
||||
Corresponding JSON schema type as string
|
||||
"""
|
||||
if python_type is str:
|
||||
return "string"
|
||||
elif python_type is int:
|
||||
return "integer"
|
||||
elif python_type is float:
|
||||
return "number"
|
||||
elif python_type is bool:
|
||||
return "boolean"
|
||||
elif python_type is list or python_type is List:
|
||||
return "array"
|
||||
elif python_type is dict or python_type is Dict:
|
||||
return "object"
|
||||
else:
|
||||
# Default to string for complex types
|
||||
return "string"
|
||||
0
sci_mcp/general_mcp/__init__.py
Normal file
0
sci_mcp/general_mcp/__init__.py
Normal file
0
sci_mcp/general_mcp/searxng_query/__init__.py
Normal file
0
sci_mcp/general_mcp/searxng_query/__init__.py
Normal file
78
sci_mcp/general_mcp/searxng_query/searxng_query_tools.py
Normal file
78
sci_mcp/general_mcp/searxng_query/searxng_query_tools.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Search Online Module
|
||||
|
||||
This module provides functions for searching information on the web.
|
||||
"""
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ...core.config import general_config
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Annotated, Any, Dict, List, Union
|
||||
from langchain_community.utilities import SearxSearchWrapper
|
||||
import mcp.types as types
|
||||
|
||||
|
||||
@llm_tool(name="search_online_searxng", description="Search scientific information online using searxng")
|
||||
async def search_online_searxng(
|
||||
query: Annotated[str, "Search term"],
|
||||
num_results: Annotated[int, "Number of results"] = 5
|
||||
) -> str:
|
||||
"""
|
||||
Searches for scientific information online and returns results as a formatted string.
|
||||
|
||||
Args:
|
||||
query: Search term for scientific content
|
||||
num_results: Number of results to return
|
||||
|
||||
Returns:
|
||||
Formatted string with search results (titles, snippets, links)
|
||||
"""
|
||||
# lzy: 此部分到正式发布时可能要删除,因为searxng 已在本地部署,因此本地调试时无需设置代理
|
||||
os.environ['HTTP_PROXY'] = ''
|
||||
os.environ['HTTPS_PROXY'] = ''
|
||||
try:
|
||||
max_results = min(int(num_results), general_config.SEARXNG_MAX_RESULTS)
|
||||
search = SearxSearchWrapper(
|
||||
searx_host=general_config.SEARXNG_HOST,
|
||||
categories=["science",],
|
||||
k=num_results
|
||||
)
|
||||
|
||||
# Execute search in a separate thread to avoid blocking the event loop
|
||||
# since SearxSearchWrapper doesn't have native async support
|
||||
loop = asyncio.get_event_loop()
|
||||
raw_results = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: search.results(query, language=['en','zh'], num_results=max_results)
|
||||
)
|
||||
|
||||
# Transform results into structured format
|
||||
formatted_results = []
|
||||
for result in raw_results:
|
||||
formatted_results.append({
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("snippet", ""),
|
||||
"link": result.get("link", ""),
|
||||
"source": result.get("source", "")
|
||||
})
|
||||
|
||||
# Format results into a readable Markdown string
|
||||
result_str = f"Search Results for '{query}' ({len(formatted_results)} items):\n\n"
|
||||
if len(formatted_results) > 0:
|
||||
for i, res in enumerate(formatted_results):
|
||||
title = res.get("title", "No Title")
|
||||
snippet = res.get("snippet", "No Snippet")
|
||||
link = res.get("link", "No Link")
|
||||
source = res.get("source", "No Source")
|
||||
result_str += f"{i + 1}. **{title}**\n"
|
||||
result_str += f" - Snippet: {snippet}\n"
|
||||
result_str += f" - Link: [{link}]({link})\n"
|
||||
result_str += f" - Source: {source}\n\n"
|
||||
else:
|
||||
result_str += "No results found.\n"
|
||||
|
||||
return result_str
|
||||
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
34
sci_mcp/material_mcp/__init__.py
Executable file
34
sci_mcp/material_mcp/__init__.py
Executable file
@@ -0,0 +1,34 @@
|
||||
|
||||
# # Core modules
|
||||
# from mars_toolkit.core.config import config
|
||||
|
||||
|
||||
# # Basic tools
|
||||
# from mars_toolkit.misc.misc_tools import get_current_time
|
||||
|
||||
# # Compute modules
|
||||
# from mars_toolkit.compute.material_gen import generate_material
|
||||
# from mars_toolkit.compute.property_pred import predict_properties
|
||||
# from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure
|
||||
|
||||
# # Query modules
|
||||
# from mars_toolkit.query.mp_query import (
|
||||
# search_material_property_from_material_project,
|
||||
# get_crystal_structures_from_materials_project,
|
||||
# get_mpid_from_formula
|
||||
# )
|
||||
# from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||
# from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||
# from mars_toolkit.query.web_search import search_online
|
||||
|
||||
# # Visualization modules
|
||||
|
||||
|
||||
# from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# __version__ = "0.1.0"
|
||||
# __all__ = ["llm_tool", "get_tools", "get_tool_schemas"]
|
||||
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Structure Optimization Module
|
||||
|
||||
This module provides functions for optimizing crystal structures using the FairChem model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from io import StringIO
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from ase.io import read, write
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
from ..support.utils import convert_structure, remove_symmetry_equiv_xyz, read_structure_from_file_name_or_content_string
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ...core.config import material_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# FairChem calculator singleton; populated lazily by init_model().
calc = None

def init_model():
    """Initialize the FairChem model (idempotent: reuses the cached calculator)."""
    global calc
    if calc is not None:
        return

    try:
        # Imported lazily so the module loads even without fairchem installed.
        from fairchem.core import OCPCalculator
        calc = OCPCalculator(checkpoint_path=material_config.FAIRCHEM_MODEL_PATH)
        logger.info("FairChem model initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize FairChem model: {str(e)}")
        raise
|
||||
|
||||
|
||||
|
||||
def generate_symmetry_cif(structure: Structure) -> str:
    """
    Produce a symmetry-refined CIF string for a structure.

    Args:
        structure: Pymatgen Structure object.

    Returns:
        The content of the refined CIF file as a string.
    """
    # Standardize the cell via spglib-based symmetry analysis first.
    refined = SpacegroupAnalyzer(structure).get_refined_structure()

    # CifWriter writes to a file, so round-trip through a named temp file.
    with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
        CifWriter(refined, symprec=0.1, refine_struct=True).write_file(tmp_file.name)
        tmp_file.seek(0)
        content = tmp_file.read()
    os.unlink(tmp_file.name)
    return content
|
||||
|
||||
def optimize_structure(atoms: Atoms, output_format: str, fmax: float = 0.05) -> str:
    """
    Optimize a crystal structure with the global FairChem calculator.

    Args:
        atoms: ASE Atoms object to optimize (modified in place).
        output_format: Output format for the relaxed structure (cif, xyz, vasp, ...).
        fmax: Force convergence criterion in eV/Å.

    Returns:
        A formatted report string with the optimization log, total energy and
        the relaxed structure, or an "Error: ..." message string on failure.
    """
    atoms.calc = calc

    try:
        # Capture the optimizer's console output. BUG FIX: restore
        # sys.stdout in a finally block — previously, an exception inside
        # dyn.run() jumped straight to the outer except and left stdout
        # permanently redirected to the StringIO for the whole process.
        temp_output = StringIO()
        original_stdout = sys.stdout
        sys.stdout = temp_output
        try:
            # Relax atomic positions and the cell together (FrechetCellFilter).
            dyn = FIRE(FrechetCellFilter(atoms))
            dyn.run(fmax=fmax)
        finally:
            sys.stdout = original_stdout
        optimization_log = temp_output.getvalue()
        temp_output.close()

        # Total energy of the relaxed configuration.
        total_energy = atoms.get_potential_energy()

        # Serialize the relaxed structure in the requested format.
        if output_format == "cif":
            optimized_structure = Structure.from_ase_atoms(atoms)
            content = generate_symmetry_cif(optimized_structure)
            content = remove_symmetry_equiv_xyz(content)
        else:
            # Other formats are written by ASE via a temp file round-trip.
            with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
                write(tmp_file.name, atoms)
                tmp_file.seek(0)
                content = tmp_file.read()

            os.unlink(tmp_file.name)

        # Assemble the human-readable report.
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**

#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
        return format_result
    except Exception as e:
        return f"Error: Failed to optimize structure: {str(e)}"
|
||||
|
||||
@llm_tool(name="optimize_crystal_structure_FairChem",
          description="Optimizes crystal structures using the FairChem model")
async def optimize_crystal_structure_FairChem(
    structure_source: str,
    format_type: str = "auto",
    optimization_level: str = "normal"
) -> str:
    """
    Optimizes a crystal structure to find its lowest energy configuration.

    Args:
        structure_source: Either a file name or direct structure content (CIF, XYZ, POSCAR)
        format_type: Structure format type (auto, cif, xyz, poscar). Default "auto" will attempt to detect format.
        optimization_level: Optimization precision level (quick, normal, precise)

    Returns:
        Optimized structure with total energy and optimization details
    """
    # Make sure the FairChem calculator is loaded before we start.
    if calc is None:
        init_model()

    # Map the requested precision level to a force threshold (eV/Å);
    # unknown levels fall back to the "normal" setting.
    fmax = {"quick": 0.1, "normal": 0.05, "precise": 0.01}.get(optimization_level, 0.05)

    def _blocking_optimize() -> str:
        try:
            # Resolve the input into raw structure text plus its detected format.
            raw_content, detected_format = read_structure_from_file_name_or_content_string(structure_source, format_type)

            # Normalize format names to what ASE expects ("poscar" -> "vasp").
            ase_format = {"cif": "cif", "xyz": "xyz", "poscar": "vasp", "vasp": "vasp"}.get(detected_format.lower(), "cif")

            # Build an ASE Atoms object from the structure text.
            atoms = convert_structure(ase_format, raw_content)
            if atoms is None:
                return "Error: Unable to convert input structure. Please check if the format is correct."

            return optimize_structure(atoms, ase_format, fmax=fmax)

        except Exception as e:
            return f"Error: Failed to optimize structure: {str(e)}"

    # Run the blocking optimization off the event loop.
    return await asyncio.to_thread(_blocking_optimize)
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import codecs
|
||||
import json
|
||||
import requests
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ...core.config import material_config
|
||||
|
||||
@llm_tool(
    name="retrieval_from_knowledge_base",
    description="Retrieve information from local materials science literature knowledge base"
)
async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
    """
    Retrieve relevant entries from the local materials-science literature
    knowledge base through the Dify chat API.

    Args:
        query: Query string, e.g. a material name such as "CsPbBr3".
        topk: Maximum number of results to return (default: 3).

    Returns:
        A JSON string with the document id, title, content and relevance
        score, or an "Error: ..." message on failure.
    """
    # Dify chat-messages endpoint.
    url = f'{material_config.DIFY_ROOT_URL}/v1/chat-messages'

    # API key authentication plus JSON payload.
    headers = {
        'Authorization': f'Bearer {material_config.DIFY_API_KEY}',
        'Content-Type': 'application/json'
    }

    data = {
        "inputs": {"topK": topk},        # Maximum number of hits to return
        "query": query,                  # The user's query string
        "response_mode": "blocking",     # Wait for the full response
        "conversation_id": "",           # Stateless: every call is independent
        "user": "abc-123"                # User identifier
    }

    try:
        # Very long timeout: knowledge-base retrieval can be slow.
        response = requests.post(url, headers=headers, json=data, timeout=1111)
        response.raise_for_status()

        # BUG FIX: the previous code ran the raw body through
        # codecs.decode(..., 'unicode_escape'), which mis-decodes non-ASCII
        # (e.g. Chinese) UTF-8 text and can corrupt legitimate JSON escapes.
        # requests' .json() decodes the body with the right charset, and
        # json.loads already resolves \uXXXX escapes correctly.
        result_json = response.json()

        # Extract the metadata block holding provenance information.
        metadata = result_json.get("metadata", {})

        # Keep only the fields that are useful to the caller.
        useful_info = {
            "id": metadata.get("document_id"),         # Document id
            "title": result_json.get("title"),          # Document title
            "content": result_json.get("answer", ""),  # Answer text from Dify
            "score": metadata.get("score")              # Relevance score
        }

        return json.dumps(useful_info, ensure_ascii=False, indent=2)

    except Exception as e:
        # Network failures, HTTP errors and decode errors all surface here.
        return f"Error: {str(e)}"
|
||||
8
sci_mcp/material_mcp/matgl_tools/__init__.py
Normal file
8
sci_mcp/material_mcp/matgl_tools/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
MatGL Tools Module
|
||||
|
||||
This module provides tools for material structure relaxation and property prediction
|
||||
using MatGL (Materials Graph Library) models.
|
||||
"""
|
||||
|
||||
from .matgl_tools import *
|
||||
487
sci_mcp/material_mcp/matgl_tools/matgl_tools.py
Normal file
487
sci_mcp/material_mcp/matgl_tools/matgl_tools.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
MatGL Tools Module
|
||||
|
||||
This module provides tools for material structure relaxation and property prediction
|
||||
using MatGL (Materials Graph Library) models.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from ...core.config import material_config
|
||||
|
||||
|
||||
import warnings
|
||||
import json
|
||||
from typing import Dict, List, Union, Optional, Any
|
||||
|
||||
import torch
|
||||
from pymatgen.core import Lattice, Structure
|
||||
from pymatgen.ext.matproj import MPRester
|
||||
from pymatgen.io.ase import AseAtomsAdaptor
|
||||
|
||||
import matgl
|
||||
from matgl.ext.ase import Relaxer, MolecularDynamics, PESCalculator
|
||||
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
|
||||
|
||||
from ...core.llm_tools import llm_tool
|
||||
import os
|
||||
from ..support.utils import read_structure_from_file_name_or_content_string
|
||||
# To suppress warnings for clearer output
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
|
||||
@llm_tool(name="relax_crystal_structure_M3GNet",
          description="Optimize crystal structure geometry using M3GNet universal potential from a structure file or content string")
async def relax_crystal_structure_M3GNet(
    structure_source: str,
    fmax: float = 0.01
) -> str:
    """
    Optimize crystal structure geometry to find its equilibrium configuration.

    Uses M3GNet universal potential for fast and accurate structure relaxation without DFT.
    Accepts a structure file or content string.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        fmax: Maximum force tolerance for convergence in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string with the relaxation results or an error message.
    """
    try:
        # Resolve the input to raw structure text plus its detected format.
        structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)

        # Parse the structure text with pymatgen.
        structure = Structure.from_str(structure_content, fmt=content_format)

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # Load the M3GNet universal potential model
        pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")

        # Create a relaxer and relax the structure
        relaxer = Relaxer(potential=pot)
        relax_results = relaxer.relax(structure, fmax=fmax)

        # Get the relaxed structure
        relaxed_structure = relax_results["final_structure"]
        reduced_formula = relaxed_structure.composition.reduced_formula

        # Collect lattice and derived quantities for the report.
        lattice_info = relaxed_structure.lattice
        volume = relaxed_structure.volume
        density = relaxed_structure.density
        symmetry = relaxed_structure.get_space_group_info()

        # Build the fractional-coordinate site table.
        sites_table = " # SP a b c\n"
        sites_table += "--- ---- -------- -------- --------\n"
        for i, site in enumerate(relaxed_structure):
            frac_coords = site.frac_coords
            sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"

        # Render the Markdown report.
        return (f"## Structure Relaxation\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Force Tolerance**: `{fmax} eV/Å`\n"
                f"- **Status**: `Successfully relaxed`\n\n"
                f"### Relaxed Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {relaxed_structure.lattice.pbc[0]!s:5s} {relaxed_structure.lattice.pbc[1]!s:5s} {relaxed_structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(relaxed_structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error during structure relaxation: {str(e)}"
|
||||
|
||||
|
||||
# 内部函数,用于结构优化,返回结构对象而不是格式化字符串
|
||||
# Internal helper used by other tools in this module: returns the relaxed
# Structure object itself rather than a formatted report string.
async def _relax_crystal_structure_M3GNet_internal(
    structure_source: str,
    fmax: float = 0.01
) -> Union[Structure, str]:
    """
    Relax a structure with the M3GNet universal potential.

    Args:
        structure_source: Structure file name or raw content string.
        fmax: Force convergence threshold in eV/Å.

    Returns:
        The relaxed pymatgen Structure, or an "Error ..." string on failure.
    """
    try:
        # Resolve the input to raw text plus its detected format, then parse.
        raw_text, detected_format = read_structure_from_file_name_or_content_string(structure_source)
        structure = Structure.from_str(raw_text, fmt=detected_format)

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # Relax with the pretrained M3GNet potential-energy-surface model.
        potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
        results = Relaxer(potential=potential).relax(structure, fmax=fmax)
        return results["final_structure"]
    except Exception as e:
        return f"Error during structure relaxation: {str(e)}"
|
||||
|
||||
|
||||
@llm_tool(name="predict_formation_energy_M3GNet",
          description="Predict the formation energy of a crystal structure using the M3GNet formation energy model from a structure file or content string, with optional structure optimization")
async def predict_formation_energy_M3GNet(
    structure_source: str,
    optimize_structure: bool = True,
    fmax: float = 0.01
) -> str:
    """
    Predict the formation energy of a crystal structure using the M3GNet formation energy model.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        optimize_structure: Whether to optimize the structure before prediction (default: True).
        fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string containing the predicted formation energy in eV/atom or an error message.
    """
    try:
        # Obtain the structure, relaxed first when requested.
        if optimize_structure:
            # Delegate relaxation to the shared internal helper.
            structure = await _relax_crystal_structure_M3GNet_internal(
                structure_source=structure_source,
                fmax=fmax
            )

            # The helper signals failure by returning an "Error ..." string.
            if isinstance(structure, str) and structure.startswith("Error"):
                return structure
        else:
            # Read and parse the structure directly, without optimization.
            structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
            structure = Structure.from_str(structure_content, fmt=content_format)

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # Load the pretrained M3GNet formation-energy model.
        model = matgl.load_model("M3GNet-MP-2018.6.1-Eform")

        # Predict the formation energy (eV/atom).
        eform = model.predict_structure(structure)
        reduced_formula = structure.composition.reduced_formula

        # Label the report with the structure's preparation status.
        optimization_status = "optimized" if optimize_structure else "non-optimized"

        # Collect lattice and derived quantities for the report.
        lattice_info = structure.lattice
        volume = structure.volume
        density = structure.density
        symmetry = structure.get_space_group_info()

        # Build the fractional-coordinate site table.
        sites_table = " # SP a b c\n"
        sites_table += "--- ---- -------- -------- --------\n"
        for i, site in enumerate(structure):
            frac_coords = site.frac_coords
            sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"

        return (f"## Formation Energy Prediction\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Structure Status**: `{optimization_status}`\n"
                f"- **Formation Energy**: `{float(eform):.3f} eV/atom`\n\n"
                f"### Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {structure.lattice.pbc[0]!s:5s} {structure.lattice.pbc[1]!s:5s} {structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error: {str(e)}"
|
||||
|
||||
|
||||
@llm_tool(name="run_molecular_dynamics_M3GNet",
          description="Run molecular dynamics simulation on a crystal structure using M3GNet universal potential, with optional structure optimization")
async def run_molecular_dynamics_M3GNet(
    structure_source: str,
    temperature_K: float = 300,
    steps: int = 100,
    optimize_structure: bool = True,
    fmax: float = 0.01
) -> str:
    """
    Run molecular dynamics simulation on a crystal structure using M3GNet universal potential.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        temperature_K: Temperature for MD simulation in Kelvin (default: 300).
        steps: Number of MD steps to run (default: 100).
        optimize_structure: Whether to optimize the structure before simulation (default: True).
        fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string containing the simulation results, including final potential energy.
    """
    try:
        # Obtain the structure, relaxed first when requested.
        if optimize_structure:
            # Delegate relaxation to the shared internal helper.
            structure = await _relax_crystal_structure_M3GNet_internal(
                structure_source=structure_source,
                fmax=fmax
            )

            # The helper signals failure by returning an "Error ..." string.
            if isinstance(structure, str) and structure.startswith("Error"):
                return structure
        else:
            # Read and parse the structure directly, without optimization.
            structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
            structure = Structure.from_str(structure_content, fmt=content_format)

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # Load the M3GNet universal potential model
        pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")

        # Convert pymatgen structure to ASE atoms
        ase_adaptor = AseAtomsAdaptor()
        atoms = ase_adaptor.get_atoms(structure)

        # Initialize the velocity according to Maxwell Boltzmann distribution
        MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_K)

        # Create the MD class and run simulation
        driver = MolecularDynamics(atoms, potential=pot, temperature=temperature_K)
        driver.run(steps)

        # Get final potential energy
        final_energy = atoms.get_potential_energy()

        # Get final structure
        final_structure = ase_adaptor.get_structure(atoms)
        reduced_formula = final_structure.composition.reduced_formula

        # Label the report with the structure's preparation status.
        optimization_status = "optimized" if optimize_structure else "non-optimized"

        # Collect lattice and derived quantities for the report.
        lattice_info = final_structure.lattice
        volume = final_structure.volume
        density = final_structure.density
        symmetry = final_structure.get_space_group_info()

        # Build the fractional-coordinate site table.
        sites_table = " # SP a b c\n"
        sites_table += "--- ---- -------- -------- --------\n"
        for i, site in enumerate(final_structure):
            frac_coords = site.frac_coords
            sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"

        return (f"## Molecular Dynamics Simulation\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Structure Status**: `{optimization_status}`\n"
                f"- **Temperature**: `{temperature_K} K`\n"
                f"- **Steps**: `{steps}`\n"
                f"- **Final Potential Energy**: `{float(final_energy):.3f} eV`\n\n"
                f"### Final Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {final_structure.lattice.pbc[0]!s:5s} {final_structure.lattice.pbc[1]!s:5s} {final_structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(final_structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error: {str(e)}"
|
||||
|
||||
|
||||
@llm_tool(name="calculate_single_point_energy_M3GNet",
          description="Calculate single point energy of a crystal structure using M3GNet universal potential, with optional structure optimization")
async def calculate_single_point_energy_M3GNet(
    structure_source: str,
    optimize_structure: bool = True,
    fmax: float = 0.01
) -> str:
    """
    Calculate single point energy of a crystal structure using M3GNet universal potential.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        optimize_structure: Whether to optimize the structure before calculation (default: True).
        fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string containing the calculated potential energy in eV.
    """
    try:
        # Obtain the structure, relaxed first when requested.
        if optimize_structure:
            # Delegate relaxation to the shared internal helper.
            structure = await _relax_crystal_structure_M3GNet_internal(
                structure_source=structure_source,
                fmax=fmax
            )

            # The helper signals failure by returning an "Error ..." string.
            if isinstance(structure, str) and structure.startswith("Error"):
                return structure
        else:
            # Read and parse the structure directly, without optimization.
            structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
            structure = Structure.from_str(structure_content, fmt=content_format)

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # Load the M3GNet universal potential model
        pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")

        # Convert pymatgen structure to ASE atoms
        ase_adaptor = AseAtomsAdaptor()
        atoms = ase_adaptor.get_atoms(structure)

        # Attach the calculator. FIX: Atoms.set_calculator() is deprecated
        # and removed in current ASE; assign the `calc` attribute instead
        # (consistent with the FairChem module in this package).
        calc = PESCalculator(pot)
        atoms.calc = calc

        # Single-point potential energy (no relaxation performed here).
        energy = atoms.get_potential_energy()
        reduced_formula = structure.composition.reduced_formula

        # Label the report with the structure's preparation status.
        optimization_status = "optimized" if optimize_structure else "non-optimized"

        # Collect lattice and derived quantities for the report.
        lattice_info = structure.lattice
        volume = structure.volume
        density = structure.density
        symmetry = structure.get_space_group_info()

        # Build the fractional-coordinate site table.
        sites_table = " # SP a b c\n"
        sites_table += "--- ---- -------- -------- --------\n"
        for i, site in enumerate(structure):
            frac_coords = site.frac_coords
            sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"

        return (f"## Single Point Energy Calculation\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Structure Status**: `{optimization_status}`\n"
                f"- **Potential Energy**: `{float(energy):.3f} eV`\n\n"
                f"### Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {structure.lattice.pbc[0]!s:5s} {structure.lattice.pbc[1]!s:5s} {structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error: {str(e)}"
|
||||
|
||||
#Error: Bad serialized model or bad model name. It is possible that you have an older model cached. Please clear your cache by running `python -c "import matgl; matgl.clear_cache()"`
|
||||
# @llm_tool(name="predict_band_gap",
|
||||
# description="Predict the band gap of a crystal structure using MEGNet multi-fidelity model from either a chemical formula or CIF file, with structure optimization")
|
||||
# async def predict_band_gap(
|
||||
# formula: str = None,
|
||||
# cif_file_name: str = None,
|
||||
# method: str = "PBE",
|
||||
# fmax: float = 0.01
|
||||
# ) -> str:
|
||||
# """
|
||||
# Predict the band gap of a crystal structure using the MEGNet multi-fidelity band gap model.
|
||||
|
||||
# First optimizes the crystal structure using M3GNet universal potential, then predicts
|
||||
# the band gap based on the relaxed structure for more accurate results.
|
||||
|
||||
# Accepts either a chemical formula (searches Materials Project database) or a CIF file.
|
||||
|
||||
# Args:
|
||||
# formula: Chemical formula to retrieve from Materials Project (e.g., "Fe2O3").
|
||||
# cif_file_name: Name of CIF file in temp directory to use as structure source.
|
||||
# method: The DFT method to use for the prediction. Options are "PBE", "GLLB-SC", "HSE", or "SCAN".
|
||||
# Default is "PBE".
|
||||
# fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).
|
||||
|
||||
# Returns:
|
||||
# A string containing the predicted band gap in eV or an error message.
|
||||
# """
|
||||
# try:
|
||||
# # First, relax the crystal structure
|
||||
# relaxed_result = await relax_crystal_structure(
|
||||
# formula=formula,
|
||||
# cif_file_name=cif_file_name,
|
||||
# fmax=fmax
|
||||
# )
|
||||
|
||||
# # Check if relaxation was successful
|
||||
# if isinstance(relaxed_result, str) and relaxed_result.startswith("Error"):
|
||||
# return relaxed_result
|
||||
|
||||
# # Use the relaxed structure for band gap prediction
|
||||
# structure = relaxed_result
|
||||
|
||||
# if structure is None:
|
||||
# return "Error: Failed to obtain a valid relaxed structure"
|
||||
|
||||
# # Load the pre-trained MEGNet band gap model
|
||||
# model = matgl.load_model("MEGNet-MP-2019.4.1-BandGap-mfi")
|
||||
|
||||
# # Map method name to index
|
||||
# method_map = {"PBE": 0, "GLLB-SC": 1, "HSE": 2, "SCAN": 3}
|
||||
# if method not in method_map:
|
||||
# return f"Error: Unsupported method: {method}. Choose from PBE, GLLB-SC, HSE, or SCAN."
|
||||
|
||||
# # Set the graph label based on the method
|
||||
# graph_attrs = torch.tensor([method_map[method]])
|
||||
|
||||
# # Predict the band gap using the relaxed structure
|
||||
# bandgap = model.predict_structure(structure=structure, state_attr=graph_attrs)
|
||||
# reduced_formula = structure.reduced_formula
|
||||
|
||||
# # Return the band gap as a string
|
||||
# return f"The predicted band gap for relaxed {reduced_formula} using {method} method is {float(bandgap):.3f} eV."
|
||||
# except Exception as e:
|
||||
# return f"Error: {str(e)}"
|
||||
|
||||
|
||||
240
sci_mcp/material_mcp/mattergen_gen/material_gen_tools.py
Executable file
240
sci_mcp/material_mcp/mattergen_gen/material_gen_tools.py
Executable file
@@ -0,0 +1,240 @@
|
||||
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
|
||||
# Configure a module-level logger with its own stream handler and a
# timestamped format.
# NOTE(review): attaching a handler inside library code can produce duplicate
# log lines when the application also configures logging — consider leaving
# handler setup to the application entry point; confirm intended behavior.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
|
||||
|
||||
import tempfile
|
||||
import os
|
||||
import datetime
|
||||
import asyncio
|
||||
import zipfile
|
||||
import shutil
|
||||
import re
|
||||
import multiprocessing
|
||||
from multiprocessing import Process, Queue
|
||||
from pathlib import Path
|
||||
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
|
||||
import logging
|
||||
# Force the 'spawn' multiprocessing start method: CUDA cannot be
# re-initialized in a forked child process, so 'fork' (the Linux default)
# would raise CUDA initialization errors in worker processes.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    # The start method can only be set once per interpreter; if it has
    # already been configured elsewhere, a RuntimeError is raised — ignore it.
    pass
|
||||
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from ase.io import read, write
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
# 导入路径已更新
|
||||
from ...core.llm_tools import llm_tool
|
||||
from .mattergen_wrapper import *
|
||||
|
||||
# 使用mattergen_wrapper
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
|
||||
def convert_values(data_str):
    """
    Parse a JSON-encoded string into Python data.

    Args:
        data_str: A JSON string.

    Returns:
        The parsed object, or the original string when it is not valid JSON.
    """
    try:
        return json.loads(data_str)
    except json.JSONDecodeError:
        # Not JSON — hand the raw string back unchanged.
        return data_str
|
||||
|
||||
|
||||
def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
    """
    Preprocess a property value based on its name, converting it to the appropriate type.

    Args:
        property_name: Name of the property
        property_value: Value of the property (can be string, float, or int)

    Returns:
        Tuple of (property_name, processed_value)

    Raises:
        ValueError: If the property name is unknown or the property value is
            invalid for the given property name.
    """
    valid_properties = [
        "dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
        "energy_above_hull", "formation_energy_per_atom", "space_group",
        "hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap"
    ]

    if property_name not in valid_properties:
        raise ValueError(f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}")

    # Numeric properties supplied as strings are converted to float;
    # chemical_system stays textual.
    if isinstance(property_value, str):
        try:
            if property_name != "chemical_system":
                property_value = float(property_value)
        except ValueError:
            # Conversion failed: leave the value as a string; validated below
            # for properties with stricter requirements.
            pass

    # Handle properties that need a specific type or range.
    if property_name == "chemical_system":
        if isinstance(property_value, (int, float)):
            logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
            property_value = str(property_value)
    elif property_name == "space_group":
        space_group = property_value
        # ROBUSTNESS FIX: a non-numeric string previously slipped through the
        # float-conversion except and raised an opaque TypeError on the
        # `< 1` comparison; report it as an explicit ValueError instead.
        if not isinstance(space_group, (int, float)):
            raise ValueError(f"Invalid space_group value: {space_group}. Must be an integer between 1 and 230.")
        if space_group < 1 or space_group > 230:
            raise ValueError(f"Invalid space_group value: {space_group}. Must be an integer between 1 and 230.")

    return property_name, property_value
|
||||
|
||||
|
||||
def main(
    output_path: str,
    pretrained_name: PRETRAINED_MODEL_NAME | None = None,
    model_path: str | None = None,
    batch_size: int = 2,
    num_batches: int = 1,
    config_overrides: list[str] | None = None,
    checkpoint_epoch: Literal["best", "last"] | int = "last",
    properties_to_condition_on: TargetProperty | None = None,
    sampling_config_path: str | None = None,
    sampling_config_name: str = "default",
    sampling_config_overrides: list[str] | None = None,
    record_trajectories: bool = True,
    diffusion_guidance_factor: float | None = None,
    strict_checkpoint_loading: bool = True,
    target_compositions: list[dict[str, int]] | None = None,
):
    """
    Generate crystal structures with a MatterGen diffusion model checkpoint.

    Exactly one of ``pretrained_name`` (Hugging Face hub model) or ``model_path``
    (local checkpoint directory) must be provided.

    Args:
        output_path: Path to output directory (created if it does not exist).
        pretrained_name: Name of a pretrained model hosted on the Hugging Face hub.
        model_path: Path to a DiffusionLightningModule checkpoint directory.
        batch_size: Number of structures sampled per batch.
        num_batches: Number of batches to sample.
        config_overrides: Overrides for the model config, e.g., `model.num_layers=3 model.hidden_dim=128`.
        checkpoint_epoch: Which epoch to load from a local checkpoint: "best", "last", or an epoch number. (default: "last")
        properties_to_condition_on: Property values to draw conditional samples with respect to. When empty/None (default), unconditional samples are drawn.
        sampling_config_path: Path to the sampling config file. (default: None, in which case the package default sampling config path is used)
        sampling_config_name: Name of the sampling config (corresponds to `{sampling_config_path}/{sampling_config_name}.yaml` on disk). (default: default)
        sampling_config_overrides: Overrides for the sampling config, e.g., `condition_loader_partial.batch_size=32`.
        record_trajectories: Whether to record the trajectories of the generated structures. (default: True)
        diffusion_guidance_factor: Strength of classifier-free guidance; None is treated as 0.0 (no guidance).
        strict_checkpoint_loading: Whether to raise an exception when not all parameters from the checkpoint can be matched to the model.
        target_compositions: List of dictionaries with target compositions to condition on. Each dictionary should have the form `{element: number_of_atoms}`. If None, the target compositions are not conditioned on.
            Only supported for models trained for crystal structure prediction (CSP) (default: None)

    NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
    """
    assert (
        pretrained_name is not None or model_path is not None
    ), "Either pretrained_name or model_path must be provided."
    assert (
        pretrained_name is None or model_path is None
    ), "Only one of pretrained_name or model_path can be provided."

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    # Normalise None defaults to empty containers (avoids mutable default arguments).
    sampling_config_overrides = sampling_config_overrides or []
    config_overrides = config_overrides or []
    properties_to_condition_on = properties_to_condition_on or {}
    target_compositions = target_compositions or []

    # Resolve the checkpoint either from the Hugging Face hub or from a local path.
    if pretrained_name is not None:
        checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
            pretrained_name, config_overrides=config_overrides
        )
    else:
        checkpoint_info = MatterGenCheckpointInfo(
            model_path=Path(model_path).resolve(),
            load_epoch=checkpoint_epoch,
            config_overrides=config_overrides,
            strict_checkpoint_loading=strict_checkpoint_loading,
        )
    _sampling_config_path = Path(sampling_config_path) if sampling_config_path is not None else None
    generator = CrystalGenerator(
        checkpoint_info=checkpoint_info,
        properties_to_condition_on=properties_to_condition_on,
        batch_size=batch_size,
        num_batches=num_batches,
        sampling_config_name=sampling_config_name,
        sampling_config_path=_sampling_config_path,
        sampling_config_overrides=sampling_config_overrides,
        record_trajectories=record_trajectories,
        diffusion_guidance_factor=(
            diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
        ),
        target_compositions_dict=target_compositions,
    )
    generator.generate(output_dir=Path(output_path))
|
||||
|
||||
|
||||
@llm_tool(name="generate_material_MatterGen", description="Generate crystal structures with optional property constraints using MatterGen model")
def generate_material_MatterGen(
    properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
    batch_size: int = 2,
    num_batches: int = 1,
    diffusion_guidance_factor: float = 2.0
) -> str:
    """
    Generate crystal structures with optional property constraints.

    This unified function can generate materials in three modes:
    1. Unconditional generation (no properties specified)
    2. Single property conditional generation (one property specified)
    3. Multi-property conditional generation (multiple properties specified)

    Args:
        properties: Optional property constraints. Can be:
            - None or empty dict for unconditional generation
            - Dict with single key-value pair for single property conditioning
            - Dict with multiple key-value pairs for multi-property conditioning
            Valid property names include: "dft_band_gap", "chemical_system", etc.
        batch_size: Number of structures per batch
        num_batches: Number of batches to generate
        diffusion_guidance_factor: Controls adherence to target properties

    Returns:
        Descriptive text with generated crystal structures in CIF format, or an
        error message starting with "Error:" when generation fails.
    """
    # Import lazily: this runs inside a worker subprocess and MatterGenService
    # initialisation is heavy (loads model checkpoints).
    from .mattergen_service import MatterGenService
    logger.info("子进程成功导入MatterGenService")

    # Obtain the process-wide singleton service.
    service = MatterGenService.get_instance()
    logger.info("子进程成功获取MatterGenService实例")

    # Delegate generation to the service.
    logger.info("子进程开始调用generate方法...")
    result = service.generate(properties=properties, batch_size=batch_size, num_batches=num_batches, diffusion_guidance_factor=diffusion_guidance_factor)
    logger.info("子进程generate方法调用完成")
    if "Error generating structures" in result:
        # BUG FIX: the underlying error was previously discarded, mislabelling every
        # failure (e.g. CUDA OOM, missing model files) as invalid properties.
        # Keep the original prefix for compatibility but surface the details.
        return f"Error: Invalid properties {properties}. Details: {result}"
    return result
|
||||
466
sci_mcp/material_mcp/mattergen_gen/mattergen_service.py
Executable file
466
sci_mcp/material_mcp/mattergen_gen/mattergen_service.py
Executable file
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
MatterGen service for mars_toolkit.
|
||||
|
||||
This module provides a service for generating crystal structures using MatterGen.
|
||||
The service initializes the CrystalGenerator once and reuses it for multiple
|
||||
generation requests, improving performance.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Union, List
|
||||
import threading
|
||||
import torch
|
||||
|
||||
from .mattergen_wrapper import *
|
||||
from ...core.config import material_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
def format_cif_content(content):
    """
    Format CIF content by removing unnecessary headers and organizing each CIF file.

    The input may contain zip-archive framing ("PK" local-file headers); those spans
    are stripped with regexes before the text is partitioned into one block per
    ``_chemical_formula_structural`` tag.

    Args:
        content: String containing CIF content, possibly with PK headers

    Returns:
        Formatted string with each CIF file properly labeled and formatted,
        or the empty string when no CIF blocks are present.
    """
    # Nothing to do for empty / whitespace-only input.
    if not content or content.strip() == '':
        return ''

    # Strip zip framing: every "PK..." run up to the next formula tag, then any
    # trailing "PK..." residue that is not followed by another formula tag.
    content = re.sub(r'PK.*?(?=_chemical_formula_structural)', '', content, flags=re.DOTALL)
    content = re.sub(r'PK[^_]*$', '', content, flags=re.DOTALL)
    content = re.sub(r'PK.*?(?!.*_chemical_formula_structural)$', '', content, flags=re.DOTALL)

    # Each CIF block starts at a _chemical_formula_structural tag.
    starts = [match.start() for match in re.finditer(r'_chemical_formula_structural', content)]
    if not starts:
        return ''

    # Slice the text between consecutive tag positions (last block runs to the end),
    # keeping the tag itself inside each block, and pull out the formula value.
    labelled_blocks = []
    bounds = starts + [len(content)]
    for begin, end in zip(bounds, bounds[1:]):
        chunk = content[begin:end].strip()
        found = re.search(r'_chemical_formula_structural\s+(\S+)', chunk)
        if found:
            labelled_blocks.append((found.group(1), chunk))

    # Wrap every block in numbered begin/end markers with a data_{formula} header.
    pieces = [
        f"[cif {index} begin]\ndata_{formula}\n{body}\n[cif {index} end]\n"
        for index, (formula, body) in enumerate(labelled_blocks, 1)
    ]
    return "\n".join(pieces)
|
||||
|
||||
def extract_cif_file_from_zip(cifs_zip_path: str):
    """
    Extract CIF files from a zip archive, extract formula from each CIF file,
    and save each CIF file with its formula as the filename.

    NOTE(review): the archive is read as raw bytes and decoded with errors
    replaced, relying on format_cif_content() to strip the zip ("PK") framing
    rather than using the zipfile module — confirm this is intentional.

    Args:
        cifs_zip_path: Path to the zip file

    Returns:
        list: List of tuples containing (index, formula, cif_path). Empty when
        the archive is missing or contains no recognizable CIF blocks.
    """
    result_dict = {}
    if os.path.exists(cifs_zip_path):
        with open(cifs_zip_path, 'rb') as f:
            result_dict['cif_content'] = f.read()
    # Decode the raw bytes (replacing undecodable sequences) and normalise the text
    # into "[cif N begin] ... [cif N end]" blocks.
    cifs_content = format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))
    pattern = r'\[cif (\d+) begin\]\n(.*?)\n\[cif \1 end\]'
    matches = re.findall(pattern, cifs_content, re.DOTALL)

    # For each matched block, extract the formula and persist the CIF text.
    saved_files = []
    for idx, cif_content in matches:
        # Pull the formula out of the data_{formula} header line.
        formula_match = re.search(r'data_([^\s]+)', cif_content)
        if formula_match:
            formula = formula_match.group(1)
            # One file per formula under TEMP_ROOT; duplicate formulas overwrite.
            cif_path = os.path.join(material_config.TEMP_ROOT, f"{formula}.cif")
            # Write the CIF text to disk.
            with open(cif_path, 'w') as f:
                f.write(cif_content)
            saved_files.append((idx, formula, cif_path))

    return saved_files
|
||||
|
||||
|
||||
class MatterGenService:
    """
    Service for generating crystal structures using MatterGen.

    This service initializes the CrystalGenerator once and reuses it for multiple
    generation requests, improving performance.

    Notes:
        - Singleton: obtain instances via :meth:`get_instance`, not the constructor.
        - Generators are cached per conditioning key in ``self._generators``.
    """

    # Singleton instance and the lock guarding its lazy construction.
    _instance = None
    _lock = threading.Lock()

    # Mapping from model name to the CUDA device id (as a string) it should run on.
    # NOTE(review): assumes an 8-GPU host; device ids wrap around for later models —
    # confirm against the deployment hardware.
    MODEL_TO_GPU = {
        "mattergen_base": "0",  # base (unconditional) model on GPU 0
        "dft_mag_density": "1",  # magnetic density model on GPU 1
        "dft_bulk_modulus": "2",  # bulk modulus model on GPU 2
        "dft_shear_modulus": "3",  # shear modulus model on GPU 3
        "energy_above_hull": "4",  # energy-above-hull model on GPU 4
        "formation_energy_per_atom": "5",  # formation energy model on GPU 5
        "space_group": "6",  # space group model on GPU 6
        "hhi_score": "7",  # HHI score model on GPU 7
        "ml_bulk_modulus": "0",  # ML bulk modulus model on GPU 0
        "chemical_system": "1",  # chemical system model on GPU 1
        "dft_band_gap": "2",  # band gap model on GPU 2
        "dft_mag_density_hhi_score": "3",  # multi-property model on GPU 3
        "chemical_system_energy_above_hull": "4"  # multi-property model on GPU 4
    }

    @classmethod
    def get_instance(cls):
        """
        Get the singleton instance of MatterGenService.

        Uses double-checked locking so concurrent first callers construct
        the service only once.

        Returns:
            MatterGenService: The singleton instance.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """
        Initialize the MatterGenService.

        This initializes the base generator without any property conditioning.
        Specific generators for different property conditions will be initialized
        on demand.
        """
        # Cache of CrystalGenerator instances keyed by conditioning ("base",
        # "single_<prop>", "multi_<...>").
        self._generators = {}
        self._output_dir = material_config.TEMP_ROOT

        # Make sure the output directory exists.
        if not os.path.exists(self._output_dir):
            os.makedirs(self._output_dir)

        # Eagerly initialise the unconditional base generator.
        self._init_base_generator()

    def _init_base_generator(self):
        """
        Initialize the base generator for unconditional generation.

        On failure (missing model directory or initialisation error) the service
        logs the problem and leaves the "base" cache entry absent.
        """
        model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, "mattergen_base")

        if not os.path.exists(model_path):
            logger.warning(f"Base model directory not found at {model_path}. MatterGen service may not work properly.")
            return

        logger.info(f"Initializing base MatterGen generator from {model_path}")

        try:
            checkpoint_info = MatterGenCheckpointInfo(
                model_path=Path(model_path).resolve(),
                load_epoch="last",
                config_overrides=[],
                strict_checkpoint_loading=True,
            )

            generator = CrystalGenerator(
                checkpoint_info=checkpoint_info,
                properties_to_condition_on=None,
                batch_size=2,  # default; overridden at generation time
                num_batches=1,  # default; overridden at generation time
                sampling_config_name="default",
                sampling_config_path=None,
                sampling_config_overrides=[],
                record_trajectories=True,
                diffusion_guidance_factor=0.0,
                target_compositions_dict=[],
            )

            self._generators["base"] = generator
            logger.info("Base MatterGen generator initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize base MatterGen generator: {e}")

    def _get_or_create_generator(
        self,
        properties: Optional[Dict[str, Any]] = None,
        batch_size: int = 2,
        num_batches: int = 1,
        diffusion_guidance_factor: float = 2.0
    ):
        """
        Get or create a generator for the specified properties.

        Args:
            properties: Optional property constraints
            batch_size: Number of structures per batch
            num_batches: Number of batches to generate
            diffusion_guidance_factor: Controls adherence to target properties

        Returns:
            tuple: (generator, generator_key, properties_to_condition_on, gpu_id);
            the generator may be None if even the base model failed to initialise.
        """
        # No property constraints: use the cached unconditional base generator.
        if not properties:
            if "base" not in self._generators:
                self._init_base_generator()
            gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")  # default GPU 0
            return self._generators.get("base"), "base", None, gpu_id

        # Copy the property constraints into the conditioning dict.
        properties_to_condition_on = {}
        for property_name, property_value in properties.items():
            properties_to_condition_on[property_name] = property_value

        # Determine the model directory for this conditioning.
        if len(properties) == 1:
            # Single-property conditioning.
            property_name = list(properties.keys())[0]
            property_to_model = {
                "dft_mag_density": "dft_mag_density",
                "dft_bulk_modulus": "dft_bulk_modulus",
                "dft_shear_modulus": "dft_shear_modulus",
                "energy_above_hull": "energy_above_hull",
                "formation_energy_per_atom": "formation_energy_per_atom",
                "space_group": "space_group",
                "hhi_score": "hhi_score",
                "ml_bulk_modulus": "ml_bulk_modulus",
                "chemical_system": "chemical_system",
                "dft_band_gap": "dft_band_gap"
            }
            model_dir = property_to_model.get(property_name, property_name)
            generator_key = f"single_{property_name}"
        else:
            # Multi-property conditioning.
            property_keys = set(properties.keys())
            if property_keys == {"dft_mag_density", "hhi_score"}:
                model_dir = "dft_mag_density_hhi_score"
                generator_key = "multi_dft_mag_density_hhi_score"
            elif property_keys == {"chemical_system", "energy_above_hull"}:
                model_dir = "chemical_system_energy_above_hull"
                generator_key = "multi_chemical_system_energy_above_hull"
            else:
                # No dedicated multi-property model: fall back to the model of
                # the first listed property.
                first_property = list(properties.keys())[0]
                model_dir = first_property
                generator_key = f"multi_{first_property}_etc"

        # Resolve the GPU for this model.
        gpu_id = self.MODEL_TO_GPU.get(model_dir, "0")  # default GPU 0

        # Build the full model path.
        model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, model_dir)

        # Fall back to the base model when the specific one is absent on disk.
        if not os.path.exists(model_path):
            logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
            model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, "mattergen_base")
            generator_key = "base"

        # Reuse a cached generator, refreshing its per-call parameters.
        if generator_key in self._generators:
            generator = self._generators[generator_key]
            generator.batch_size = batch_size
            generator.num_batches = num_batches
            generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0
            return generator, generator_key, properties_to_condition_on, gpu_id

        # Create and cache a new generator for this conditioning.
        try:
            logger.info(f"Initializing new MatterGen generator for {generator_key} from {model_path}")

            checkpoint_info = MatterGenCheckpointInfo(
                model_path=Path(model_path).resolve(),
                load_epoch="last",
                config_overrides=[],
                strict_checkpoint_loading=True,
            )

            generator = CrystalGenerator(
                checkpoint_info=checkpoint_info,
                properties_to_condition_on=properties_to_condition_on,
                batch_size=batch_size,
                num_batches=num_batches,
                sampling_config_name="default",
                sampling_config_path=None,
                sampling_config_overrides=[],
                record_trajectories=True,
                diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0,
                target_compositions_dict=[],
            )

            self._generators[generator_key] = generator
            logger.info(f"MatterGen generator for {generator_key} initialized successfully")
            return generator, generator_key, properties_to_condition_on, gpu_id
        except Exception as e:
            logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
            # Fall back to the base generator on failure.
            if "base" not in self._generators:
                self._init_base_generator()
            base_gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")
            return self._generators.get("base"), "base", None, base_gpu_id

    def generate(
        self,
        properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
        batch_size: int = 2,
        num_batches: int = 1,
        diffusion_guidance_factor: float = 2.0
    ) -> str:
        """
        Generate crystal structures with optional property constraints.

        Args:
            properties: Optional property constraints (dict, or a JSON string
                encoding one)
            batch_size: Number of structures per batch
            num_batches: Number of batches to generate
            diffusion_guidance_factor: Controls adherence to target properties

        Returns:
            str: Descriptive text with generated crystal structures in CIF format,
            or an error message on failure.

        Raises:
            ValueError: If ``properties`` is a string that is not valid JSON.
        """

        # Accept a JSON string for properties (if provided).
        if isinstance(properties, str):
            try:
                properties = json.loads(properties)
            except json.JSONDecodeError:
                raise ValueError(f"Invalid properties JSON string: {properties}")

        # Default to an empty dict when None.
        properties = properties or {}

        # Fetch or build the generator and resolve its GPU.
        generator, generator_key, properties_to_condition_on, gpu_id = self._get_or_create_generator(
            properties, batch_size, num_batches, diffusion_guidance_factor
        )
        print("gpu_id",gpu_id)
        if generator is None:
            return "Error: Failed to initialize MatterGen generator"

        # Select the GPU for this model directly via torch.cuda.set_device().
        try:
            # gpu_id is stored as a string; CUDA wants an integer device index.
            cuda_device_id = int(gpu_id)
            torch.cuda.set_device(cuda_device_id)
            logger.info(f"Setting CUDA device to GPU {cuda_device_id} for model {generator_key}")
            print(f"Using GPU {cuda_device_id} (CUDA device index) for model {generator_key}")
        except Exception as e:
            logger.warning(f"Error setting CUDA device: {e}. Falling back to default device.")

        # Run generation into a timestamped sub-directory of the temp root.
        try:

            output_dir= Path(self._output_dir+f'/{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}')
            Path.mkdir(output_dir, parents=True, exist_ok=True)
            generator.generate(output_dir=output_dir)
        except Exception as e:
            logger.error(f"Error generating structures: {e}")
            return f"Error generating structures: {e}"

        # Collect generated artifacts.
        result_dict = {}

        # Expected output files produced by CrystalGenerator.generate().
        cif_zip_path = os.path.join(str(output_dir), f"generated_crystals_cif.zip")
        xyz_file_path = os.path.join(str(output_dir), f"generated_crystals.extxyz")
        trajectories_zip_path = os.path.join(str(output_dir), f"generated_trajectories.zip")

        # Read the zipped CIF archive as raw bytes (decoded leniently below).
        if os.path.exists(cif_zip_path):
            with open(cif_zip_path, 'rb') as f:
                result_dict['cif_content'] = f.read()

        # Build a descriptive header tailored to the generation mode.
        if not properties:
            generation_type = "unconditional"
            title = "Generated Material Structures"
            description = "These structures were generated unconditionally, meaning no specific properties were targeted."
            property_description = "unconditionally"
        elif len(properties) == 1:
            generation_type = "single_property"
            property_name = list(properties.keys())[0]
            property_value = properties[property_name]
            title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
            description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
            property_description = f"conditioned on {property_name} = {property_value}"
        else:
            generation_type = "multi_property"
            title = "Generated Material Structures Conditioned on Multiple Properties"
            description = "These structures were generated with multi-property conditioning, targeting the specified property values."
            property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"

        # Assemble the full prompt returned to the caller.
        prompt = f"""
# {title}

This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.

{'' if generation_type == 'unconditional' else f'''
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
'''}

## CIF Files (Crystallographic Information Files)

- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools

```
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
```

{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""
        # print("prompt",prompt)
        # Cleanup of the generated files (delete after reading) — currently disabled.
        # try:
        #     if os.path.exists(cif_zip_path):
        #         os.remove(cif_zip_path)
        #     if os.path.exists(xyz_file_path):
        #         os.remove(xyz_file_path)
        #     if os.path.exists(trajectories_zip_path):
        #         os.remove(trajectories_zip_path)
        # except Exception as e:
        #     logger.warning(f"Error cleaning up files: {e}")

        # The GPU device was already selected via torch.cuda.set_device() before
        # generation; no extra cleanup is required.
        logger.info(f"Generation completed on GPU for model {generator_key}")

        return prompt
|
||||
26
sci_mcp/material_mcp/mattergen_gen/mattergen_wrapper.py
Executable file
26
sci_mcp/material_mcp/mattergen_gen/mattergen_wrapper.py
Executable file
@@ -0,0 +1,26 @@
|
||||
"""
This is a wrapper module that provides access to the mattergen modules
by modifying the Python path at runtime.

It prepends ``material_config.MATTERGEN_ROOT`` to ``sys.path`` so the vendored
mattergen package can be imported, then re-exports the names the rest of the
package needs.
"""
import sys
import os
from pathlib import Path
from ...core.config import material_config
# Add the mattergen directory to the Python path so its package resolves.
mattergen_dir = material_config.MATTERGEN_ROOT
sys.path.insert(0, mattergen_dir)

# Import the necessary modules from the mattergen package; fail loudly (with the
# effective sys.path) if the vendored package is missing or broken.
try:
    from mattergen import generator
    from mattergen.common.data import chemgraph
    from mattergen.common.data.types import TargetProperty
    from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
    from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME
except ImportError as e:
    print(f"Error importing mattergen modules: {e}")
    print(f"Python path: {sys.path}")
    raise
# Convenience alias so callers can import the generator class directly.
CrystalGenerator = generator.CrystalGenerator
# Re-export the modules
__all__ = ['generator', 'chemgraph', 'TargetProperty', 'MatterGenCheckpointInfo', 'PRETRAINED_MODEL_NAME','CrystalGenerator']
|
||||
73
sci_mcp/material_mcp/mattersim_pred/property_pred_tools.py
Normal file
73
sci_mcp/material_mcp/mattersim_pred/property_pred_tools.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Property Prediction Module
|
||||
|
||||
This module provides functions for predicting properties of crystal structures.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
import numpy as np
|
||||
from ase.units import GPa
|
||||
from mattersim.forcefield import MatterSimCalculator
|
||||
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ..support.utils import convert_structure,read_structure_from_file_name_or_content_string
|
||||
|
||||
@llm_tool(
    name="predict_properties_MatterSim",
    description="Predict energy, forces, and stress of crystal structures using MatterSim model based on CIF string",
)
async def predict_properties_MatterSim(structure_source: str) -> str:
    """
    Use MatterSim model to predict energy, forces, and stress of crystal structures.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string

    Returns:
        String containing prediction results, or an error message when the
        structure cannot be parsed
    """
    # Run the potentially blocking model inference off the event-loop thread.
    def run_prediction():
        # Convert the file name / content string into an ASE Atoms object.
        structure_content,content_format=read_structure_from_file_name_or_content_string(structure_source)
        structure = convert_structure(content_format, structure_content)
        if structure is None:
            return "Unable to parse CIF string. Please check if the format is correct."

        # Prefer GPU when available, otherwise fall back to CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Attach the MatterSim calculator so ASE property queries use the model.
        structure.calc = MatterSimCalculator(device=device)

        # Query energy, forces and the full (non-Voigt) stress tensor.
        energy = structure.get_potential_energy()
        forces = structure.get_forces()
        stresses = structure.get_stress(voigt=False)

        # Per-atom energy.
        num_atoms = len(structure)
        energy_per_atom = energy / num_atoms

        # Stress in both eV/A^3 (ASE native units) and GPa.
        stresses_ev_a3 = stresses
        stresses_gpa = stresses / GPa

        # Assemble the human-readable report.
        prompt = f"""
## {structure.get_chemical_formula()} Crystal Structure Property Prediction Results

Prediction results using the provided CIF structure:

- Total Energy (eV): {energy}
- Energy per Atom (eV/atom): {energy_per_atom:.4f}
- Forces (eV/Angstrom): {forces[0]} # Forces on the first atom
- Stress (GPa): {stresses_gpa[0][0]} # First component of the stress tensor
- Stress (eV/A^3): {stresses_ev_a3[0][0]} # First component of the stress tensor

"""
        return prompt

    # Execute the prediction asynchronously in a worker thread.
    return await asyncio.to_thread(run_prediction)
|
||||
0
sci_mcp/material_mcp/mp_query/__init__.py
Normal file
0
sci_mcp/material_mcp/mp_query/__init__.py
Normal file
42
sci_mcp/material_mcp/mp_query/get_mp_id.py
Normal file
42
sci_mcp/material_mcp/mp_query/get_mp_id.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
from typing import List
|
||||
from mp_api.client import MPRester
|
||||
from ...core.config import material_config
|
||||
|
||||
async def get_mpid_from_formula(formula: str) -> List[str]:
    """
    Get material IDs (mpid) from Materials Project database by chemical formula.
    Returns mpids for the lowest energy structures.

    Args:
        formula: Chemical formula (e.g., "Fe2O3"). A "name=formula" pair is also
            accepted; only the part after "=" is used.

    Returns:
        List of material IDs on success. NOTE: on failure (no results or any
        exception) a plain string is returned instead, so callers must check
        the return type.
    """
    # Route MP API traffic through the configured proxy (empty string disables it).
    os.environ['HTTP_PROXY'] = material_config.HTTP_PROXY or ''
    os.environ['HTTPS_PROXY'] = material_config.HTTPS_PROXY or ''

    try:
        id_list = []

        # Strip whitespace and stray quotes that LLM callers often leave in.
        cleaned_formula = formula.replace(" ", "").replace("\n", "").replace("\'", "").replace("\"", "")
        if "=" in cleaned_formula:
            # Accept "name=formula" input; keep only the formula part.
            # BUG FIX: renamed local from `id`, which shadowed the builtin.
            _name, query_formula = cleaned_formula.split("=")
        else:
            query_formula = cleaned_formula

        formula_list = [query_formula]

        with MPRester(material_config.MP_API_KEY) as mpr:
            docs = mpr.materials.summary.search(formula=formula_list)
            if not docs:
                return "No materials found"
            for doc in docs:
                id_list.append(doc.material_id)
            return id_list
    except Exception as e:
        # Deliberate best-effort: surface the failure as a string for the tool layer.
        return f"Error: get_mpid_from_formula: {str(e)}"
|
||||
168
sci_mcp/material_mcp/mp_query/mp_query_tools.py
Normal file
168
sci_mcp/material_mcp/mp_query/mp_query_tools.py
Normal file
@@ -0,0 +1,168 @@
|
||||
|
||||
import glob
|
||||
import json
|
||||
from typing import Dict, Any, Union
|
||||
from ...core.llm_tools import llm_tool
|
||||
from .get_mp_id import get_mpid_from_formula
|
||||
from ..support.utils import extract_cif_info, remove_symmetry_equiv_xyz
|
||||
from ...core.config import material_config
|
||||
|
||||
from pymatgen.core import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
@llm_tool(name="search_crystal_structures_from_materials_project",
          description="Retrieve and optimize crystal structures from Materials Project database using a chemical formula")
async def search_crystal_structures_from_materials_project(
    formula: str,
    conventional_unit_cell: bool = True,
    symprec: float = 0.1
) -> str:
    """
    Retrieve crystal structures for a chemical formula from the Materials Project
    database and apply symmetry refinement.

    Args:
        formula: Chemical formula to search for (e.g., "Fe2O3")
        conventional_unit_cell: If True, convert each structure to its conventional
            standard cell; if False, keep the cell as stored in the local CIF file.
        symprec: Symmetry precision parameter for structure refinement (default: 0.1)

    Returns:
        Formatted CIF data for the retrieved crystal structures with symmetry
        analysis, or an error / "not found" message string.
    """
    try:
        structures = {}
        mp_id_list = await get_mpid_from_formula(formula=formula)
        # get_mpid_from_formula signals errors by returning a string.
        if isinstance(mp_id_list, str):
            return mp_id_list

        for i, mp_id in enumerate(mp_id_list):
            try:
                # Structures are served from a local mirror of MP CIF files.
                cif_files = glob.glob(material_config.LOCAL_MP_CIF_ROOT + f"/{mp_id}.cif")
                if not cif_files:
                    continue  # no local CIF for this mp_id

                structure = Structure.from_file(cif_files[0])

                if conventional_unit_cell:
                    structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure()

                # Symmetrize the structure before exporting it.
                sga = SpacegroupAnalyzer(structure, symprec=symprec)
                symmetrized_structure = sga.get_refined_structure()

                cif_writer = CifWriter(symmetrized_structure, symprec=symprec, refine_struct=True)
                cif_data = str(cif_writer)

                # Drop the symmetry-operation loop and the pymatgen banner to keep output compact.
                cif_data = remove_symmetry_equiv_xyz(cif_data)
                cif_data = cif_data.replace('# generated using pymatgen', "")

                # Unique key per result: reduced formula plus result index.
                key = f"{structure.composition.reduced_formula}_{i}"
                structures[key] = cif_data

                # Cap the number of returned structures at config.MP_TOPK.
                if len(structures) >= material_config.MP_TOPK:
                    break

            except Exception as process_error:
                # One handler for missing files, parse failures and refinement errors:
                # report and skip this mp_id, keep processing the rest.
                print(f"Error: processing structure {mp_id}: {str(process_error)}")
                continue

        if not structures:
            return f"No valid crystal structures found for formula: {formula}"

        # Format the results into a single readable prompt string.
        prompt = f"""
# Materials Project Symmetrized Crystal Structure Data

Below are symmetrized crystal structure data for {len(structures)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
"""
        for i, (key, cif_data) in enumerate(structures.items(), 1):
            prompt += f"[cif {i} begin]\n"
            prompt += cif_data
            prompt += f"\n[cif {i} end]\n\n"

        return prompt

    except Exception as e:
        # Catch-all boundary: tool functions must always return a string.
        return f"Error: An unexpected error occurred while processing crystal structures: {str(e)}"
|
||||
|
||||
@llm_tool(name="search_material_property_from_material_project",
          description="Query material properties from Materials Project database using chemical formula")
# NOTE(review): the registered tool name says "material_project" (no "s") while the
# function and the sibling structure tool use "materials_project" — confirm whether
# clients depend on the current name before aligning it.
async def search_material_property_from_materials_project(
    formula: str,
) -> str:
    """
    Retrieve detailed property data for materials matching a chemical formula
    from the Materials Project database.

    Args:
        formula: Chemical formula of the material(s) to search for (e.g. 'Fe2O3', 'LiFePO4')

    Returns:
        Formatted string containing material properties including structure,
        electronic, thermodynamic and mechanical data, or an error message string.
    """
    # Resolve the formula to MP ids; a string result is an error message.
    mp_id_list = await get_mpid_from_formula(formula=formula)
    if isinstance(mp_id_list, str):
        return mp_id_list

    try:
        properties = []
        for mp_id in mp_id_list:
            try:
                # Property documents are mirrored locally as one JSON file per mp_id.
                file_path = material_config.LOCAL_MP_PROPS_ROOT + f"/{mp_id}.json"
                properties.append(extract_cif_info(file_path, ['all_fields']))
            except Exception:
                # Missing or malformed file: skip this id and keep going.
                continue

        if not properties:
            return "No material properties found for the given formula, please try again."

        # Keep only the first MP_TOPK results.
        properties = properties[:material_config.MP_TOPK]

        # Format each property document into a delimited block.
        formatted_results = []
        for i, item in enumerate(properties, 1):
            formatted_results.append(
                f"[property {i} begin]\n{json.dumps(item, indent=2)}\n[property {i} end]\n\n"
            )

        res_chunk = "\n\n".join(formatted_results)
        res_template = f"""
Here are the search material property from the Materials Project database:
Due to length limitations, only the top {len(properties)} results are shown below:\n
{res_chunk}
"""
        return res_template

    except Exception as e:
        return f"Error: processing material properties: {str(e)}"
|
||||
0
sci_mcp/material_mcp/oqmd_query/__init__.py
Normal file
0
sci_mcp/material_mcp/oqmd_query/__init__.py
Normal file
92
sci_mcp/material_mcp/oqmd_query/oqmd_query_tools.py
Executable file
92
sci_mcp/material_mcp/oqmd_query/oqmd_query_tools.py
Executable file
@@ -0,0 +1,92 @@
|
||||
import logging
|
||||
import httpx
|
||||
import pandas as pd
|
||||
from bs4 import BeautifulSoup
|
||||
from io import StringIO
|
||||
from typing import Annotated, Any, Dict, List
|
||||
import mcp.types as types
|
||||
from ...core.llm_tools import llm_tool
|
||||
|
||||
|
||||
|
||||
@llm_tool(name="query_material_from_OQMD", description="Query material properties by chemical formula from OQMD database")
async def query_material_from_OQMD(
    formula: Annotated[str, "Chemical formula (e.g., Fe2O3, LiFePO4)"]
) -> str:
    """
    Query material information by chemical formula from the OQMD database.

    Scrapes the OQMD composition page: collects the page heading, the
    descriptive paragraphs (anchors rewritten as Markdown links), and the
    first property table (rendered as a Markdown table).

    Args:
        formula: Chemical formula of the material (e.g., Fe2O3, LiFePO4)

    Returns:
        Formatted text with material information and property tables, or an
        "Error: ..." message string on failure.
    """
    # Keep the request URL in its own name; a loop below builds per-link URLs.
    query_url = f"https://www.oqmd.org/materials/composition/{formula}"
    try:
        async with httpx.AsyncClient(timeout=100.0) as client:
            response = await client.get(query_url)
            response.raise_for_status()

            # Guard against empty or truncated pages.
            if not response.text or len(response.text) < 100:
                raise ValueError("Invalid response content from OQMD API")

            soup = BeautifulSoup(response.text, 'html.parser')

            # Page heading (fall back to the formula when the page has none).
            basic_data = []
            h1_element = soup.find('h1')
            if h1_element:
                basic_data.append(h1_element.text.strip())
            else:
                basic_data.append(f"Material: {formula}")

            # Collect paragraph text, rewriting anchors as Markdown links.
            for paragraph in soup.find_all('p'):
                if not paragraph:
                    # A bs4 Tag with no children is falsy; skip empty <p> tags.
                    continue
                combined_text = ""
                for element in paragraph.contents:
                    if hasattr(element, 'name') and element.name == 'a' and 'href' in element.attrs:
                        link_url = "https://www.oqmd.org" + element['href']
                        combined_text += f"[{element.text.strip()}]({link_url}) "
                    elif hasattr(element, 'text'):
                        combined_text += element.text.strip() + " "
                    else:
                        combined_text += str(element).strip() + " "
                basic_data.append(combined_text.strip())

            # Render the first HTML table (if any) as Markdown.
            table_data = ""
            table = soup.find('table')
            if table:
                try:
                    df = pd.read_html(StringIO(str(table)))[0]
                    df = df.fillna('')
                    df = df.replace([float('inf'), float('-inf')], '')
                    table_data = df.to_markdown(index=False)
                except Exception:
                    table_data = "Error: parsing table data"

            # Integrate everything into a single text blob.
            combined_text = "\n\n".join(basic_data)
            if table_data:
                combined_text += "\n\n## Material Properties Table\n\n" + table_data

            return combined_text

    except httpx.HTTPStatusError as e:
        return f"Error: OQMD API request failed - {str(e)}"
    except httpx.TimeoutException:
        return "Error: OQMD API request timed out"
    except httpx.NetworkError as e:
        return f"Error: Network error occurred - {str(e)}"
    except ValueError as e:
        return f"Error: Invalid response content - {str(e)}"
    except Exception as e:
        return f"Error: Unexpected error occurred - {str(e)}"
|
||||
|
||||
|
||||
95
sci_mcp/material_mcp/pymatgen_cal/pymatgen_cal_tools.py
Normal file
95
sci_mcp/material_mcp/pymatgen_cal/pymatgen_cal_tools.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import os
|
||||
import asyncio
|
||||
from pymatgen.core import Structure
|
||||
from ...core.config import material_config
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ..support.utils import read_structure_from_file_name_or_content_string
|
||||
|
||||
@llm_tool(name="calculate_density_Pymatgen", description="Calculate the density of a crystal structure from a file or content string using Pymatgen")
async def calculate_density_Pymatgen(structure_source: str) -> str:
    """
    Compute the density of a crystal structure given as a file name or raw content.

    Args:
        structure_source (str): Structure file name (e.g., POSCAR, CIF) or the
            raw structure content itself.

    Returns:
        str: Markdown with the density, or an error message if the calculation fails.
    """
    try:
        # Resolve the input (file path, temp-dir file name, or raw content) to text.
        content, fmt = read_structure_from_file_name_or_content_string(structure_source)
        parsed = Structure.from_str(content, fmt=fmt)
        return (
            "## Density Calculation\n\n"
            f"- **Structure**: `{parsed.composition.reduced_formula}`\n"
            f"- **Density**: `{parsed.density:.2f} g/cm³`\n"
        )
    except Exception as e:
        return f"Error: error occurred while calculating density: {str(e)}\n"
|
||||
|
||||
|
||||
@llm_tool(name="get_element_composition_Pymatgen", description="Analyze and retrieve the elemental composition of a crystal structure from a file or content string using Pymatgen")
async def get_element_composition_Pymatgen(structure_source: str) -> str:
    """
    Report the elemental composition of a structure given as a file name or raw content.

    Args:
        structure_source (str): Structure file name (e.g., POSCAR, CIF) or the
            raw structure content itself.

    Returns:
        str: Markdown with the elemental composition, or an error message on failure.
    """
    try:
        # Resolve the input (file path, temp-dir file name, or raw content) to text.
        content, fmt = read_structure_from_file_name_or_content_string(structure_source)
        parsed = Structure.from_str(content, fmt=fmt)
        lines = [
            "## Element Composition\n",
            f"- **Structure**: `{parsed.composition.reduced_formula}`",
            f"- **Composition**: `{parsed.composition}`\n",
        ]
        return "\n".join(lines)
    except Exception as e:
        return f"Error: error occurred while getting element composition: {str(e)}\n"
|
||||
|
||||
|
||||
|
||||
@llm_tool(name="calculate_symmetry_Pymatgen", description="Determine the space group and symmetry operations of a crystal structure from a file or content string using Pymatgen")
async def calculate_symmetry_Pymatgen(structure_source: str) -> str:
    """
    Determine the space group of a structure given as a file name or raw content.

    Args:
        structure_source (str): Structure file name (e.g., POSCAR, CIF) or the
            raw structure content itself.

    Returns:
        str: Markdown with the space-group symbol and number, or an error
            message if the operation fails.
    """
    try:
        # Resolve the input (file path, temp-dir file name, or raw content) to text.
        content, fmt = read_structure_from_file_name_or_content_string(structure_source)
        parsed = Structure.from_str(content, fmt=fmt)
        spg_symbol, spg_number = parsed.get_space_group_info()
        return (
            "## Symmetry Information\n\n"
            f"- **Structure**: `{parsed.composition.reduced_formula}`\n"
            f"- **Space Group**: `{spg_symbol}`\n"
            f"- **Number**: `{spg_number}`\n"
        )
    except Exception as e:
        return f"Error: error occurred while calculating symmetry: {str(e)}\n"
|
||||
|
||||
0
sci_mcp/material_mcp/support/__init__.py
Normal file
0
sci_mcp/material_mcp/support/__init__.py
Normal file
212
sci_mcp/material_mcp/support/utils.py
Executable file
212
sci_mcp/material_mcp/support/utils.py
Executable file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
CIF Utilities Module
|
||||
|
||||
This module provides basic functions for handling CIF (Crystallographic Information File) files,
|
||||
which are commonly used in materials science for representing crystal structures.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from ase.io import read
|
||||
import tempfile
|
||||
from typing import Optional, Tuple
|
||||
from ase import Atoms
|
||||
from ...core.config import material_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def read_cif_txt_file(file_path):
    """
    Return the text content of a CIF file, or None when it cannot be read.

    Args:
        file_path: Path to the CIF file

    Returns:
        String content of the CIF file, or None if an error occurs (the error
        is logged rather than raised).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            content = handle.read()
    except Exception as exc:
        logger.error(f"Error reading file {file_path}: {exc}")
        return None
    return content
|
||||
|
||||
def extract_cif_info(path: str, fields_name: list):
    """
    Extract selected property fields from a Materials Project JSON document.

    Args:
        path: Path to the JSON file containing the material's property data.
        fields_name: List of field-category names to extract. Use ['all_fields']
            to extract every known field; other options are 'basic_fields',
            'energy_electronic_fields' and 'metal_magentic_fields' (the
            'magentic' spelling is preserved for compatibility with callers).
            Unknown category names are ignored; an empty list yields {}.

    Returns:
        Dict mapping each selected field to its value, or '' when the field is
        absent from the document.
    """
    basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
    energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
    metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites']

    # Explicit category registry. The previous implementation used
    # locals().get(name), which is fragile and could match unrelated local
    # variables (e.g. 'path', a string, would be extended character by character).
    field_groups = {
        'basic_fields': basic_fields,
        'energy_electronic_fields': energy_electronic_fields,
        'metal_magentic_fields': metal_magentic_fields,
    }

    if fields_name and fields_name[0] == 'all_fields':
        selected_fields = basic_fields + energy_electronic_fields + metal_magentic_fields
    else:
        selected_fields = []
        for group_name in fields_name:
            selected_fields.extend(field_groups.get(group_name, []))

    with open(path, 'r') as f:
        docs = json.load(f)

    # Missing fields default to the empty string to keep downstream formatting simple.
    return {field: docs.get(field, '') for field in selected_fields}
|
||||
|
||||
def remove_symmetry_equiv_xyz(cif_content):
    """
    Remove the symmetry-operations loop from CIF file content.

    This is often useful when working with CIF files in certain visualization
    tools or when focusing on the basic structure without symmetry operations.

    The scan keeps every line except a ``loop_`` block whose tag lines include
    ``_symmetry_equiv_pos_as_xyz``; that block is skipped until the next
    ``loop_`` / ``data_`` / ``_atom_site_`` line or end of input.

    Args:
        cif_content: CIF file content string

    Returns:
        Cleaned CIF content string with the symmetry-operations loop removed
    """
    lines = cif_content.split('\n')
    output_lines = []

    i = 0
    while i < len(lines):
        line = lines[i].strip()

        # Detect the start of a loop block
        if line == 'loop_':
            # Peek at the following tag lines (lines starting with '_')
            # to decide whether this is the symmetry loop
            next_lines = []
            j = i + 1
            while j < len(lines) and lines[j].strip().startswith('_'):
                next_lines.append(lines[j].strip())
                j += 1

            # Does this loop declare the symmetry-operation tag?
            if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
                # Skip the entire loop block: advance i until the line AFTER i
                # starts the next section (the final i += 1 below then lands on it)
                while i < len(lines):
                    if i + 1 >= len(lines):
                        break

                    next_line = lines[i + 1].strip()
                    # Stop at the next loop or data block
                    if next_line == 'loop_' or next_line.startswith('data_'):
                        break

                    # Stop when the atom-site section begins
                    if next_line.startswith('_atom_site_'):
                        break

                    i += 1
            else:
                # Not the symmetry loop: keep the loop_ line
                output_lines.append(lines[i])
        else:
            # Regular (non-loop_) line: keep as-is
            output_lines.append(lines[i])

        i += 1

    return '\n'.join(output_lines)
|
||||
|
||||
def _infer_format_from_path(path: str) -> str:
    """Infer a structure format from a file extension; unknown extensions default to 'cif'."""
    ext = os.path.splitext(path)[1].lower().lstrip('.')
    if ext == 'cif':
        return 'cif'
    if ext == 'xyz':
        return 'xyz'
    if ext in ('vasp', 'poscar'):
        return 'vasp'
    return 'cif'


def _infer_format_from_content(content: str) -> str:
    """Guess the structure format of raw content with simple heuristics; defaults to 'cif'."""
    # CIF files usually contain "data_" and "_cell_" tags.
    if "data_" in content and "_cell_" in content:
        return "cif"
    content_lines = content.strip().split('\n')
    # XYZ files start with the atom count on the first line.
    if content_lines[0].strip().isdigit():
        return "xyz"
    # POSCAR/VASP files carry three 3-component lattice vectors on lines 3-5.
    if len(content_lines) > 5 and all(len(line.split()) == 3 for line in content_lines[2:5]):
        return "vasp"
    return "cif"


def read_structure_from_file_name_or_content_string(file_name_or_content_string: str, format_type: str = "auto") -> Tuple[str, str]:
    """
    Resolve a structure input that may be a full file path, a bare file name,
    or the raw structure content itself.

    Bare file names are looked up under material_config.TEMP_ROOT, where
    LLM-generated temporary files are stored.

    Args:
        file_name_or_content_string: file path, file name, or structure content string
        format_type: structure format; "auto" means detect automatically

    Returns:
        tuple: (content string, resolved format type)
    """
    candidate = file_name_or_content_string
    # Case 1: an existing file given by full path.
    if os.path.exists(candidate) and os.path.isfile(candidate):
        with open(candidate, 'r', encoding='utf-8') as f:
            content = f.read()
        if format_type == "auto":
            format_type = _infer_format_from_path(candidate)
        return content, format_type

    # Case 2: a bare file name located in the temporary-files directory.
    temp_path = os.path.join(material_config.TEMP_ROOT, file_name_or_content_string)
    if os.path.exists(temp_path) and os.path.isfile(temp_path):
        with open(temp_path, 'r', encoding='utf-8') as f:
            content = f.read()
        if format_type == "auto":
            format_type = _infer_format_from_path(temp_path)
        return content, format_type

    # Case 3: not a file — treat the argument as the structure content itself.
    content = file_name_or_content_string
    if format_type == "auto":
        format_type = _infer_format_from_content(content)
    return content, format_type
|
||||
|
||||
def convert_structure(input_format: str = 'cif', content: str = None) -> Optional[Atoms]:
    """
    Convert raw structure content into an ASE Atoms object.

    Args:
        input_format: input format (cif, xyz, vasp, ...); used as the temp-file
            suffix so ase.io.read picks the matching parser
        content: structure content string

    Returns:
        ASE Atoms object, or None if the conversion fails (the error is logged).
    """
    tmp_path = None
    try:
        # ase.io.read expects a file, so round-trip the content through a temp file.
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            tmp_file.write(content)
            tmp_path = tmp_file.name

        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        # Always remove the temp file; the previous version leaked it whenever
        # read() raised, because the unlink only ran on the success path.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
||||
Reference in New Issue
Block a user