初次提交

This commit is contained in:
lzy
2025-05-09 14:16:33 +08:00
commit 3a50afeec4
56 changed files with 9224 additions and 0 deletions

38
sci_mcp/__init__.py Normal file
View File

@@ -0,0 +1,38 @@
from .core.llm_tools import llm_tool,get_all_tools, get_all_tool_schemas, get_domain_tools, get_domain_tool_schemas
#general_mcp
#from .general_mcp.searxng_query.searxng_query_tools import search_online
# #material_mcp
from .material_mcp.mp_query.mp_query_tools import search_crystal_structures_from_materials_project,search_material_property_from_materials_project
from .material_mcp.oqmd_query.oqmd_query_tools import query_material_from_OQMD
from .material_mcp.knowledge_base_query.retrieval_from_knowledge_base_tools import retrieval_from_knowledge_base
from .material_mcp.mattersim_pred.property_pred_tools import predict_properties_MatterSim
from .material_mcp.mattergen_gen.material_gen_tools import generate_material_MatterGen
from .material_mcp.fairchem_structure_opt.structure_opt_tools import optimize_crystal_structure_FairChem
from .material_mcp.pymatgen_cal.pymatgen_cal_tools import calculate_density_Pymatgen,get_element_composition_Pymatgen,calculate_symmetry_Pymatgen
from .material_mcp.matgl_tools.matgl_tools import relax_crystal_structure_M3GNet,predict_formation_energy_M3GNet,run_molecular_dynamics_M3GNet,calculate_single_point_energy_M3GNet
#chemistry_mcp
from .chemistry_mcp.pubchem_tools.pubchem_tools import search_advanced_pubchem
from .chemistry_mcp.rdkit_tools.rdkit_tools import (
calculate_molecular_properties_rdkit,
calculate_drug_likeness_rdkit,
calculate_topological_descriptors_rdkit,
generate_molecular_fingerprints_rdkit,
calculate_molecular_similarity_rdkit,
analyze_molecular_structure_rdkit,
generate_molecular_conformer_rdkit,
identify_scaffolds_rdkit,
convert_between_chemical_formats_rdkit,
standardize_molecule_rdkit,
enumerate_stereoisomers_rdkit,
perform_substructure_search_rdkit
)
from .chemistry_mcp.rxn_tools.rxn_tools import (
predict_reaction_outcome_rxn,
predict_reaction_topn_rxn,
predict_reaction_properties_rxn,
extract_reaction_actions_rxn
)
__all__ = ["llm_tool", "get_all_tools", "get_all_tool_schemas", "get_domain_tools", "get_domain_tool_schemas"]

View File

@@ -0,0 +1,9 @@
"""
Chemistry MCP Module
This module provides tools for chemistry-related operations.
"""
from .pubchem_tools import search_advanced_pubchem
__all__ = ["search_advanced_pubchem"]

View File

@@ -0,0 +1,9 @@
"""
PubChem Tools Module
This module provides tools for accessing and processing chemical data from PubChem.
"""
from .pubchem_tools import search_advanced_pubchem
__all__ = ["search_advanced_pubchem"]

View File

@@ -0,0 +1,325 @@
"""
PubChem Tools Module
This module provides tools for searching and retrieving chemical compound information
from the PubChem database using the PubChemPy library.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Dict, List, Union, Optional, Any
import pubchempy as pcp
from ...core.llm_tools import llm_tool
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def compound_to_dict(compound: pcp.Compound) -> Dict[str, Any]:
    """
    Convert a PubChem compound to a structured dictionary with relevant information.

    Args:
        compound: PubChem compound object; any falsy value yields an empty dict.

    Returns:
        Dictionary grouping compound information under "basic_info",
        "identifiers", "physical_properties", "molecular_features" and,
        when synonyms are present, "alternative_names".
    """
    if not compound:
        return {}
    # Basic information
    result = {
        "basic_info": {
            "cid": compound.cid,
            "iupac_name": compound.iupac_name,
            "molecular_formula": compound.molecular_formula,
            "molecular_weight": compound.molecular_weight,
            "canonical_smiles": compound.canonical_smiles,
            "isomeric_smiles": compound.isomeric_smiles,
        },
        "identifiers": {
            "inchi": compound.inchi,
            "inchikey": compound.inchikey,
        },
        "physical_properties": {
            "xlogp": compound.xlogp,
            "exact_mass": compound.exact_mass,
            "monoisotopic_mass": compound.monoisotopic_mass,
            "tpsa": compound.tpsa,
            "complexity": compound.complexity,
            "charge": compound.charge,
        },
        "molecular_features": {
            "h_bond_donor_count": compound.h_bond_donor_count,
            "h_bond_acceptor_count": compound.h_bond_acceptor_count,
            "rotatable_bond_count": compound.rotatable_bond_count,
            "heavy_atom_count": compound.heavy_atom_count,
            "atom_stereo_count": compound.atom_stereo_count,
            "defined_atom_stereo_count": compound.defined_atom_stereo_count,
            "undefined_atom_stereo_count": compound.undefined_atom_stereo_count,
            "bond_stereo_count": compound.bond_stereo_count,
            "defined_bond_stereo_count": compound.defined_bond_stereo_count,
            "undefined_bond_stereo_count": compound.undefined_bond_stereo_count,
            "covalent_unit_count": compound.covalent_unit_count,
        }
    }
    # Add at most the first 10 synonyms if available. Slicing already
    # truncates safely, so the explicit length check is unnecessary.
    if getattr(compound, 'synonyms', None):
        result["alternative_names"] = {
            "synonyms": compound.synonyms[:10]
        }
    return result
async def _search_by_name(name: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """
    Search PubChem for compounds matching a chemical name.

    The blocking PubChemPy call runs in a worker thread so the event loop
    stays responsive.

    Args:
        name: Chemical compound name
        max_results: Maximum number of results to return

    Returns:
        List of compound dictionaries; on failure, a one-element list whose
        single dict carries an "error" message.
    """
    try:
        hits = await asyncio.to_thread(
            pcp.get_compounds, name, 'name', max_records=max_results
        )
        records = []
        for hit in hits:
            records.append(hit.to_dict())
        return records
    except Exception as e:
        logging.error(f"Error searching by name '{name}': {str(e)}")
        return [{"error": f"Error: {str(e)}"}]
async def _search_by_smiles(smiles: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """
    Search PubChem for compounds matching a SMILES string.

    The blocking PubChemPy call runs in a worker thread so the event loop
    stays responsive.

    Args:
        smiles: SMILES notation of chemical compound
        max_results: Maximum number of results to return

    Returns:
        List of compound dictionaries; on failure, a one-element list whose
        single dict carries an "error" message.
    """
    try:
        hits = await asyncio.to_thread(
            pcp.get_compounds, smiles, 'smiles', max_records=max_results
        )
        records = []
        for hit in hits:
            records.append(hit.to_dict())
        return records
    except Exception as e:
        logging.error(f"Error searching by SMILES '{smiles}': {str(e)}")
        return [{"error": f"Error: {str(e)}"}]
async def _search_by_formula(
    formula: str,
    max_results: int = 5,
    listkey_count: int = 5,
    listkey_start: int = 0
) -> List[Dict[str, Any]]:
    """
    Search PubChem for compounds with a given molecular formula.

    Formula queries can match very large result sets, so the request is
    paginated via the listkey parameters to avoid server-side timeouts.

    Args:
        formula: Molecular formula
        max_results: Maximum number of results to return
        listkey_count: Number of results per page (default: 5)
        listkey_start: Starting position for pagination (default: 0)

    Returns:
        List of compound dictionaries; on failure, a one-element list whose
        single dict carries an "error" message.
    """
    try:
        hits = await asyncio.to_thread(
            pcp.get_compounds,
            formula,
            'formula',
            max_records=max_results,
            listkey_count=listkey_count,
            listkey_start=listkey_start
        )
        records = []
        for hit in hits:
            records.append(hit.to_dict())
        return records
    except Exception as e:
        logging.error(f"Error searching by formula '{formula}': {str(e)}")
        return [{"error": f"Error: {str(e)}"}]
def _format_results_as_markdown(results: List[Dict[str, Any]], query_type: str, query_value: str) -> str:
"""
Format search results as a structured Markdown string.
Args:
results: List of compound dictionaries from compound.to_dict()
query_type: Type of search query (name, SMILES, formula)
query_value: Value of the search query
Returns:
Formatted Markdown string
"""
if not results:
return f"## PubChem Search Results\n\nNo compounds found for {query_type}: `{query_value}`"
if "error" in results[0]:
return f"## PubChem Search Error\n\n{results[0]['error']}"
markdown = f"## PubChem Search Results\n\nSearch by {query_type}: `{query_value}`\n\nFound {len(results)} compound(s)\n\n"
for i, compound in enumerate(results):
# Extract information from the compound.to_dict() structure
cid = compound.get("cid", "N/A")
iupac_name = compound.get("iupac_name", "Unknown")
molecular_formula = compound.get("molecular_formula", "N/A")
molecular_weight = compound.get("molecular_weight", "N/A")
canonical_smiles = compound.get("canonical_smiles", "N/A")
isomeric_smiles = compound.get("isomeric_smiles", "N/A")
inchi = compound.get("inchi", "N/A")
inchikey = compound.get("inchikey", "N/A")
xlogp = compound.get("xlogp", "N/A")
exact_mass = compound.get("exact_mass", "N/A")
tpsa = compound.get("tpsa", "N/A")
h_bond_donor_count = compound.get("h_bond_donor_count", "N/A")
h_bond_acceptor_count = compound.get("h_bond_acceptor_count", "N/A")
rotatable_bond_count = compound.get("rotatable_bond_count", "N/A")
heavy_atom_count = compound.get("heavy_atom_count", "N/A")
# Get atoms and bonds information if available
atoms = compound.get("atoms", [])
bonds = compound.get("bonds", [])
# Format the markdown output
markdown += f"### Compound {i+1}: {iupac_name}\n\n"
# Basic information section
markdown += "#### Basic Information\n\n"
markdown += f"- **CID**: {cid}\n"
markdown += f"- **Formula**: {molecular_formula}\n"
markdown += f"- **Molecular Weight**: {molecular_weight} g/mol\n"
markdown += f"- **Canonical SMILES**: `{canonical_smiles}`\n"
markdown += f"- **Isomeric SMILES**: `{isomeric_smiles}`\n"
# Identifiers section
markdown += "\n#### Identifiers\n\n"
markdown += f"- **InChI**: `{inchi}`\n"
markdown += f"- **InChIKey**: `{inchikey}`\n"
# Physical properties section
markdown += "\n#### Physical Properties\n\n"
markdown += f"- **XLogP**: {xlogp}\n"
markdown += f"- **Exact Mass**: {exact_mass}\n"
markdown += f"- **TPSA**: {tpsa} Ų\n"
# Molecular features section
markdown += "\n#### Molecular Features\n\n"
markdown += f"- **H-Bond Donors**: {h_bond_donor_count}\n"
markdown += f"- **H-Bond Acceptors**: {h_bond_acceptor_count}\n"
markdown += f"- **Rotatable Bonds**: {rotatable_bond_count}\n"
markdown += f"- **Heavy Atoms**: {heavy_atom_count}\n"
# Structure information
markdown += "\n#### Structure Information\n\n"
markdown += f"- **Atoms Count**: {len(atoms)}\n"
markdown += f"- **Bonds Count**: {len(bonds)}\n"
# Add a summary of atom elements if available
if atoms:
elements = {}
for atom in atoms:
element = atom.get("element", "")
if element:
elements[element] = elements.get(element, 0) + 1
if elements:
markdown += "- **Elements**: "
elements_str = ", ".join([f"{element}: {count}" for element, count in elements.items()])
markdown += f"{elements_str}\n"
markdown += "\n---\n\n" if i < len(results) - 1 else "\n"
return markdown
@llm_tool(name="search_advanced_pubchem",
          description="Search for chemical compounds on PubChem database using name, SMILES notation, or molecular formula via PubChemPy library")
async def search_advanced_pubchem(
    name: Optional[str] = None,
    smiles: Optional[str] = None,
    formula: Optional[str] = None,
    max_results: int = 3
) -> str:
    """
    Perform an advanced PubChem search by name, SMILES notation, or formula.

    At least one search parameter must be provided. When several are given,
    the search priority is: name > smiles > formula.

    Args:
        name: Name of the chemical compound (e.g., "Aspirin", "Caffeine")
        smiles: SMILES notation of the chemical compound (e.g., "CC(=O)OC1=CC=CC=C1C(=O)O" for Aspirin)
        formula: Molecular formula (e.g., "C9H8O4" for Aspirin)
        max_results: Maximum number of results to return (default: 3, clamped into [1, 10])

    Returns:
        A formatted Markdown string with search results

    Examples:
        >>> search_advanced_pubchem(name="Aspirin")
        # Returns information about Aspirin
        >>> search_advanced_pubchem(smiles="CC(=O)OC1=CC=CC=C1C(=O)O")
        # Returns information about compounds matching the SMILES notation
        >>> search_advanced_pubchem(formula="C9H8O4", max_results=5)
        # Returns up to 5 compounds with the formula C9H8O4
    """
    logging.info(f"Performing advanced PubChem search with parameters: name={name}, smiles={smiles}, formula={formula}, max_results={max_results}")
    if name is None and smiles is None and formula is None:
        return "## PubChem Search Error\n\nError: At least one search parameter (name, smiles, or formula) must be provided"
    # Clamp max_results into [1, 10] to keep responses manageable.
    max_results = max(1, min(max_results, 10))
    try:
        # Dispatch in priority order: name, then SMILES, then formula.
        if name is not None:
            query_type, query_value = "name", name
            results = await _search_by_name(name, max_results)
        elif smiles is not None:
            query_type, query_value = "SMILES", smiles
            results = await _search_by_smiles(smiles, max_results)
        else:
            # Only reachable when formula is not None (checked above).
            query_type, query_value = "formula", formula
            results = await _search_by_formula(formula, max_results)
        return _format_results_as_markdown(results, query_type, query_value)
    except Exception as e:
        return f"## PubChem Search Error\n\nError: {str(e)}"

View File

@@ -0,0 +1,9 @@
"""
RDKit Tools Package
This package provides tools for molecular analysis, manipulation, and visualization
using the RDKit library. It includes functions for calculating molecular descriptors,
generating molecular fingerprints, analyzing molecular structures, and more.
"""
from .rdkit_tools import *

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
"""
RXN Tools Module
This module provides tools for chemical reaction prediction and analysis
using the IBM RXN for Chemistry API.
"""

View File

@@ -0,0 +1,772 @@
"""
RXN Tools Module
This module provides tools for chemical reaction prediction and analysis
using the IBM RXN for Chemistry API through the rxn4chemistry package.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Dict, List, Union, Optional, Any, Tuple
from rxn4chemistry import RXN4ChemistryWrapper
from ...core.llm_tools import llm_tool
from ...core.config import Chemistry_Config
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Constants
DEFAULT_MAX_RESULTS = 3
DEFAULT_TIMEOUT = 180 # seconds - increased from 60 to 180
def _get_rxn_wrapper() -> RXN4ChemistryWrapper:
    """
    Build an RXN4Chemistry wrapper, create a scratch project and select it.

    Returns:
        Initialized RXN4ChemistryWrapper instance with project set

    Raises:
        ValueError: If API key is not available
    """
    # The key may come from the package config or from the environment.
    api_key = Chemistry_Config.RXN4_CHEMISTRY_KEY or os.environ.get("RXN_API_KEY")
    if not api_key:
        raise ValueError("Error: RXN API key not found. Please set the RXN_API_KEY environment variable.")
    wrapper = RXN4ChemistryWrapper(api_key=api_key)
    try:
        # Include the process ID so concurrently created project names stay unique.
        project_name = f"RXN_Tools_Project_{os.getpid()}"
        project_response = wrapper.create_project(project_name)
        if project_response and isinstance(project_response, dict):
            # The project ID may appear directly or nested as
            # {"response": {"payload": {"id": "..."}}} depending on API version.
            if "project_id" in project_response:
                project_id = project_response["project_id"]
            else:
                project_id = None
                nested = project_response.get("response")
                if isinstance(nested, dict):
                    payload = nested.get("payload", {})
                    if isinstance(payload, dict) and "id" in payload:
                        project_id = payload["id"]
            if project_id:
                wrapper.set_project(project_id)
                logger.info(f"RXN project '{project_name}' created and set successfully with ID: {project_id}")
            else:
                logger.warning(f"Could not extract project ID from response: {project_response}")
        else:
            logger.warning(f"Unexpected project creation response: {project_response}")
    except Exception as e:
        # Project creation is best-effort; the wrapper is still returned.
        logger.error(f"Error creating RXN project: {e}")
    if not getattr(wrapper, "project_id", None):
        logger.warning("No project set. Some API calls may fail.")
    return wrapper
