初次提交
This commit is contained in:
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Structure Optimization Module
|
||||
|
||||
This module provides functions for optimizing crystal structures using the FairChem model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from io import StringIO
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from ase.io import read, write
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
from ..support.utils import convert_structure, remove_symmetry_equiv_xyz, read_structure_from_file_name_or_content_string
|
||||
from ...core.llm_tools import llm_tool
|
||||
from ...core.config import material_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化FairChem模型
|
||||
calc = None
|
||||
|
||||
def init_model():
|
||||
"""初始化FairChem模型"""
|
||||
global calc
|
||||
if calc is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
from fairchem.core import OCPCalculator
|
||||
calc = OCPCalculator(checkpoint_path=material_config.FAIRCHEM_MODEL_PATH)
|
||||
logger.info("FairChem model initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize FairChem model: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
def generate_symmetry_cif(structure: Structure) -> str:
|
||||
"""
|
||||
生成对称性CIF
|
||||
|
||||
Args:
|
||||
structure: Pymatgen Structure对象
|
||||
|
||||
Returns:
|
||||
CIF格式的字符串
|
||||
"""
|
||||
analyzer = SpacegroupAnalyzer(structure)
|
||||
structure_refined = analyzer.get_refined_structure()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
|
||||
cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
|
||||
cif_writer.write_file(tmp_file.name)
|
||||
tmp_file.seek(0)
|
||||
content = tmp_file.read()
|
||||
os.unlink(tmp_file.name)
|
||||
return content
|
||||
|
||||
def optimize_structure(atoms: Atoms, output_format: str, fmax: float = 0.05) -> str:
|
||||
"""
|
||||
优化晶体结构
|
||||
|
||||
Args:
|
||||
atoms: ASE Atoms对象
|
||||
output_format: 输出格式 (cif, xyz, vasp等)
|
||||
fmax: 力收敛标准
|
||||
|
||||
Returns:
|
||||
包含优化结果的格式化字符串
|
||||
"""
|
||||
atoms.calc = calc
|
||||
|
||||
try:
|
||||
# 捕获优化过程的输出
|
||||
temp_output = StringIO()
|
||||
original_stdout = sys.stdout
|
||||
sys.stdout = temp_output
|
||||
|
||||
# 执行优化
|
||||
dyn = FIRE(FrechetCellFilter(atoms))
|
||||
dyn.run(fmax=fmax)
|
||||
|
||||
# 恢复标准输出并获取日志
|
||||
sys.stdout = original_stdout
|
||||
optimization_log = temp_output.getvalue()
|
||||
temp_output.close()
|
||||
|
||||
# 获取总能量
|
||||
total_energy = atoms.get_potential_energy()
|
||||
|
||||
# 处理优化后的结构
|
||||
if output_format == "cif":
|
||||
optimized_structure = Structure.from_ase_atoms(atoms)
|
||||
content = generate_symmetry_cif(optimized_structure)
|
||||
content = remove_symmetry_equiv_xyz(content)
|
||||
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
|
||||
write(tmp_file.name, atoms)
|
||||
tmp_file.seek(0)
|
||||
content = tmp_file.read()
|
||||
|
||||
os.unlink(tmp_file.name)
|
||||
|
||||
# 格式化返回结果
|
||||
format_result = f"""
|
||||
The following is the optimized crystal structure information:
|
||||
### Optimization Results (using FIRE(eqV2_86M) algorithm):
|
||||
**Total Energy: {total_energy} eV**
|
||||
|
||||
#### Optimizing Log:
|
||||
```text
|
||||
{optimization_log}
|
||||
```
|
||||
### Optimized {output_format.upper()} Content:
|
||||
```
|
||||
{content}
|
||||
```
|
||||
"""
|
||||
return format_result
|
||||
except Exception as e:
|
||||
return f"Error: Failed to optimize structure: {str(e)}"
|
||||
|
||||
@llm_tool(name="optimize_crystal_structure_FairChem",
|
||||
description="Optimizes crystal structures using the FairChem model")
|
||||
async def optimize_crystal_structure_FairChem(
|
||||
structure_source: str,
|
||||
format_type: str = "auto",
|
||||
optimization_level: str = "normal"
|
||||
) -> str:
|
||||
"""
|
||||
Optimizes a crystal structure to find its lowest energy configuration.
|
||||
|
||||
Args:
|
||||
structure_source: Either a file name or direct structure content (CIF, XYZ, POSCAR)
|
||||
format_type: Structure format type (auto, cif, xyz, poscar). Default "auto" will attempt to detect format.
|
||||
optimization_level: Optimization precision level (quick, normal, precise)
|
||||
|
||||
Returns:
|
||||
Optimized structure with total energy and optimization details
|
||||
"""
|
||||
# 确保模型已初始化
|
||||
if calc is None:
|
||||
init_model()
|
||||
|
||||
# 设置优化参数
|
||||
fmax_values = {
|
||||
"quick": 0.1,
|
||||
"normal": 0.05,
|
||||
"precise": 0.01
|
||||
}
|
||||
fmax = fmax_values.get(optimization_level, 0.05)
|
||||
|
||||
# 使用asyncio.to_thread异步执行可能阻塞的操作
|
||||
def run_optimization():
|
||||
try:
|
||||
# 处理输入结构
|
||||
content, actual_format = read_structure_from_file_name_or_content_string(structure_source, format_type)
|
||||
|
||||
# 转换格式映射
|
||||
format_mapping = {
|
||||
"cif": "cif",
|
||||
"xyz": "xyz",
|
||||
"poscar": "vasp",
|
||||
"vasp": "vasp"
|
||||
}
|
||||
final_format = format_mapping.get(actual_format.lower(), "cif")
|
||||
|
||||
# 转换结构
|
||||
atoms = convert_structure(final_format, content)
|
||||
if atoms is None:
|
||||
return f"Error: Unable to convert input structure. Please check if the format is correct."
|
||||
|
||||
# 优化结构
|
||||
return optimize_structure(atoms, final_format, fmax=fmax)
|
||||
|
||||
except Exception as e:
|
||||
return f"Error: Failed to optimize structure: {str(e)}"
|
||||
|
||||
return await asyncio.to_thread(run_optimization)
|
||||
|
||||
Reference in New Issue
Block a user