初次提交
This commit is contained in:
34
sci_mcp/material_mcp/__init__.py
Executable file
34
sci_mcp/material_mcp/__init__.py
Executable file
@@ -0,0 +1,34 @@
|
||||
|
||||
# # Core modules
|
||||
# from mars_toolkit.core.config import config
|
||||
|
||||
|
||||
# # Basic tools
|
||||
# from mars_toolkit.misc.misc_tools import get_current_time
|
||||
|
||||
# # Compute modules
|
||||
# from mars_toolkit.compute.material_gen import generate_material
|
||||
# from mars_toolkit.compute.property_pred import predict_properties
|
||||
# from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure
|
||||
|
||||
# # Query modules
|
||||
# from mars_toolkit.query.mp_query import (
|
||||
# search_material_property_from_material_project,
|
||||
# get_crystal_structures_from_materials_project,
|
||||
# get_mpid_from_formula
|
||||
# )
|
||||
# from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||
# from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||
# from mars_toolkit.query.web_search import search_online
|
||||
|
||||
# # Visualization modules
|
||||
|
||||
|
||||
# from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# __version__ = "0.1.0"
|
||||
# __all__ = ["llm_tool", "get_tools", "get_tool_schemas"]
|
||||
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Structure Optimization Module
|
||||
|
||||
This module provides functions for optimizing crystal structures using the FairChem model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from io import StringIO
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from ase.io import read, write
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
from ..support.utils import convert_structure, remove_symmetry_equiv_xyz, read_structure_from_file_name_or_content_string
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ...core.config import material_config
|
||||
|
||||
logger = logging.getLogger(__name__)

# Module-level FairChem calculator. It stays None until init_model() assigns it.
calc = None
|
||||
|
||||
def init_model():
    """Populate the module-level ``calc`` with a FairChem calculator, once.

    Does nothing when ``calc`` is already set. Otherwise imports
    ``OCPCalculator`` lazily and builds it from
    ``material_config.FAIRCHEM_MODEL_PATH``.

    Raises:
        Exception: Whatever the import or the calculator construction raises;
            the failure is logged and then re-raised unchanged.
    """
    global calc
    if calc is None:
        try:
            # Imported here so the dependency is only loaded on first use.
            from fairchem.core import OCPCalculator

            checkpoint = material_config.FAIRCHEM_MODEL_PATH
            calc = OCPCalculator(checkpoint_path=checkpoint)
            logger.info("FairChem model initialized successfully")
        except Exception as exc:
            logger.error(f"Failed to initialize FairChem model: {str(exc)}")
            raise
|
||||
|
||||
|
||||
|
||||
def generate_symmetry_cif(structure: Structure) -> str:
    """Return a symmetry-aware CIF string for ``structure``.

    The structure is first refined with ``SpacegroupAnalyzer`` and then
    serialised by ``CifWriter`` (``symprec=0.1``, ``refine_struct=True``).

    The CIF text is taken directly from the writer instead of round-tripping
    through a ``delete=False`` temporary file, so nothing can be left behind
    on disk when serialisation fails.

    Args:
        structure: Pymatgen ``Structure`` to export.

    Returns:
        The CIF content as a string.
    """
    refined_structure = SpacegroupAnalyzer(structure).get_refined_structure()
    cif_writer = CifWriter(refined_structure, symprec=0.1, refine_struct=True)
    # str(CifWriter) yields the same text that write_file() would put on disk.
    return str(cif_writer)
|
||||
|
||||
def optimize_structure(atoms: Atoms, output_format: str, fmax: float = 0.05) -> str:
    """Relax ``atoms`` with FIRE and return a Markdown report.

    The module-level ``calc`` is attached to ``atoms`` and a ``FIRE``
    optimiser wrapped around ``FrechetCellFilter(atoms)`` is run until
    ``fmax`` is reached. Everything printed to stdout during the run is
    captured and embedded in the report.

    Args:
        atoms: ASE ``Atoms`` object to optimise (modified in place).
        output_format: Output format name (e.g. ``cif``, ``xyz``, ``vasp``).
            ``"cif"`` goes through ``generate_symmetry_cif``; anything else is
            written by ``ase.io.write`` using the format as file suffix.
        fmax: Force convergence criterion passed to the optimiser.

    Returns:
        A formatted string with total energy, optimisation log and the
        optimised structure, or an ``"Error: ..."`` string on failure.
    """
    atoms.calc = calc

    try:
        # Capture whatever the optimisation prints to stdout.
        log_buffer = StringIO()
        saved_stdout = sys.stdout
        sys.stdout = log_buffer
        try:
            dyn = FIRE(FrechetCellFilter(atoms))
            dyn.run(fmax=fmax)
        finally:
            # Always restore stdout, even when the optimiser raises; otherwise
            # the whole process would keep printing into the buffer.
            sys.stdout = saved_stdout

        optimization_log = log_buffer.getvalue()
        log_buffer.close()

        total_energy = atoms.get_potential_energy()

        if output_format == "cif":
            optimized_structure = Structure.from_ase_atoms(atoms)
            content = generate_symmetry_cif(optimized_structure)
            content = remove_symmetry_equiv_xyz(content)
        else:
            # Reserve a uniquely named file, close it, then let ASE write to it
            # by name; the finally clause guarantees it is removed on any path.
            with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
                tmp_path = tmp_file.name
            try:
                write(tmp_path, atoms)
                with open(tmp_path, "r") as result_file:
                    content = result_file.read()
            finally:
                os.unlink(tmp_path)

        # Assemble the report returned to the caller.
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**

#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
        return format_result
    except Exception as e:
        return f"Error: Failed to optimize structure: {str(e)}"
|
||||
|
||||
@llm_tool(name="optimize_crystal_structure_FairChem",
          description="Optimizes crystal structures using the FairChem model")
async def optimize_crystal_structure_FairChem(
    structure_source: str,
    format_type: str = "auto",
    optimization_level: str = "normal"
) -> str:
    """
    Optimizes a crystal structure to find its lowest energy configuration.

    Args:
        structure_source: Either a file name or direct structure content (CIF, XYZ, POSCAR)
        format_type: Structure format type (auto, cif, xyz, poscar). Default "auto" will attempt to detect format.
        optimization_level: Optimization precision level (quick, normal, precise).
            Unknown values fall back to the "normal" force threshold.

    Returns:
        Optimized structure with total energy and optimization details, or a
        string starting with "Error:" when model loading, parsing or
        optimization fails.
    """
    # Make sure the model is loaded. A load failure is reported as an error
    # string, like every other failure of this tool, instead of escaping as an
    # exception.
    if calc is None:
        try:
            init_model()
        except Exception as e:
            return f"Error: Failed to initialize FairChem model: {str(e)}"

    # Map the requested precision level to a force convergence threshold.
    fmax_values = {
        "quick": 0.1,
        "normal": 0.05,
        "precise": 0.01
    }
    fmax = fmax_values.get(optimization_level, 0.05)

    def run_optimization():
        """Parse the input, convert it to ASE atoms and run the optimization."""
        try:
            content, actual_format = read_structure_from_file_name_or_content_string(structure_source, format_type)

            # Normalise the detected format to the names used downstream;
            # anything unrecognised is treated as CIF.
            format_mapping = {
                "cif": "cif",
                "xyz": "xyz",
                "poscar": "vasp",
                "vasp": "vasp"
            }
            final_format = format_mapping.get(actual_format.lower(), "cif")

            atoms = convert_structure(final_format, content)
            if atoms is None:
                return "Error: Unable to convert input structure. Please check if the format is correct."

            return optimize_structure(atoms, final_format, fmax=fmax)

        except Exception as e:
            return f"Error: Failed to optimize structure: {str(e)}"

    # The optimization is blocking, so run it in a worker thread.
    return await asyncio.to_thread(run_optimization)
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import codecs
|
||||
import json
|
||||
import requests
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ...core.config import material_config
|
||||
|
||||
@llm_tool(
    name="retrieval_from_knowledge_base",
    description="Retrieve information from local materials science literature knowledge base"
)
async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
    """
    Retrieve relevant information from the local materials-science literature
    knowledge base through the Dify chat-messages endpoint.

    Args:
        query: Query string, e.g. a material name such as "CsPbBr3".
        topk: Number of results to request (default 3).

    Returns:
        A JSON-formatted string with the keys "id", "title", "content" and
        "score", or a string starting with "Error:" if the request or the
        response parsing fails.
    """
    # Local import: this module does not import asyncio at file level.
    import asyncio

    # Dify API endpoint.
    url = f'{material_config.DIFY_ROOT_URL}/v1/chat-messages'

    # Request headers carrying the API key and the content type.
    headers = {
        'Authorization': f'Bearer {material_config.DIFY_API_KEY}',
        'Content-Type': 'application/json'
    }

    data = {
        "inputs": {"topK": topk},     # maximum number of results to return
        "query": query,               # the query string
        "response_mode": "blocking",  # wait for the complete response
        "conversation_id": "",        # no conversation: every query is independent
        "user": "abc-123"             # user identifier
    }

    try:
        # requests is synchronous, so run the call in a worker thread to keep
        # the event loop responsive. The timeout is deliberately long (1111 s).
        response = await asyncio.to_thread(
            requests.post, url, headers=headers, json=data, timeout=1111
        )
        # Surface HTTP failures instead of parsing an error body as a result.
        response.raise_for_status()

        # Parse the body as JSON directly. The JSON parser already resolves
        # \uXXXX escapes; pre-decoding with 'unicode_escape' mangles non-ASCII
        # text and turns escaped newlines/quotes into invalid JSON.
        result_json = response.json()

        metadata = result_json.get("metadata") or {}

        useful_info = {
            "id": metadata.get("document_id"),         # document ID
            "title": result_json.get("title"),         # document title
            "content": result_json.get("answer", ""),  # the 'answer' field holds the content
            "score": metadata.get("score")             # relevance score
        }

        return json.dumps(useful_info, ensure_ascii=False, indent=2)

    except Exception as e:
        # Every failure is reported to the caller as an error string.
        return f"Error: {str(e)}"