def _format_reaction_markdown(reactants: str, products: List[str],
confidence: Optional[List[float]] = None) -> str:
"""
Format reaction results as Markdown.
Args:
reactants: SMILES of reactants
products: List of product SMILES
confidence: Optional list of confidence scores
Returns:
Formatted Markdown string
"""
markdown = f"## 反应预测结果\n\n"
markdown += f"### 输入反应物\n\n`{reactants}`\n\n"
markdown += f"### 预测产物\n\n"
for i, product in enumerate(products):
conf_str = f" (置信度: {confidence[i]:.2f})" if confidence and i < len(confidence) else ""
markdown += f"{i+1}. `{product}`{conf_str}\n"
return markdown
@llm_tool(name="predict_reaction_outcome_rxn",
          description="Predict chemical reaction outcomes for given reactants using IBM RXN for Chemistry API")
async def predict_reaction_outcome_rxn(reactants: str) -> str:
    """
    Predict chemical reaction outcomes for given reactants.

    This function uses the IBM RXN for Chemistry API to predict the most likely
    products formed when the given reactants are combined.

    Args:
        reactants: SMILES notation of reactants, multiple reactants separated by dots (.).

    Returns:
        Formatted Markdown string containing the predicted reaction results,
        or an "Error: ..." string (with Chinese detail text) on failure.

    Examples:
        >>> predict_reaction_outcome_rxn("BrBr.c1ccc2cc3ccccc3cc2c1")
        # Returns predicted products of bromine and anthracene reaction
    """
    try:
        # Get RXN wrapper (creates and selects a fresh project).
        wrapper = _get_rxn_wrapper()
        # Clean input
        reactants = reactants.strip()
        # Submit prediction in a worker thread so the event loop is not blocked.
        response = await asyncio.to_thread(
            wrapper.predict_reaction, reactants
        )
        if not response or "prediction_id" not in response:
            return "Error: 无法提交反应预测请求"
        # Fetch results directly instead of going through a _wait_for_result helper.
        # NOTE(review): there is no wait/retry loop here, so the prediction may
        # still be running when results are fetched — confirm the API call
        # blocks until the prediction completes.
        results = await asyncio.to_thread(
            wrapper.get_predict_reaction_results,
            response["prediction_id"]
        )
        # Extract products from the nested response payload.
        try:
            attempts = results.get("response", {}).get("payload", {}).get("attempts", [])
            if not attempts:
                return "Error: 未找到预测结果"
            # Get the top predicted product (first attempt).
            product_smiles = attempts[0].get("smiles", "")
            confidence = attempts[0].get("confidence", None)
            # Format results as Markdown.
            return _format_reaction_markdown(
                reactants,
                [product_smiles] if product_smiles else ["无法预测产物"],
                [confidence] if confidence is not None else None
            )
        except (KeyError, IndexError) as e:
            logger.error(f"Error parsing prediction results: {e}")
            return f"Error: 解析预测结果时出错: {str(e)}"
    except Exception as e:
        logger.error(f"Error in predict_reaction_outcome: {e}")
        return f"Error: {str(e)}"
@llm_tool(name="predict_reaction_topn_rxn",
          description="Predict multiple possible products for chemical reactions using IBM RXN for Chemistry API")
async def predict_reaction_topn_rxn(reactants: Union[str, List[str], List[List[str]]], topn: int = 3) -> str:
    """
    Predict multiple possible products for chemical reactions.

    This function uses the IBM RXN for Chemistry API to predict multiple products
    that may be formed from given reactants, ranked by likelihood. Suitable for
    scenarios where multiple reaction pathways need to be considered.

    Args:
        reactants: Reactants in one of the following formats:
            - String: SMILES notation for a single reaction, multiple reactants separated by dots (.)
            - List of strings: Multiple reactants for a single reaction, each reactant as a SMILES string
            - List of lists of strings: Multiple reactions, each reaction composed of multiple reactant SMILES strings
        topn: Number of predicted products to return for each reaction, default is 3 (clamped into [1, 10]).

    Returns:
        Formatted Markdown string containing multiple predicted reaction results,
        or an "Error: ..." string (with Chinese detail text) on failure.

    Examples:
        >>> predict_reaction_topn_rxn("BrBr.c1ccc2cc3ccccc3cc2c1", 5)
        # Returns top 5 possible products for bromine and anthracene reaction
        >>> predict_reaction_topn_rxn(["BrBr", "c1ccc2cc3ccccc3cc2c1"], 3)
        # Returns top 3 possible products for bromine and anthracene reaction
        >>> predict_reaction_topn_rxn([
        ...     ["BrBr", "c1ccc2cc3ccccc3cc2c1"],
        ...     ["BrBr", "c1ccc2cc3ccccc3cc2c1CCO"]
        ... ], 3)
        # Returns top 3 possible products for two different reactions
    """
    try:
        # Get RXN wrapper (creates and selects a fresh project).
        wrapper = _get_rxn_wrapper()
        # Clamp topn into [1, 10].
        if topn < 1:
            topn = 1
        elif topn > 10:
            topn = 10
            logger.warning("topn限制为最大10个结果")
        # Normalize the three accepted input shapes into precursors_lists
        # (a list of reactions, each a list of reactant SMILES strings).
        precursors_lists = []
        if isinstance(reactants, str):
            # Single reaction as string (e.g., "BrBr.c1ccc2cc3ccccc3cc2c1")
            reactants = reactants.strip()
            precursors_lists = [reactants.split(".")]
            # For display in results
            reactants_display = [reactants]
        elif isinstance(reactants, list):
            if all(isinstance(r, str) for r in reactants):
                # Single reaction as list of strings (e.g., ["BrBr", "c1ccc2cc3ccccc3cc2c1"])
                precursors_lists = [reactants]
                # For display in results
                reactants_display = [".".join(reactants)]
            elif all(isinstance(r, list) for r in reactants):
                # Multiple reactions as list of lists (e.g., [["BrBr", "c1ccc2cc3ccccc3cc2c1"], ["BrBr", "c1ccc2cc3ccccc3cc2c1CCO"]])
                precursors_lists = reactants
                # For display in results
                reactants_display = [".".join(r) for r in reactants]
            else:
                return "Error: 反应物列表格式无效,必须是字符串列表或字符串列表的列表"
        else:
            return "Error: 反应物参数类型无效,必须是字符串或列表"
        # Submit the batch top-n prediction in a worker thread.
        response = await asyncio.to_thread(
            wrapper.predict_reaction_batch_topn,
            precursors_lists=precursors_lists,
            topn=topn
        )
        if not response or "task_id" not in response:
            return "Error: 无法提交多产物反应预测请求"
        # Fetch results directly, without a polling loop.
        # NOTE(review): as in predict_reaction_outcome_rxn, there is no
        # wait/retry here — confirm the call blocks until results are ready.
        results = await asyncio.to_thread(
            wrapper.get_predict_reaction_batch_topn_results,
            response["task_id"]
        )
        # Extract products, tolerating several observed response formats.
        try:
            # Log the result structure for debugging.
            # NOTE(review): this assumes results is a dict; the
            # isinstance(results, list) fallback below can never be reached
            # because .keys() would already have raised for a list.
            logger.info(f"Results structure: {results.keys()}")
            # Read results flexibly: use .get with defaults.
            reaction_results = results.get("result", [])
            # If empty, try other plausible keys.
            if not reaction_results and "predictions" in results:
                reaction_results = results.get("predictions", [])
                logger.info("Using 'predictions' key instead of 'result'")
            # If still empty, fall back to treating the entire payload as the list.
            if not reaction_results and isinstance(results, list):
                reaction_results = results
                logger.info("Using entire results as list")
            if not reaction_results:
                logger.warning(f"No reaction results found. Available keys: {results.keys()}")
                return "Error: 未找到预测结果。请检查API响应格式。"
            # Format results for all reactions
            markdown = "## 反应预测结果\n\n"
            # Keep reaction_results and reactants_display the same length.
            if len(reaction_results) != len(reactants_display):
                logger.warning(f"Mismatch between results ({len(reaction_results)}) and reactants ({len(reactants_display)})")
                # On mismatch, truncate both to the shorter length.
                min_len = min(len(reaction_results), len(reactants_display))
                reaction_results = reaction_results[:min_len]
                reactants_display = reactants_display[:min_len]
            for i, (reaction_result, reactants_str) in enumerate(zip(reaction_results, reactants_display), 1):
                if not reaction_result:
                    markdown += f"### 反应 {i}: 未找到预测结果\n\n"
                    continue
                # Log the structure of each per-reaction result.
                logger.info(f"Reaction {i} result structure: {type(reaction_result)}")
                # Extract products and confidences for this reaction
                products = []
                confidences = []
                # Handle the different per-reaction result formats.
                if isinstance(reaction_result, list):
                    # Standard format: each list item is one prediction.
                    for pred in reaction_result:
                        if isinstance(pred, dict) and "smiles" in pred:
                            # "smiles" may itself be a list; take its first element.
                            if isinstance(pred["smiles"], list) and pred["smiles"]:
                                products.append(pred["smiles"][0])
                            else:
                                products.append(pred["smiles"])
                            confidences.append(pred.get("confidence", 0.0))
                elif isinstance(reaction_result, dict):
                    # Per user feedback, check for a 'results' key first.
                    if "results" in reaction_result:
                        # Iterate the nested results list.
                        for pred in reaction_result.get("results", []):
                            if isinstance(pred, dict) and "smiles" in pred:
                                # "smiles" may itself be a list; take its first element.
                                if isinstance(pred["smiles"], list) and pred["smiles"]:
                                    products.append(pred["smiles"][0])
                                else:
                                    products.append(pred["smiles"])
                                confidences.append(pred.get("confidence", 0.0))
                    # Alternative format: the dict is itself a single prediction.
                    elif "smiles" in reaction_result:
                        # "smiles" may itself be a list; take its first element.
                        if isinstance(reaction_result["smiles"], list) and reaction_result["smiles"]:
                            products.append(reaction_result["smiles"][0])
                        else:
                            products.append(reaction_result["smiles"])
                        confidences.append(reaction_result.get("confidence", 0.0))
                    # Another possible format: a "products" list of dicts.
                    elif "products" in reaction_result:
                        for prod in reaction_result.get("products", []):
                            if isinstance(prod, dict) and "smiles" in prod:
                                # "smiles" may itself be a list; take its first element.
                                if isinstance(prod["smiles"], list) and prod["smiles"]:
                                    products.append(prod["smiles"][0])
                                else:
                                    products.append(prod["smiles"])
                                confidences.append(prod.get("confidence", 0.0))
                # Add results for this reaction
                markdown += f"### 反应 {i}\n\n"
                markdown += f"**输入反应物:** `{reactants_str}`\n\n"
                if products:
                    markdown += "**预测产物:**\n\n"
                    for j, (product, confidence) in enumerate(zip(products, confidences), 1):
                        markdown += f"{j}. `{product}` (置信度: {confidence:.2f})\n"
                else:
                    markdown += "**预测产物:** 无法解析产物结构\n\n"
                    # Include the raw result to aid debugging unparsed formats.
                    markdown += f"**原始结果:** `{reaction_result}`\n\n"
                markdown += "\n"
            return markdown
        except Exception as e:
            logger.error(f"Error parsing topn prediction results: {e}", exc_info=True)
            return f"Error: 解析多产物预测结果时出错: {str(e)}"
    except Exception as e:
        logger.error(f"Error in predict_reaction_topn: {e}")
        return f"Error: {str(e)}"
