105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
"""
|
|
Author: Yutang LI
|
|
Institution: SIAT-MIC
|
|
Contact: yt.li2@siat.ac.cn
|
|
"""
|
|
|
|
import logging
|
|
import tempfile
|
|
import os
|
|
import datetime
|
|
from typing import Optional
|
|
from ase.optimize import FIRE
|
|
from ase.filters import FrechetCellFilter
|
|
from ase.atoms import Atoms
|
|
from ase.io import read, write
|
|
from pymatgen.core.structure import Structure
|
|
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
|
from pymatgen.io.cif import CifWriter
|
|
from utils import settings, handle_minio_upload
|
|
|
|
logger = logging.getLogger(name=__name__)

# FairChem calculator shared by this module; stays None until init_model() sets it.
calc = None
|
|
|
|
def init_model():
    """Initialize the module-level FairChem calculator.

    Imports ``OCPCalculator`` from ``fairchem.core`` inside the function and
    builds it from ``settings.fairchem_model_path``, storing it in the global
    ``calc``. On failure ``calc`` is left unchanged.

    Raises:
        Exception: Any error raised while importing fairchem or constructing
            the calculator is logged (with traceback) and re-raised unchanged.
    """
    global calc
    try:
        from fairchem.core import OCPCalculator

        calc = OCPCalculator(checkpoint_path=settings.fairchem_model_path)
        logger.info("FairChem model initialized successfully")
    except Exception as e:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("Failed to initialize FairChem model: %s", e)
        raise
|
|
|
|
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """Parse structure-file text into an ASE ``Atoms`` object.

    The text is written to a temporary file with suffix ``.{input_format}``
    and loaded with ``ase.io.read(path)``. Because no explicit format is
    passed, ASE presumably picks the reader from that suffix (verify for
    unusual extensions).

    Args:
        input_format: File extension (without the dot) for the temp file.
        content: Raw contents of the structure file.

    Returns:
        The parsed ``Atoms``, or ``None`` if writing or parsing fails; the
        failure is logged.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            suffix=f".{input_format}", mode="w", delete=False
        ) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(content)
        return read(tmp_path)
    except Exception as e:
        logger.exception("Failed to convert structure: %s", e)
        return None
    finally:
        # Remove the temp file even when ASE fails to parse it (previously leaked).
        if tmp_path is not None:
            os.unlink(tmp_path)
|
|
|
|
def generate_symmetry_cif(structure: Structure) -> str:
    """Return a symmetrized CIF string for *structure*.

    The structure is first refined with ``SpacegroupAnalyzer`` (its default
    tolerance), then serialized by ``CifWriter`` with ``symprec=0.1`` and
    ``refine_struct=True``, which runs its own symmetry detection/refinement
    at that looser tolerance before writing symmetry information.

    Args:
        structure: pymatgen ``Structure`` to symmetrize.

    Returns:
        The CIF document as text.
    """
    refined = SpacegroupAnalyzer(structure).get_refined_structure()
    cif_writer = CifWriter(refined, symprec=0.1, refine_struct=True)
    # str(CifWriter) is exactly the text write_file() would put on disk.
    # Using it directly avoids a leaked delete=False temp file and reading
    # back through a second handle to a file reopened by name.
    return str(cif_writer)
|
|
|
|
def optimize_structure(atoms: Atoms, output_format: str):
    """Relax *atoms* (positions and cell) with FIRE and upload the result.

    Attaches the module-level FairChem calculator and runs ``FIRE`` on a
    ``FrechetCellFilter`` until the max force drops below ``settings.fmax``.
    The relaxed structure is then serialized: as a symmetrized CIF via
    ``generate_symmetry_cif`` when *output_format* is ``"cif"``, otherwise
    with ``ase.io.write``. Finally the text is uploaded with
    ``handle_minio_upload`` under a timestamped file name.

    Args:
        atoms: Structure to optimize. It is modified in place and keeps the
            calculator attached.
        output_format: Output file extension without the dot, e.g. ``"cif"``.

    Returns:
        Tuple ``(total_energy, content, optimization_log, url)``:
        - the final potential energy;
        - the serialized structure text;
        - everything printed to stdout during the optimization (echoed to
          the console as well);
        - the value returned by ``handle_minio_upload``.

    Raises:
        RuntimeError: If the calculator has not been initialized by
            ``init_model()``.
    """
    if calc is None:
        raise RuntimeError("FairChem model is not initialized; call init_model() first")
    atoms.calc = calc

    optimization_log = _run_fire_captured(atoms)
    # Echo the captured optimizer output to the console too.
    print(optimization_log)
    total_energy = atoms.get_potential_energy()

    if output_format == "cif":
        content = generate_symmetry_cif(Structure.from_ase_atoms(atoms))
    else:
        content = _atoms_to_text(atoms, output_format)

    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    file_name = f"optimized_structure_{timestamp}.{output_format}"
    url = _upload_text(content, output_format, file_name)
    return total_energy, content, optimization_log, url


def _run_fire_captured(atoms: Atoms) -> str:
    """Run FIRE on a FrechetCellFilter of *atoms*; return all stdout output produced."""
    import io
    from contextlib import redirect_stdout

    buffer = io.StringIO()
    with redirect_stdout(buffer):
        # The optimizer must be built *inside* the redirect. ASE resolves its
        # default logfile "-" to sys.stdout when the optimizer is constructed.
        # One created before redirect_stdout keeps writing to the real console,
        # so its log would never reach `buffer`.
        dyn = FIRE(FrechetCellFilter(atoms))
        dyn.run(fmax=settings.fmax)
    return buffer.getvalue()


def _atoms_to_text(atoms: Atoms, output_format: str) -> str:
    """Serialize *atoms* via ase.io.write to a ``.{output_format}`` temp file and return its text."""
    with tempfile.NamedTemporaryFile(
        suffix=f".{output_format}", mode="w", delete=False
    ) as tmp_file:
        tmp_path = tmp_file.name
    try:
        # Write by path after our handle is closed, then read it back fresh.
        write(tmp_path, atoms)
        with open(tmp_path) as handle:
            return handle.read()
    finally:
        os.unlink(tmp_path)


def _upload_text(content: str, output_format: str, file_name: str):
    """Write *content* to a temp file, upload it as *file_name*, and return the upload result."""
    with tempfile.NamedTemporaryFile(
        suffix=f".{output_format}", mode="w", delete=False
    ) as tmp_file:
        tmp_file.write(content)
        tmp_path = tmp_file.name
    try:
        return handle_minio_upload(tmp_path, file_name)
    finally:
        os.unlink(tmp_path)
|