|
||||
8
sci_mcp/material_mcp/matgl_tools/__init__.py
Normal file
8
sci_mcp/material_mcp/matgl_tools/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
MatGL Tools Module
|
||||
|
||||
This module provides tools for material structure relaxation and property prediction
|
||||
using MatGL (Materials Graph Library) models.
|
||||
"""
|
||||
|
||||
from .matgl_tools import *
|
||||
487
sci_mcp/material_mcp/matgl_tools/matgl_tools.py
Normal file
487
sci_mcp/material_mcp/matgl_tools/matgl_tools.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
MatGL Tools Module
|
||||
|
||||
This module provides tools for material structure relaxation and property prediction
|
||||
using MatGL (Materials Graph Library) models.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from ...core.config import material_config
|
||||
|
||||
|
||||
import warnings
|
||||
import json
|
||||
from typing import Dict, List, Union, Optional, Any
|
||||
|
||||
import torch
|
||||
from pymatgen.core import Lattice, Structure
|
||||
from pymatgen.ext.matproj import MPRester
|
||||
from pymatgen.io.ase import AseAtomsAdaptor
|
||||
|
||||
import matgl
|
||||
from matgl.ext.ase import Relaxer, MolecularDynamics, PESCalculator
|
||||
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
|
||||
|
||||
from ...core.llm_tools import llm_tool
|
||||
import os
|
||||
from ..support.utils import read_structure_from_file_name_or_content_string
|
||||
# Suppress all warnings for clearer tool output. This installs an "ignore"
# filter at import time, so it also affects code outside this module.
warnings.simplefilter("ignore")
|
||||
|
||||
|
||||
@llm_tool(name="relax_crystal_structure_M3GNet",
          description="Optimize crystal structure geometry using M3GNet universal potential from a structure file or content string")
async def relax_crystal_structure_M3GNet(
    structure_source: str,
    fmax: float = 0.01
) -> str:
    """
    Optimize crystal structure geometry to find its equilibrium configuration.

    The relaxation itself is delegated to
    ``_relax_crystal_structure_M3GNet_internal`` (the same helper the other
    M3GNet tools use); this function only formats the relaxed structure.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        fmax: Maximum force tolerance for convergence in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string with the relaxation results or an error message.
    """
    try:
        relaxed_structure = await _relax_crystal_structure_M3GNet_internal(
            structure_source=structure_source,
            fmax=fmax
        )

        # The helper reports failures as an error string instead of raising.
        if isinstance(relaxed_structure, str):
            return relaxed_structure

        reduced_formula = relaxed_structure.composition.reduced_formula

        # Structure summary.
        lattice_info = relaxed_structure.lattice
        volume = relaxed_structure.volume
        density = relaxed_structure.density
        symmetry = relaxed_structure.get_space_group_info()

        # Atomic positions table: header, separator, then one row per site.
        table_rows = [
            " # SP a b c\n",
            "--- ---- -------- -------- --------\n",
        ]
        for i, site in enumerate(relaxed_structure):
            frac_coords = site.frac_coords
            table_rows.append(
                f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"
            )
        sites_table = "".join(table_rows)

        return (f"## Structure Relaxation\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Force Tolerance**: `{fmax} eV/Å`\n"
                f"- **Status**: `Successfully relaxed`\n\n"
                f"### Relaxed Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {relaxed_structure.lattice.pbc[0]!s:5s} {relaxed_structure.lattice.pbc[1]!s:5s} {relaxed_structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(relaxed_structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error during structure relaxation: {str(e)}"
|
||||
|
||||
|
||||
# Internal helper for structure relaxation: hands back the structure object
# itself rather than a formatted string.
async def _relax_crystal_structure_M3GNet_internal(
    structure_source: str,
    fmax: float = 0.01
) -> Union[Structure, str]:
    """
    Relax a structure with the M3GNet potential and return the result object.

    Args:
        structure_source: Structure file name or structure content string.
        fmax: Force convergence threshold (eV/Å).

    Returns:
        The relaxed structure (the "final_structure" entry of the relaxer
        output), or an error message string if anything fails.
    """
    try:
        # Resolve the source into raw text plus its detected format.
        raw_text, text_format = read_structure_from_file_name_or_content_string(structure_source)

        parsed = Structure.from_str(raw_text, fmt=text_format)
        if parsed is None:
            return "Error: Failed to obtain a valid structure"

        # M3GNet universal potential drives the relaxation.
        potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
        outcome = Relaxer(potential=potential).relax(parsed, fmax=fmax)

        return outcome["final_structure"]
    except Exception as exc:
        return f"Error during structure relaxation: {str(exc)}"
|
||||
|
||||
|
||||
@llm_tool(name="predict_formation_energy_M3GNet",
          description="Predict the formation energy of a crystal structure using the M3GNet formation energy model from a structure file or content string, with optional structure optimization")
async def predict_formation_energy_M3GNet(
    structure_source: str,
    optimize_structure: bool = True,
    fmax: float = 0.01
) -> str:
    """
    Predict a structure's formation energy with the M3GNet Eform model.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        optimize_structure: Relax the structure first when True (default); otherwise use it as read.
        fmax: Maximum force tolerance for the optional relaxation in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string with the predicted formation energy in
        eV/atom and a structure summary, or an error message.
    """
    try:
        if not optimize_structure:
            # Use the structure exactly as supplied.
            raw_text, text_format = read_structure_from_file_name_or_content_string(structure_source)
            structure = Structure.from_str(raw_text, fmt=text_format)
        else:
            # Relax first; the helper returns an error string on failure.
            structure = await _relax_crystal_structure_M3GNet_internal(
                structure_source=structure_source,
                fmax=fmax
            )
            if isinstance(structure, str) and structure.startswith("Error"):
                return structure

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # Pre-trained formation-energy model.
        eform_model = matgl.load_model("M3GNet-MP-2018.6.1-Eform")
        eform = eform_model.predict_structure(structure)
        reduced_formula = structure.composition.reduced_formula

        optimization_status = "non-optimized" if not optimize_structure else "optimized"

        # Structure summary.
        lattice_info = structure.lattice
        volume = structure.volume
        density = structure.density
        symmetry = structure.get_space_group_info()

        # Atomic positions table: header, separator, one row per site.
        table_rows = [
            " # SP a b c\n",
            "--- ---- -------- -------- --------\n",
        ]
        for idx, site in enumerate(structure):
            fc = site.frac_coords
            table_rows.append(
                f"{idx:3d} {site.species_string:4s} {fc[0]:8.6f} {fc[1]:8.6f} {fc[2]:8.6f}\n"
            )
        sites_table = "".join(table_rows)

        report = [
            f"## Formation Energy Prediction\n\n",
            f"- **Structure**: `{reduced_formula}`\n",
            f"- **Structure Status**: `{optimization_status}`\n",
            f"- **Formation Energy**: `{float(eform):.3f} eV/atom`\n\n",
            f"### Structure Information\n\n",
            f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n",
            f"- **Volume**: `{volume:.2f} ų`\n",
            f"- **Density**: `{density:.2f} g/cm³`\n",
            f"- **Lattice Parameters**:\n",
            f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n",
            f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n",
            f"### Atomic Positions (Fractional Coordinates)\n\n",
            f"```\n",
            f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n",
            f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n",
            f"pbc : {structure.lattice.pbc[0]!s:5s} {structure.lattice.pbc[1]!s:5s} {structure.lattice.pbc[2]!s:5s}\n",
            f"Sites ({len(structure)})\n",
            f"{sites_table}```\n",
        ]
        return "".join(report)
    except Exception as exc:
        return f"Error: {str(exc)}"
|
||||
|
||||
|
||||
@llm_tool(name="run_molecular_dynamics_M3GNet",
          description="Run molecular dynamics simulation on a crystal structure using M3GNet universal potential, with optional structure optimization")
async def run_molecular_dynamics_M3GNet(
    structure_source: str,
    temperature_K: float = 300,
    steps: int = 100,
    optimize_structure: bool = True,
    fmax: float = 0.01
) -> str:
    """
    Run an M3GNet-driven molecular dynamics simulation on a crystal structure.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        temperature_K: Temperature for MD simulation in Kelvin (default: 300).
        steps: Number of MD steps to run (default: 100).
        optimize_structure: Relax the structure first when True (default); otherwise use it as read.
        fmax: Maximum force tolerance for the optional relaxation in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string with the simulation settings, the final
        potential energy and the final structure, or an error message.
    """
    try:
        if not optimize_structure:
            # Use the structure exactly as supplied.
            raw_text, text_format = read_structure_from_file_name_or_content_string(structure_source)
            structure = Structure.from_str(raw_text, fmt=text_format)
        else:
            # Relax first; the helper returns an error string on failure.
            structure = await _relax_crystal_structure_M3GNet_internal(
                structure_source=structure_source,
                fmax=fmax
            )
            if isinstance(structure, str) and structure.startswith("Error"):
                return structure

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # M3GNet universal potential.
        potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")

        # pymatgen structure -> ASE atoms.
        adaptor = AseAtomsAdaptor()
        atoms = adaptor.get_atoms(structure)

        # Draw initial velocities from a Maxwell-Boltzmann distribution.
        MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_K)

        # Set up the MD driver and advance it by the requested number of steps.
        md_driver = MolecularDynamics(atoms, potential=potential, temperature=temperature_K)
        md_driver.run(steps)

        final_energy = atoms.get_potential_energy()

        # Convert the evolved atoms back for reporting.
        final_structure = adaptor.get_structure(atoms)
        reduced_formula = final_structure.composition.reduced_formula

        optimization_status = "non-optimized" if not optimize_structure else "optimized"

        # Structure summary.
        lattice_info = final_structure.lattice
        volume = final_structure.volume
        density = final_structure.density
        symmetry = final_structure.get_space_group_info()

        # Atomic positions table: header, separator, one row per site.
        table_rows = [
            " # SP a b c\n",
            "--- ---- -------- -------- --------\n",
        ]
        for idx, site in enumerate(final_structure):
            fc = site.frac_coords
            table_rows.append(
                f"{idx:3d} {site.species_string:4s} {fc[0]:8.6f} {fc[1]:8.6f} {fc[2]:8.6f}\n"
            )
        sites_table = "".join(table_rows)

        report = [
            f"## Molecular Dynamics Simulation\n\n",
            f"- **Structure**: `{reduced_formula}`\n",
            f"- **Structure Status**: `{optimization_status}`\n",
            f"- **Temperature**: `{temperature_K} K`\n",
            f"- **Steps**: `{steps}`\n",
            f"- **Final Potential Energy**: `{float(final_energy):.3f} eV`\n\n",
            f"### Final Structure Information\n\n",
            f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n",
            f"- **Volume**: `{volume:.2f} ų`\n",
            f"- **Density**: `{density:.2f} g/cm³`\n",
            f"- **Lattice Parameters**:\n",
            f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n",
            f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n",
            f"### Atomic Positions (Fractional Coordinates)\n\n",
            f"```\n",
            f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n",
            f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n",
            f"pbc : {final_structure.lattice.pbc[0]!s:5s} {final_structure.lattice.pbc[1]!s:5s} {final_structure.lattice.pbc[2]!s:5s}\n",
            f"Sites ({len(final_structure)})\n",
            f"{sites_table}```\n",
        ]
        return "".join(report)
    except Exception as exc:
        return f"Error: {str(exc)}"
|
||||
|
||||
|
||||
@llm_tool(name="calculate_single_point_energy_M3GNet",
          description="Calculate single point energy of a crystal structure using M3GNet universal potential, with optional structure optimization")
async def calculate_single_point_energy_M3GNet(
    structure_source: str,
    optimize_structure: bool = True,
    fmax: float = 0.01
) -> str:
    """
    Calculate single point energy of a crystal structure using M3GNet universal potential.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string.
        optimize_structure: Whether to optimize the structure before calculation (default: True).
        fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).

    Returns:
        A Markdown formatted string containing the calculated potential energy
        in eV and a structure summary, or an error message.
    """
    try:
        if optimize_structure:
            # Relax first; the helper returns an error string on failure.
            structure = await _relax_crystal_structure_M3GNet_internal(
                structure_source=structure_source,
                fmax=fmax
            )
            if isinstance(structure, str) and structure.startswith("Error"):
                return structure
        else:
            # Use the structure exactly as supplied.
            structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
            structure = Structure.from_str(structure_content, fmt=content_format)

        if structure is None:
            return "Error: Failed to obtain a valid structure"

        # Load the M3GNet universal potential model.
        pot = matgl.load_model("M3GNet-MP-2021.2.8-PES")

        # Convert pymatgen structure to ASE atoms.
        ase_adaptor = AseAtomsAdaptor()
        atoms = ase_adaptor.get_atoms(structure)

        # Attach the calculator through the ``calc`` attribute;
        # ``Atoms.set_calculator`` is deprecated in ASE.
        atoms.calc = PESCalculator(pot)

        energy = atoms.get_potential_energy()
        reduced_formula = structure.composition.reduced_formula

        optimization_status = "optimized" if optimize_structure else "non-optimized"

        # Structure summary.
        lattice_info = structure.lattice
        volume = structure.volume
        density = structure.density
        symmetry = structure.get_space_group_info()

        # Atomic positions table: header, separator, one row per site.
        table_rows = [
            " # SP a b c\n",
            "--- ---- -------- -------- --------\n",
        ]
        for i, site in enumerate(structure):
            frac_coords = site.frac_coords
            table_rows.append(
                f"{i:3d} {site.species_string:4s} {frac_coords[0]:8.6f} {frac_coords[1]:8.6f} {frac_coords[2]:8.6f}\n"
            )
        sites_table = "".join(table_rows)

        return (f"## Single Point Energy Calculation\n\n"
                f"- **Structure**: `{reduced_formula}`\n"
                f"- **Structure Status**: `{optimization_status}`\n"
                f"- **Potential Energy**: `{float(energy):.3f} eV`\n\n"
                f"### Structure Information\n\n"
                f"- **Space Group**: `{symmetry[0]} (#{symmetry[1]})`\n"
                f"- **Volume**: `{volume:.2f} ų`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n"
                f"- **Lattice Parameters**:\n"
                f" - a = `{lattice_info.a:.6f} Å`, b = `{lattice_info.b:.6f} Å`, c = `{lattice_info.c:.6f} Å`\n"
                f" - α = `{lattice_info.alpha:.6f}°`, β = `{lattice_info.beta:.6f}°`, γ = `{lattice_info.gamma:.6f}°`\n\n"
                f"### Atomic Positions (Fractional Coordinates)\n\n"
                f"```\n"
                f"abc : {lattice_info.a:.6f} {lattice_info.b:.6f} {lattice_info.c:.6f}\n"
                f"angles: {lattice_info.alpha:.6f} {lattice_info.beta:.6f} {lattice_info.gamma:.6f}\n"
                f"pbc : {structure.lattice.pbc[0]!s:5s} {structure.lattice.pbc[1]!s:5s} {structure.lattice.pbc[2]!s:5s}\n"
                f"Sites ({len(structure)})\n"
                f"{sites_table}```\n")
    except Exception as e:
        return f"Error: {str(e)}"
|
||||
|
||||
#Error: Bad serialized model or bad model name. It is possible that you have an older model cached. Please clear your cache by running `python -c "import matgl; matgl.clear_cache()"`
|
||||
# @llm_tool(name="predict_band_gap",
|
||||
# description="Predict the band gap of a crystal structure using MEGNet multi-fidelity model from either a chemical formula or CIF file, with structure optimization")
|
||||
# async def predict_band_gap(
|
||||
# formula: str = None,
|
||||
# cif_file_name: str = None,
|
||||
# method: str = "PBE",
|
||||
# fmax: float = 0.01
|
||||
# ) -> str:
|
||||
# """
|
||||
# Predict the band gap of a crystal structure using the MEGNet multi-fidelity band gap model.
|
||||
|
||||
# First optimizes the crystal structure using M3GNet universal potential, then predicts
|
||||
# the band gap based on the relaxed structure for more accurate results.
|
||||
|
||||
# Accepts either a chemical formula (searches Materials Project database) or a CIF file.
|
||||
|
||||
# Args:
|
||||
# formula: Chemical formula to retrieve from Materials Project (e.g., "Fe2O3").
|
||||
# cif_file_name: Name of CIF file in temp directory to use as structure source.
|
||||
# method: The DFT method to use for the prediction. Options are "PBE", "GLLB-SC", "HSE", or "SCAN".
|
||||
# Default is "PBE".
|
||||
# fmax: Maximum force tolerance for structure relaxation in eV/Å (default: 0.01).
|
||||
|
||||
# Returns:
|
||||
# A string containing the predicted band gap in eV or an error message.
|
||||
# """
|
||||
# try:
|
||||
# # First, relax the crystal structure
|
||||
# relaxed_result = await relax_crystal_structure(
|
||||
# formula=formula,
|
||||
# cif_file_name=cif_file_name,
|
||||
# fmax=fmax
|
||||
# )
|
||||
|
||||
# # Check if relaxation was successful
|
||||
# if isinstance(relaxed_result, str) and relaxed_result.startswith("Error"):
|
||||
# return relaxed_result
|
||||
|
||||
# # Use the relaxed structure for band gap prediction
|
||||
# structure = relaxed_result
|
||||
|
||||
# if structure is None:
|
||||
# return "Error: Failed to obtain a valid relaxed structure"
|
||||
|
||||
# # Load the pre-trained MEGNet band gap model
|
||||
# model = matgl.load_model("MEGNet-MP-2019.4.1-BandGap-mfi")
|
||||
|
||||
# # Map method name to index
|
||||
# method_map = {"PBE": 0, "GLLB-SC": 1, "HSE": 2, "SCAN": 3}
|
||||
# if method not in method_map:
|
||||
# return f"Error: Unsupported method: {method}. Choose from PBE, GLLB-SC, HSE, or SCAN."
|
||||
|
||||
# # Set the graph label based on the method
|
||||
# graph_attrs = torch.tensor([method_map[method]])
|
||||
|
||||
# # Predict the band gap using the relaxed structure
|
||||
# bandgap = model.predict_structure(structure=structure, state_attr=graph_attrs)
|
||||
# reduced_formula = structure.reduced_formula
|
||||
|
||||
# # Return the band gap as a string
|
||||
# return f"The predicted band gap for relaxed {reduced_formula} using {method} method is {float(bandgap):.3f} eV."
|
||||
# except Exception as e:
|
||||
# return f"Error: {str(e)}"
|
||||
|
||||
|
||||
240
sci_mcp/material_mcp/mattergen_gen/material_gen_tools.py
Executable file
240
sci_mcp/material_mcp/mattergen_gen/material_gen_tools.py
Executable file
@@ -0,0 +1,240 @@
|
||||
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
|
||||
# Configure the module logger: INFO level, timestamped records on a stream handler.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
# getLogger() returns the same logger object every time this module body runs
# (e.g. on importlib.reload), so attach a handler only when none is present;
# otherwise every record would be emitted once per re-import.
if not logger.handlers:
    logger.addHandler(handler)