# @llm_tool(name="predict_retrosynthesis",
# description="预测目标分子的逆合成路径")
# async def predict_retrosynthesis(target_molecule: str, max_steps: int = 3) -> str:
# """
# 预测目标分子的逆合成路径。
# 此函数使用IBM RXN for Chemistry API建议可能的合成路线
# 将目标分子分解为可能商业可得的更简单前体。
# Args:
# target_molecule: 目标分子的SMILES表示法。
# max_steps: 考虑的最大逆合成步骤数默认为3。
# Returns:
# 包含预测逆合成路径的格式化Markdown字符串。
# Examples:
# >>> predict_retrosynthesis("Brc1c2ccccc2c(Br)c2ccccc12")
# # 返回目标分子的可能合成路线
# """
# try:
# # Get RXN wrapper
# wrapper = _get_rxn_wrapper()
# # Clean input
# target_molecule = target_molecule.strip()
# # Validate max_steps
# if max_steps < 1:
# max_steps = 1
# elif max_steps > 5:
# max_steps = 5
# logger.warning("max_steps限制为最大5步")
# # Submit prediction
# response = await asyncio.to_thread(
# wrapper.predict_automatic_retrosynthesis,
# product=target_molecule,
# max_steps=max_steps
# )
# if not response or "prediction_id" not in response:
# return "Error: 无法提交逆合成预测请求"
# # 直接获取结果而不是通过_wait_for_result
# results = await asyncio.to_thread(
# wrapper.get_predict_automatic_retrosynthesis_results,
# response["prediction_id"]
# )
# # Extract retrosynthetic paths
# try:
# paths = results.get("retrosynthetic_paths", [])
# if not paths:
# return "## 逆合成分析结果\n\n未找到可行的逆合成路径。目标分子可能太复杂或结构有问题。"
# # Format results
# markdown = f"## 逆合成分析结果\n\n"
# markdown += f"### 目标分子\n\n`{target_molecule}`\n\n"
# markdown += f"### 找到的合成路径: {len(paths)}\n\n"
# # Limit to top 3 paths for readability
# display_paths = paths[:3]
# for i, path in enumerate(display_paths, 1):
# markdown += f"#### 路径 {i}\n\n"
# # Extract sequence information
# sequence_id = path.get("sequenceId", "未知")
# confidence = path.get("confidence", 0.0)
# markdown += f"**置信度:** {confidence:.2f}\n\n"
# # Extract steps
# steps = path.get("steps", [])
# if steps:
# markdown += "**合成步骤:**\n\n"
# for j, step in enumerate(steps, 1):
# # Extract reactants and products
# reactants = step.get("reactants", [])
# reactant_smiles = [r.get("smiles", "") for r in reactants if "smiles" in r]
# product = step.get("product", {})
# product_smiles = product.get("smiles", "")
# markdown += f"步骤 {j}: "
# if reactant_smiles and product_smiles:
# markdown += f"`{'.' if len(reactant_smiles) > 1 else ''.join(reactant_smiles)}` → `{product_smiles}`\n\n"
# else:
# markdown += "反应细节不可用\n\n"
# else:
# markdown += "**合成步骤:** 未提供详细步骤\n\n"
# markdown += "---\n\n"
# if len(paths) > 3:
# markdown += f"*注: 仅显示前3条路径共找到{len(paths)}条可能的合成路径。*\n"
# return markdown
# except (KeyError, IndexError) as e:
# logger.error(f"Error parsing retrosynthesis results: {e}")
# return f"Error: 解析逆合成结果时出错: {str(e)}"
# except Exception as e:
# logger.error(f"Error in predict_retrosynthesis: {e}")
# return f"Error: {str(e)}"
# @llm_tool(name="predict_biocatalytic_retrosynthesis",
# description="使用生物催化模型预测目标分子的逆合成路径")
# async def predict_biocatalytic_retrosynthesis(target_molecule: str) -> str:
# """
# 使用生物催化模型预测目标分子的逆合成路径。
# 此函数使用IBM RXN for Chemistry API的专门生物催化模型
# 建议可能的酶催化合成路线来创建目标分子。
# Args:
# target_molecule: 目标分子的SMILES表示法。
# Returns:
# 包含预测生物催化逆合成路径的格式化Markdown字符串。
# Examples:
# >>> predict_biocatalytic_retrosynthesis("OC1C(O)C=C(Br)C=C1")
# # 返回目标分子的可能酶催化合成路线
# """
# try:
# # Get RXN wrapper
# wrapper = _get_rxn_wrapper()
# # Clean input
# target_molecule = target_molecule.strip()
# # Submit prediction with enzymatic model
# # Note: The model name might change in future API versions
# response = await asyncio.to_thread(
# wrapper.predict_automatic_retrosynthesis,
# product=target_molecule,
# ai_model="enzymatic-2021-04-16" # Use the enzymatic model
# )
# if not response or "prediction_id" not in response:
# return "Error: 无法提交生物催化逆合成预测请求"
# # 直接获取结果而不是通过_wait_for_result
# results = await asyncio.to_thread(
# wrapper.get_predict_automatic_retrosynthesis_results,
# response["prediction_id"]
# )
# # Extract retrosynthetic paths
# try:
# paths = results.get("retrosynthetic_paths", [])
# if not paths:
# return "## 生物催化逆合成分析结果\n\n未找到可行的酶催化合成路径。目标分子可能不适合酶催化或结构有问题。"
# # Format results
# markdown = f"## 生物催化逆合成分析结果\n\n"
# markdown += f"### 目标分子\n\n`{target_molecule}`\n\n"
# markdown += f"### 找到的酶催化合成路径: {len(paths)}\n\n"
# # Limit to top 3 paths for readability
# display_paths = paths[:3]
# for i, path in enumerate(display_paths, 1):
# markdown += f"#### 路径 {i}\n\n"
# # Extract sequence information
# sequence_id = path.get("sequenceId", "未知")
# confidence = path.get("confidence", 0.0)
# markdown += f"**置信度:** {confidence:.2f}\n\n"
# # Extract steps
# steps = path.get("steps", [])
# if steps:
# markdown += "**酶催化步骤:**\n\n"
# for j, step in enumerate(steps, 1):
# # Extract reactants and products
# reactants = step.get("reactants", [])
# reactant_smiles = [r.get("smiles", "") for r in reactants if "smiles" in r]
# product = step.get("product", {})
# product_smiles = product.get("smiles", "")
# markdown += f"步骤 {j}: "
# if reactant_smiles and product_smiles:
# markdown += f"`{'.' if len(reactant_smiles) > 1 else ''.join(reactant_smiles)}` → `{product_smiles}`\n\n"
# else:
# markdown += "反应细节不可用\n\n"
# # Add enzyme information if available
# if "enzymeClass" in step:
# markdown += f"*可能的酶类别: {step['enzymeClass']}*\n\n"
# else:
# markdown += "**合成步骤:** 未提供详细步骤\n\n"
# markdown += "---\n\n"
# if len(paths) > 3:
# markdown += f"*注: 仅显示前3条路径共找到{len(paths)}条可能的酶催化合成路径。*\n"
# return markdown
# except (KeyError, IndexError) as e:
# logger.error(f"Error parsing biocatalytic retrosynthesis results: {e}")
# return f"Error: 解析生物催化逆合成结果时出错: {str(e)}"
# except Exception as e:
# logger.error(f"Error in predict_biocatalytic_retrosynthesis: {e}")
# return f"Error: {str(e)}"
@llm_tool(name="predict_reaction_properties_rxn",
          description="Predict chemical reaction properties such as atom mapping and yield using IBM RXN for Chemistry API")
async def predict_reaction_properties_rxn(
    reaction: str,
    property_type: str = "atom-mapping"
) -> str:
    """
    Predict chemical reaction properties such as atom mapping and yield.

    This function uses the IBM RXN for Chemistry API to predict various properties
    of chemical reactions, including atom-to-atom mapping (showing how atoms in
    reactants correspond to atoms in products) and potential reaction yields.

    Args:
        reaction: SMILES notation of the reaction (reactants>>products).
        property_type: Type of property to predict. Options: "atom-mapping", "yield".

    Returns:
        Formatted Markdown string containing predicted reaction properties,
        or an "Error: ..." string on failure.

    Examples:
        >>> predict_reaction_properties_rxn("CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F", "atom-mapping")
        # Returns atom mapping for the reaction
    """
    try:
        # Validate inputs *before* acquiring the RXN wrapper so that bad
        # arguments are reported even when the API client cannot be built.
        reaction = reaction.strip()
        if not reaction:
            return "Error: 反应SMILES为空"
        valid_property_types = ["atom-mapping", "yield"]
        if property_type not in valid_property_types:
            return f"Error: 无效的属性类型 '{property_type}'。支持的类型: {', '.join(valid_property_types)}"
        # Get RXN wrapper
        wrapper = _get_rxn_wrapper()
        # Each property type is served by a dedicated pretrained model.
        ai_model = "atom-mapping-2020" if property_type == "atom-mapping" else "yield-2020-08-10"
        # Submit prediction in a worker thread (the wrapper call is blocking).
        response = await asyncio.to_thread(
            wrapper.predict_reaction_properties,
            reactions=[reaction],
            ai_model=ai_model
        )
        if not response or "response" not in response:
            return f"Error: 无法提交{property_type}预测请求"
        # Extract results
        try:
            content = response.get("response", {}).get("payload", {}).get("content", [])
            if not content:
                return f"Error: 未找到{property_type}预测结果"
            # Format results based on property type
            markdown = f"## 反应{property_type}预测结果\n\n"
            markdown += f"### 输入反应\n\n`{reaction}`\n\n"
            if property_type == "atom-mapping":
                # Extract mapped reaction
                mapped_reaction = content[0].get("value", "")
                if not mapped_reaction:
                    return "Error: 无法生成原子映射"
                markdown += "### 原子映射结果\n\n"
                markdown += f"`{mapped_reaction}`\n\n"
                # Split into reactants and products for explanation
                if ">>" in mapped_reaction:
                    reactants, products = mapped_reaction.split(">>")
                    markdown += "### 映射解释\n\n"
                    markdown += "原子映射显示了反应物中的原子如何对应到产物中的原子。\n"
                    markdown += "每个原子上的数字表示映射ID相同ID的原子在反应前后是同一个原子。\n\n"
                    markdown += f"**映射的反应物:** `{reactants}`\n\n"
                    markdown += f"**映射的产物:** `{products}`\n"
            elif property_type == "yield":
                # Extract predicted yield
                predicted_yield = content[0].get("value", "")
                if not predicted_yield:
                    return "Error: 无法预测反应产率"
                try:
                    yield_value = float(predicted_yield)
                    markdown += "### 产率预测结果\n\n"
                    markdown += f"**预测产率:** {yield_value:.1f}%\n\n"
                    # Add a qualitative interpretation of the numeric yield
                    if yield_value < 30:
                        markdown += "**解释:** 预测产率较低,反应可能效率不高。考虑优化反应条件或探索替代路线。\n"
                    elif yield_value < 70:
                        markdown += "**解释:** 预测产率中等,反应可能是可行的,但有优化空间。\n"
                    else:
                        markdown += "**解释:** 预测产率较高,反应可能非常有效。\n"
                except ValueError:
                    # Value is not numeric — report it verbatim.
                    markdown += f"**预测产率:** {predicted_yield}\n"
            return markdown
        except (KeyError, IndexError) as e:
            logger.error(f"Error parsing reaction properties results: {e}")
            return f"Error: 解析反应属性预测结果时出错: {str(e)}"
    except Exception as e:
        logger.error(f"Error in predict_reaction_properties: {e}")
        return f"Error: {str(e)}"
@llm_tool(name="extract_reaction_actions_rxn",
          description="Extract structured reaction steps from text descriptions using IBM RXN for Chemistry API")
async def extract_reaction_actions_rxn(reaction_text: str) -> str:
    """
    Extract structured reaction steps from text descriptions.

    This function uses the IBM RXN for Chemistry API to parse text descriptions
    of chemical procedures and extract structured actions representing the steps
    of the procedure.

    Args:
        reaction_text: Text description of a chemical reaction procedure.

    Returns:
        Formatted Markdown string containing the extracted reaction steps,
        or an "Error: ..." string on failure.

    Examples:
        >>> extract_reaction_actions_rxn("To a stirred solution of 7-(difluoromethylsulfonyl)-4-fluoro-indan-1-one (110 mg, 0.42 mmol) in methanol (4 mL) was added sodium borohydride (24 mg, 0.62 mmol). The reaction mixture was stirred at ambient temperature for 1 hour.")
        # Returns structured steps extracted from the text
    """
    try:
        # Get RXN wrapper
        wrapper = _get_rxn_wrapper()
        # Clean input
        reaction_text = reaction_text.strip()
        if not reaction_text:
            return "Error: 反应文本为空"
        # Submit the extraction request in a worker thread (blocking call).
        response = await asyncio.to_thread(
            wrapper.paragraph_to_actions,
            paragraph=reaction_text
        )
        if not response:
            return "Error: 无法从文本中提取反应步骤"
        # The reference implementation prints results['actions'] directly, so
        # the raw response is embedded verbatim in the Markdown report.
        # (A dead `return markdown` that followed this return and referenced
        # an undefined variable has been removed.)
        return f"""## 反应步骤提取结果
### 输入文本
{reaction_text}
### 提取的反应步骤
```
{response}
```
"""
    except Exception as e:
        logger.error(f"Error in extract_reaction_actions: {e}")
        return f"Error: {str(e)}"

84
sci_mcp/core/config.py Executable file
View File

@@ -0,0 +1,84 @@
"""
Configuration Module
This module provides configuration settings for the Mars Toolkit.
It includes API keys, endpoints, paths, and other configuration parameters.
"""
from typing import Dict, Any
class Config:
    """Base class for namespaced configuration containers.

    Settings are plain class attributes on subclasses; this base provides
    dict export and in-place updates.
    """
    @classmethod
    def as_dict(cls) -> Dict[str, Any]:
        """Return this class's own configuration settings as a dictionary.

        NOTE: only attributes defined directly on *cls* (``cls.__dict__``)
        are included; settings inherited from parent classes are not.
        """
        return {
            key: value for key, value in cls.__dict__.items()
            if not key.startswith('__')
            and not callable(value)
            # classmethod/staticmethod objects are descriptors, not callables,
            # so a plain callable() check leaked them into the exported dict.
            and not isinstance(value, (classmethod, staticmethod))
        }
    @classmethod
    def update(cls, **kwargs):
        """Update existing configuration settings; unknown keys are ignored."""
        for key, value in kwargs.items():
            if hasattr(cls, key):
                setattr(cls, key, value)
class General_Config(Config):
    """Configuration for the general-purpose MCP tools."""

    # SearxNG metasearch endpoint and per-query result cap.
    SEARXNG_HOST = "http://192.168.168.1:40032/"
    SEARXNG_MAX_RESULTS = 10
class Material_Config(Config):
    """Configuration for the materials-science MCP tools."""

    # Materials Project API access.
    # NOTE(review): credentials are committed in source; consider loading
    # them from environment variables or a secrets manager instead.
    MP_API_KEY = 'PMASAg256b814q3OaSRWeVc7MKx4mlKI'
    MP_ENDPOINT = 'https://api.materialsproject.org/'
    MP_TOPK = 3
    LOCAL_MP_PROPS_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Props/'
    LOCAL_MP_CIF_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/MPDatasets/'

    # Outbound proxy (currently disabled; former value kept for reference:
    # 'http://192.168.168.1:20171').
    HTTP_PROXY = ''
    HTTPS_PROXY = ''

    # FairChem pretrained checkpoint and default force-convergence criterion.
    FAIRCHEM_MODEL_PATH = '/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/support/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
    FMAX = 0.05

    # MatterGen checkpoints, code root, and output location.
    MATTERGENMODEL_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/support/pretrained_models/mattergen_ckpt'
    MATTERGEN_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/sci_mcp/material_mcp/mattergen_gen/mattergen'
    MATTERGENMODEL_RESULT_PATH = 'results/'

    # Dify knowledge-base service.
    DIFY_ROOT_URL = 'http://192.168.191.101:6080'
    DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'

    # Scratch directory for intermediate material files.
    TEMP_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/temp/material'
class Chemistry_Config(Config):
    """Configuration for the chemistry MCP tools."""

    # Scratch directory for intermediate chemistry files.
    TEMP_ROOT = '/home/ubuntu/sas0/lzy/multi_mcp_server/temp/chemistry'
    # IBM RXN for Chemistry API key.
    # NOTE(review): secret committed in source; move to environment variables.
    RXN4_CHEMISTRY_KEY = 'apk-8928522a146c2503f30b16d9909222d7583f412ee8f1049f08d32a089ba88d34'
# Module-level singleton-style config instances shared across the toolkit.
general_config = General_Config()
material_config = Material_Config()
chemistry_config = Chemistry_Config()

315
sci_mcp/core/llm_tools.py Executable file
View File

@@ -0,0 +1,315 @@
"""
LLM Tools Module
This module provides decorators and utilities for defining, registering, and managing LLM tools.
It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving
registered tools for use with LLM APIs.
"""
import asyncio
import inspect
import importlib
import pkgutil
import os
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args
import docstring_parser
from pydantic import BaseModel, create_model, Field
# Global registry mapping tool name -> decorated wrapper function; populated
# as a side effect of applying @llm_tool at import time.
_TOOL_REGISTRY = {}
# Mapping of domain names to their package paths; used by import_domain_tools
# to locate and eagerly import each domain's tool modules.
_DOMAIN_MODULE_MAPPING = {
    'material': 'sci_mcp.material_mcp',
    'general': 'sci_mcp.general_mcp',
    'biology': 'sci_mcp.biology_mcp',
    'chemistry': 'sci_mcp.chemistry_mcp'
}
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
    """
    Decorator that registers a function as an LLM tool.

    Supports both bare usage (``@llm_tool``) and parameterized usage
    (``@llm_tool(name=..., description=...)``). Registration, schema
    generation and metadata attachment are delegated to ``_llm_tool_impl``.

    Args:
        name: Custom tool name; defaults to the function name.
        description: Custom tool description; defaults to the docstring.

    Returns:
        The decorated function (bare usage), or a decorator (parameterized).

    Example:
        @llm_tool(name="weather_lookup", description="Get current weather for a location")
        def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
            '''Get weather information for a specific location.'''
            # Implementation...
            return {"temperature": 22.5, "conditions": "sunny"}
    """
    # Bare usage: @llm_tool — `name` is actually the decorated function.
    if callable(name):
        return _llm_tool_impl(name, None, None)

    # Parameterized usage: @llm_tool(...) — return a real decorator.
    def decorator(target: Callable) -> Callable:
        return _llm_tool_impl(target, name, description)
    return decorator
