192 lines
5.4 KiB
Python
Executable File
192 lines
5.4 KiB
Python
Executable File
"""
|
||
Structure Optimization Module
|
||
|
||
This module provides functions for optimizing crystal structures using the FairChem model.
|
||
"""
|
||
|
||
import asyncio
|
||
from io import StringIO
|
||
import sys
|
||
import tempfile
|
||
import os
|
||
import logging
|
||
from typing import Optional, Dict, Any
|
||
|
||
from ase.io import read, write
|
||
from ase.optimize import FIRE
|
||
from ase.filters import FrechetCellFilter
|
||
from ase.atoms import Atoms
|
||
from pymatgen.core.structure import Structure
|
||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||
from pymatgen.io.cif import CifWriter
|
||
|
||
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
|
||
from mars_toolkit.core.llm_tools import llm_tool
|
||
from mars_toolkit.core.config import config
|
||
|
||
logger = logging.getLogger(__name__)

# Module-level FairChem calculator shared by all optimizations.
# It stays None until init_model() assigns an OCPCalculator to it.
calc = None
|
||
|
||
def init_model():
    """Initialize the module-level FairChem calculator if it is not loaded yet.

    Builds an ``OCPCalculator`` from ``config.FAIRCHEM_MODEL_PATH`` and stores
    it in the module-global ``calc``. Once ``calc`` is set, further calls
    return immediately.

    Raises:
        Exception: Any error from importing fairchem or loading the checkpoint
            is logged together with its traceback and re-raised unchanged.
    """
    global calc
    if calc is not None:
        return

    try:
        # Imported inside the function, so fairchem is only needed once a
        # model is actually loaded.
        from fairchem.core import OCPCalculator

        calc = OCPCalculator(checkpoint_path=config.FAIRCHEM_MODEL_PATH)
        logger.info("FairChem model initialized successfully")
    except Exception as e:
        # logger.exception keeps the traceback; logger.error discarded it.
        logger.exception("Failed to initialize FairChem model: %s", e)
        raise
|
||
|
||
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """Parse a structure string into an ASE ``Atoms`` object.

    The content is written to a temporary file named with the suffix
    ``.{input_format}``. ``ase.io.read`` is called without an explicit format,
    so ASE infers the format from that file name.

    Args:
        input_format: Input format / file extension (cif, xyz, vasp, ...).
        content: Structure file content as a string.

    Returns:
        The parsed ``Atoms`` object, or ``None`` if writing or parsing failed.
        The failure is logged.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            # Record the path before writing, so the file is cleaned up even if
            # the write itself fails.
            tmp_path = tmp_file.name
            tmp_file.write(content)
        return read(tmp_path)
    except Exception as e:
        logger.error("Failed to convert structure: %s", e)
        return None
    finally:
        # Always remove the temporary file. Previously a parse error skipped
        # os.unlink and leaked the file.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_err:
                logger.warning("Failed to remove temp file %s: %s", tmp_path, cleanup_err)
|
||
|
||
def generate_symmetry_cif(structure: Structure) -> str:
    """Return a symmetrized CIF string for *structure*.

    The structure is first refined with ``SpacegroupAnalyzer``, using its
    default tolerance. It is then serialized by ``CifWriter`` with
    ``symprec=0.1`` and ``refine_struct=True``.

    Args:
        structure: Pymatgen ``Structure`` to symmetrize.

    Returns:
        The CIF file content as a string.
    """
    analyzer = SpacegroupAnalyzer(structure)
    structure_refined = analyzer.get_refined_structure()

    cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
    # str(CifWriter) is the same text write_file() puts on disk, so no temporary
    # file is needed. The old round-trip had two problems:
    # - it re-opened a still-open NamedTemporaryFile, which fails on Windows;
    # - it leaked the file if writing raised.
    return str(cif_writer)
|
||
|
||
def optimize_structure(atoms: Atoms, output_format: str) -> str:
    """Relax a crystal structure (positions and cell) with FIRE and format the result.

    The module-level ``calc`` is attached to *atoms*. FIRE is then run on a
    ``FrechetCellFilter`` until forces drop below ``config.FMAX``. *atoms* is
    modified in place.

    Args:
        atoms: ASE ``Atoms`` object to optimize.
        output_format: Output format (cif, xyz, vasp, ...). ``"cif"`` produces a
            symmetrized CIF. Any other value is written by ``ase.io.write``,
            which infers the format from the file suffix.

    Returns:
        A Markdown-formatted string with the total energy, the FIRE log and
        the optimized structure.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised.
    """
    atoms.calc = calc

    try:
        # Send the FIRE log to a private buffer instead of reassigning
        # sys.stdout. Swapping sys.stdout is process-wide, which matters when
        # this runs under asyncio.to_thread as optimize_crystal_structure does.
        # It was also never restored if dyn.run() raised.
        with StringIO() as log_buffer:
            dyn = FIRE(FrechetCellFilter(atoms), logfile=log_buffer)
            dyn.run(fmax=config.FMAX)
            optimization_log = log_buffer.getvalue()

        total_energy = atoms.get_potential_energy()

        if output_format == "cif":
            optimized_structure = Structure.from_ase_atoms(atoms)
            content = generate_symmetry_cif(optimized_structure)
            content = remove_symmetry_equiv_xyz(content)
        else:
            content = _atoms_to_text(atoms, output_format)

        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**

#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
        return format_result
    except Exception as e:
        logger.exception("Failed to optimize structure: %s", e)
        raise


def _atoms_to_text(atoms: Atoms, output_format: str) -> str:
    """Serialize *atoms* via ``ase.io.write`` to a ``.{output_format}`` file and return its text."""
    # Write to a real path with the requested suffix, because ase.io.write is
    # called without format= and infers the writer from the file name.
    # The file is closed before ASE opens it (Windows-safe) and always removed.
    fd, tmp_path = tempfile.mkstemp(suffix=f".{output_format}")
    os.close(fd)
    try:
        write(tmp_path, atoms)
        with open(tmp_path) as fh:
            return fh.read()
    finally:
        os.unlink(tmp_path)
|
||
|
||
@llm_tool(name="optimize_crystal_structure",
          description="Optimize crystal structure using FairChem model")
async def optimize_crystal_structure(
    content: str,
    input_format: str = "cif",
    output_format: str = "cif"
) -> str:
    """
    Optimize crystal structure using FairChem model.

    Args:
        content: Crystal structure content string
        input_format: Input format (cif, xyz, vasp)
        output_format: Output format (cif, xyz, vasp)

    Returns:
        Optimized structure with energy and optimization log
    """
    # Make sure the model is loaded. init_model is a blocking call (it builds
    # the calculator from a checkpoint), so run it in a worker thread instead
    # of on the event loop. It stays outside the try below, so init failures
    # still propagate to the caller as before.
    # NOTE(review): init_model's check-then-assign is not locked, so concurrent
    # first calls may each load the model once.
    if calc is None:
        await asyncio.to_thread(init_model)

    def run_optimization():
        # Parse the input structure.
        atoms = convert_structure(input_format, content)
        if atoms is None:
            raise ValueError(f"无法转换输入的{input_format}格式内容,请检查格式是否正确")
        # Optimize and format the result.
        return optimize_structure(atoms, output_format)

    try:
        # The CPU/GPU-bound work runs off the event loop. Any exception raised
        # during conversion or optimization is returned as its message string,
        # not raised.
        return await asyncio.to_thread(run_optimization)
    except Exception as e:
        return str(e)
|