|
||||
|
||||
import tempfile
|
||||
import os
|
||||
import datetime
|
||||
import asyncio
|
||||
import zipfile
|
||||
import shutil
|
||||
import re
|
||||
import multiprocessing
|
||||
from multiprocessing import Process, Queue
|
||||
from pathlib import Path
|
||||
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
|
||||
import logging
|
||||
# Use the "spawn" multiprocessing start method; per the original author this
# works around CUDA initialisation errors.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    # The original author notes RuntimeError is raised when the start method
    # was already set. NOTE(review): with force=True this branch may never
    # trigger -- confirm.
    pass
|
||||
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from ase.io import read, write
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
# 导入路径已更新
|
||||
from ...core.llm_tools import llm_tool
|
||||
from .mattergen_wrapper import *
|
||||
|
||||
# 使用mattergen_wrapper
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
|
||||
def convert_values(data_str):
    """
    Parse a JSON string into the corresponding Python object.

    Args:
        data_str: JSON text to parse. Values that are not JSON text (for
            example an already-parsed dict, or None) are accepted as well.

    Returns:
        The parsed data. If the input cannot be parsed as JSON, the input is
        returned unchanged.
    """
    try:
        return json.loads(data_str)
    except (json.JSONDecodeError, TypeError):
        # JSONDecodeError: the text is not valid JSON.
        # TypeError: the value is not str/bytes/bytearray at all (dict, None, ...).
        # Both cases fall under "could not be parsed", so hand the input back.
        return data_str