def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
    """Implementation of the llm_tool decorator.

    Builds OpenAI- and MCP-style JSON schemas plus a pydantic args model from
    the function's signature and parsed docstring, attaches them as attributes
    on a signature-preserving wrapper, and registers that wrapper in
    ``_TOOL_REGISTRY`` keyed by the tool name.
    """
    # Get function signature and docstring
    sig = inspect.signature(func)
    doc = inspect.getdoc(func) or ""
    parsed_doc = docstring_parser.parse(doc)
    # Determine tool name (explicit name wins over the function name)
    tool_name = name or func.__name__
    # Determine tool description (explicit description wins over docstring)
    tool_description = description or doc
    # Create parameter properties for JSON schema
    properties = {}
    required = []
    for param_name, param in sig.parameters.items():
        # Skip self parameter for methods
        if param_name == "self":
            continue
        param_type = param.annotation
        # NOTE: an explicit default of None is indistinguishable from "no
        # default" for the "default" schema field added below.
        param_default = None if param.default is inspect.Parameter.empty else param.default
        param_required = param.default is inspect.Parameter.empty
        # Get parameter description from docstring if available
        param_desc = ""
        for param_doc in parsed_doc.params:
            if param_doc.arg_name == param_name:
                param_desc = param_doc.description
                break
        # Handle Annotated types: unwrap the underlying type and prefer the
        # Annotated metadata string as the parameter description.
        if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
            args = get_args(param_type)
            param_type = args[0]  # The actual type
            if len(args) > 1 and isinstance(args[1], str):
                param_desc = args[1]  # The description
        # Create property for parameter
        param_schema = {
            "type": _get_json_type(param_type),
            "description": param_desc,
            "title": param_name.replace("_", " ").title()
        }
        # Add default value if available (falsy defaults like 0/False/"" are
        # kept; only None/absent defaults are omitted)
        if param_default is not None:
            param_schema["default"] = param_default
        properties[param_name] = param_schema
        # Add to required list if no default value
        if param_required:
            required.append(param_name)
    # Create OpenAI function-calling format JSON schema
    openai_schema = {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required
            }
        }
    }
    # Create MCP format JSON schema
    mcp_schema = {
        "name": tool_name,
        "description": tool_description,
        "inputSchema": {
            "type": "object",
            "properties": properties,
            "required": required
        }
    }
    # Create Pydantic model for args schema (Ellipsis marks a required field)
    field_definitions = {}
    for param_name, param in sig.parameters.items():
        if param_name == "self":
            continue
        param_type = param.annotation
        param_default = ... if param.default is inspect.Parameter.empty else param.default
        # Handle Annotated types
        if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
            args = get_args(param_type)
            param_type = args[0]
            # NOTE(review): this rebinds the `description` parameter of
            # _llm_tool_impl; harmless because tool_description was already
            # computed above, but worth renaming in a future change.
            description = args[1] if len(args) > 1 and isinstance(args[1], str) else ""
            field_definitions[param_name] = (param_type, Field(default=param_default, description=description))
        else:
            field_definitions[param_name] = (param_type, Field(default=param_default))
    # Create args schema model
    model_name = f"{tool_name.title().replace('_', '')}Schema"
    args_schema = create_model(model_name, **field_definitions)
    # Create a wrapper of the matching sync/async kind so the registered
    # callable keeps the original function's calling convention.
    if asyncio.iscoroutinefunction(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
    else:
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
    # Attach metadata to function
    wrapper.is_llm_tool = True
    wrapper.tool_name = tool_name
    wrapper.tool_description = tool_description
    wrapper.openai_schema = openai_schema
    wrapper.mcp_schema = mcp_schema
    wrapper.args_schema = args_schema
    # Register the tool
    _TOOL_REGISTRY[tool_name] = wrapper
    return wrapper
def get_all_tools() -> Dict[str, Callable]:
    """
    Return every tool registered via ``@llm_tool``.

    Returns:
        The live registry dict mapping tool names to functions (not a copy).
    """
    return _TOOL_REGISTRY
def get_all_tool_schemas(schema_type='openai') -> List[Dict[str, Any]]:
    """
    Return JSON schemas for every registered LLM tool.

    Args:
        schema_type: 'mcp' for MCP-style schemas; any other value yields
            OpenAI function-calling schemas.

    Returns:
        A list of schema dictionaries suitable for use with LLM APIs.
    """
    if schema_type == 'mcp':
        return [tool.mcp_schema for tool in _TOOL_REGISTRY.values()]
    return [tool.openai_schema for tool in _TOOL_REGISTRY.values()]
def import_domain_tools(domains: List[str]) -> None:
    """
    Eagerly import every module under the given domains.

    Importing triggers the ``@llm_tool`` decorators in those modules, which
    registers their tools into ``_TOOL_REGISTRY`` as a side effect.

    Args:
        domains: Domain names (e.g. ``['material', 'general']``); names not
            present in ``_DOMAIN_MODULE_MAPPING`` are silently skipped.
    """
    for domain in domains:
        module_path = _DOMAIN_MODULE_MAPPING.get(domain)
        if module_path is None:
            continue
        try:
            # Import the domain's base package to locate it on disk.
            base_module = importlib.import_module(module_path)
            package_dir = os.path.dirname(base_module.__file__)
            # Walk the package recursively so nested tool modules register too.
            for _finder, name, _is_pkg in pkgutil.walk_packages([package_dir], f"{module_path}."):
                try:
                    importlib.import_module(name)
                except ImportError as e:
                    print(f"Error importing {name}: {e}")
        except ImportError as e:
            print(f"Error importing domain {domain}: {e}")
def get_domain_tools(domains: List[str]) -> Dict[str, Callable]:
    """
    Collect registered tools belonging to the specified domains.

    Args:
        domains: List of domain names (e.g., ['material', 'general'])

    Returns:
        A flat dictionary mapping tool names to their functions. (The return
        annotation previously claimed ``Dict[str, Dict[str, Callable]]``,
        which never matched the flat mapping actually returned.)
    """
    # First, ensure all tools from the specified domains are imported and
    # registered (the @llm_tool decorators run at import time).
    import_domain_tools(domains)
    domain_tools = {}
    for domain in domains:
        if domain not in _DOMAIN_MODULE_MAPPING:
            continue
        domain_module_prefix = _DOMAIN_MODULE_MAPPING[domain]
        for tool_name, tool_func in _TOOL_REGISTRY.items():
            # A tool belongs to this domain when its defining module lives
            # under the domain's package.
            if hasattr(tool_func, "__module__") and tool_func.__module__.startswith(domain_module_prefix):
                domain_tools[tool_name] = tool_func
    return domain_tools
def get_domain_tool_schemas(domains: List[str], schema_type='openai') -> List[Dict[str, Any]]:
    """
    Get JSON schemas for tools from the specified domains.

    Args:
        domains: List of domain names (e.g., ['material', 'general'])
        schema_type: 'mcp' for MCP-style schemas; any other value yields
            OpenAI function-calling schemas.

    Returns:
        A flat list of tool schema dictionaries for all requested domains.
        (The previous annotation claimed a dict keyed by domain; the
        function has always returned a flat list.)
    """
    # Collect the tools first; schema objects are attached to each wrapper.
    domain_tools = get_domain_tools(domains)
    schema_attr = 'mcp_schema' if schema_type == 'mcp' else 'openai_schema'
    return [getattr(tool, schema_attr) for tool in domain_tools.values()]
def _get_json_type(python_type: Any) -> str:
"""
Convert Python type to JSON schema type.
Args:
python_type: Python type annotation
Returns:
Corresponding JSON schema type as string
"""
if python_type is str:
return "string"
elif python_type is int:
return "integer"
elif python_type is float:
return "number"
elif python_type is bool:
return "boolean"
elif python_type is list or python_type is List:
return "array"
elif python_type is dict or python_type is Dict:
return "object"
else:
# Default to string for complex types
return "string"

View File

View File

@@ -0,0 +1,78 @@
"""
Search Online Module
This module provides functions for searching information on the web.
"""
from ...core.llm_tools import llm_tool
from ...core.config import general_config
import asyncio
import os
from typing import Annotated, Any, Dict, List, Union
from langchain_community.utilities import SearxSearchWrapper
import mcp.types as types
@llm_tool(name="search_online_searxng", description="Search scientific information online using searxng")
async def search_online_searxng(
    query: Annotated[str, "Search term"],
    num_results: Annotated[int, "Number of results"] = 5
) -> str:
    """
    Searches for scientific information online and returns results as a formatted string.

    Args:
        query: Search term for scientific content
        num_results: Number of results to return (capped at
            ``general_config.SEARXNG_MAX_RESULTS``)

    Returns:
        Formatted string with search results (titles, snippets, links)
    """
    # NOTE(lzy): this part may be removed before release — SearxNG is
    # deployed locally, so no proxy is needed for local debugging.
    os.environ['HTTP_PROXY'] = ''
    os.environ['HTTPS_PROXY'] = ''
    try:
        # Clamp the requested result count to the configured maximum and use
        # the clamped value consistently (previously the raw num_results
        # leaked into the wrapper's `k` parameter).
        max_results = min(int(num_results), general_config.SEARXNG_MAX_RESULTS)
        search = SearxSearchWrapper(
            searx_host=general_config.SEARXNG_HOST,
            categories=["science",],
            k=max_results
        )
        # Execute search in a separate thread to avoid blocking the event loop
        # since SearxSearchWrapper doesn't have native async support
        loop = asyncio.get_event_loop()
        raw_results = await loop.run_in_executor(
            None,
            lambda: search.results(query, language=['en','zh'], num_results=max_results)
        )
        # Transform results into structured format
        formatted_results = []
        for result in raw_results:
            formatted_results.append({
                "title": result.get("title", ""),
                "snippet": result.get("snippet", ""),
                "link": result.get("link", ""),
                "source": result.get("source", "")
            })
        # Format results into a readable Markdown string
        result_str = f"Search Results for '{query}' ({len(formatted_results)} items):\n\n"
        if len(formatted_results) > 0:
            for i, res in enumerate(formatted_results):
                title = res.get("title", "No Title")
                snippet = res.get("snippet", "No Snippet")
                link = res.get("link", "No Link")
                source = res.get("source", "No Source")
                result_str += f"{i + 1}. **{title}**\n"
                result_str += f" - Snippet: {snippet}\n"
                result_str += f" - Link: [{link}]({link})\n"
                result_str += f" - Source: {source}\n\n"
        else:
            result_str += "No results found.\n"
        return result_str
    except Exception as e:
        return f"Error: {str(e)}"

View File

@@ -0,0 +1,34 @@
# # Core modules
# from mars_toolkit.core.config import config
# # Basic tools
# from mars_toolkit.misc.misc_tools import get_current_time
# # Compute modules
# from mars_toolkit.compute.material_gen import generate_material
# from mars_toolkit.compute.property_pred import predict_properties
# from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure
# # Query modules
# from mars_toolkit.query.mp_query import (
# search_material_property_from_material_project,
# get_crystal_structures_from_materials_project,
# get_mpid_from_formula
# )
# from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
# from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
# from mars_toolkit.query.web_search import search_online
# # Visualization modules
# from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
# __version__ = "0.1.0"
# __all__ = ["llm_tool", "get_tools", "get_tool_schemas"]

View File

@@ -0,0 +1,191 @@
"""
Structure Optimization Module
This module provides functions for optimizing crystal structures using the FairChem model.
"""
import asyncio
from io import StringIO
import sys
import tempfile
import os
import logging
from typing import Optional, Dict, Any
from ase.io import read, write
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from ..support.utils import convert_structure, remove_symmetry_equiv_xyz, read_structure_from_file_name_or_content_string
from ...core.llm_tools import llm_tool
from ...core.config import material_config
logger = logging.getLogger(__name__)
# 初始化FairChem模型
calc = None
def init_model():
    """Load the FairChem OCP calculator into the module-global ``calc``.

    Idempotent: repeated calls are no-ops once the model has been loaded.

    Raises:
        Exception: re-raised if model construction fails (after logging).
    """
    global calc
    if calc is not None:
        # Already initialized — nothing to do.
        return
    try:
        from fairchem.core import OCPCalculator
        calc = OCPCalculator(checkpoint_path=material_config.FAIRCHEM_MODEL_PATH)
        logger.info("FairChem model initialized successfully")
    except Exception as exc:
        logger.error(f"Failed to initialize FairChem model: {str(exc)}")
        raise
def generate_symmetry_cif(structure: Structure) -> str:
    """
    Render a symmetry-refined CIF string for the given structure.

    Args:
        structure: Pymatgen ``Structure`` object.

    Returns:
        CIF-format string of the spacegroup-refined structure.
    """
    refined = SpacegroupAnalyzer(structure).get_refined_structure()
    # CifWriter only writes to a path, so round-trip through a temp file.
    with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as handle:
        writer = CifWriter(refined, symprec=0.1, refine_struct=True)
        writer.write_file(handle.name)
        handle.seek(0)
        cif_text = handle.read()
    os.unlink(handle.name)
    return cif_text
def optimize_structure(atoms: Atoms, output_format: str, fmax: float = 0.05) -> str:
    """
    Optimize a crystal structure with the FairChem calculator.

    Args:
        atoms: ASE Atoms object.
        output_format: Output format (cif, xyz, vasp, ...).
        fmax: Force-convergence criterion in eV/Å.

    Returns:
        Formatted string with optimization results, or an "Error: ..." string.
    """
    from contextlib import redirect_stdout  # local import: stdlib only

    atoms.calc = calc
    try:
        # Capture the optimizer's console output. Using redirect_stdout
        # guarantees sys.stdout is restored even if the optimization raises
        # (the previous manual sys.stdout swap leaked the redirected stream
        # when dyn.run failed).
        temp_output = StringIO()
        with redirect_stdout(temp_output):
            # Run FIRE with a cell filter so lattice vectors relax too.
            dyn = FIRE(FrechetCellFilter(atoms))
            dyn.run(fmax=fmax)
        optimization_log = temp_output.getvalue()
        temp_output.close()
        # Total potential energy of the relaxed configuration
        total_energy = atoms.get_potential_energy()
        # Serialize the optimized structure in the requested format
        if output_format == "cif":
            optimized_structure = Structure.from_ase_atoms(atoms)
            content = generate_symmetry_cif(optimized_structure)
            content = remove_symmetry_equiv_xyz(content)
        else:
            with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
                write(tmp_file.name, atoms)
                tmp_file.seek(0)
                content = tmp_file.read()
            os.unlink(tmp_file.name)
        # Format the returned report
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**
#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
        return format_result
    except Exception as e:
        return f"Error: Failed to optimize structure: {str(e)}"
@llm_tool(name="optimize_crystal_structure_FairChem",
          description="Optimizes crystal structures using the FairChem model")
async def optimize_crystal_structure_FairChem(
    structure_source: str,
    format_type: str = "auto",
    optimization_level: str = "normal"
) -> str:
    """
    Optimizes a crystal structure to find its lowest energy configuration.

    Args:
        structure_source: Either a file name or direct structure content (CIF, XYZ, POSCAR)
        format_type: Structure format type (auto, cif, xyz, poscar). Default "auto" will attempt to detect format.
        optimization_level: Optimization precision level (quick, normal, precise)

    Returns:
        Optimized structure with total energy and optimization details
    """
    # Lazily load the FairChem model on first use.
    if calc is None:
        init_model()
    # Map the requested precision level to a force-convergence threshold;
    # unknown levels fall back to "normal" (0.05 eV/Å).
    fmax = {"quick": 0.1, "normal": 0.05, "precise": 0.01}.get(optimization_level, 0.05)

    def _blocking_optimize():
        """Synchronous worker executed off the event loop."""
        try:
            # Resolve the input into raw content plus its detected format.
            content, actual_format = read_structure_from_file_name_or_content_string(structure_source, format_type)
            # Normalize format aliases to what the converter expects.
            final_format = {
                "cif": "cif",
                "xyz": "xyz",
                "poscar": "vasp",
                "vasp": "vasp"
            }.get(actual_format.lower(), "cif")
            atoms = convert_structure(final_format, content)
            if atoms is None:
                return "Error: Unable to convert input structure. Please check if the format is correct."
            return optimize_structure(atoms, final_format, fmax=fmax)
        except Exception as e:
            return f"Error: Failed to optimize structure: {str(e)}"

    # Run the potentially blocking optimization in a worker thread.
    return await asyncio.to_thread(_blocking_optimize)

View File

@@ -0,0 +1,70 @@
import codecs
import json
import requests
from ...core.llm_tools import llm_tool
from ...core.config import material_config
@llm_tool(
    name="retrieval_from_knowledge_base",
    description="Retrieve information from local materials science literature knowledge base"
)
async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
    """
    Retrieve relevant information from the local materials-science literature
    knowledge base (served by a Dify application).

    Args:
        query: Query string, e.g. a material name such as "CsPbBr3".
        topk: Maximum number of results to return (default: 3).

    Returns:
        JSON string containing the document id, title, content and relevance
        score, or an "Error: ..." string on failure.
    """
    # Dify chat-messages endpoint
    url = f'{material_config.DIFY_ROOT_URL}/v1/chat-messages'
    # Request headers: API key and content type
    headers = {
        'Authorization': f'Bearer {material_config.DIFY_API_KEY}',
        'Content-Type': 'application/json'
    }
    # Request payload
    data = {
        "inputs": {"topK": topk},        # maximum number of retrieved results
        "query": query,                  # the query string
        "response_mode": "blocking",     # wait for the complete response
        "conversation_id": "",           # stateless: each call is independent
        "user": "abc-123"                # caller identifier
    }
    try:
        # POST to the Dify API; long timeout to accommodate slow responses.
        response = requests.post(url, headers=headers, json=data, timeout=1111)
        # Surface HTTP errors as exceptions (caught below).
        response.raise_for_status()
        # Parse the JSON body directly. (The previous implementation ran the
        # text through codecs.decode(..., 'unicode_escape'), which corrupts
        # non-ASCII characters and can break valid JSON; requests/json
        # already decode \uXXXX escapes correctly.)
        result_json = response.json()
        # Extract metadata from the response
        metadata = result_json.get("metadata", {})
        # Build the result dict with the key information
        useful_info = {
            "id": metadata.get("document_id"),         # document id
            "title": result_json.get("title"),         # document title
            "content": result_json.get("answer", ""),  # answer payload
            "score": metadata.get("score")             # relevance score
        }
        return json.dumps(useful_info, ensure_ascii=False, indent=2)
    except Exception as e:
        # Catch all failures and report them as a string
        return f"Error: {str(e)}"

View File

@@ -0,0 +1,8 @@
"""
MatGL Tools Module
This module provides tools for material structure relaxation and property prediction
using MatGL (Materials Graph Library) models.
"""
from .matgl_tools import *

View File

@@ -0,0 +1,487 @@
"""
MatGL Tools Module
This module provides tools for material structure relaxation and property prediction
using MatGL (Materials Graph Library) models.
"""
from __future__ import annotations
from ...core.config import material_config
import warnings
import json
from typing import Dict, List, Union, Optional, Any
import torch
from pymatgen.core import Lattice, Structure
from pymatgen.ext.matproj import MPRester
from pymatgen.io.ase import AseAtomsAdaptor
import matgl
from matgl.ext.ase import Relaxer, MolecularDynamics, PESCalculator
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ...core.llm_tools import llm_tool
import os
from ..support.utils import read_structure_from_file_name_or_content_string
# To suppress warnings for clearer output
warnings.simplefilter("ignore")
@llm_tool(name="relax_crystal_structure_M3GNet",
          description="Optimize crystal structure geometry using M3GNet universal potential from a structure file or content string")
async def relax_crystal_structure_M3GNet(
    structure_source: str,
    fmax: float = 0.01
) -> str:
    """
    Optimize crystal structure geometry to find its equilibrium configuration.
    Uses M3GNet universal potential for fast and accurate structure relaxation without DFT.
    Accepts a structure file or content string.
    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        fmax: Maximum force tolerance for convergence in eV/Å (default: 0.01).
    Returns:
        A Markdown formatted string with the relaxation results or an error message.
    """
    try:
        # Resolve the input into raw structure content plus its format
        structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
        # Parse the structure with pymatgen
        structure = Structure.from_str(structure_content, fmt=content_format)
        if structure is None:
            return "Error: Failed to obtain a valid structure"
        # Load the M3GNet universal potential model
        pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")
        # Create a relaxer and relax the structure
        relaxer = Relaxer(potential=pot)
        relax_results = relaxer.relax(structure, fmax=fmax)
        # Get the relaxed structure
        relaxed_structure = relax_results["final_structure"]
        reduced_formula = relaxed_structure.composition.reduced_formula
        # Gather lattice/symmetry information for the report
        lattice_info = relaxed_structure.lattice
        volume = relaxed_structure.volume
        density = relaxed_structure.density
        symmetry = relaxed_structure.get_space_group_info()
        # Build the table of fractional atomic positions
        sites_table = " # SP a b c\n"
        sites_table += "--- ---- -------- -------- --------\n"
        for i, site in enumerate(relaxed_structure):
            frac_coords = site.frac_coords
            sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"
        # Markdown report: space group, cell parameters and atomic positions
        return (f"## Structure Relaxation\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Force Tolerance**: `{fmax} eV/Å`\n"
                f"- **Status**: `Successfully relaxed`\n\n"
                f"### Relaxed Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {relaxed_structure.lattice.pbc[0]!s:5s} {relaxed_structure.lattice.pbc[1]!s:5s} {relaxed_structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(relaxed_structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error during structure relaxation: {str(e)}"
# Internal helper used by other tools: relaxes a structure and returns the
# Structure object instead of a formatted report string.
async def _relax_crystal_structure_M3GNet_internal(
    structure_source: str,
    fmax: float = 0.01
) -> Union[Structure, str]:
    """
    Relax a crystal structure and return the pymatgen ``Structure``.

    Args:
        structure_source: Structure file name or raw content string.
        fmax: Force-convergence threshold in eV/Å.

    Returns:
        The relaxed ``Structure`` on success, or an error message string.
    """
    try:
        # Resolve the input and parse it with pymatgen.
        raw_content, fmt = read_structure_from_file_name_or_content_string(structure_source)
        parsed = Structure.from_str(raw_content, fmt=fmt)
        if parsed is None:
            return "Error: Failed to obtain a valid structure"
        # Relax with the M3GNet universal interatomic potential.
        potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
        outcome = Relaxer(potential=potential).relax(parsed, fmax=fmax)
        return outcome["final_structure"]
    except Exception as e:
        return f"Error during structure relaxation: {str(e)}"
@llm_tool(name="predict_formation_energy_M3GNet",
          description="Predict the formation energy of a crystal structure using the M3GNet formation energy model from a structure file or content string, with optional structure optimization")
async def predict_formation_energy_M3GNet(
    structure_source: str,
    optimize_structure: bool = True,
    fmax: float = 0.01
) -> str:
    """
    Predict the formation energy of a crystal structure using the M3GNet formation energy model.
    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        optimize_structure: Whether to optimize the structure before prediction (default: True).
        fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).
    Returns:
        A Markdown formatted string containing the predicted formation energy in eV/atom or an error message.
    """
    try:
        # Obtain the structure (relaxed first, or parsed as-is)
        if optimize_structure:
            # Relax via the internal helper (returns Structure or error str)
            structure = await _relax_crystal_structure_M3GNet_internal(
                structure_source=structure_source,
                fmax=fmax
            )
            # Propagate relaxation failures directly to the caller
            if isinstance(structure, str) and structure.startswith("Error"):
                return structure
        else:
            # Parse the structure without relaxation
            structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
            structure = Structure.from_str(structure_content, fmt=content_format)
            if structure is None:
                return "Error: Failed to obtain a valid structure"
        # Load the pretrained formation-energy model
        model = matgl.load_model("M3GNet-MP-2018.6.1-Eform")
        # Predict the formation energy (eV/atom)
        eform = model.predict_structure(structure)
        reduced_formula = structure.composition.reduced_formula
        # Record whether the prediction used a relaxed structure
        optimization_status = "optimized" if optimize_structure else "non-optimized"
        # Gather lattice/symmetry information for the report
        lattice_info = structure.lattice
        volume = structure.volume
        density = structure.density
        symmetry = structure.get_space_group_info()
        # Build the table of fractional atomic positions
        sites_table = " # SP a b c\n"
        sites_table += "--- ---- -------- -------- --------\n"
        for i, site in enumerate(structure):
            frac_coords = site.frac_coords
            sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"
        return (f"## Formation Energy Prediction\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Structure Status**: `{optimization_status}`\n"
                f"- **Formation Energy**: `{float(eform):.3f} eV/atom`\n\n"
                f"### Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {structure.lattice.pbc[0]!s:5s} {structure.lattice.pbc[1]!s:5s} {structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error: {str(e)}"
@llm_tool(name="run_molecular_dynamics_M3GNet",
description="Run molecular dynamics simulation on a crystal structure using M3GNet universal potential, with optional structure optimization")
async def run_molecular_dynamics_M3GNet(
structure_source: str,
temperature_K: float = 300,
steps: int = 100,
optimize_structure: bool = True,
fmax: float = 0.01
) -> str:
"""
Run molecular dynamics simulation on a crystal structure using M3GNet universal potential.
Args:
structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
temperature_K: Temperature for MD simulation in Kelvin (default: 300).
steps: Number of MD steps to run (default: 100).
optimize_structure: Whether to optimize the structure before simulation (default: True).
fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).
Returns:
A Markdown formatted string containing the simulation results, including final potential energy.
"""
try:
# 获取结构(优化或不优化)
if optimize_structure:
# 使用内部函数优化结构
structure = await _relax_crystal_structure_M3GNet_internal(
structure_source=structure_source,
fmax=fmax
)
# 检查优化是否成功
if isinstance(structure, str) and structure.startswith("Error"):
return structure
else:
# 直接读取结构,不进行优化
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
structure = Structure.from_str(structure_content, fmt=content_format)
if structure is None:
return "Error: Failed to obtain a valid structure"
# Load the M3GNet universal potential model
pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")
# Convert pymatgen structure to ASE atoms
ase_adaptor = AseAtomsAdaptor()
atoms = ase_adaptor.get_atoms(structure)
# Initialize the velocity according to Maxwell Boltzmann distribution
MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_K)
# Create the MD class and run simulation
driver = MolecularDynamics(atoms, potential=pot, temperature=temperature_K)
driver.run(steps)
# Get final potential energy
final_energy = atoms.get_potential_energy()
# Get final structure
final_structure = ase_adaptor.get_structure(atoms)
reduced_formula = final_structure.composition.reduced_formula
# 构建结果字符串
optimization_status = "optimized" if optimize_structure else "non-optimized"
# 添加结构信息
lattice_info = final_structure.lattice
volume = final_structure.volume
density = final_structure.density
symmetry = final_structure.get_space_group_info()
# 构建原子位置表格
sites_table = " # SP a b c\n"
sites_table += "--- ---- -------- -------- --------\n"
for i, site in enumerate(final_structure):
frac_coords = site.frac_coords
sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"
return (f"## Molecular Dynamics Simulation\n\n"
f"- **Structure**: `{reduced_formula}`\n"
f"- **Structure Status**: `{optimization_status}`\n"
f"- **Temperature**: `{temperature_K} K`\n"
f"- **Steps**: `{steps}`\n"
f"- **Final Potential Energy**: `{float(final_energy):.3f} eV`\n\n"
f"### Final Structure Information\n\n"
f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
f"- **Volume**: `{volume:.2f} ų`\n"
f"- **Density**: `{density:.2f} g/cm³`\n"
f"- **Lattice Parameters**:\n"
f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
f"### Atomic Positions (Fractional Coordinates)\n\n"
f"```\n"
f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
f"pbc : {final_structure.lattice.pbc[0]!s:5s} {final_structure.lattice.pbc[1]!s:5s} {final_structure.lattice.pbc[2]!s:5s}\n"
f"Sites ({len(final_structure)})\n"
f"{sites_table}```\n")
except Exception as e:
return f"Error: {str(e)}"
@llm_tool(name="calculate_single_point_energy_M3GNet",
description="Calculate single point energy of a crystal structure using M3GNet universal potential, with optional structure optimization")
async def calculate_single_point_energy_M3GNet(
structure_source: str,
optimize_structure: bool = True,
fmax: float = 0.01
) -> str:
"""
Calculate single point energy of a crystal structure using M3GNet universal potential.
Args:
structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
optimize_structure: Whether to optimize the structure before calculation (default: True).
fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).
Returns:
A Markdown formatted string containing the calculated potential energy in eV.
"""
try:
# 获取结构(优化或不优化)
if optimize_structure:
# 使用内部函数优化结构
structure = await _relax_crystal_structure_M3GNet_internal(
structure_source=structure_source,
fmax=fmax
)
# 检查优化是否成功
if isinstance(structure, str) and structure.startswith("Error"):
return structure
else:
# 直接读取结构,不进行优化
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
structure = Structure.from_str(structure_content, fmt=content_format)
if structure is None:
return "Error: Failed to obtain a valid structure"
# Load the M3GNet universal potential model
pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")
# Convert pymatgen structure to ASE atoms
ase_adaptor = AseAtomsAdaptor()
atoms = ase_adaptor.get_atoms(structure)
# Set up the calculator for atoms object
calc = PESCalculator(pot)
atoms.set_calculator(calc)
# Calculate potential energy
energy = atoms.get_potential_energy()
reduced_formula = structure.composition.reduced_formula
# 构建结果字符串
optimization_status = "optimized" if optimize_structure else "non-optimized"
# 添加结构信息
lattice_info = structure.lattice
volume = structure.volume
density = structure.density
symmetry = structure.get_space_group_info()
# 构建原子位置表格
sites_table = " # SP a b c\n"
sites_table += "--- ---- -------- -------- --------\n"
for i, site in enumerate(structure):
frac_coords = site.frac_coords
sites_table += f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"
return (f"## Single Point Energy Calculation\n\n"
f"- **Structure**: `{reduced_formula}`\n"
f"- **Structure Status**: `{optimization_status}`\n"
f"- **Potential Energy**: `{float(energy):.3f} eV`\n\n"
f"### Structure Information\n\n"
f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
f"- **Volume**: `{volume:.2f} ų`\n"
f"- **Density**: `{density:.2f} g/cm³`\n"
f"- **Lattice Parameters**:\n"
f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
f"### Atomic Positions (Fractional Coordinates)\n\n"
f"```\n"
f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
f"pbc : {structure.lattice.pbc[0]!s:5s} {structure.lattice.pbc[1]!s:5s} {structure.lattice.pbc[2]!s:5s}\n"
f"Sites ({len(structure)})\n"
f"{sites_table}```\n")
except Exception as e:
return f"Error: {str(e)}"
#Error: Bad serialized model or bad model name. It is possible that you have an older model cached. Please clear your cache by running `python -c "import matgl; matgl.clear_cache()"`
# @llm_tool(name="predict_band_gap",
# description="Predict the band gap of a crystal structure using MEGNet multi-fidelity model from either a chemical formula or CIF file, with structure optimization")
# async def predict_band_gap(
# formula: str = None,
# cif_file_name: str = None,
# method: str = "PBE",
# fmax: float = 0.01
# ) -> str:
# """
# Predict the band gap of a crystal structure using the MEGNet multi-fidelity band gap model.
# First optimizes the crystal structure using M3GNet universal potential, then predicts
# the band gap based on the relaxed structure for more accurate results.
# Accepts either a chemical formula (searches Materials Project database) or a CIF file.
# Args:
# formula: Chemical formula to retrieve from Materials Project (e.g., "Fe2O3").
# cif_file_name: Name of CIF file in temp directory to use as structure source.
# method: The DFT method to use for the prediction. Options are "PBE", "GLLB-SC", "HSE", or "SCAN".
# Default is "PBE".
# fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).
# Returns:
# A string containing the predicted band gap in eV or an error message.
# """
# try:
# # First, relax the crystal structure
# relaxed_result = await relax_crystal_structure(
# formula=formula,
# cif_file_name=cif_file_name,
# fmax=fmax
# )
# # Check if relaxation was successful
# if isinstance(relaxed_result, str) and relaxed_result.startswith("Error"):
# return relaxed_result
# # Use the relaxed structure for band gap prediction
# structure = relaxed_result
# if structure is None:
# return "Error: Failed to obtain a valid relaxed structure"
# # Load the pre-trained MEGNet band gap model
# model = matgl.load_model("MEGNet-MP-2019.4.1-BandGap-mfi")
# # Map method name to index
# method_map = {"PBE": 0, "GLLB-SC": 1, "HSE": 2, "SCAN": 3}
# if method not in method_map:
# return f"Error: Unsupported method: {method}. Choose from PBE, GLLB-SC, HSE, or SCAN."
# # Set the graph label based on the method
# graph_attrs = torch.tensor([method_map[method]])
# # Predict the band gap using the relaxed structure
# bandgap = model.predict_structure(structure=structure, state_attr=graph_attrs)
# reduced_formula = structure.reduced_formula
# # Return the band gap as a string
# return f"The predicted band gap for relaxed {reduced_formula} using {method} method is {float(bandgap):.3f} eV."
# except Exception as e:
# return f"Error: {str(e)}"

View File

@@ -0,0 +1,240 @@
import ast
import json
import logging
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
import tempfile
import os
import datetime
import asyncio
import zipfile
import shutil
import re
import multiprocessing
from multiprocessing import Process, Queue
from pathlib import Path
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
import logging
# Use the 'spawn' start method for worker processes: forked children inherit a
# broken CUDA context, so 'spawn' avoids CUDA initialization errors.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    # set_start_method raises RuntimeError when a start method was already set
    pass
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
# Import path updated for the package layout
from ...core.llm_tools import llm_tool
from .mattergen_wrapper import *
# mattergen_wrapper adjusts sys.path so the mattergen package can be imported
import sys
import os
def convert_values(data_str):
    """Parse a JSON string into the corresponding Python object.

    Args:
        data_str: A JSON-encoded string.

    Returns:
        The decoded Python object, or the original string unchanged when it is
        not valid JSON.
    """
    try:
        parsed = json.loads(data_str)
    except json.JSONDecodeError:
        # Not valid JSON — hand the caller back exactly what it passed in.
        return data_str
    else:
        return parsed
def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
    """
    Preprocess a property value based on its name, converting it to the appropriate type.
    Args:
        property_name: Name of the property
        property_value: Value of the property (can be string, float, or int)
    Returns:
        Tuple of (property_name, processed_value)
    Raises:
        ValueError: If the property name is unknown, or the property value is
            invalid for the given property name (e.g. a non-numeric or
            out-of-range space_group).
    """
    valid_properties = [
        "dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
        "energy_above_hull", "formation_energy_per_atom", "space_group",
        "hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap"
    ]
    if property_name not in valid_properties:
        raise ValueError(f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}")
    # Values often arrive from JSON/CLI as strings; coerce numeric properties
    # to float where possible. chemical_system is the only string-valued property.
    if isinstance(property_value, str) and property_name != "chemical_system":
        try:
            property_value = float(property_value)
        except ValueError:
            # Leave as a string; property-specific validation below decides
            # whether a non-numeric value is acceptable.
            pass
    # Handle special cases for properties that need specific types
    if property_name == "chemical_system":
        if isinstance(property_value, (int, float)):
            logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
            property_value = str(property_value)
    elif property_name == "space_group":
        # Space groups are numbered 1..230. Reject non-numeric values with a
        # clear ValueError instead of letting the range comparison below raise
        # an opaque TypeError (str < int).
        if not isinstance(property_value, (int, float)):
            raise ValueError(f"Invalid space_group value: {property_value!r}. Must be an integer between 1 and 230.")
        space_group = property_value
        if space_group < 1 or space_group > 230:
            raise ValueError(f"Invalid space_group value: {space_group}. Must be an integer between 1 and 230.")
    return property_name, property_value
def main(
    output_path: str,
    pretrained_name: PRETRAINED_MODEL_NAME | None = None,
    model_path: str | None = None,
    batch_size: int = 2,
    num_batches: int = 1,
    config_overrides: list[str] | None = None,
    checkpoint_epoch: Literal["best", "last"] | int = "last",
    properties_to_condition_on: TargetProperty | None = None,
    sampling_config_path: str | None = None,
    sampling_config_name: str = "default",
    sampling_config_overrides: list[str] | None = None,
    record_trajectories: bool = True,
    diffusion_guidance_factor: float | None = None,
    strict_checkpoint_loading: bool = True,
    target_compositions: list[dict[str, int]] | None = None,
):
    """
    Generate crystal structures from a MatterGen diffusion-model checkpoint.
    Args:
        output_path: Path to output directory (created if it does not exist).
        pretrained_name: Name of a pretrained model to fetch from the HF hub.
            Mutually exclusive with model_path.
        model_path: Path to DiffusionLightningModule checkpoint directory.
        batch_size: Number of structures generated per batch.
        num_batches: Number of batches to generate.
        config_overrides: Overrides for the model config, e.g., `model.num_layers=3 model.hidden_dim=128`.
        checkpoint_epoch: Which epoch to load from the checkpoint: "best", "last", or an epoch number.
        properties_to_condition_on: Property values to draw conditional samples for. When this is
            an empty dictionary or None (default), unconditional samples are drawn.
        sampling_config_path: Path to the sampling config file. (default: None, in which case
            `DEFAULT_SAMPLING_CONFIG_PATH` from explorers.common.utils.utils.py is used)
        sampling_config_name: Name of the sampling config (corresponds to
            `{sampling_config_path}/{sampling_config_name}.yaml` on disk). (default: default)
        sampling_config_overrides: Overrides for the sampling config, e.g., `condition_loader_partial.batch_size=32`.
        record_trajectories: Whether to record the trajectories of the generated structures. (default: True)
        diffusion_guidance_factor: Classifier-free guidance strength; None means 0.0 (no guidance).
        strict_checkpoint_loading: Whether to raise an exception when not all parameters from the
            checkpoint can be matched to the model.
        target_compositions: List of dictionaries with target compositions to condition on. Each
            dictionary should have the form `{element: number_of_atoms}`. Only supported for models
            trained for crystal structure prediction (CSP). (default: None)
    Raises:
        ValueError: If neither or both of pretrained_name / model_path are provided.
    NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between
    the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
    """
    # Validate the mutually exclusive checkpoint sources with explicit raises;
    # `assert` is stripped under `python -O`, so it must not guard user input.
    if pretrained_name is None and model_path is None:
        raise ValueError("Either pretrained_name or model_path must be provided.")
    if pretrained_name is not None and model_path is not None:
        raise ValueError("Only one of pretrained_name or model_path can be provided.")
    # exist_ok avoids a check-then-create race when several runs start at once.
    os.makedirs(output_path, exist_ok=True)
    sampling_config_overrides = sampling_config_overrides or []
    config_overrides = config_overrides or []
    properties_to_condition_on = properties_to_condition_on or {}
    target_compositions = target_compositions or []
    if pretrained_name is not None:
        # Resolve the checkpoint from the Hugging Face hub.
        checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
            pretrained_name, config_overrides=config_overrides
        )
    else:
        checkpoint_info = MatterGenCheckpointInfo(
            model_path=Path(model_path).resolve(),
            load_epoch=checkpoint_epoch,
            config_overrides=config_overrides,
            strict_checkpoint_loading=strict_checkpoint_loading,
        )
    _sampling_config_path = Path(sampling_config_path) if sampling_config_path is not None else None
    generator = CrystalGenerator(
        checkpoint_info=checkpoint_info,
        properties_to_condition_on=properties_to_condition_on,
        batch_size=batch_size,
        num_batches=num_batches,
        sampling_config_name=sampling_config_name,
        sampling_config_path=_sampling_config_path,
        sampling_config_overrides=sampling_config_overrides,
        record_trajectories=record_trajectories,
        diffusion_guidance_factor=(
            diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
        ),
        target_compositions_dict=target_compositions,
    )
    generator.generate(output_dir=Path(output_path))
@llm_tool(name="generate_material_MatterGen", description="Generate crystal structures with optional property constraints using MatterGen model")
def generate_material_MatterGen(
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
batch_size: int = 2,
num_batches: int = 1,
diffusion_guidance_factor: float = 2.0
) -> str:
"""
Generate crystal structures with optional property constraints.
This unified function can generate materials in three modes:
1. Unconditional generation (no properties specified)
2. Single property conditional generation (one property specified)
3. Multi-property conditional generation (multiple properties specified)
Args:
properties: Optional property constraints. Can be:
- None or empty dict for unconditional generation
- Dict with single key-value pair for single property conditioning
- Dict with multiple key-value pairs for multi-property conditioning
Valid property names include: "dft_band_gap", "chemical_system", etc.
batch_size: Number of structures per batch
num_batches: Number of batches to generate
diffusion_guidance_factor: Controls adherence to target properties
Returns:
Descriptive text with generated crystal structures in CIF format
"""
# 导入MatterGenService
from .mattergen_service import MatterGenService
logger.info("子进程成功导入MatterGenService")
# 获取MatterGenService实例
service = MatterGenService.get_instance()
logger.info("子进程成功获取MatterGenService实例")
# 使用服务生成材料
logger.info("子进程开始调用generate方法...")
result = service.generate(properties=properties, batch_size=batch_size, num_batches=num_batches, diffusion_guidance_factor=diffusion_guidance_factor)
logger.info("子进程generate方法调用完成")
if "Error generating structures" in result:
return f"Error: Invalid properties {properties}."
else:
return result

View File

@@ -0,0 +1,466 @@
"""
MatterGen service for mars_toolkit.
This module provides a service for generating crystal structures using MatterGen.
The service initializes the CrystalGenerator once and reuses it for multiple
generation requests, improving performance.
"""
import datetime
import os
import logging
import json
from pathlib import Path
import re
from typing import Dict, Any, Optional, Union, List
import threading
import torch
from .mattergen_wrapper import *
from ...core.config import material_config
logger = logging.getLogger(__name__)
def format_cif_content(content):
    """
    Format CIF content by removing unnecessary headers and organizing each CIF file.

    Args:
        content: String containing CIF content, possibly with PK headers

    Returns:
        Formatted string with each CIF record wrapped in `[cif N begin]` /
        `[cif N end]` markers, or an empty string when no record is found.
    """
    # Nothing to format for empty or whitespace-only input.
    if not content or not content.strip():
        return ''
    # Strip zip "PK" residue that precedes the first _chemical_formula_structural
    # tag, then any trailing PK residue that has no formula tag after it.
    content = re.sub(r'PK.*?(?=_chemical_formula_structural)', '', content, flags=re.DOTALL)
    content = re.sub(r'PK[^_]*$', '', content, flags=re.DOTALL)
    content = re.sub(r'PK.*?(?!.*_chemical_formula_structural)$', '', content, flags=re.DOTALL)
    # Every CIF record starts at a _chemical_formula_structural tag; split on
    # those positions while keeping the tag inside each record.
    starts = [match.start() for match in re.finditer(r'_chemical_formula_structural', content)]
    if not starts:
        return ''
    labeled_records = []
    for idx, begin in enumerate(starts):
        finish = starts[idx + 1] if idx + 1 < len(starts) else len(content)
        record = content[begin:finish].strip()
        # Keep only records whose formula value can be extracted.
        found = re.search(r'_chemical_formula_structural\s+(\S+)', record)
        if found:
            labeled_records.append((found.group(1), record))
    pieces = [
        f"[cif {number} begin]\ndata_{formula}\n{record}\n[cif {number} end]\n"
        for number, (formula, record) in enumerate(labeled_records, 1)
    ]
    return "\n".join(pieces)
def extract_cif_file_from_zip(cifs_zip_path: str):
    """
    Extract CIF files from a zip archive, extract formula from each CIF file,
    and save each CIF file with its formula as the filename.
    Args:
        cifs_zip_path: Path to the zip file
    Returns:
        list: List of tuples containing (index, formula, cif_path). Empty when
        the archive does not exist or contains no recognizable CIF records.
    """
    # NOTE(review): the archive is read as raw bytes and the zip "PK" headers
    # are stripped textually by format_cif_content instead of using the
    # `zipfile` module — this only works for uncompressed (stored) entries;
    # confirm that is the archive layout MatterGen actually produces.
    result_dict = {}
    if os.path.exists(cifs_zip_path):
        with open(cifs_zip_path, 'rb') as f:
            result_dict['cif_content'] = f.read()
    # Normalize into "[cif N begin] ... [cif N end]" records (decoding bytes
    # with replacement so undecodable zip residue cannot raise).
    cifs_content = format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))
    pattern = r'\[cif (\d+) begin\]\n(.*?)\n\[cif \1 end\]'
    matches = re.findall(pattern, cifs_content, re.DOTALL)
    # For each record: extract the formula and save the CIF content to disk.
    saved_files = []
    for idx, cif_content in matches:
        # Pull the formula out of the data_{formula} header line
        formula_match = re.search(r'data_([^\s]+)', cif_content)
        if formula_match:
            formula = formula_match.group(1)
            # Destination path under the shared temp directory; records with the
            # same formula overwrite each other.
            cif_path = os.path.join(material_config.TEMP_ROOT, f"{formula}.cif")
            # Write out the CIF file
            with open(cif_path, 'w') as f:
                f.write(cif_content)
            saved_files.append((idx, formula, cif_path))
    return saved_files
