Files
mars-mcp/mars_toolkit/compute/structure_opt.py

193 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Structure Optimization Module
This module provides functions for optimizing crystal structures using the FairChem model.
"""
import asyncio
from io import StringIO
import sys
import tempfile
import os
import logging
from typing import Optional, Dict, Any
from ase.io import read, write
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
from mars_toolkit.core.error_handlers import handle_general_error
logger = logging.getLogger(__name__)
# Module-level FairChem calculator; stays None until init_model() assigns it.
calc = None
def init_model():
    """
    Initialize the FairChem model.

    Builds an ``OCPCalculator`` from ``config.FAIRCHEM_MODEL_PATH`` and stores
    it in the module-level ``calc``. Does nothing when ``calc`` is already set.

    Raises:
        Exception: Any error raised while importing or constructing the
            calculator is logged and then re-raised unchanged.
    """
    global calc
    if calc is None:
        try:
            # Imported here rather than at module top, so fairchem is only
            # needed once a model is actually requested.
            from fairchem.core import OCPCalculator
            calc = OCPCalculator(checkpoint_path=config.FAIRCHEM_MODEL_PATH)
            logger.info("FairChem model initialized successfully")
        except Exception as exc:
            logger.error(f"Failed to initialize FairChem model: {str(exc)}")
            raise
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """
    Convert structure text into an ASE Atoms object.

    The content is written to a temporary file whose suffix is
    ``.{input_format}`` and then loaded with ``ase.io.read``. The temporary
    file is removed whether or not parsing succeeds.

    Args:
        input_format: Input format, used as the temp-file extension
            (cif, xyz, vasp, ...).
        content: Structure content string.

    Returns:
        The ASE Atoms object, or None if the conversion failed (the error is
        logged).
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            suffix=f".{input_format}", mode="w", delete=False
        ) as tmp_file:
            # Record the path before writing so cleanup also covers a failed write.
            tmp_path = tmp_file.name
            tmp_file.write(content)
        # Read only after the ``with`` block has closed (and flushed) the file.
        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        # delete=False means nothing removes the file for us; without this
        # the file was left behind every time read() raised.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                # A failed cleanup must not discard a successfully parsed structure.
                logger.warning(
                    "Could not remove temporary file %s: %s", tmp_path, cleanup_error
                )
def generate_symmetry_cif(structure: Structure) -> str:
    """
    Generate a symmetrized CIF for a structure.

    The structure is refined with ``SpacegroupAnalyzer`` and serialized by
    ``CifWriter`` (``symprec=0.1``, ``refine_struct=True``).

    Args:
        structure: Pymatgen Structure object.

    Returns:
        The CIF content as a string.
    """
    analyzer = SpacegroupAnalyzer(structure)
    structure_refined = analyzer.get_refined_structure()
    cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
    # str(CifWriter) is the text write_file() would put on disk; rendering it
    # in memory avoids the temp-file round trip and the file that used to be
    # left behind when writing failed.
    return str(cif_writer)
def optimize_structure(atoms: Atoms, output_format: str) -> str:
    """
    Optimize a crystal structure and report the result as formatted text.

    Attaches the module-level ``calc`` to ``atoms``, relaxes the structure
    with ``FIRE`` on a ``FrechetCellFilter`` until ``config.FMAX`` is reached,
    then serializes the relaxed structure.

    Args:
        atoms: ASE Atoms object; it is modified in place by the optimizer.
        output_format: Output format (cif, xyz, vasp, ...). "cif" goes through
            ``generate_symmetry_cif`` and ``remove_symmetry_equiv_xyz``; any
            other value is written with ``ase.io.write`` using the format as
            the file extension.

    Returns:
        A formatted string containing the total energy, the captured
        optimizer log and the optimized structure content.

    Raises:
        Exception: Any failure during optimization or serialization is logged
            and re-raised.
    """
    atoms.calc = calc
    try:
        # Capture what the optimizer prints to stdout while it runs.
        temp_output = StringIO()
        original_stdout = sys.stdout
        sys.stdout = temp_output
        try:
            # Run the optimization.
            dyn = FIRE(FrechetCellFilter(atoms))
            dyn.run(fmax=config.FMAX)
        finally:
            # Restore stdout even when the optimizer raises; otherwise the
            # whole process would keep printing into the buffer.
            # NOTE: sys.stdout is process-global, so overlapping calls from
            # several threads can still capture each other's output.
            sys.stdout = original_stdout
        optimization_log = temp_output.getvalue()
        temp_output.close()

        # Total energy of the relaxed structure.
        total_energy = atoms.get_potential_energy()

        # Serialize the optimized structure.
        if output_format == "cif":
            optimized_structure = Structure.from_ase_atoms(atoms)
            content = generate_symmetry_cif(optimized_structure)
            content = remove_symmetry_equiv_xyz(content)
        else:
            tmp_path = None
            try:
                with tempfile.NamedTemporaryFile(
                    suffix=f".{output_format}", mode="w+", delete=False
                ) as tmp_file:
                    tmp_path = tmp_file.name
                    write(tmp_path, atoms)
                    tmp_file.seek(0)
                    content = tmp_file.read()
            finally:
                # delete=False: remove the file ourselves, also when write() fails.
                if tmp_path is not None:
                    os.unlink(tmp_path)

        # Format the result.
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**
#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
        return format_result
    except Exception as e:
        logger.exception("Failed to optimize structure: %s", e)
        raise
@llm_tool(name="optimize_crystal_structure",
          description="Optimize crystal structure using FairChem model")
async def optimize_crystal_structure(
    content: str,
    input_format: str = "cif",
    output_format: str = "cif"
) -> str:
    """
    Optimize crystal structure using FairChem model.

    Args:
        content: Crystal structure content string
        input_format: Input format (cif, xyz, vasp)
        output_format: Output format (cif, xyz, vasp)

    Returns:
        Optimized structure with energy and optimization log. On any failure
        (model initialization, parsing or optimization) the value returned by
        ``handle_general_error`` is returned instead.
    """
    def run_optimization():
        # Parse the input into an ASE structure.
        atoms = convert_structure(input_format, content)
        if atoms is None:
            raise ValueError(f"无法转换输入的{input_format}格式内容,请检查格式是否正确")
        # Relax the structure and format the report.
        return optimize_structure(atoms, output_format)

    try:
        # Make sure the model is loaded. This sits inside the try so that a
        # load failure is reported through handle_general_error like every
        # other error, instead of escaping the tool as a raw exception.
        if calc is None:
            init_model()
        # The optimization blocks, so run it in a worker thread.
        return await asyncio.to_thread(run_optimization)
    except Exception as e:
        return handle_general_error(e)