|
||||
|
||||
|
||||
def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
    """
    Preprocess a property value based on its name, converting it to the appropriate type.

    Conversion rules:
        - ``chemical_system``: numeric values are converted to ``str`` (with a
          warning); strings are returned untouched.
        - ``space_group``: the value must represent an integer in [1, 230] and
          is returned as ``int``.
        - every other property: string values are converted to ``float`` when
          possible; a string that is not a number is returned unchanged.

    Args:
        property_name: Name of the property
        property_value: Value of the property (can be string, float, or int)

    Returns:
        Tuple of (property_name, processed_value)

    Raises:
        ValueError: If ``property_name`` is unknown, or if a ``space_group``
            value is not an integer between 1 and 230.
    """
    valid_properties = [
        "dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
        "energy_above_hull", "formation_energy_per_atom", "space_group",
        "hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap"
    ]

    if property_name not in valid_properties:
        raise ValueError(f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}")

    # chemical_system is the only textual property: never parse it as a number.
    if property_name == "chemical_system":
        if isinstance(property_value, (int, float)):
            logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
            property_value = str(property_value)
        return property_name, property_value

    if property_name == "space_group":
        invalid = ValueError(
            f"Invalid space_group value: {property_value}. Must be an integer between 1 and 230."
        )
        try:
            numeric = float(property_value)
        except (TypeError, ValueError):
            # Previously a non-numeric string reached the `<` comparison below
            # and surfaced as a TypeError instead of the documented ValueError.
            raise invalid from None
        # is_integer() also rejects NaN and infinities.
        if not numeric.is_integer() or not 1 <= numeric <= 230:
            raise invalid
        return property_name, int(numeric)

    # Remaining properties are numeric; strings are converted on a best-effort basis.
    if isinstance(property_value, str):
        try:
            property_value = float(property_value)
        except ValueError:
            # Not a number: keep the original string (historical behaviour).
            pass

    return property_name, property_value
|
||||
|
||||
|
||||
def main(
    output_path: str,
    pretrained_name: PRETRAINED_MODEL_NAME | None = None,
    model_path: str | None = None,
    batch_size: int = 2,
    num_batches: int = 1,
    config_overrides: list[str] | None = None,
    checkpoint_epoch: Literal["best", "last"] | int = "last",
    properties_to_condition_on: TargetProperty | None = None,
    sampling_config_path: str | None = None,
    sampling_config_name: str = "default",
    sampling_config_overrides: list[str] | None = None,
    record_trajectories: bool = True,
    diffusion_guidance_factor: float | None = None,
    strict_checkpoint_loading: bool = True,
    target_compositions: list[dict[str, int]] | None = None,
):
    """
    Build a CrystalGenerator from a checkpoint and write generated structures to ``output_path``.

    Exactly one of ``pretrained_name`` and ``model_path`` must be given.

    Args:
        output_path: Path to output directory. Created if it does not exist.
        pretrained_name: Name of a pretrained model; the checkpoint info is obtained via
            `MatterGenCheckpointInfo.from_hf_hub`. In this branch `checkpoint_epoch` and
            `strict_checkpoint_loading` are not forwarded.
        model_path: Path to DiffusionLightningModule checkpoint directory.
        batch_size: Batch size passed to the generator. (default: 2)
        num_batches: Number of batches passed to the generator. (default: 1)
        config_overrides: Overrides for the model config, e.g., `model.num_layers=3 model.hidden_dim=128`.
        checkpoint_epoch: Epoch to load from the checkpoint (`"best"`, `"last"` or an epoch number).
            Only used together with `model_path`. (default: "last")
        properties_to_condition_on: Property value to draw conditional sampling with respect to. When this value is None or an empty dictionary (default), an empty dictionary is passed to the generator and unconditional samples are drawn.
        sampling_config_path: Path to the sampling config file. (default: None, in which case None is passed on and the generator picks its own default)
        sampling_config_name: Name of the sampling config (corresponds to `{sampling_config_path}/{sampling_config_name}.yaml` on disk). (default: default)
        sampling_config_overrides: Overrides for the sampling config, e.g., `condition_loader_partial.batch_size=32`.
        record_trajectories: Whether to record the trajectories of the generated structures. (default: True)
        diffusion_guidance_factor: Guidance factor passed to the generator; None is treated as 0.0. (default: None)
        strict_checkpoint_loading: Whether to raise an exception when not all parameters from the checkpoint can be matched to the model.
            Only used together with `model_path`.
        target_compositions: List of dictionaries with target compositions to condition on. Each dictionary should have the form `{element: number_of_atoms}`. If None, the target compositions are not conditioned on.
            Only supported for models trained for crystal structure prediction (CSP) (default: None)

    Raises:
        AssertionError: If neither or both of `pretrained_name` and `model_path` are provided.

    NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
    """
    assert (
        pretrained_name is not None or model_path is not None
    ), "Either pretrained_name or model_path must be provided."
    assert (
        pretrained_name is None or model_path is None
    ), "Only one of pretrained_name or model_path can be provided."

    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(output_path, exist_ok=True)

    # Normalise "not provided" to empty containers before handing them on.
    sampling_config_overrides = sampling_config_overrides or []
    config_overrides = config_overrides or []
    properties_to_condition_on = properties_to_condition_on or {}
    target_compositions = target_compositions or []

    if pretrained_name is not None:
        checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
            pretrained_name, config_overrides=config_overrides
        )
    else:
        checkpoint_info = MatterGenCheckpointInfo(
            model_path=Path(model_path).resolve(),
            load_epoch=checkpoint_epoch,
            config_overrides=config_overrides,
            strict_checkpoint_loading=strict_checkpoint_loading,
        )
    _sampling_config_path = Path(sampling_config_path) if sampling_config_path is not None else None
    generator = CrystalGenerator(
        checkpoint_info=checkpoint_info,
        properties_to_condition_on=properties_to_condition_on,
        batch_size=batch_size,
        num_batches=num_batches,
        sampling_config_name=sampling_config_name,
        sampling_config_path=_sampling_config_path,
        sampling_config_overrides=sampling_config_overrides,
        record_trajectories=record_trajectories,
        diffusion_guidance_factor=(
            diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
        ),
        target_compositions_dict=target_compositions,
    )
    generator.generate(output_dir=Path(output_path))
|
||||
|
||||
|
||||
@llm_tool(name="generate_material_MatterGen", description="Generate crystal structures with optional property constraints using MatterGen model")
def generate_material_MatterGen(
    properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
    batch_size: int = 2,
    num_batches: int = 1,
    diffusion_guidance_factor: float = 2.0
) -> str:
    """
    Generate crystal structures through the shared MatterGenService instance.

    Depending on ``properties`` the generation is:
        1. unconditional (None or empty dict),
        2. conditioned on a single property (one key-value pair), or
        3. conditioned on several properties (multiple key-value pairs).

    Args:
        properties: Optional property constraints, keyed by property name
            (e.g. "dft_band_gap", "chemical_system").
        batch_size: Number of structures per batch
        num_batches: Number of batches to generate
        diffusion_guidance_factor: Controls adherence to target properties

    Returns:
        The text produced by ``MatterGenService.generate``. If that text
        contains "Error generating structures", a short
        "Error: Invalid properties ..." message is returned instead.
    """
    # The service module is imported inside the function, i.e. only when the tool is called.
    from .mattergen_service import MatterGenService
    logger.info("子进程成功导入MatterGenService")

    matter_gen = MatterGenService.get_instance()
    logger.info("子进程成功获取MatterGenService实例")

    logger.info("子进程开始调用generate方法...")
    generation_output = matter_gen.generate(
        properties=properties,
        batch_size=batch_size,
        num_batches=num_batches,
        diffusion_guidance_factor=diffusion_guidance_factor,
    )
    logger.info("子进程generate方法调用完成")

    generation_failed = "Error generating structures" in generation_output
    return f"Error: Invalid properties {properties}." if generation_failed else generation_output
|
||||
466
sci_mcp/material_mcp/mattergen_gen/mattergen_service.py
Executable file
466
sci_mcp/material_mcp/mattergen_gen/mattergen_service.py
Executable file
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
MatterGen service for mars_toolkit.
|
||||
|
||||
This module provides a service for generating crystal structures using MatterGen.
|
||||
The service initializes the CrystalGenerator once and reuses it for multiple
|
||||
generation requests, improving performance.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Union, List
|
||||
import threading
|
||||
import torch
|
||||
|
||||
from .mattergen_wrapper import *
|
||||
from ...core.config import material_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
def format_cif_content(content):
    """
    Split raw text into labelled CIF sections.

    Spans beginning with the literal ``PK`` (presumably zip record headers --
    the caller feeds the raw bytes of a zip archive) are stripped, then the
    text is cut at every ``_chemical_formula_structural`` marker. Each piece
    that carries a formula value is wrapped as::

        [cif N begin]
        data_<formula>
        <piece>
        [cif N end]

    Args:
        content: String containing CIF content, possibly with PK headers

    Returns:
        The wrapped sections joined by a newline, or '' when the input is
        empty/blank or contains no marker.
    """
    if not content or not content.strip():
        return ''

    # Applied in order: PK..marker spans, a trailing PK span without "_",
    # and finally anything from a remaining PK to the end of the text.
    junk_patterns = (
        r'PK.*?(?=_chemical_formula_structural)',
        r'PK[^_]*$',
        r'PK.*?(?!.*_chemical_formula_structural)$',
    )
    cleaned = content
    for junk_pattern in junk_patterns:
        cleaned = re.sub(junk_pattern, '', cleaned, flags=re.DOTALL)

    # Every marker occurrence opens a new section that runs to the next marker.
    starts = [hit.start() for hit in re.finditer(r'_chemical_formula_structural', cleaned)]
    if not starts:
        return ''
    stops = starts[1:] + [len(cleaned)]

    sections = []
    for begin, stop in zip(starts, stops):
        piece = cleaned[begin:stop].strip()
        found = re.search(r'_chemical_formula_structural\s+(\S+)', piece)
        if found is None:
            # Marker without a value: skipped, and it does not consume a number.
            continue
        number = len(sections) + 1
        sections.append(
            f"[cif {number} begin]\ndata_{found.group(1)}\n{piece}\n[cif {number} end]\n"
        )

    return "\n".join(sections)