class MatterGenService:
    """
    Service for generating crystal structures using MatterGen.
    This service initializes the CrystalGenerator once and reuses it for multiple
    generation requests, improving performance.
    """
    # Singleton storage and the lock guarding its construction.
    _instance = None
    _lock = threading.Lock()
    # Mapping from model directory name to the GPU id (CUDA device index,
    # stored as a string) the model is pinned to.
    MODEL_TO_GPU = {
        "mattergen_base": "0",  # base model uses GPU 0
        "dft_mag_density": "1",  # magnetic density model uses GPU 1
        "dft_bulk_modulus": "2",  # bulk modulus model uses GPU 2
        "dft_shear_modulus": "3",  # shear modulus model uses GPU 3
        "energy_above_hull": "4",  # energy-above-hull model uses GPU 4
        "formation_energy_per_atom": "5",  # formation energy model uses GPU 5
        "space_group": "6",  # space group model uses GPU 6
        "hhi_score": "7",  # HHI score model uses GPU 7
        "ml_bulk_modulus": "0",  # ML bulk modulus model uses GPU 0
        "chemical_system": "1",  # chemical system model uses GPU 1
        "dft_band_gap": "2",  # band gap model uses GPU 2
        "dft_mag_density_hhi_score": "3",  # multi-property model uses GPU 3
        "chemical_system_energy_above_hull": "4"  # multi-property model uses GPU 4
    }
    @classmethod
    def get_instance(cls):
        """
        Get the singleton instance of MatterGenService.
        Returns:
            MatterGenService: The singleton instance.
        """
        # Double-checked locking: cheap unlocked read first, then re-check
        # under the lock before constructing.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance
    def __init__(self):
        """
        Initialize the MatterGenService.
        This initializes the base generator without any property conditioning.
        Specific generators for different property conditions will be initialized
        on demand.
        """
        # Cache of CrystalGenerator instances keyed by generator_key.
        self._generators = {}
        self._output_dir = material_config.TEMP_ROOT
        # Make sure the output directory exists
        if not os.path.exists(self._output_dir):
            os.makedirs(self._output_dir)
        # Initialize the base generator (unconditional generation)
        self._init_base_generator()
    def _init_base_generator(self):
        """
        Initialize the base generator for unconditional generation.
        """
        model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, "mattergen_base")
        if not os.path.exists(model_path):
            logger.warning(f"Base model directory not found at {model_path}. MatterGen service may not work properly.")
            return
        logger.info(f"Initializing base MatterGen generator from {model_path}")
        try:
            checkpoint_info = MatterGenCheckpointInfo(
                model_path=Path(model_path).resolve(),
                load_epoch="last",
                config_overrides=[],
                strict_checkpoint_loading=True,
            )
            generator = CrystalGenerator(
                checkpoint_info=checkpoint_info,
                properties_to_condition_on=None,
                batch_size=2,  # default; may be overridden at generation time
                num_batches=1,  # default; may be overridden at generation time
                sampling_config_name="default",
                sampling_config_path=None,
                sampling_config_overrides=[],
                record_trajectories=True,
                diffusion_guidance_factor=0.0,
                target_compositions_dict=[],
            )
            self._generators["base"] = generator
            logger.info("Base MatterGen generator initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize base MatterGen generator: {e}")
    def _get_or_create_generator(
        self,
        properties: Optional[Dict[str, Any]] = None,
        batch_size: int = 2,
        num_batches: int = 1,
        diffusion_guidance_factor: float = 2.0
    ):
        """
        Get or create a generator for the specified properties.
        Args:
            properties: Optional property constraints
            batch_size: Number of structures per batch
            num_batches: Number of batches to generate
            diffusion_guidance_factor: Controls adherence to target properties
        Returns:
            tuple: (generator, generator_key, properties_to_condition_on, gpu_id)
        """
        # With no property constraints, use the base (unconditional) generator
        if not properties:
            if "base" not in self._generators:
                self._init_base_generator()
            gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")  # default to GPU 0
            return self._generators.get("base"), "base", None, gpu_id
        # Copy the property constraints into the conditioning dict
        properties_to_condition_on = {}
        for property_name, property_value in properties.items():
            properties_to_condition_on[property_name] = property_value
        # Determine which model directory serves these properties
        if len(properties) == 1:
            # Single-property conditioning
            property_name = list(properties.keys())[0]
            property_to_model = {
                "dft_mag_density": "dft_mag_density",
                "dft_bulk_modulus": "dft_bulk_modulus",
                "dft_shear_modulus": "dft_shear_modulus",
                "energy_above_hull": "energy_above_hull",
                "formation_energy_per_atom": "formation_energy_per_atom",
                "space_group": "space_group",
                "hhi_score": "hhi_score",
                "ml_bulk_modulus": "ml_bulk_modulus",
                "chemical_system": "chemical_system",
                "dft_band_gap": "dft_band_gap"
            }
            model_dir = property_to_model.get(property_name, property_name)
            generator_key = f"single_{property_name}"
        else:
            # Multi-property conditioning
            property_keys = set(properties.keys())
            if property_keys == {"dft_mag_density", "hhi_score"}:
                model_dir = "dft_mag_density_hhi_score"
                generator_key = "multi_dft_mag_density_hhi_score"
            elif property_keys == {"chemical_system", "energy_above_hull"}:
                model_dir = "chemical_system_energy_above_hull"
                generator_key = "multi_chemical_system_energy_above_hull"
            else:
                # No dedicated multi-property model: fall back to the model of
                # the first property
                first_property = list(properties.keys())[0]
                model_dir = first_property
                generator_key = f"multi_{first_property}_etc"
        # Look up the GPU id assigned to this model
        gpu_id = self.MODEL_TO_GPU.get(model_dir, "0")  # default to GPU 0
        # Build the full model path
        model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, model_dir)
        # Check whether the model directory exists
        if not os.path.exists(model_path):
            # If the specific model is missing, fall back to the base model
            logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
            model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, "mattergen_base")
            generator_key = "base"
        # Reuse an already-constructed generator if we have one
        if generator_key in self._generators:
            # Refresh the per-call generation parameters
            generator = self._generators[generator_key]
            generator.batch_size = batch_size
            generator.num_batches = num_batches
            generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0
            return generator, generator_key, properties_to_condition_on, gpu_id
        # Otherwise create a new generator
        try:
            logger.info(f"Initializing new MatterGen generator for {generator_key} from {model_path}")
            checkpoint_info = MatterGenCheckpointInfo(
                model_path=Path(model_path).resolve(),
                load_epoch="last",
                config_overrides=[],
                strict_checkpoint_loading=True,
            )
            generator = CrystalGenerator(
                checkpoint_info=checkpoint_info,
                properties_to_condition_on=properties_to_condition_on,
                batch_size=batch_size,
                num_batches=num_batches,
                sampling_config_name="default",
                sampling_config_path=None,
                sampling_config_overrides=[],
                record_trajectories=True,
                diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0,
                target_compositions_dict=[],
            )
            self._generators[generator_key] = generator
            logger.info(f"MatterGen generator for {generator_key} initialized successfully")
            return generator, generator_key, properties_to_condition_on, gpu_id
        except Exception as e:
            logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
            # Fall back to the base generator
            if "base" not in self._generators:
                self._init_base_generator()
            base_gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")
            return self._generators.get("base"), "base", None, base_gpu_id
    def generate(
        self,
        properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
        batch_size: int = 2,
        num_batches: int = 1,
        diffusion_guidance_factor: float = 2.0
    ) -> str:
        """
        Generate crystal structures with optional property constraints.
        Args:
            properties: Optional property constraints
            batch_size: Number of structures per batch
            num_batches: Number of batches to generate
            diffusion_guidance_factor: Controls adherence to target properties
        Returns:
            str: Descriptive text with generated crystal structures in CIF format
        """
        # Accept a JSON string for properties (if provided that way)
        if isinstance(properties, str):
            try:
                properties = json.loads(properties)
            except json.JSONDecodeError:
                raise ValueError(f"Invalid properties JSON string: {properties}")
        # Default to an empty dict when None
        properties = properties or {}
        # Resolve (or lazily build) the generator and its GPU id
        generator, generator_key, properties_to_condition_on, gpu_id = self._get_or_create_generator(
            properties, batch_size, num_batches, diffusion_guidance_factor
        )
        print("gpu_id",gpu_id)
        if generator is None:
            return "Error: Failed to initialize MatterGen generator"
        # Select the target GPU directly via torch.cuda.set_device()
        try:
            # gpu_id is stored as a string; convert to an integer device index
            cuda_device_id = int(gpu_id)
            torch.cuda.set_device(cuda_device_id)
            logger.info(f"Setting CUDA device to GPU {cuda_device_id} for model {generator_key}")
            print(f"Using GPU {cuda_device_id} (CUDA device index) for model {generator_key}")
        except Exception as e:
            logger.warning(f"Error setting CUDA device: {e}. Falling back to default device.")
        # Run the generation into a timestamped output directory
        try:
            output_dir= Path(self._output_dir+f'/{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}')
            Path.mkdir(output_dir, parents=True, exist_ok=True)
            generator.generate(output_dir=output_dir)
        except Exception as e:
            logger.error(f"Error generating structures: {e}")
            return f"Error generating structures: {e}"
        # Dictionary holding raw output-file contents
        result_dict = {}
        # Output artifact paths (the .extxyz and trajectory paths are only used
        # by the disabled cleanup block below)
        cif_zip_path = os.path.join(str(output_dir), f"generated_crystals_cif.zip")
        xyz_file_path = os.path.join(str(output_dir), f"generated_crystals.extxyz")
        trajectories_zip_path = os.path.join(str(output_dir), f"generated_trajectories.zip")
        # Read the zipped CIF archive as raw bytes
        if os.path.exists(cif_zip_path):
            with open(cif_zip_path, 'rb') as f:
                result_dict['cif_content'] = f.read()
        # Build a descriptive header that matches the generation mode
        if not properties:
            generation_type = "unconditional"
            title = "Generated Material Structures"
            description = "These structures were generated unconditionally, meaning no specific properties were targeted."
            property_description = "unconditionally"
        elif len(properties) == 1:
            generation_type = "single_property"
            property_name = list(properties.keys())[0]
            property_value = properties[property_name]
            title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
            description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
            property_description = f"conditioned on {property_name} = {property_value}"
        else:
            generation_type = "multi_property"
            title = "Generated Material Structures Conditioned on Multiple Properties"
            description = "These structures were generated with multi-property conditioning, targeting the specified property values."
            property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
        # Assemble the final prompt returned to the caller
        prompt = f"""
# {title}
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
{'' if generation_type == 'unconditional' else f'''
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
'''}
## CIF Files (Crystallographic Information Files)
- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools
```
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
```
{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""
        # print("prompt",prompt)
        # Clean up files after reading — currently disabled.
        # try:
        #     if os.path.exists(cif_zip_path):
        #         os.remove(cif_zip_path)
        #     if os.path.exists(xyz_file_path):
        #         os.remove(xyz_file_path)
        #     if os.path.exists(trajectories_zip_path):
        #         os.remove(trajectories_zip_path)
        # except Exception as e:
        #     logger.warning(f"Error cleaning up files: {e}")
        # The GPU device was already selected via torch.cuda.set_device() before
        # generation, so no extra cleanup is required here.
        logger.info(f"Generation completed on GPU for model {generator_key}")
        return prompt

View File

@@ -0,0 +1,26 @@
"""
This is a wrapper module that provides access to the mattergen modules
by modifying the Python path at runtime.
"""
import sys
import os
from pathlib import Path
from ...core.config import material_config
# Add the mattergen directory to the Python path
mattergen_dir = material_config.MATTERGEN_ROOT
sys.path.insert(0, mattergen_dir)
# Import the necessary modules from the mattergen package
try:
from mattergen import generator
from mattergen.common.data import chemgraph
from mattergen.common.data.types import TargetProperty
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME
except ImportError as e:
print(f"Error importing mattergen modules: {e}")
print(f"Python path: {sys.path}")
raise
CrystalGenerator = generator.CrystalGenerator
# Re-export the modules
__all__ = ['generator', 'chemgraph', 'TargetProperty', 'MatterGenCheckpointInfo', 'PRETRAINED_MODEL_NAME','CrystalGenerator']

View File

@@ -0,0 +1,73 @@
"""
Property Prediction Module
This module provides functions for predicting properties of crystal structures.
"""
import asyncio
import torch
import numpy as np
from ase.units import GPa
from mattersim.forcefield import MatterSimCalculator
from ...core.llm_tools import llm_tool
from ..support.utils import convert_structure,read_structure_from_file_name_or_content_string
@llm_tool(
    name="predict_properties_MatterSim",
    description="Predict energy, forces, and stress of crystal structures using MatterSim model based on CIF string",
)
async def predict_properties_MatterSim(structure_source: str) -> str:
    """
    Predict energy, forces, and stress for a crystal structure with MatterSim.

    Args:
        structure_source: Structure file name (e.g., POSCAR, CIF) or the raw
            structure content itself.

    Returns:
        A Markdown-formatted report with the predicted properties, or an
        error message when the structure cannot be parsed.
    """
    def _blocking_predict() -> str:
        # Resolve the input to (content, format) and build an ASE Atoms object.
        raw_content, detected_format = read_structure_from_file_name_or_content_string(structure_source)
        atoms = convert_structure(detected_format, raw_content)
        if atoms is None:
            return "Unable to parse CIF string. Please check if the format is correct."

        # Run inference on GPU when one is available.
        atoms.calc = MatterSimCalculator(device="cuda" if torch.cuda.is_available() else "cpu")

        total_energy = atoms.get_potential_energy()
        atomic_forces = atoms.get_forces()
        # Full 3x3 stress tensor in ASE units (eV/A^3).
        stress_ev_a3 = atoms.get_stress(voigt=False)
        # Dividing by ase.units.GPa converts the tensor to GPa.
        stress_gpa = stress_ev_a3 / GPa
        per_atom_energy = total_energy / len(atoms)

        return f"""
