"""
Structure Optimization Module

This module provides functions for optimizing crystal structures using the FairChem model.
"""
import asyncio
import contextlib
import logging
import os
import sys
import tempfile
import threading
from io import StringIO
from typing import Optional, Dict, Any

from ase.io import read, write
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter

from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config

logger = logging.getLogger(__name__)

# FairChem calculator, created lazily by init_model() and shared by every optimization.
calc = None
# Guards one-time model loading so concurrent first calls load the checkpoint only once.
_init_lock = threading.Lock()
# ASE calculators keep per-structure state, so the single shared calculator must not be
# driven by two optimizations at once. optimize_crystal_structure runs its work in worker
# threads (asyncio.to_thread), so calls can overlap; this lock serializes use of `calc`.
_calc_lock = threading.Lock()


def init_model():
    """Initialize the global FairChem calculator if it has not been created yet.

    Raises:
        Exception: Whatever importing FairChem or constructing ``OCPCalculator``
            raises; the error is logged and re-raised.
    """
    global calc
    if calc is not None:
        return
    with _init_lock:
        # Re-check under the lock: another thread may have finished loading meanwhile.
        if calc is not None:
            return
        try:
            from fairchem.core import OCPCalculator

            calc = OCPCalculator(checkpoint_path=config.FAIRCHEM_MODEL_PATH)
            logger.info("FairChem model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize FairChem model: {str(e)}")
            raise


def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """
    Convert structure text into an ASE Atoms object.

    The text is written to a temporary file whose suffix is ``.{input_format}`` so
    that ASE's ``read`` picks the parser from the extension.

    Args:
        input_format: Input format / file extension (cif, xyz, vasp, ...).
        content: Structure content string.

    Returns:
        The ASE Atoms object, or None if conversion failed (the error is logged).
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            suffix=f".{input_format}", mode="w", delete=False, encoding="utf-8"
        ) as tmp_file:
            # Record the path first so the file is removed even if the write fails.
            tmp_path = tmp_file.name
            tmp_file.write(content)
        # The handle is closed before reading, so reopening by name also works on Windows.
        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        if tmp_path is not None:
            with contextlib.suppress(OSError):
                os.unlink(tmp_path)


def generate_symmetry_cif(structure: Structure) -> str:
    """
    Generate a symmetrized CIF for a structure.

    Args:
        structure: Pymatgen Structure object.

    Returns:
        The CIF text produced by ``CifWriter`` for the symmetry-refined structure.
    """
    analyzer = SpacegroupAnalyzer(structure)
    structure_refined = analyzer.get_refined_structure()
    cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
    # CifWriter.write_file() writes str(cif_writer); render in memory instead of
    # round-tripping through a temporary file.
    return str(cif_writer)


def _atoms_to_text(atoms: Atoms, output_format: str) -> str:
    """Serialize ``atoms`` with ASE, letting the ``.{output_format}`` suffix select the writer."""
    fd, tmp_path = tempfile.mkstemp(suffix=f".{output_format}")
    # ASE opens the path itself; our descriptor is only needed to reserve the name.
    os.close(fd)
    try:
        write(tmp_path, atoms)
        with open(tmp_path, encoding="utf-8") as fh:
            return fh.read()
    finally:
        with contextlib.suppress(OSError):
            os.unlink(tmp_path)


def _format_optimization_result(
    total_energy: float, optimization_log: str, output_format: str, content: str
) -> str:
    """Render the optimization outcome as the Markdown report returned to the caller."""
    return f"""
The following is the optimized crystal structure information:

### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**

#### Optimizing Log:
```text
{optimization_log}
```

### Optimized {output_format.upper()} Content:
```
{content}
```
"""


def optimize_structure(atoms: Atoms, output_format: str) -> str:
    """
    Optimize a crystal structure (atomic positions and cell) with FIRE.

    The shared FairChem calculator is attached to ``atoms``; ``atoms`` is modified
    in place by the optimization.

    Args:
        atoms: ASE Atoms object to optimize.
        output_format: Output format (cif, xyz, vasp, ...). "cif" goes through
            pymatgen symmetry refinement; any other value is written by ASE using
            the format implied by the ``.{output_format}`` file suffix.

    Returns:
        A Markdown-formatted string with the total energy, the optimizer log and
        the optimized structure content.

    Raises:
        Exception: Any error from optimization or serialization is logged and re-raised.
    """
    log_buffer = StringIO()
    try:
        with _calc_lock:
            atoms.calc = calc
            # Give FIRE the buffer as its log file instead of redirecting sys.stdout:
            # stdout is process-global, racy across threads, and would stay
            # redirected if run() raised.
            dyn = FIRE(FrechetCellFilter(atoms), logfile=log_buffer)
            dyn.run(fmax=config.FMAX)

            total_energy = atoms.get_potential_energy()

            # Serialization can consult atoms.calc (writers/converters that export
            # calculator results), so it stays under the calculator lock too.
            if output_format == "cif":
                optimized_structure = Structure.from_ase_atoms(atoms)
                content = generate_symmetry_cif(optimized_structure)
                content = remove_symmetry_equiv_xyz(content)
            else:
                content = _atoms_to_text(atoms, output_format)

        optimization_log = log_buffer.getvalue()
        return _format_optimization_result(total_energy, optimization_log, output_format, content)
    except Exception as e:
        logger.error(f"Failed to optimize structure: {str(e)}")
        raise
    finally:
        log_buffer.close()


@llm_tool(name="optimize_crystal_structure", description="Optimize crystal structure using FairChem model")
async def optimize_crystal_structure(
    content: str,
    input_format: str = "cif",
    output_format: str = "cif"
) -> str:
    """
    Optimize crystal structure using FairChem model.

    Args:
        content: Crystal structure content string
        input_format: Input format (cif, xyz, vasp)
        output_format: Output format (cif, xyz, vasp)

    Returns:
        Optimized structure with energy and optimization log
    """
    # Make sure the model is loaded. Loading is blocking, so it runs in a worker
    # thread; errors from it propagate to the caller rather than being stringified.
    if calc is None:
        await asyncio.to_thread(init_model)

    # Conversion and optimization are blocking; they run via asyncio.to_thread below.
    def run_optimization() -> str:
        atoms = convert_structure(input_format, content)
        if atoms is None:
            raise ValueError(f"无法转换输入的{input_format}格式内容,请检查格式是否正确")
        return optimize_structure(atoms, output_format)

    try:
        return await asyncio.to_thread(run_optimization)
    except Exception as e:
        # Failures during conversion/optimization are reported back as the tool's text result.
        return str(e)