|
||||
|
||||
def extract_cif_file_from_zip(cifs_zip_path: str):
    """
    Write every CIF section found in a zip archive to ``<formula>.cif``.

    The archive is read as raw bytes, decoded as UTF-8 (undecodable bytes are
    replaced) and passed through ``format_cif_content``. Each resulting
    section is saved under ``material_config.TEMP_ROOT`` using the formula
    taken from its ``data_<formula>`` line as the file name.

    Args:
        cifs_zip_path: Path to the zip file

    Returns:
        list: List of tuples containing (index, formula, cif_path); the index
        is the section number as a string. Empty when the archive does not
        exist or yields no sections.
    """
    raw_text = ''
    if os.path.exists(cifs_zip_path):
        with open(cifs_zip_path, 'rb') as archive:
            raw_text = archive.read().decode('utf-8', errors='replace')

    labelled_sections = re.findall(
        r'\[cif (\d+) begin\]\n(.*?)\n\[cif \1 end\]',
        format_cif_content(raw_text),
        re.DOTALL,
    )

    written = []
    for section_index, section_text in labelled_sections:
        name_hit = re.search(r'data_([^\s]+)', section_text)
        if not name_hit:
            continue
        formula = name_hit.group(1)
        target_path = os.path.join(material_config.TEMP_ROOT, f"{formula}.cif")
        with open(target_path, 'w') as out_file:
            out_file.write(section_text)
        written.append((section_index, formula, target_path))

    return written
|
||||
|
||||
|
||||
class MatterGenService:
    """
    Service for generating crystal structures using MatterGen.

    Generators are created once per model and kept in ``self._generators`` so
    that later requests reuse them; only the per-request settings (batch size,
    number of batches, guidance factor, target property values) are updated.
    """

    # Shared instance handed out by get_instance(), and the lock guarding its creation.
    _instance = None
    _lock = threading.Lock()

    # Model directory name -> CUDA device index (kept as a string).
    MODEL_TO_GPU = {
        "mattergen_base": "0",
        "dft_mag_density": "1",
        "dft_bulk_modulus": "2",
        "dft_shear_modulus": "3",
        "energy_above_hull": "4",
        "formation_energy_per_atom": "5",
        "space_group": "6",
        "hhi_score": "7",
        "ml_bulk_modulus": "0",
        "chemical_system": "1",
        "dft_band_gap": "2",
        "dft_mag_density_hhi_score": "3",  # multi-property model
        "chemical_system_energy_above_hull": "4",  # multi-property model
    }

    @classmethod
    def get_instance(cls):
        """
        Get the singleton instance of MatterGenService.

        The instance is created on first use while holding ``cls._lock``.

        Returns:
            MatterGenService: The singleton instance.
        """
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created it meanwhile.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """
        Initialize the MatterGenService.

        Creates the output directory (``material_config.TEMP_ROOT``) if needed
        and initializes the base generator without any property conditioning.
        Generators for specific property conditions are created on demand.
        """
        self._generators = {}
        self._output_dir = material_config.TEMP_ROOT

        os.makedirs(self._output_dir, exist_ok=True)

        self._init_base_generator()

    def _init_base_generator(self):
        """
        Initialize the base generator for unconditional generation.

        On success the generator is stored under the key ``"base"``. If the
        model directory is missing or construction fails, the problem is
        logged and no generator is stored.
        """
        model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, "mattergen_base")

        if not os.path.exists(model_path):
            logger.warning(f"Base model directory not found at {model_path}. MatterGen service may not work properly.")
            return

        logger.info(f"Initializing base MatterGen generator from {model_path}")

        try:
            checkpoint_info = MatterGenCheckpointInfo(
                model_path=Path(model_path).resolve(),
                load_epoch="last",
                config_overrides=[],
                strict_checkpoint_loading=True,
            )

            generator = CrystalGenerator(
                checkpoint_info=checkpoint_info,
                properties_to_condition_on=None,
                batch_size=2,  # default; overridden per request
                num_batches=1,  # default; overridden per request
                sampling_config_name="default",
                sampling_config_path=None,
                sampling_config_overrides=[],
                record_trajectories=True,
                diffusion_guidance_factor=0.0,
                target_compositions_dict=[],
            )

            self._generators["base"] = generator
            logger.info("Base MatterGen generator initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize base MatterGen generator: {e}")

    def _use_base_generator(self, batch_size: int, num_batches: int):
        """
        Return the base generator configured for an unconditional run.

        The base generator is (re)initialized if it is not cached yet. When it
        is available, the requested batch settings are applied and conditioning
        is switched off (no target properties, guidance factor 0.0).

        Returns:
            tuple: (generator or None, "base", None, gpu_id)
        """
        if "base" not in self._generators:
            self._init_base_generator()
        generator = self._generators.get("base")
        if generator is not None:
            generator.batch_size = batch_size
            generator.num_batches = num_batches
            generator.properties_to_condition_on = None
            generator.diffusion_guidance_factor = 0.0
        return generator, "base", None, self.MODEL_TO_GPU.get("mattergen_base", "0")

    def _get_or_create_generator(
        self,
        properties: Optional[Dict[str, Any]] = None,
        batch_size: int = 2,
        num_batches: int = 1,
        diffusion_guidance_factor: float = 2.0
    ):
        """
        Get or create a generator for the specified properties.

        Model selection:
            - no properties: the base model;
            - one property: the model directory named after that property;
            - {"dft_mag_density", "hhi_score"} or
              {"chemical_system", "energy_above_hull"}: the matching
              multi-property model;
            - any other combination: the model of the first property.

        If the selected model directory does not exist, or building the
        generator fails, the base generator is used unconditionally instead.

        Args:
            properties: Optional property constraints
            batch_size: Number of structures per batch
            num_batches: Number of batches to generate
            diffusion_guidance_factor: Controls adherence to target properties

        Returns:
            tuple: (generator, generator_key, properties_to_condition_on, gpu_id).
            ``generator`` is None when no generator could be initialized;
            ``properties_to_condition_on`` is None whenever the base generator
            is returned.
        """
        if not properties:
            return self._use_base_generator(batch_size, num_batches)

        properties_to_condition_on = dict(properties)

        if len(properties) == 1:
            # Single-property models live in a directory named after the property.
            property_name = next(iter(properties))
            model_dir = property_name
            generator_key = f"single_{property_name}"
        else:
            property_keys = set(properties)
            if property_keys == {"dft_mag_density", "hhi_score"}:
                model_dir = "dft_mag_density_hhi_score"
                generator_key = "multi_dft_mag_density_hhi_score"
            elif property_keys == {"chemical_system", "energy_above_hull"}:
                model_dir = "chemical_system_energy_above_hull"
                generator_key = "multi_chemical_system_energy_above_hull"
            else:
                # No dedicated multi-property model: use the first property's model.
                first_property = next(iter(properties))
                model_dir = first_property
                generator_key = f"multi_{first_property}_etc"

        gpu_id = self.MODEL_TO_GPU.get(model_dir, "0")  # default to GPU 0

        generator = self._generators.get(generator_key)
        if generator is not None:
            # Reuse the loaded model but refresh ALL per-request settings. The
            # target property values must be updated too, otherwise the values
            # of the first request would be reused for every later request.
            generator.properties_to_condition_on = properties_to_condition_on
            generator.batch_size = batch_size
            generator.num_batches = num_batches
            generator.diffusion_guidance_factor = diffusion_guidance_factor
            return generator, generator_key, properties_to_condition_on, gpu_id

        model_path = os.path.join(material_config.MATTERGENMODEL_ROOT, model_dir)
        if not os.path.exists(model_path):
            logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
            # Same fallback as the construction-failure path below: the base
            # generator runs unconditionally and is never registered with
            # target properties.
            return self._use_base_generator(batch_size, num_batches)

        try:
            logger.info(f"Initializing new MatterGen generator for {generator_key} from {model_path}")

            checkpoint_info = MatterGenCheckpointInfo(
                model_path=Path(model_path).resolve(),
                load_epoch="last",
                config_overrides=[],
                strict_checkpoint_loading=True,
            )

            generator = CrystalGenerator(
                checkpoint_info=checkpoint_info,
                properties_to_condition_on=properties_to_condition_on,
                batch_size=batch_size,
                num_batches=num_batches,
                sampling_config_name="default",
                sampling_config_path=None,
                sampling_config_overrides=[],
                record_trajectories=True,
                diffusion_guidance_factor=diffusion_guidance_factor,
                target_compositions_dict=[],
            )

            self._generators[generator_key] = generator
            logger.info(f"MatterGen generator for {generator_key} initialized successfully")
            return generator, generator_key, properties_to_condition_on, gpu_id
        except Exception as e:
            logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
            return self._use_base_generator(batch_size, num_batches)

    def generate(
        self,
        properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
        batch_size: int = 2,
        num_batches: int = 1,
        diffusion_guidance_factor: float = 2.0
    ) -> str:
        """
        Generate crystal structures with optional property constraints.

        Args:
            properties: Optional property constraints. A JSON string is
                accepted and parsed into a dict.
            batch_size: Number of structures per batch
            num_batches: Number of batches to generate
            diffusion_guidance_factor: Controls adherence to target properties

        Returns:
            str: Descriptive text with generated crystal structures in CIF
            format, or a message starting with "Error" when no generator is
            available or generation fails.

        Raises:
            ValueError: If ``properties`` is a string that is not valid JSON.
        """
        if isinstance(properties, str):
            try:
                properties = json.loads(properties)
            except json.JSONDecodeError as err:
                raise ValueError(f"Invalid properties JSON string: {properties}") from err

        properties = properties or {}

        generator, generator_key, properties_to_condition_on, gpu_id = self._get_or_create_generator(
            properties, batch_size, num_batches, diffusion_guidance_factor
        )
        if generator is None:
            return "Error: Failed to initialize MatterGen generator"

        # Diagnostics go through the logger rather than print(): stdout may be
        # reserved by the hosting process.
        try:
            cuda_device_id = int(gpu_id)
            torch.cuda.set_device(cuda_device_id)
            logger.info(f"Setting CUDA device to GPU {cuda_device_id} for model {generator_key}")
        except Exception as e:
            # Best effort: generation continues on whatever device is current.
            logger.warning(f"Error setting CUDA device: {e}. Falling back to default device.")

        try:
            # One output directory per request, named after the current timestamp.
            output_dir = Path(self._output_dir) / datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            output_dir.mkdir(parents=True, exist_ok=True)
            generator.generate(output_dir=output_dir)
        except Exception as e:
            logger.error(f"Error generating structures: {e}")
            return f"Error generating structures: {e}"

        # The archive is read as raw bytes; format_cif_content extracts the CIF text from it.
        cif_zip_path = os.path.join(str(output_dir), "generated_crystals_cif.zip")
        cif_text = ''
        if os.path.exists(cif_zip_path):
            with open(cif_zip_path, 'rb') as f:
                cif_text = f.read().decode('utf-8', errors='replace')
        cif_section = format_cif_content(cif_text)

        # Describe what was actually run: after a fallback to the base model
        # properties_to_condition_on is None and the run was unconditional.
        applied_properties = properties_to_condition_on or {}
        if properties and not applied_properties:
            logger.warning(f"Requested properties {properties} were not applied; generated unconditionally")

        if not applied_properties:
            title = "Generated Material Structures"
            description = "These structures were generated unconditionally, meaning no specific properties were targeted."
            property_description = "unconditionally"
            guidance_note = ''
        else:
            if len(applied_properties) == 1:
                property_name, property_value = next(iter(applied_properties.items()))
                title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
                description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
                property_description = f"conditioned on {property_name} = {property_value}"
            else:
                title = "Generated Material Structures Conditioned on Multiple Properties"
                description = "These structures were generated with multi-property conditioning, targeting the specified property values."
                property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in applied_properties.items()])}"
            guidance_note = f'''
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
'''

        prompt = f"""
# {title}

This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.

{guidance_note}

## CIF Files (Crystallographic Information Files)

- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools

```
{cif_section}
```

{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""

        # Generated files are intentionally left in the output directory.
        logger.info(f"Generation completed on GPU for model {generator_key}")

        return prompt
|
||||
26
sci_mcp/material_mcp/mattergen_gen/mattergen_wrapper.py
Executable file
26
sci_mcp/material_mcp/mattergen_gen/mattergen_wrapper.py
Executable file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
This is a wrapper module that provides access to the mattergen modules
|
||||
by modifying the Python path at runtime.
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from ...core.config import material_config
|
||||
# Add the mattergen directory to the Python path
|
||||
mattergen_dir = material_config.MATTERGEN_ROOT
|
||||
sys.path.insert(0, mattergen_dir)
|
||||
|
||||
# Import the necessary modules from the mattergen package
|
||||
try:
|
||||
from mattergen import generator
|
||||
from mattergen.common.data import chemgraph
|
||||
from mattergen.common.data.types import TargetProperty
|
||||
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
|
||||
from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME
|
||||
except ImportError as e:
|
||||
print(f"Error importing mattergen modules: {e}")
|
||||
print(f"Python path: {sys.path}")
|
||||
raise
|
||||
CrystalGenerator = generator.CrystalGenerator
|
||||
# Re-export the modules
|
||||
__all__ = ['generator', 'chemgraph', 'TargetProperty', 'MatterGenCheckpointInfo', 'PRETRAINED_MODEL_NAME','CrystalGenerator']
|
||||
73
sci_mcp/material_mcp/mattersim_pred/property_pred_tools.py
Normal file
73
sci_mcp/material_mcp/mattersim_pred/property_pred_tools.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Property Prediction Module
|
||||
|
||||
This module provides functions for predicting properties of crystal structures.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
import numpy as np
|
||||
from ase.units import GPa
|
||||
from mattersim.forcefield import MatterSimCalculator
|
||||
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ..support.utils import convert_structure,read_structure_from_file_name_or_content_string
|
||||
|
||||
@llm_tool(
    name="predict_properties_MatterSim",
    description="Predict energy, forces, and stress of crystal structures using MatterSim model based on CIF string",
)
async def predict_properties_MatterSim(structure_source: str) -> str:
    """
    Predict energy, forces and stress of a crystal structure with MatterSim.

    The blocking work (parsing the structure and running the calculator) is
    executed in a worker thread via ``asyncio.to_thread``.

    Args:
        structure_source: The name of the structure file (e.g., POSCAR, CIF) or the content string

    Returns:
        A text report with the total energy, the energy per atom, the forces
        on the first atom and the first stress-tensor component (in GPa and
        in eV/A^3), or an error message when the structure cannot be parsed.
    """

    def _predict_blocking() -> str:
        # Resolve the input to (content, format) and build the structure object from it.
        raw_content, detected_format = read_structure_from_file_name_or_content_string(structure_source)
        atoms = convert_structure(detected_format, raw_content)
        if atoms is None:
            return "Unable to parse CIF string. Please check if the format is correct."

        # Run on the GPU when one is available, otherwise on the CPU.
        if torch.cuda.is_available():
            target_device = "cuda"
        else:
            target_device = "cpu"
        atoms.calc = MatterSimCalculator(device=target_device)

        total_energy = atoms.get_potential_energy()
        atomic_forces = atoms.get_forces()
        stress_tensor = atoms.get_stress(voigt=False)

        per_atom_energy = total_energy / len(atoms)
        # The stress is reported both as returned and divided by ase.units.GPa.
        stress_tensor_gpa = stress_tensor / GPa

        return f"""