## {atoms.get_chemical_formula()} Crystal Structure Property Prediction Results
Prediction results using the provided CIF structure:
- Total Energy (eV): {total_energy}
- Energy per Atom (eV/atom): {per_atom_energy:.4f}
- Forces (eV/Angstrom): {atomic_forces[0]} # Forces on the first atom
- Stress (GPa): {stress_gpa[0][0]} # First component of the stress tensor
- Stress (eV/A^3): {stress_ev_a3[0][0]} # First component of the stress tensor
"""

    # Off-load the blocking model call so the event loop stays responsive.
    return await asyncio.to_thread(_blocking_predict)

View File

@@ -0,0 +1,42 @@
import os
from typing import List, Union

from mp_api.client import MPRester

from ...core.config import material_config
async def get_mpid_from_formula(formula: str) -> Union[List[str], str]:
    """
    Get material IDs (mpid) from Materials Project database by chemical formula.
    Returns mpids for the lowest energy structures.

    Args:
        formula: Chemical formula (e.g., "Fe2O3"), optionally given as
            "name=formula"; only the part after "=" is used for the query.

    Returns:
        List of material IDs on success, or an error/status string on failure.
        Callers must check ``isinstance(result, str)`` to detect the error case.
    """
    # The MP API may only be reachable through a proxy in some deployments.
    os.environ['HTTP_PROXY'] = material_config.HTTP_PROXY or ''
    os.environ['HTTPS_PROXY'] = material_config.HTTPS_PROXY or ''
    try:
        # Strip whitespace/quotes an LLM may have wrapped around the formula.
        cleaned_formula = formula.replace(" ", "").replace("\n", "").replace("\'", "").replace("\"", "")
        if "=" in cleaned_formula:
            # Accept "name=formula" pairs; only the formula part is queried.
            _, query_formula = cleaned_formula.split("=")
        else:
            query_formula = cleaned_formula
        with MPRester(material_config.MP_API_KEY) as mpr:
            docs = mpr.materials.summary.search(formula=[query_formula])
            if not docs:
                return "No materials found"
            return [doc.material_id for doc in docs]
    except Exception as e:
        return f"Error: get_mpid_from_formula: {str(e)}"

View File

@@ -0,0 +1,168 @@
import glob
import json
from typing import Dict, Any, Union
from ...core.llm_tools import llm_tool
from .get_mp_id import get_mpid_from_formula
from ..support.utils import extract_cif_info, remove_symmetry_equiv_xyz
from ...core.config import material_config
from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
@llm_tool(name="search_crystal_structures_from_materials_project",
description="Retrieve and optimize crystal structures from Materials Project database using a chemical formula")
async def search_crystal_structures_from_materials_project(
formula: str,
conventional_unit_cell: bool = True,
symprec: float = 0.1
) -> str:
"""
Retrieves crystal structures for a given chemical formula from Materials Project database and applies symmetry optimization.
Args:
formula: Chemical formula to search for (e.g., "Fe2O3")
conventional_unit_cell: If True, returns conventional unit cell; if False, returns primitive cell
symprec: Symmetry precision parameter for structure refinement (default: 0.1)
Returns:
Formatted CIF data for the retrieved crystal structures with symmetry analysis
"""
try:
structures = {}
mp_id_list = await get_mpid_from_formula(formula=formula)
if isinstance(mp_id_list, str):
return mp_id_list # 直接返回错误信息
for i, mp_id in enumerate(mp_id_list):
try:
# 文件操作可能引发异常
cif_files = glob.glob(material_config.LOCAL_MP_CIF_ROOT + f"/{mp_id}.cif")
if not cif_files:
continue # 如果没有找到文件跳过这个mp_id
cif_file = cif_files[0]
structure = Structure.from_file(cif_file)
# 结构处理可能引发异常
if conventional_unit_cell:
structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure()
# 对结构进行对称化处理
sga = SpacegroupAnalyzer(structure, symprec=symprec)
symmetrized_structure = sga.get_refined_structure()
# 使用CifWriter生成CIF数据
cif_writer = CifWriter(symmetrized_structure, symprec=symprec, refine_struct=True)
cif_data = str(cif_writer)
# 删除CIF文件中的对称性操作部分
cif_data = remove_symmetry_equiv_xyz(cif_data)
cif_data = cif_data.replace('# generated using pymatgen', "")
# 生成一个唯一的键
formula_key = structure.composition.reduced_formula
key = f"{formula_key}_{i}"
structures[key] = cif_data
# 只保留前config.MP_TOPK个结果
if len(structures) >= material_config.MP_TOPK:
break
except (FileNotFoundError, IndexError) as file_error:
# 处理文件相关错误
continue # 跳过这个mp_id继续处理下一个
except ValueError as value_error:
# 处理结构处理中的值错误
continue # 跳过这个mp_id继续处理下一个
except Exception as process_error:
# 记录处理特定结构时的错误,但继续处理其他结构
print(f"Error processing structure {mp_id}: {str(process_error)}")
continue
# 如果没有成功处理任何结构
if not structures:
return f"No valid crystal structures found for formula: {formula}"
# 格式化结果为可读字符串
prompt = f"""
# Materials Project Symmetrized Crystal Structure Data
Below are symmetrized crystal structure data for {len(structures)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
"""
for i, (key, cif_data) in enumerate(structures.items(), 1):
prompt += f"[cif {i} begin]\n"
prompt += cif_data
prompt += f"\n[cif {i} end]\n\n"
return prompt
except Exception as e:
# 捕获整个函数执行过程中的任何未处理异常
return f"Error: An unexpected error occurred while processing crystal structures: {str(e)}"
@llm_tool(name="search_material_property_from_material_project",
description="Query material properties from Materials Project database using chemical formula")
async def search_material_property_from_materials_project(
formula: str,
) -> str:
"""
Retrieve detailed property data for materials matching a chemical formula from Materials Project database.
Args:
formula: Chemical formula of the material(s) to search for (e.g. 'Fe2O3', 'LiFePO4')
Returns:
Formatted string containing material properties including structure, electronic, thermodynamic and mechanical data
"""
# 获取MP ID列表
mp_id_list = await get_mpid_from_formula(formula=formula)
# 检查get_mpid_from_formula的返回值类型
# 如果返回的是字符串,说明发生了错误或没有找到材料
if isinstance(mp_id_list, str):
return mp_id_list # 直接返回错误信息
# 如果代码执行到这里说明mp_id_list是一个有效的ID列表
try:
# 获取材料属性
properties = []
for mp_id in mp_id_list:
try:
file_path = material_config.LOCAL_MP_PROPS_ROOT + f"/{mp_id}.json"
crystal_props = extract_cif_info(file_path, ['all_fields'])
properties.append(crystal_props)
except Exception as file_error:
# 记录单个文件处理错误但继续处理其他ID
continue
# 检查是否有结果
if len(properties) == 0:
return "No material properties found for the given formula, please try again."
# 只保留前MP_TOPK个结果
properties = properties[:material_config.MP_TOPK]
# 格式化结果
formatted_results = []
for i, item in enumerate(properties, 1):
formatted_result = f"[property {i} begin]\n"
formatted_result += json.dumps(item, indent=2)
formatted_result += f"\n[property {i} end]\n\n"
formatted_results.append(formatted_result)
# 将所有结果合并为一个字符串
res_chunk = "\n\n".join(formatted_results)
res_template = f"""
Here are the search material property from the Materials Project database:
Due to length limitations, only the top {len(properties)} results are shown below:\n
{res_chunk}
"""
return res_template
except Exception as e:
return f"Error: processing material properties: {str(e)}"

View File

@@ -0,0 +1,92 @@
import logging
import httpx
import pandas as pd
from bs4 import BeautifulSoup
from io import StringIO
from typing import Annotated, Any, Dict, List
import mcp.types as types
from ...core.llm_tools import llm_tool
@llm_tool(name="query_material_from_OQMD", description="Query material properties by chemical formula from OQMD database")
async def query_material_from_OQMD(
formula: Annotated[str, "Chemical formula (e.g., Fe2O3, LiFePO4)"]
) -> str:
"""
Query material information by chemical formula from OQMD database.
Args:
formula: Chemical formula of the material (e.g., Fe2O3, LiFePO4)
Returns:
Formatted text with material information and property tables
"""
# Fetch data from OQMD
url = f"https://www.oqmd.org/materials/composition/{formula}"
try:
async with httpx.AsyncClient(timeout=100.0) as client:
response = await client.get(url)
response.raise_for_status()
# Validate response content
if not response.text or len(response.text) < 100:
raise ValueError("Invalid response content from OQMD API")
# Parse HTML data
html = response.text
soup = BeautifulSoup(html, 'html.parser')
# Parse basic data
basic_data = []
h1_element = soup.find('h1')
if h1_element:
basic_data.append(h1_element.text.strip())
else:
basic_data.append(f"Material: {formula}")
for script in soup.find_all('p'):
if script:
combined_text = ""
for element in script.contents:
if hasattr(element, 'name') and element.name == 'a' and 'href' in element.attrs:
url = "https://www.oqmd.org" + element['href']
combined_text += f"[{element.text.strip()}]({url}) "
elif hasattr(element, 'text'):
combined_text += element.text.strip() + " "
else:
combined_text += str(element).strip() + " "
basic_data.append(combined_text.strip())
# Parse table data
table_data = ""
table = soup.find('table')
if table:
try:
df = pd.read_html(StringIO(str(table)))[0]
df = df.fillna('')
df = df.replace([float('inf'), float('-inf')], '')
table_data = df.to_markdown(index=False)
except Exception as e:
table_data = "Error: parsing table data"
# Integrate data into a single text
combined_text = "\n\n".join(basic_data)
if table_data:
combined_text += "\n\n## Material Properties Table\n\n" + table_data
return combined_text
except httpx.HTTPStatusError as e:
return f"Error: OQMD API request failed - {str(e)}"
except httpx.TimeoutException:
return "Error: OQMD API request timed out"
except httpx.NetworkError as e:
return f"Error: Network error occurred - {str(e)}"
except ValueError as e:
return f"Error: Invalid response content - {str(e)}"
except Exception as e:
return f"Error: Unexpected error occurred - {str(e)}"

View File

@@ -0,0 +1,95 @@
import os
import asyncio
from pymatgen.core import Structure
from ...core.config import material_config
from ...core.llm_tools import llm_tool
from ..support.utils import read_structure_from_file_name_or_content_string
@llm_tool(name="calculate_density_Pymatgen", description="Calculate the density of a crystal structure from a file or content string using Pymatgen")
async def calculate_density_Pymatgen(structure_source: str) -> str:
"""
Calculates the density of a structure from a file or content string.
Args:
structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.
Returns:
str: A Markdown formatted string with the density or an error message if the calculation fails.
"""
# 使用asyncio.to_thread异步执行可能阻塞的操作
try:
# 使用read_structure_from_file_name_or_content_string函数读取结构
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
# # 使用pymatgen读取结构
structure = Structure.from_str(structure_content,fmt=content_format)
density = structure.density
# 删除临时文件
return (f"## Density Calculation\n\n"
f"- **Structure**: `{structure.composition.reduced_formula}`\n"
f"- **Density**: `{density:.2f} g/cm³`\n")
except Exception as e:
return f"Error: error occurred while calculating density: {str(e)}\n"
@llm_tool(name="get_element_composition_Pymatgen", description="Analyze and retrieve the elemental composition of a crystal structure from a file or content string using Pymatgen")
async def get_element_composition_Pymatgen(structure_source: str) -> str:
"""
Returns the elemental composition of a structure from a file or content string.
Args:
structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.
Returns:
str: A Markdown formatted string with the elemental composition or an error message if the operation fails.
"""
# 使用asyncio.to_thread异步执行可能阻塞的操作
try:
# 使用read_structure_from_file_name_or_content_string函数读取结构
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
# 使用pymatgen读取结构
structure = Structure.from_str(structure_content, fmt=content_format)
composition = structure.composition
return (f"## Element Composition\n\n"
f"- **Structure**: `{structure.composition.reduced_formula}`\n"
f"- **Composition**: `{composition}`\n")
except Exception as e:
return f"Error: error occurred while getting element composition: {str(e)}\n"
@llm_tool(name="calculate_symmetry_Pymatgen", description="Determine the space group and symmetry operations of a crystal structure from a file or content string using Pymatgen")
async def calculate_symmetry_Pymatgen(structure_source: str) -> str:
"""
Calculates the symmetry of a structure from a file or content string.
Args:
structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.
Returns:
str: A Markdown formatted string with the symmetry information or an error message if the operation fails.
"""
# 使用asyncio.to_thread异步执行可能阻塞的操作
try:
# 使用read_structure_from_file_name_or_content_string函数读取结构
structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
# 使用pymatgen读取结构
structure = Structure.from_str(structure_content, fmt=content_format)
symmetry = structure.get_space_group_info()
return (f"## Symmetry Information\n\n"
f"- **Structure**: `{structure.composition.reduced_formula}`\n"
f"- **Space Group**: `{symmetry[0]}`\n"
f"- **Number**: `{symmetry[1]}`\n")
except Exception as e:
return f"Error: error occurred while calculating symmetry: {str(e)}\n"

View File

View File

@@ -0,0 +1,212 @@
"""
CIF Utilities Module
This module provides basic functions for handling CIF (Crystallographic Information File) files,
which are commonly used in materials science for representing crystal structures.
"""
import json
import logging
import os
from ase.io import read
import tempfile
from typing import Optional, Tuple
from ase import Atoms
from ...core.config import material_config
logger = logging.getLogger(__name__)
def read_cif_txt_file(file_path):
    """
    Read a CIF file from disk and return its textual content.

    Args:
        file_path: Path to the CIF file

    Returns:
        The file content as a string, or None when the file cannot be read
        (the failure is logged).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            content = handle.read()
    except Exception as e:
        logger.error(f"Error reading file {file_path}: {e}")
        return None
    return content
def extract_cif_info(path: str, fields_name: list) -> dict:
    """
    Extract selected fields from a CIF-description JSON file.

    Args:
        path: Path to the JSON file containing CIF information
        fields_name: List of field categories to extract. Use 'all_fields' to
            extract all fields. Other options include 'basic_fields',
            'energy_electronic_fields', 'metal_magentic_fields'.

    Returns:
        Dictionary mapping each selected field to its value ('' when missing).
    """
    # Field groups; the 'metal_magentic' spelling is kept for backward
    # compatibility with existing callers that pass it as a category name.
    field_groups = {
        'basic_fields': ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density'],
        'energy_electronic_fields': ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct'],
        'metal_magentic_fields': ['is_metal', 'is_magnetic', 'ordering', 'total_magnetization', 'num_magnetic_sites'],
    }
    if fields_name and fields_name[0] == 'all_fields':
        selected_fields = [field for group in field_groups.values() for field in group]
    else:
        # Explicit mapping instead of the fragile locals() string lookup;
        # unknown group names are silently ignored, matching the old behavior.
        selected_fields = []
        for group_name in fields_name:
            selected_fields.extend(field_groups.get(group_name, []))
    with open(path, 'r') as f:
        docs = json.load(f)
    # Missing fields default to '' so downstream formatting never KeyErrors.
    return {field_name: docs.get(field_name, '') for field_name in selected_fields}
def remove_symmetry_equiv_xyz(cif_content):
    """
    Remove symmetry operations section from CIF file content.
    This is often useful when working with CIF files in certain visualization tools
    or when focusing on the basic structure without symmetry operations.

    Args:
        cif_content: CIF file content string

    Returns:
        Cleaned CIF content string with symmetry operations removed
    """
    lines = cif_content.split('\n')
    output_lines = []
    i = 0
    while i < len(lines):
        line = lines[i].strip()
        # Detect the start of a loop_ block
        if line == 'loop_':
            # Peek at the following tag lines to decide whether this is the symmetry loop
            next_lines = []
            j = i + 1
            while j < len(lines) and lines[j].strip().startswith('_'):
                next_lines.append(lines[j].strip())
                j += 1
            # Does this loop declare symmetry operations?
            if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
                # Skip the whole loop block (header tags and the operation rows);
                # the loop_ line itself is intentionally not emitted either.
                while i < len(lines):
                    if i + 1 >= len(lines):
                        break
                    next_line = lines[i + 1].strip()
                    # Stop before the next loop or data block
                    if next_line == 'loop_' or next_line.startswith('data_'):
                        break
                    # Stop before the atom-site section
                    if next_line.startswith('_atom_site_'):
                        break
                    i += 1
            else:
                # Not the symmetry loop: keep the loop_ line
                output_lines.append(lines[i])
        else:
            # Line outside any loop header: keep as-is
            output_lines.append(lines[i])
        i += 1
    return '\n'.join(output_lines)
def _infer_format_from_extension(path: str) -> str:
    """Map a structure file extension to a format name (default: cif)."""
    ext = os.path.splitext(path)[1].lower().lstrip('.')
    return {'cif': 'cif', 'xyz': 'xyz', 'vasp': 'vasp', 'poscar': 'vasp'}.get(ext, 'cif')


def _infer_format_from_content(content: str) -> str:
    """Heuristically guess the structure format from raw content (default: cif)."""
    # CIF files normally contain "data_" and "_cell_" tags.
    if "data_" in content and "_cell_" in content:
        return "cif"
    stripped_lines = content.strip().split('\n')
    # XYZ files start with the atom count on the first line.
    if stripped_lines[0].strip().isdigit():
        return "xyz"
    # POSCAR/VASP files typically carry three lattice vectors on lines 3-5.
    if len(stripped_lines) > 5 and all(len(line.split()) == 3 for line in stripped_lines[2:5]):
        return "vasp"
    return "cif"


def read_structure_from_file_name_or_content_string(file_name_or_content_string: str, format_type: str = "auto") -> Tuple[str, str]:
    """
    Resolve a structure input that may be a file path, a bare file name, or raw content.

    A bare file name is looked up under ``material_config.TEMP_ROOT`` — the
    directory where LLM-generated temporary structure files are stored.

    Args:
        file_name_or_content_string: Absolute/relative file path, a file name in
            the temp directory, or the structure content itself.
        format_type: Structure format type; "auto" detects it from the file
            extension or, for raw content, from the content itself.

    Returns:
        tuple: (content string, resolved format type)
    """
    # 1) Absolute/relative path to an existing file.
    if os.path.exists(file_name_or_content_string) and os.path.isfile(file_name_or_content_string):
        path = file_name_or_content_string
    else:
        # 2) Bare file name inside the temp directory.
        temp_path = os.path.join(material_config.TEMP_ROOT, file_name_or_content_string)
        path = temp_path if os.path.exists(temp_path) and os.path.isfile(temp_path) else None

    if path is not None:
        with open(path, 'r', encoding='utf-8') as f:
            content = f.read()
        if format_type == "auto":
            format_type = _infer_format_from_extension(path)
    else:
        # 3) Not a file: treat the argument as the structure content itself.
        content = file_name_or_content_string
        if format_type == "auto":
            format_type = _infer_format_from_content(content)
    return content, format_type
def convert_structure(input_format: str = 'cif', content: str = None) -> Optional[Atoms]:
    """
    Convert raw structure content into an ASE Atoms object.

    Args:
        input_format: Input format (cif, xyz, vasp, ...); used as the temp-file
            suffix so ASE's reader can dispatch on the extension.
        content: Structure content string.

    Returns:
        ASE Atoms object, or None if the conversion fails (the failure is logged).
    """
    tmp_path = None
    try:
        # ASE's read() dispatches on the file extension, so round-trip the
        # content through a temporary file with the right suffix.
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            tmp_file.write(content)
            tmp_path = tmp_file.name
        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        # Always remove the temp file, even when parsing fails (the original
        # only deleted it on the success path, leaking files on errors).
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)