初次提交

This commit is contained in:
lzy
2025-05-09 14:16:33 +08:00
commit 3a50afeec4
56 changed files with 9224 additions and 0 deletions

View File

@@ -0,0 +1,191 @@
"""
Structure Optimization Module
This module provides functions for optimizing crystal structures using the FairChem model.
"""
import asyncio
from io import StringIO
import sys
import tempfile
import os
import logging
from typing import Optional, Dict, Any
from ase.io import read, write
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from ..support.utils import convert_structure, remove_symmetry_equiv_xyz, read_structure_from_file_name_or_content_string
from ...core.llm_tools import llm_tool
from ...core.config import material_config
logger = logging.getLogger(__name__)
# 初始化FairChem模型
calc = None
def init_model():
"""初始化FairChem模型"""
global calc
if calc is not None:
return
try:
from fairchem.core import OCPCalculator
calc = OCPCalculator(checkpoint_path=material_config.FAIRCHEM_MODEL_PATH)
logger.info("FairChem model initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize FairChem model: {str(e)}")
raise
def generate_symmetry_cif(structure: Structure) -> str:
"""
生成对称性CIF
Args:
structure: Pymatgen Structure对象
Returns:
CIF格式的字符串
"""
analyzer = SpacegroupAnalyzer(structure)
structure_refined = analyzer.get_refined_structure()
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
cif_writer.write_file(tmp_file.name)
tmp_file.seek(0)
content = tmp_file.read()
os.unlink(tmp_file.name)
return content
def optimize_structure(atoms: Atoms, output_format: str, fmax: float = 0.05) -> str:
"""
优化晶体结构
Args:
atoms: ASE Atoms对象
output_format: 输出格式 (cif, xyz, vasp等)
fmax: 力收敛标准
Returns:
包含优化结果的格式化字符串
"""
atoms.calc = calc
try:
# 捕获优化过程的输出
temp_output = StringIO()
original_stdout = sys.stdout
sys.stdout = temp_output
# 执行优化
dyn = FIRE(FrechetCellFilter(atoms))
dyn.run(fmax=fmax)
# 恢复标准输出并获取日志
sys.stdout = original_stdout
optimization_log = temp_output.getvalue()
temp_output.close()
# 获取总能量
total_energy = atoms.get_potential_energy()
# 处理优化后的结构
if output_format == "cif":
optimized_structure = Structure.from_ase_atoms(atoms)
content = generate_symmetry_cif(optimized_structure)
content = remove_symmetry_equiv_xyz(content)
else:
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
write(tmp_file.name, atoms)
tmp_file.seek(0)
content = tmp_file.read()
os.unlink(tmp_file.name)
# 格式化返回结果
format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**
#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
return format_result
except Exception as e:
return f"Error: Failed to optimize structure: {str(e)}"
@llm_tool(name="optimize_crystal_structure_FairChem",
description="Optimizes crystal structures using the FairChem model")
async def optimize_crystal_structure_FairChem(
structure_source: str,
format_type: str = "auto",
optimization_level: str = "normal"
) -> str:
"""
Optimizes a crystal structure to find its lowest energy configuration.
Args:
structure_source: Either a file name or direct structure content (CIF, XYZ, POSCAR)
format_type: Structure format type (auto, cif, xyz, poscar). Default "auto" will attempt to detect format.
optimization_level: Optimization precision level (quick, normal, precise)
Returns:
Optimized structure with total energy and optimization details
"""
# 确保模型已初始化
if calc is None:
init_model()
# 设置优化参数
fmax_values = {
"quick": 0.1,
"normal": 0.05,
"precise": 0.01
}
fmax = fmax_values.get(optimization_level, 0.05)
# 使用asyncio.to_thread异步执行可能阻塞的操作
def run_optimization():
try:
# 处理输入结构
content, actual_format = read_structure_from_file_name_or_content_string(structure_source, format_type)
# 转换格式映射
format_mapping = {
"cif": "cif",
"xyz": "xyz",
"poscar": "vasp",
"vasp": "vasp"
}
final_format = format_mapping.get(actual_format.lower(), "cif")
# 转换结构
atoms = convert_structure(final_format, content)
if atoms is None:
return f"Error: Unable to convert input structure. Please check if the format is correct."
# 优化结构
return optimize_structure(atoms, final_format, fmax=fmax)
except Exception as e:
return f"Error: Failed to optimize structure: {str(e)}"
return await asyncio.to_thread(run_optimization)