## {atoms.get_chemical_formula()} Crystal Structure Property Prediction Results

Prediction results using the provided CIF structure:

- Total Energy (eV): {total_energy}
- Energy per Atom (eV/atom): {per_atom_energy:.4f}
- Forces (eV/Angstrom): {atomic_forces[0]} # Forces on the first atom
- Stress (GPa): {stress_tensor_gpa[0][0]} # First component of the stress tensor
- Stress (eV/A^3): {stress_tensor[0][0]} # First component of the stress tensor

"""

    return await asyncio.to_thread(_predict_blocking)
|
||||
0
sci_mcp/material_mcp/mp_query/__init__.py
Normal file
0
sci_mcp/material_mcp/mp_query/__init__.py
Normal file
42
sci_mcp/material_mcp/mp_query/get_mp_id.py
Normal file
42
sci_mcp/material_mcp/mp_query/get_mp_id.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
from typing import List
|
||||
from mp_api.client import MPRester
|
||||
from ...core.config import material_config
|
||||
|
||||
async def get_mpid_from_formula(formula: str) -> List[str]:
    """
    Get material IDs (mpid) from the Materials Project database by chemical formula.

    The input is normalised by removing spaces, newlines and quote characters.
    An assignment-style input such as ``name=Fe2O3`` is accepted; only the text
    after the last ``=`` is used as the formula.

    Args:
        formula: Chemical formula (e.g., "Fe2O3")

    Returns:
        List of material IDs, one per summary document returned by the search.
        On failure a plain string is returned instead of a list: either
        "No materials found" or a message starting with "Error:". Callers
        tell the two apart with ``isinstance(result, str)``.
    """
    try:
        cleaned_formula = formula.replace(" ", "").replace("\n", "").replace("\'", "").replace("\"", "")

        # Keep only what follows the last "=", so "a=b=Fe2O3" no longer fails
        # with an unpacking error and a bare formula passes through unchanged.
        target_formula = cleaned_formula.rpartition("=")[2]
        if not target_formula:
            return "Error: get_mpid_from_formula: empty chemical formula"

        # Export the configured proxies for the HTTP client used by MPRester.
        # NOTE(review): when the config value is unset this overwrites any proxy
        # already present in the environment with '' -- confirm that is intended.
        os.environ['HTTP_PROXY'] = material_config.HTTP_PROXY or ''
        os.environ['HTTPS_PROXY'] = material_config.HTTPS_PROXY or ''

        with MPRester(material_config.MP_API_KEY) as mpr:
            docs = mpr.materials.summary.search(formula=[target_formula])

        if not docs:
            return "No materials found"
        return [doc.material_id for doc in docs]
    except Exception as e:
        # Tool-facing helper: report the failure as text rather than raising.
        return f"Error: get_mpid_from_formula: {str(e)}"
|
||||
168
sci_mcp/material_mcp/mp_query/mp_query_tools.py
Normal file
168
sci_mcp/material_mcp/mp_query/mp_query_tools.py
Normal file
@@ -0,0 +1,168 @@
|
||||
|
||||
import glob
|
||||
import json
|
||||
from typing import Dict, Any, Union
|
||||
from ...core.llm_tools import llm_tool
|
||||
from .get_mp_id import get_mpid_from_formula
|
||||
from ..support.utils import extract_cif_info, remove_symmetry_equiv_xyz
|
||||
from ...core.config import material_config
|
||||
|
||||
from pymatgen.core import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
@llm_tool(name="search_crystal_structures_from_materials_project",
          description="Retrieve and optimize crystal structures from Materials Project database using a chemical formula")
async def search_crystal_structures_from_materials_project(
    formula: str,
    conventional_unit_cell: bool = True,
    symprec: float = 0.1
) -> str:
    """
    Look up the Materials Project IDs for a formula, load the matching local CIF
    files, symmetrize each structure and return them as one formatted text block.

    Args:
        formula: Chemical formula to search for (e.g., "Fe2O3")
        conventional_unit_cell: If True, the structure is first converted to its
            conventional standard cell before refinement
        symprec: Symmetry precision passed to SpacegroupAnalyzer and CifWriter (default: 0.1)

    Returns:
        Text containing up to MP_TOPK CIF entries, each wrapped in
        "[cif N begin]" / "[cif N end]" markers, or a message string when the
        ID lookup fails, nothing usable is found, or an unexpected error occurs
    """
    try:
        mp_id_list = await get_mpid_from_formula(formula=formula)
        if isinstance(mp_id_list, str):
            # A string result is an error / "not found" message: pass it through.
            return mp_id_list

        cif_by_key = {}
        for position, mp_id in enumerate(mp_id_list):
            try:
                matches = glob.glob(material_config.LOCAL_MP_CIF_ROOT + f"/{mp_id}.cif")
                if not matches:
                    # No local CIF for this ID.
                    continue

                structure = Structure.from_file(matches[0])
                if conventional_unit_cell:
                    structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure()

                # Refine the structure, then serialise it to CIF text.
                refined = SpacegroupAnalyzer(structure, symprec=symprec).get_refined_structure()
                cif_text = str(CifWriter(refined, symprec=symprec, refine_struct=True))

                # Drop the symmetry-operation loop and the generator banner.
                cif_text = remove_symmetry_equiv_xyz(cif_text)
                cif_text = cif_text.replace('# generated using pymatgen', "")

                # The list position keeps keys unique for identical formulas.
                cif_by_key[f"{structure.composition.reduced_formula}_{position}"] = cif_text

                # Stop once MP_TOPK structures have been collected.
                if len(cif_by_key) >= material_config.MP_TOPK:
                    break
            except (FileNotFoundError, IndexError, ValueError):
                # File-related and value errors: skip this ID silently.
                continue
            except Exception as process_error:
                # Anything else is reported, then the remaining IDs are still tried.
                print(f"Error: processing structure {mp_id}: {str(process_error)}")
                continue

        if not cif_by_key:
            return f"No valid crystal structures found for formula: {formula}"

        header = f"""
