构建mars_toolkit,删除tools_for_ms
This commit is contained in:
12
mars_toolkit/compute/__init__.py
Normal file
12
mars_toolkit/compute/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Compute Module
|
||||
|
||||
This module provides computational tools for materials science, including:
|
||||
- Material generation
|
||||
- Property prediction
|
||||
- Structure optimization
|
||||
"""
|
||||
|
||||
from mars_toolkit.compute.material_gen import generate_material
|
||||
from mars_toolkit.compute.property_pred import predict_properties
|
||||
from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure
|
||||
BIN
mars_toolkit/compute/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/compute/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc
Normal file
BIN
mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/compute/__pycache__/property_pred.cpython-310.pyc
Normal file
BIN
mars_toolkit/compute/__pycache__/property_pred.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc
Normal file
BIN
mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc
Normal file
Binary file not shown.
423
mars_toolkit/compute/material_gen.py
Normal file
423
mars_toolkit/compute/material_gen.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Material Generation Module
|
||||
|
||||
This module provides functions for generating crystal structures with optional property constraints.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
import os
|
||||
import datetime
|
||||
import asyncio
|
||||
import zipfile
|
||||
import shutil
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
|
||||
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from ase.io import read, write
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
# 导入路径已更新
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
|
||||
# 使用mattergen_wrapper
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
|
||||
from mattergen_wrapper import generator
|
||||
CrystalGenerator = generator.CrystalGenerator
|
||||
from mattergen.common.data.types import TargetProperty
|
||||
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
|
||||
from mattergen.common.utils.data_classes import (
|
||||
PRETRAINED_MODEL_NAME,
|
||||
MatterGenCheckpointInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def format_cif_content(content):
|
||||
"""
|
||||
Format CIF content by removing unnecessary headers and organizing each CIF file.
|
||||
|
||||
Args:
|
||||
content: String containing CIF content, possibly with PK headers
|
||||
|
||||
Returns:
|
||||
Formatted string with each CIF file properly labeled and formatted
|
||||
"""
|
||||
# 如果内容为空,直接返回空字符串
|
||||
if not content or content.strip() == '':
|
||||
return ''
|
||||
|
||||
# 删除从PK开始到第一个_chemical_formula_structural之前的所有内容
|
||||
content = re.sub(r'PK.*?(?=_chemical_formula_structural)', '', content, flags=re.DOTALL)
|
||||
|
||||
# 删除从PK开始到字符串结束且没有_chemical_formula_structural的内容
|
||||
content = re.sub(r'PK[^_]*$', '', content, flags=re.DOTALL)
|
||||
content = re.sub(r'PK.*?(?!.*_chemical_formula_structural)$', '', content, flags=re.DOTALL)
|
||||
|
||||
# 使用_chemical_formula_structural作为分隔符来分割不同的CIF文件
|
||||
# 但我们需要保留这个字段在每个CIF文件中
|
||||
cif_blocks = []
|
||||
|
||||
# 查找所有_chemical_formula_structural的位置
|
||||
formula_positions = [m.start() for m in re.finditer(r'_chemical_formula_structural', content)]
|
||||
|
||||
# 如果没有找到任何_chemical_formula_structural,返回空字符串
|
||||
if not formula_positions:
|
||||
return ''
|
||||
|
||||
# 分割CIF块
|
||||
for i in range(len(formula_positions)):
|
||||
start_pos = formula_positions[i]
|
||||
# 如果是最后一个块,结束位置是字符串末尾
|
||||
end_pos = formula_positions[i+1] if i < len(formula_positions)-1 else len(content)
|
||||
|
||||
cif_block = content[start_pos:end_pos].strip()
|
||||
|
||||
# 提取formula值
|
||||
formula_match = re.search(r'_chemical_formula_structural\s+(\S+)', cif_block)
|
||||
if formula_match:
|
||||
formula = formula_match.group(1)
|
||||
cif_blocks.append((formula, cif_block))
|
||||
|
||||
# 格式化输出
|
||||
result = []
|
||||
for i, (formula, cif_content) in enumerate(cif_blocks, 1):
|
||||
formatted = f"[cif {i} begin]\ndata_{formula}\n{cif_content}\n[cif {i} end]"
|
||||
result.append(formatted)
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def convert_values(data_str):
|
||||
"""
|
||||
将字符串转换为字典
|
||||
|
||||
Args:
|
||||
data_str: JSON字符串
|
||||
|
||||
Returns:
|
||||
解析后的数据,如果解析失败则返回原字符串
|
||||
"""
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
return data_str # 如果无法解析为JSON,返回原字符串
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
|
||||
"""
|
||||
Preprocess a property value based on its name, converting it to the appropriate type.
|
||||
|
||||
Args:
|
||||
property_name: Name of the property
|
||||
property_value: Value of the property (can be string, float, or int)
|
||||
|
||||
Returns:
|
||||
Tuple of (property_name, processed_value)
|
||||
|
||||
Raises:
|
||||
ValueError: If the property value is invalid for the given property name
|
||||
"""
|
||||
valid_properties = [
|
||||
"dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
|
||||
"energy_above_hull", "formation_energy_per_atom", "space_group",
|
||||
"hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap"
|
||||
]
|
||||
|
||||
if property_name not in valid_properties:
|
||||
raise ValueError(f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}")
|
||||
|
||||
# Process property_value if it's a string
|
||||
if isinstance(property_value, str):
|
||||
try:
|
||||
# Try to convert string to float for numeric properties
|
||||
if property_name != "chemical_system":
|
||||
property_value = float(property_value)
|
||||
except ValueError:
|
||||
# If conversion fails, keep as string (for chemical_system)
|
||||
pass
|
||||
|
||||
# Handle special cases for properties that need specific types
|
||||
if property_name == "chemical_system":
|
||||
if isinstance(property_value, (int, float)):
|
||||
logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
|
||||
property_value = str(property_value)
|
||||
elif property_name == "space_group" :
|
||||
space_group = property_value
|
||||
if space_group < 1 or space_group > 230:
|
||||
raise ValueError(f"Invalid space_group value: {space_group}. Must be an integer between 1 and 230.")
|
||||
|
||||
return property_name, property_value
|
||||
|
||||
|
||||
def main(
|
||||
output_path: str,
|
||||
pretrained_name: PRETRAINED_MODEL_NAME | None = None,
|
||||
model_path: str | None = None,
|
||||
batch_size: int = 2,
|
||||
num_batches: int = 1,
|
||||
config_overrides: list[str] | None = None,
|
||||
checkpoint_epoch: Literal["best", "last"] | int = "last",
|
||||
properties_to_condition_on: TargetProperty | None = None,
|
||||
sampling_config_path: str | None = None,
|
||||
sampling_config_name: str = "default",
|
||||
sampling_config_overrides: list[str] | None = None,
|
||||
record_trajectories: bool = True,
|
||||
diffusion_guidance_factor: float | None = None,
|
||||
strict_checkpoint_loading: bool = True,
|
||||
target_compositions: list[dict[str, int]] | None = None,
|
||||
):
|
||||
"""
|
||||
Evaluate diffusion model against molecular metrics.
|
||||
|
||||
Args:
|
||||
model_path: Path to DiffusionLightningModule checkpoint directory.
|
||||
output_path: Path to output directory.
|
||||
config_overrides: Overrides for the model config, e.g., `model.num_layers=3 model.hidden_dim=128`.
|
||||
properties_to_condition_on: Property value to draw conditional sampling with respect to. When this value is an empty dictionary (default), unconditional samples are drawn.
|
||||
sampling_config_path: Path to the sampling config file. (default: None, in which case we use `DEFAULT_SAMPLING_CONFIG_PATH` from explorers.common.utils.utils.py)
|
||||
sampling_config_name: Name of the sampling config (corresponds to `{sampling_config_path}/{sampling_config_name}.yaml` on disk). (default: default)
|
||||
sampling_config_overrides: Overrides for the sampling config, e.g., `condition_loader_partial.batch_size=32`.
|
||||
load_epoch: Epoch to load from the checkpoint. If None, the best epoch is loaded. (default: None)
|
||||
record: Whether to record the trajectories of the generated structures. (default: True)
|
||||
strict_checkpoint_loading: Whether to raise an exception when not all parameters from the checkpoint can be matched to the model.
|
||||
target_compositions: List of dictionaries with target compositions to condition on. Each dictionary should have the form `{element: number_of_atoms}`. If None, the target compositions are not conditioned on.
|
||||
Only supported for models trained for crystal structure prediction (CSP) (default: None)
|
||||
|
||||
NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
|
||||
"""
|
||||
assert (
|
||||
pretrained_name is not None or model_path is not None
|
||||
), "Either pretrained_name or model_path must be provided."
|
||||
assert (
|
||||
pretrained_name is None or model_path is None
|
||||
), "Only one of pretrained_name or model_path can be provided."
|
||||
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
|
||||
sampling_config_overrides = sampling_config_overrides or []
|
||||
config_overrides = config_overrides or []
|
||||
properties_to_condition_on = properties_to_condition_on or {}
|
||||
target_compositions = target_compositions or []
|
||||
|
||||
if pretrained_name is not None:
|
||||
checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
|
||||
pretrained_name, config_overrides=config_overrides
|
||||
)
|
||||
else:
|
||||
checkpoint_info = MatterGenCheckpointInfo(
|
||||
model_path=Path(model_path).resolve(),
|
||||
load_epoch=checkpoint_epoch,
|
||||
config_overrides=config_overrides,
|
||||
strict_checkpoint_loading=strict_checkpoint_loading,
|
||||
)
|
||||
_sampling_config_path = Path(sampling_config_path) if sampling_config_path is not None else None
|
||||
generator = CrystalGenerator(
|
||||
checkpoint_info=checkpoint_info,
|
||||
properties_to_condition_on=properties_to_condition_on,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
sampling_config_name=sampling_config_name,
|
||||
sampling_config_path=_sampling_config_path,
|
||||
sampling_config_overrides=sampling_config_overrides,
|
||||
record_trajectories=record_trajectories,
|
||||
diffusion_guidance_factor=(
|
||||
diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
|
||||
),
|
||||
target_compositions_dict=target_compositions,
|
||||
)
|
||||
generator.generate(output_dir=Path(output_path))
|
||||
|
||||
|
||||
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
|
||||
async def generate_material(
|
||||
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
|
||||
batch_size: int = 2,
|
||||
num_batches: int = 1,
|
||||
diffusion_guidance_factor: float = 2.0
|
||||
) -> str:
|
||||
"""
|
||||
Generate crystal structures with optional property constraints.
|
||||
|
||||
This unified function can generate materials in three modes:
|
||||
1. Unconditional generation (no properties specified)
|
||||
2. Single property conditional generation (one property specified)
|
||||
3. Multi-property conditional generation (multiple properties specified)
|
||||
|
||||
Args:
|
||||
properties: Optional property constraints. Can be:
|
||||
- None or empty dict for unconditional generation
|
||||
- Dict with single key-value pair for single property conditioning
|
||||
- Dict with multiple key-value pairs for multi-property conditioning
|
||||
Valid property names include: "dft_band_gap", "chemical_system", etc.
|
||||
batch_size: Number of structures per batch
|
||||
num_batches: Number of batches to generate
|
||||
diffusion_guidance_factor: Controls adherence to target properties
|
||||
|
||||
Returns:
|
||||
Descriptive text with generated crystal structures in CIF format
|
||||
"""
|
||||
# 使用配置中的结果目录
|
||||
output_dir = config.MATTERGENMODEL_RESULT_PATH
|
||||
|
||||
# 处理字符串输入(如果提供)
|
||||
if isinstance(properties, str):
|
||||
try:
|
||||
properties = json.loads(properties)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid properties JSON string: {properties}")
|
||||
|
||||
# 如果为None,默认为空字典
|
||||
properties = properties or {}
|
||||
|
||||
# 根据生成模式处理属性
|
||||
if not properties:
|
||||
# 无条件生成
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
||||
properties_to_condition_on = None
|
||||
generation_type = "unconditional"
|
||||
property_description = "unconditionally"
|
||||
else:
|
||||
# 条件生成(单属性或多属性)
|
||||
properties_to_condition_on = {}
|
||||
|
||||
# 处理每个属性
|
||||
for property_name, property_value in properties.items():
|
||||
_, processed_value = preprocess_property(property_name, property_value)
|
||||
properties_to_condition_on[property_name] = processed_value
|
||||
|
||||
# 根据属性确定使用哪个模型
|
||||
if len(properties) == 1:
|
||||
# 单属性条件
|
||||
property_name = list(properties.keys())[0]
|
||||
property_to_model = {
|
||||
"dft_mag_density": "dft_mag_density",
|
||||
"dft_bulk_modulus": "dft_bulk_modulus",
|
||||
"dft_shear_modulus": "dft_shear_modulus",
|
||||
"energy_above_hull": "energy_above_hull",
|
||||
"formation_energy_per_atom": "formation_energy_per_atom",
|
||||
"space_group": "space_group",
|
||||
"hhi_score": "hhi_score",
|
||||
"ml_bulk_modulus": "ml_bulk_modulus",
|
||||
"chemical_system": "chemical_system",
|
||||
"dft_band_gap": "dft_band_gap"
|
||||
}
|
||||
model_dir = property_to_model.get(property_name, property_name)
|
||||
generation_type = "single_property"
|
||||
property_description = f"conditioned on {property_name} = {properties[property_name]}"
|
||||
else:
|
||||
# 多属性条件
|
||||
property_keys = set(properties.keys())
|
||||
if property_keys == {"dft_mag_density", "hhi_score"}:
|
||||
model_dir = "dft_mag_density_hhi_score"
|
||||
elif property_keys == {"chemical_system", "energy_above_hull"}:
|
||||
model_dir = "chemical_system_energy_above_hull"
|
||||
else:
|
||||
# 如果没有特定的多属性模型,使用第一个属性的模型
|
||||
first_property = list(properties.keys())[0]
|
||||
model_dir = first_property
|
||||
generation_type = "multi_property"
|
||||
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
|
||||
|
||||
# 构建完整的模型路径
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
|
||||
|
||||
# 检查模型目录是否存在
|
||||
if not os.path.exists(model_path):
|
||||
# 如果特定模型不存在,回退到基础模型
|
||||
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
||||
|
||||
# 使用适当的参数调用main函数
|
||||
main(
|
||||
output_path=output_dir,
|
||||
model_path=model_path,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
properties_to_condition_on=properties_to_condition_on,
|
||||
record_trajectories=True,
|
||||
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0
|
||||
)
|
||||
|
||||
# 创建字典存储文件内容
|
||||
result_dict = {}
|
||||
|
||||
# 定义文件路径
|
||||
cif_zip_path = os.path.join(output_dir, "generated_crystals_cif.zip")
|
||||
xyz_file_path = os.path.join(output_dir, "generated_crystals.extxyz")
|
||||
trajectories_zip_path = os.path.join(output_dir, "generated_trajectories.zip")
|
||||
|
||||
# 读取CIF压缩文件
|
||||
if os.path.exists(cif_zip_path):
|
||||
with open(cif_zip_path, 'rb') as f:
|
||||
result_dict['cif_content'] = f.read()
|
||||
|
||||
# 根据生成类型创建描述性提示
|
||||
if generation_type == "unconditional":
|
||||
title = "Generated Material Structures"
|
||||
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
|
||||
elif generation_type == "single_property":
|
||||
property_name = list(properties.keys())[0]
|
||||
property_value = properties[property_name]
|
||||
title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
|
||||
description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
|
||||
else: # multi_property
|
||||
title = "Generated Material Structures Conditioned on Multiple Properties"
|
||||
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
|
||||
|
||||
# 创建完整的提示
|
||||
prompt = f"""
|
||||
# {title}
|
||||
|
||||
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
|
||||
|
||||
{'' if generation_type == 'unconditional' else f'''
|
||||
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
|
||||
the generation adheres to the specified property values. Higher values produce samples that more
|
||||
closely match the target properties but may reduce diversity.
|
||||
'''}
|
||||
|
||||
## CIF Files (Crystallographic Information Files)
|
||||
|
||||
- Standard format for crystallographic structures
|
||||
- Contains unit cell parameters, atomic positions, and symmetry information
|
||||
- Used by crystallographic software and visualization tools
|
||||
|
||||
```
|
||||
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
|
||||
```
|
||||
|
||||
{description}
|
||||
You can use these structures for materials discovery, property prediction, or further analysis.
|
||||
"""
|
||||
|
||||
# 清理文件(读取后删除)
|
||||
try:
|
||||
if os.path.exists(cif_zip_path):
|
||||
os.remove(cif_zip_path)
|
||||
if os.path.exists(xyz_file_path):
|
||||
os.remove(xyz_file_path)
|
||||
if os.path.exists(trajectories_zip_path):
|
||||
os.remove(trajectories_zip_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning up files: {e}")
|
||||
return prompt
|
||||
72
mars_toolkit/compute/property_pred.py
Normal file
72
mars_toolkit/compute/property_pred.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Property Prediction Module
|
||||
|
||||
This module provides functions for predicting properties of crystal structures.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
import numpy as np
|
||||
from ase.units import GPa
|
||||
from mattersim.forcefield import MatterSimCalculator
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.compute.structure_opt import convert_structure
|
||||
|
||||
@llm_tool(
|
||||
name="predict_properties",
|
||||
description="Predict energy, forces, and stress of crystal structures based on CIF string",
|
||||
)
|
||||
async def predict_properties(cif_content: str) -> str:
|
||||
"""
|
||||
Use MatterSim to predict energy, forces, and stress of crystal structures.
|
||||
|
||||
Args:
|
||||
cif_content: Crystal structure string in CIF format
|
||||
|
||||
Returns:
|
||||
String containing prediction results
|
||||
"""
|
||||
# 使用asyncio.to_thread异步执行可能阻塞的操作
|
||||
def run_prediction():
|
||||
# 使用 convert_structure 函数将 CIF 字符串转换为 Atoms 对象
|
||||
structure = convert_structure("cif", cif_content)
|
||||
if structure is None:
|
||||
return "Unable to parse CIF string. Please check if the format is correct."
|
||||
|
||||
# 设置设备
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# 使用 MatterSimCalculator 计算属性
|
||||
structure.calc = MatterSimCalculator(device=device)
|
||||
|
||||
# 直接获取能量、力和应力
|
||||
energy = structure.get_potential_energy()
|
||||
forces = structure.get_forces()
|
||||
stresses = structure.get_stress(voigt=False)
|
||||
|
||||
# 计算每原子能量
|
||||
num_atoms = len(structure)
|
||||
energy_per_atom = energy / num_atoms
|
||||
|
||||
# 计算应力(GPa和eV/A^3格式)
|
||||
stresses_ev_a3 = stresses
|
||||
stresses_gpa = stresses / GPa
|
||||
|
||||
# 构建返回的提示信息
|
||||
prompt = f"""
|
||||
## {structure.get_chemical_formula()} Crystal Structure Property Prediction Results
|
||||
|
||||
Prediction results using the provided CIF structure:
|
||||
|
||||
- Total Energy (eV): {energy}
|
||||
- Energy per Atom (eV/atom): {energy_per_atom:.4f}
|
||||
- Forces (eV/Angstrom): {forces[0]} # Forces on the first atom
|
||||
- Stress (GPa): {stresses_gpa[0][0]} # First component of the stress tensor
|
||||
- Stress (eV/A^3): {stresses_ev_a3[0][0]} # First component of the stress tensor
|
||||
|
||||
"""
|
||||
return prompt
|
||||
|
||||
# 异步执行预测操作
|
||||
return await asyncio.to_thread(run_prediction)
|
||||
192
mars_toolkit/compute/structure_opt.py
Normal file
192
mars_toolkit/compute/structure_opt.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Structure Optimization Module
|
||||
|
||||
This module provides functions for optimizing crystal structures using the FairChem model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from io import StringIO
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from ase.io import read, write
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.error_handlers import handle_general_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化FairChem模型
|
||||
calc = None
|
||||
|
||||
def init_model():
|
||||
"""初始化FairChem模型"""
|
||||
global calc
|
||||
if calc is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
from fairchem.core import OCPCalculator
|
||||
calc = OCPCalculator(checkpoint_path=config.FAIRCHEM_MODEL_PATH)
|
||||
logger.info("FairChem model initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize FairChem model: {str(e)}")
|
||||
raise
|
||||
|
||||
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
|
||||
"""
|
||||
将输入内容转换为Atoms对象
|
||||
|
||||
Args:
|
||||
input_format: 输入格式 (cif, xyz, vasp等)
|
||||
content: 结构内容字符串
|
||||
|
||||
Returns:
|
||||
ASE Atoms对象,如果转换失败则返回None
|
||||
"""
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
|
||||
tmp_file.write(content)
|
||||
tmp_path = tmp_file.name
|
||||
|
||||
atoms = read(tmp_path)
|
||||
os.unlink(tmp_path)
|
||||
return atoms
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert structure: {str(e)}")
|
||||
return None
|
||||
|
||||
def generate_symmetry_cif(structure: Structure) -> str:
|
||||
"""
|
||||
生成对称性CIF
|
||||
|
||||
Args:
|
||||
structure: Pymatgen Structure对象
|
||||
|
||||
Returns:
|
||||
CIF格式的字符串
|
||||
"""
|
||||
analyzer = SpacegroupAnalyzer(structure)
|
||||
structure_refined = analyzer.get_refined_structure()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
|
||||
cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
|
||||
cif_writer.write_file(tmp_file.name)
|
||||
tmp_file.seek(0)
|
||||
content = tmp_file.read()
|
||||
os.unlink(tmp_file.name)
|
||||
return content
|
||||
|
||||
def optimize_structure(atoms: Atoms, output_format: str) -> str:
|
||||
"""
|
||||
优化晶体结构
|
||||
|
||||
Args:
|
||||
atoms: ASE Atoms对象
|
||||
output_format: 输出格式 (cif, xyz, vasp等)
|
||||
|
||||
Returns:
|
||||
包含优化结果的格式化字符串
|
||||
"""
|
||||
atoms.calc = calc
|
||||
|
||||
try:
|
||||
# 捕获优化过程的输出
|
||||
temp_output = StringIO()
|
||||
original_stdout = sys.stdout
|
||||
sys.stdout = temp_output
|
||||
|
||||
# 执行优化
|
||||
dyn = FIRE(FrechetCellFilter(atoms))
|
||||
dyn.run(fmax=config.FMAX)
|
||||
|
||||
# 恢复标准输出并获取日志
|
||||
sys.stdout = original_stdout
|
||||
optimization_log = temp_output.getvalue()
|
||||
temp_output.close()
|
||||
|
||||
# 获取总能量
|
||||
total_energy = atoms.get_potential_energy()
|
||||
|
||||
# 处理优化后的结构
|
||||
if output_format == "cif":
|
||||
optimized_structure = Structure.from_ase_atoms(atoms)
|
||||
content = generate_symmetry_cif(optimized_structure)
|
||||
content = remove_symmetry_equiv_xyz(content)
|
||||
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
|
||||
write(tmp_file.name, atoms)
|
||||
tmp_file.seek(0)
|
||||
content = tmp_file.read()
|
||||
|
||||
os.unlink(tmp_file.name)
|
||||
|
||||
# 格式化返回结果
|
||||
format_result = f"""
|
||||
The following is the optimized crystal structure information:
|
||||
### Optimization Results (using FIRE(eqV2_86M) algorithm):
|
||||
**Total Energy: {total_energy} eV**
|
||||
|
||||
#### Optimizing Log:
|
||||
```text
|
||||
{optimization_log}
|
||||
```
|
||||
### Optimized {output_format.upper()} Content:
|
||||
```
|
||||
{content}
|
||||
```
|
||||
"""
|
||||
return format_result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to optimize structure: {str(e)}")
|
||||
raise e
|
||||
|
||||
@llm_tool(name="optimize_crystal_structure",
|
||||
description="Optimize crystal structure using FairChem model")
|
||||
async def optimize_crystal_structure(
|
||||
content: str,
|
||||
input_format: str = "cif",
|
||||
output_format: str = "cif"
|
||||
) -> str:
|
||||
"""
|
||||
Optimize crystal structure using FairChem model.
|
||||
|
||||
Args:
|
||||
content: Crystal structure content string
|
||||
input_format: Input format (cif, xyz, vasp)
|
||||
output_format: Output format (cif, xyz, vasp)
|
||||
|
||||
Returns:
|
||||
Optimized structure with energy and optimization log
|
||||
"""
|
||||
# 确保模型已初始化
|
||||
if calc is None:
|
||||
init_model()
|
||||
|
||||
# 使用asyncio.to_thread异步执行可能阻塞的操作
|
||||
def run_optimization():
|
||||
# 转换结构
|
||||
atoms = convert_structure(input_format, content)
|
||||
if atoms is None:
|
||||
raise ValueError(f"无法转换输入的{input_format}格式内容,请检查格式是否正确")
|
||||
|
||||
# 优化结构
|
||||
return optimize_structure(atoms, output_format)
|
||||
|
||||
try:
|
||||
# 直接返回结果或抛出异常
|
||||
return await asyncio.to_thread(run_optimization)
|
||||
except Exception as e:
|
||||
return handle_general_error(e)
|
||||
Reference in New Issue
Block a user