# Materials Project Symmetrized Crystal Structure Data

Below are symmetrized crystal structure data for {len(cif_by_key)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
"""
        sections = [
            f"[cif {number} begin]\n" + cif_text + f"\n[cif {number} end]\n\n"
            for number, cif_text in enumerate(cif_by_key.values(), 1)
        ]
        return header + "".join(sections)

    except Exception as e:
        # Last-resort guard so the tool always answers with text.
        return f"Error: An unexpected error occurred while processing crystal structures: {str(e)}"
|
||||
|
||||
@llm_tool(name="search_material_property_from_material_project",
          description="Query material properties from Materials Project database using chemical formula")
async def search_material_property_from_materials_project(
    formula: str,
) -> str:
    """
    Retrieve property data for materials matching a chemical formula.

    The Materials Project IDs are resolved with ``get_mpid_from_formula``; for
    each ID the local ``<LOCAL_MP_PROPS_ROOT>/<mp_id>.json`` file is read with
    ``extract_cif_info(..., ['all_fields'])``. IDs whose file cannot be read are
    skipped. Reading stops as soon as MP_TOPK entries have been collected.

    Args:
        formula: Chemical formula of the material(s) to search for (e.g. 'Fe2O3', 'LiFePO4')

    Returns:
        Text with up to MP_TOPK JSON property records, each wrapped in
        "[property N begin]" / "[property N end]" markers, or a message string
        when the ID lookup fails, nothing is found, or an error occurs
    """
    mp_id_list = await get_mpid_from_formula(formula=formula)

    # A string result means the lookup failed or found nothing: pass it through.
    if isinstance(mp_id_list, str):
        return mp_id_list

    try:
        top_k = material_config.MP_TOPK
        # Only a positive integer limit allows stopping early; any other value
        # keeps the original "collect everything, then slice" semantics.
        can_stop_early = isinstance(top_k, int) and top_k > 0

        properties = []
        for mp_id in mp_id_list:
            try:
                file_path = material_config.LOCAL_MP_PROPS_ROOT + f"/{mp_id}.json"
                properties.append(extract_cif_info(file_path, ['all_fields']))
            except Exception:
                # Best effort: a missing or unreadable file only drops this ID.
                continue
            if can_stop_early and len(properties) >= top_k:
                break

        if not properties:
            return "No material properties found for the given formula, please try again."

        properties = properties[:top_k]

        formatted_results = [
            f"[property {i} begin]\n" + json.dumps(item, indent=2) + f"\n[property {i} end]\n\n"
            for i, item in enumerate(properties, 1)
        ]

        res_chunk = "\n\n".join(formatted_results)
        res_template = f"""
Here are the search material property from the Materials Project database:
Due to length limitations, only the top {len(properties)} results are shown below:\n
{res_chunk}
"""
        return res_template

    except Exception as e:
        return f"Error: processing material properties: {str(e)}"
|
||||
0
sci_mcp/material_mcp/oqmd_query/__init__.py
Normal file
0
sci_mcp/material_mcp/oqmd_query/__init__.py
Normal file
92
sci_mcp/material_mcp/oqmd_query/oqmd_query_tools.py
Executable file
92
sci_mcp/material_mcp/oqmd_query/oqmd_query_tools.py
Executable file
@@ -0,0 +1,92 @@
|
||||
import logging
|
||||
import httpx
|
||||
import pandas as pd
|
||||
from bs4 import BeautifulSoup
|
||||
from io import StringIO
|
||||
from typing import Annotated, Any, Dict, List
|
||||
import mcp.types as types
|
||||
from ...core.llm_tools import llm_tool
|
||||
|
||||
|
||||
|
||||
@llm_tool(name="query_material_from_OQMD", description="Query material properties by chemical formula from OQMD database")
async def query_material_from_OQMD(
    formula: Annotated[str, "Chemical formula (e.g., Fe2O3, LiFePO4)"]
) -> str:
    """
    Query material information by chemical formula from the OQMD website.

    The composition page for the formula is downloaded and scraped: the first
    ``<h1>`` becomes the title, every ``<p>`` becomes one text paragraph (links
    are rendered as Markdown links to www.oqmd.org), and the first ``<table>``
    is converted to a Markdown table.

    Args:
        formula: Chemical formula of the material (e.g., Fe2O3, LiFePO4)

    Returns:
        Formatted text with material information and the property table, or a
        message starting with "Error:" when the request or parsing fails
    """
    from urllib.parse import quote

    try:
        # Percent-encode the formula so spaces, slashes or query characters in
        # the input cannot alter the request path.
        encoded_formula = quote(formula.strip(), safe="")
        url = f"https://www.oqmd.org/materials/composition/{encoded_formula}"

        async with httpx.AsyncClient(timeout=100.0) as client:
            response = await client.get(url)
            response.raise_for_status()

            # A very short body is treated as an invalid response.
            if not response.text or len(response.text) < 100:
                raise ValueError("Invalid response content from OQMD API")

            soup = BeautifulSoup(response.text, 'html.parser')

            # Title: page heading when present, otherwise the requested formula.
            basic_data = []
            h1_element = soup.find('h1')
            if h1_element:
                basic_data.append(h1_element.text.strip())
            else:
                basic_data.append(f"Material: {formula}")

            # One entry per paragraph; anchors become Markdown links.
            for paragraph in soup.find_all('p'):
                pieces = []
                for element in paragraph.contents:
                    if hasattr(element, 'name') and element.name == 'a' and 'href' in element.attrs:
                        link_url = "https://www.oqmd.org" + element['href']
                        pieces.append(f"[{element.text.strip()}]({link_url})")
                    elif hasattr(element, 'text'):
                        pieces.append(element.text.strip())
                    else:
                        pieces.append(str(element).strip())
                basic_data.append(" ".join(pieces).strip())

            # First table on the page, rendered as Markdown.
            table_data = ""
            table = soup.find('table')
            if table:
                try:
                    df = pd.read_html(StringIO(str(table)))[0]
                    df = df.fillna('')
                    df = df.replace([float('inf'), float('-inf')], '')
                    table_data = df.to_markdown(index=False)
                except Exception:
                    # Keep the text part of the answer even if the table fails.
                    table_data = "Error: parsing table data"

            combined_text = "\n\n".join(basic_data)
            if table_data:
                combined_text += "\n\n## Material Properties Table\n\n" + table_data

            return combined_text

    except httpx.HTTPStatusError as e:
        return f"Error: OQMD API request failed - {str(e)}"
    except httpx.TimeoutException:
        return "Error: OQMD API request timed out"
    except httpx.NetworkError as e:
        return f"Error: Network error occurred - {str(e)}"
    except ValueError as e:
        return f"Error: Invalid response content - {str(e)}"
    except Exception as e:
        return f"Error: Unexpected error occurred - {str(e)}"
|
||||
|
||||
|
||||
95
sci_mcp/material_mcp/pymatgen_cal/pymatgen_cal_tools.py
Normal file
95
sci_mcp/material_mcp/pymatgen_cal/pymatgen_cal_tools.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import os
|
||||
import asyncio
|
||||
from pymatgen.core import Structure
|
||||
from ...core.config import material_config
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ..support.utils import read_structure_from_file_name_or_content_string
|
||||
|
||||
@llm_tool(name="calculate_density_Pymatgen", description="Calculate the density of a crystal structure from a file or content string using Pymatgen")
async def calculate_density_Pymatgen(structure_source: str) -> str:
    """
    Calculates the density of a structure from a file or content string.

    Args:
        structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.

    Returns:
        str: A Markdown formatted string with the density or an error message if the calculation fails.
    """
    def _compute() -> str:
        # Resolve the input to (text, format), parse it and format the report.
        structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
        structure = Structure.from_str(structure_content, fmt=content_format)
        density = structure.density
        return (f"## Density Calculation\n\n"
                f"- **Structure**: `{structure.composition.reduced_formula}`\n"
                f"- **Density**: `{density:.2f} g/cm³`\n")

    try:
        # File reading and parsing are blocking; run them in a worker thread
        # so the event loop stays responsive.
        return await asyncio.to_thread(_compute)
    except Exception as e:
        return f"Error: error occurred while calculating density: {str(e)}\n"
|
||||
|
||||
|
||||
@llm_tool(name="get_element_composition_Pymatgen", description="Analyze and retrieve the elemental composition of a crystal structure from a file or content string using Pymatgen")
async def get_element_composition_Pymatgen(structure_source: str) -> str:
    """
    Returns the elemental composition of a structure from a file or content string.

    Args:
        structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.

    Returns:
        str: A Markdown formatted string with the elemental composition or an error message if the operation fails.
    """
    def _compute() -> str:
        # Resolve the input to (text, format), parse it and format the report.
        structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
        structure = Structure.from_str(structure_content, fmt=content_format)
        composition = structure.composition
        return (f"## Element Composition\n\n"
                f"- **Structure**: `{structure.composition.reduced_formula}`\n"
                f"- **Composition**: `{composition}`\n")

    try:
        # File reading and parsing are blocking; run them in a worker thread
        # so the event loop stays responsive.
        return await asyncio.to_thread(_compute)
    except Exception as e:
        return f"Error: error occurred while getting element composition: {str(e)}\n"
|
||||
|
||||
|
||||
|
||||
@llm_tool(name="calculate_symmetry_Pymatgen", description="Determine the space group and symmetry operations of a crystal structure from a file or content string using Pymatgen")
async def calculate_symmetry_Pymatgen(structure_source: str) -> str:
    """
    Calculates the symmetry of a structure from a file or content string.

    Args:
        structure_source (str): The name of the structure file (e.g., POSCAR, CIF) or the content string.

    Returns:
        str: A Markdown formatted string with the space group symbol and number,
        or an error message if the operation fails.
    """
    def _compute() -> str:
        # Resolve the input to (text, format), parse it and format the report.
        structure_content, content_format = read_structure_from_file_name_or_content_string(structure_source)
        structure = Structure.from_str(structure_content, fmt=content_format)
        symmetry = structure.get_space_group_info()
        return (f"## Symmetry Information\n\n"
                f"- **Structure**: `{structure.composition.reduced_formula}`\n"
                f"- **Space Group**: `{symmetry[0]}`\n"
                f"- **Number**: `{symmetry[1]}`\n")

    try:
        # File reading, parsing and symmetry detection are blocking; run them
        # in a worker thread so the event loop stays responsive.
        return await asyncio.to_thread(_compute)
    except Exception as e:
        return f"Error: error occurred while calculating symmetry: {str(e)}\n"
|
||||
|
||||
0
sci_mcp/material_mcp/support/__init__.py
Normal file
0
sci_mcp/material_mcp/support/__init__.py
Normal file
212
sci_mcp/material_mcp/support/utils.py
Executable file
212
sci_mcp/material_mcp/support/utils.py
Executable file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
CIF Utilities Module
|
||||
|
||||
This module provides basic functions for handling CIF (Crystallographic Information File) files,
|
||||
which are commonly used in materials science for representing crystal structures.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from ase.io import read
|
||||
import tempfile
|
||||
from typing import Optional, Tuple
|
||||
from ase import Atoms
|
||||
from ...core.config import material_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def read_cif_txt_file(file_path):
    """
    Load a CIF file as UTF-8 text.

    Args:
        file_path: Path to the CIF file

    Returns:
        The file's full text, or None when reading fails (the failure is
        logged through the module logger)
    """
    try:
        with open(file_path, mode='r', encoding='utf-8') as cif_file:
            text = cif_file.read()
    except Exception as e:
        logger.error(f"Error reading file {file_path}: {e}")
        return None
    return text
|
||||
|
||||
def extract_cif_info(path: str, fields_name: list):
    """
    Extract specific fields from the CIF description JSON file.

    Args:
        path: Path to the JSON file containing CIF information
        fields_name: List of field categories to extract. Use 'all_fields' to extract all fields.
            Other options include 'basic_fields', 'energy_electronic_fields', 'metal_magentic_fields'.
            Unknown category names are ignored; an empty list selects nothing.

    Returns:
        Dictionary mapping every selected field name to its value in the JSON
        document, or to '' when the document has no such key
    """
    # Explicit category table: only these names can select fields. (Looking
    # categories up through locals() let a name such as 'path' pull in an
    # unrelated local variable.)
    field_groups = {
        'basic_fields': ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density'],
        'energy_electronic_fields': ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct'],
        'metal_magentic_fields': ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites'],
    }

    if 'all_fields' in fields_name:
        selected_fields = [name for group in field_groups.values() for name in group]
    else:
        selected_fields = [name for category in fields_name for name in field_groups.get(category, [])]

    with open(path, 'r', encoding='utf-8') as f:
        docs = json.load(f)

    return {field_name: docs.get(field_name, '') for field_name in selected_fields}
|
||||
|
||||
def remove_symmetry_equiv_xyz(cif_content):
    """
    Remove the symmetry-operation loop from CIF file content.

    A ``loop_`` whose header tags include ``_symmetry_equiv_pos_as_xyz`` is
    dropped together with its data rows. The removed span ends at the next
    ``loop_``, ``data_`` block header, or data item (a line starting with
    ``_``), so items that follow the symmetry loop are preserved. All other
    lines are returned unchanged.

    Args:
        cif_content: CIF file content string

    Returns:
        Cleaned CIF content string with symmetry operations removed
    """
    lines = cif_content.split('\n')
    total = len(lines)
    kept = []

    idx = 0
    while idx < total:
        if lines[idx].strip() != 'loop_':
            kept.append(lines[idx])
            idx += 1
            continue

        # Header tags are the consecutive lines after "loop_" starting with "_".
        header_end = idx + 1
        while header_end < total and lines[header_end].strip().startswith('_'):
            header_end += 1

        is_symmetry_loop = any(
            '_symmetry_equiv_pos_as_xyz' in tag for tag in lines[idx + 1:header_end]
        )
        if not is_symmetry_loop:
            # Ordinary loop: keep "loop_"; its tags and rows are handled as
            # plain lines on the following iterations.
            kept.append(lines[idx])
            idx += 1
            continue

        # Symmetry loop: skip the header, then every row up to the next
        # construct. A line starting with "_" is a data name, never a loop
        # value, so it marks the end of the loop.
        idx = header_end
        while idx < total:
            stripped = lines[idx].strip()
            if stripped == 'loop_' or stripped.startswith(('data_', '_')):
                break
            idx += 1

    return '\n'.join(kept)
|
||||
|
||||
def _structure_format_from_path(path: str) -> Optional[str]:
    """Infer 'cif', 'xyz' or 'vasp' from a file name; None when the name gives no hint."""
    base_name = os.path.basename(path)
    ext = os.path.splitext(base_name)[1].lower().lstrip('.')
    if ext in ('cif', 'xyz'):
        return ext
    if ext in ('vasp', 'poscar'):
        return 'vasp'
    # VASP files are conventionally named POSCAR / CONTCAR without an extension.
    if base_name.upper().startswith(('POSCAR', 'CONTCAR')):
        return 'vasp'
    return None


def _structure_format_from_content(content: str) -> str:
    """Guess the structure format from the text itself, defaulting to 'cif'."""
    # CIF text contains a "data_" block header and "_cell_" items.
    if "data_" in content and "_cell_" in content:
        return "cif"
    content_lines = content.strip().split('\n')
    # XYZ text starts with a line holding only the atom count.
    if content_lines[0].strip().isdigit():
        return "xyz"
    # POSCAR text: lines 3-5 each hold three tokens (the lattice vectors).
    if len(content_lines) > 5 and all(len(line.split()) == 3 for line in content_lines[2:5]):
        return "vasp"
    return "cif"


def read_structure_from_file_name_or_content_string(file_name_or_content_string: str, format_type: str = "auto") -> Tuple[str, str]:
    """
    Resolve a structure input that is either a file name or the structure text itself.

    The argument is tried, in order, as (1) a path to an existing file,
    (2) a file name inside ``material_config.TEMP_ROOT``, and (3) the structure
    content itself. Files are read as UTF-8 text.

    When ``format_type`` is "auto" the format is inferred from the file name
    (extensions cif / xyz / vasp / poscar, or a name starting with POSCAR or
    CONTCAR); when the name gives no hint, or the input is raw content, it is
    guessed from the text, with "cif" as the default.

    Args:
        file_name_or_content_string: File name or structure content string
        format_type: Structure format; "auto" requests automatic detection,
            any other value is returned unchanged

    Returns:
        tuple: (content string, resolved format type)
    """
    source_path = None
    if os.path.isfile(file_name_or_content_string):
        source_path = file_name_or_content_string
    else:
        # Not a direct path: look for a file of that name in the temp directory.
        temp_path = os.path.join(material_config.TEMP_ROOT, file_name_or_content_string)
        if os.path.isfile(temp_path):
            source_path = temp_path

    if source_path is None:
        # Neither location holds such a file: treat the argument as content.
        content = file_name_or_content_string
        format_from_name = None
    else:
        with open(source_path, 'r', encoding='utf-8') as f:
            content = f.read()
        format_from_name = _structure_format_from_path(source_path)

    if format_type == "auto":
        format_type = format_from_name or _structure_format_from_content(content)

    return content, format_type
|
||||
|
||||
def convert_structure(input_format: str='cif', content: str=None) -> Optional[Atoms]:
    """
    Convert structure text into an ASE Atoms object.

    The text is written to a temporary file whose suffix is ``.<input_format>``
    and loaded with ``ase.io.read``. The temporary file is removed whether or
    not the conversion succeeds.

    Args:
        input_format: Input format, used as the temporary file's extension (cif, xyz, vasp, ...)
        content: Structure content string

    Returns:
        ASE Atoms object, or None if the conversion fails (the failure is logged)
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            # Record the name before writing so cleanup also covers a failed write.
            tmp_path = tmp_file.name
            tmp_file.write(content)

        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Cleanup is best effort and must not mask the actual result.
                pass
|
||||
Reference in New Issue
Block a user