构建mars_toolkit,删除tools_for_ms

This commit is contained in:
lzy
2025-04-02 12:53:50 +08:00
parent 603304e10f
commit a77c2cd377
73 changed files with 1884 additions and 896 deletions

View File

@@ -0,0 +1,12 @@
"""
Compute Module
This module provides computational tools for materials science, including:
- Material generation
- Property prediction
- Structure optimization
"""
from mars_toolkit.compute.material_gen import generate_material
from mars_toolkit.compute.property_pred import predict_properties
from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure

View File

@@ -0,0 +1,423 @@
"""
Material Generation Module
This module provides functions for generating crystal structures with optional property constraints.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import ast
import json
import logging
import tempfile
import os
import datetime
import asyncio
import zipfile
import shutil
import re
from pathlib import Path
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
# 导入路径已更新
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
# 使用mattergen_wrapper
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from mattergen_wrapper import generator
CrystalGenerator = generator.CrystalGenerator
from mattergen.common.data.types import TargetProperty
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
from mattergen.common.utils.data_classes import (
PRETRAINED_MODEL_NAME,
MatterGenCheckpointInfo,
)
logger = logging.getLogger(__name__)
def format_cif_content(content):
"""
Format CIF content by removing unnecessary headers and organizing each CIF file.
Args:
content: String containing CIF content, possibly with PK headers
Returns:
Formatted string with each CIF file properly labeled and formatted
"""
# 如果内容为空,直接返回空字符串
if not content or content.strip() == '':
return ''
# 删除从PK开始到第一个_chemical_formula_structural之前的所有内容
content = re.sub(r'PK.*?(?=_chemical_formula_structural)', '', content, flags=re.DOTALL)
# 删除从PK开始到字符串结束且没有_chemical_formula_structural的内容
content = re.sub(r'PK[^_]*$', '', content, flags=re.DOTALL)
content = re.sub(r'PK.*?(?!.*_chemical_formula_structural)$', '', content, flags=re.DOTALL)
# 使用_chemical_formula_structural作为分隔符来分割不同的CIF文件
# 但我们需要保留这个字段在每个CIF文件中
cif_blocks = []
# 查找所有_chemical_formula_structural的位置
formula_positions = [m.start() for m in re.finditer(r'_chemical_formula_structural', content)]
# 如果没有找到任何_chemical_formula_structural返回空字符串
if not formula_positions:
return ''
# 分割CIF块
for i in range(len(formula_positions)):
start_pos = formula_positions[i]
# 如果是最后一个块,结束位置是字符串末尾
end_pos = formula_positions[i+1] if i < len(formula_positions)-1 else len(content)
cif_block = content[start_pos:end_pos].strip()
# 提取formula值
formula_match = re.search(r'_chemical_formula_structural\s+(\S+)', cif_block)
if formula_match:
formula = formula_match.group(1)
cif_blocks.append((formula, cif_block))
# 格式化输出
result = []
for i, (formula, cif_content) in enumerate(cif_blocks, 1):
formatted = f"[cif {i} begin]\ndata_{formula}\n{cif_content}\n[cif {i} end]"
result.append(formatted)
return "\n".join(result)
def convert_values(data_str):
"""
将字符串转换为字典
Args:
data_str: JSON字符串
Returns:
解析后的数据,如果解析失败则返回原字符串
"""
try:
data = json.loads(data_str)
except json.JSONDecodeError:
return data_str # 如果无法解析为JSON返回原字符串
return data
def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
"""
Preprocess a property value based on its name, converting it to the appropriate type.
Args:
property_name: Name of the property
property_value: Value of the property (can be string, float, or int)
Returns:
Tuple of (property_name, processed_value)
Raises:
ValueError: If the property value is invalid for the given property name
"""
valid_properties = [
"dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
"energy_above_hull", "formation_energy_per_atom", "space_group",
"hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap"
]
if property_name not in valid_properties:
raise ValueError(f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}")
# Process property_value if it's a string
if isinstance(property_value, str):
try:
# Try to convert string to float for numeric properties
if property_name != "chemical_system":
property_value = float(property_value)
except ValueError:
# If conversion fails, keep as string (for chemical_system)
pass
# Handle special cases for properties that need specific types
if property_name == "chemical_system":
if isinstance(property_value, (int, float)):
logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
property_value = str(property_value)
elif property_name == "space_group" :
space_group = property_value
if space_group < 1 or space_group > 230:
raise ValueError(f"Invalid space_group value: {space_group}. Must be an integer between 1 and 230.")
return property_name, property_value
def main(
output_path: str,
pretrained_name: PRETRAINED_MODEL_NAME | None = None,
model_path: str | None = None,
batch_size: int = 2,
num_batches: int = 1,
config_overrides: list[str] | None = None,
checkpoint_epoch: Literal["best", "last"] | int = "last",
properties_to_condition_on: TargetProperty | None = None,
sampling_config_path: str | None = None,
sampling_config_name: str = "default",
sampling_config_overrides: list[str] | None = None,
record_trajectories: bool = True,
diffusion_guidance_factor: float | None = None,
strict_checkpoint_loading: bool = True,
target_compositions: list[dict[str, int]] | None = None,
):
"""
Evaluate diffusion model against molecular metrics.
Args:
model_path: Path to DiffusionLightningModule checkpoint directory.
output_path: Path to output directory.
config_overrides: Overrides for the model config, e.g., `model.num_layers=3 model.hidden_dim=128`.
properties_to_condition_on: Property value to draw conditional sampling with respect to. When this value is an empty dictionary (default), unconditional samples are drawn.
sampling_config_path: Path to the sampling config file. (default: None, in which case we use `DEFAULT_SAMPLING_CONFIG_PATH` from explorers.common.utils.utils.py)
sampling_config_name: Name of the sampling config (corresponds to `{sampling_config_path}/{sampling_config_name}.yaml` on disk). (default: default)
sampling_config_overrides: Overrides for the sampling config, e.g., `condition_loader_partial.batch_size=32`.
load_epoch: Epoch to load from the checkpoint. If None, the best epoch is loaded. (default: None)
record: Whether to record the trajectories of the generated structures. (default: True)
strict_checkpoint_loading: Whether to raise an exception when not all parameters from the checkpoint can be matched to the model.
target_compositions: List of dictionaries with target compositions to condition on. Each dictionary should have the form `{element: number_of_atoms}`. If None, the target compositions are not conditioned on.
Only supported for models trained for crystal structure prediction (CSP) (default: None)
NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
"""
assert (
pretrained_name is not None or model_path is not None
), "Either pretrained_name or model_path must be provided."
assert (
pretrained_name is None or model_path is None
), "Only one of pretrained_name or model_path can be provided."
if not os.path.exists(output_path):
os.makedirs(output_path)
sampling_config_overrides = sampling_config_overrides or []
config_overrides = config_overrides or []
properties_to_condition_on = properties_to_condition_on or {}
target_compositions = target_compositions or []
if pretrained_name is not None:
checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
pretrained_name, config_overrides=config_overrides
)
else:
checkpoint_info = MatterGenCheckpointInfo(
model_path=Path(model_path).resolve(),
load_epoch=checkpoint_epoch,
config_overrides=config_overrides,
strict_checkpoint_loading=strict_checkpoint_loading,
)
_sampling_config_path = Path(sampling_config_path) if sampling_config_path is not None else None
generator = CrystalGenerator(
checkpoint_info=checkpoint_info,
properties_to_condition_on=properties_to_condition_on,
batch_size=batch_size,
num_batches=num_batches,
sampling_config_name=sampling_config_name,
sampling_config_path=_sampling_config_path,
sampling_config_overrides=sampling_config_overrides,
record_trajectories=record_trajectories,
diffusion_guidance_factor=(
diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
),
target_compositions_dict=target_compositions,
)
generator.generate(output_dir=Path(output_path))
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
async def generate_material(
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
batch_size: int = 2,
num_batches: int = 1,
diffusion_guidance_factor: float = 2.0
) -> str:
"""
Generate crystal structures with optional property constraints.
This unified function can generate materials in three modes:
1. Unconditional generation (no properties specified)
2. Single property conditional generation (one property specified)
3. Multi-property conditional generation (multiple properties specified)
Args:
properties: Optional property constraints. Can be:
- None or empty dict for unconditional generation
- Dict with single key-value pair for single property conditioning
- Dict with multiple key-value pairs for multi-property conditioning
Valid property names include: "dft_band_gap", "chemical_system", etc.
batch_size: Number of structures per batch
num_batches: Number of batches to generate
diffusion_guidance_factor: Controls adherence to target properties
Returns:
Descriptive text with generated crystal structures in CIF format
"""
# 使用配置中的结果目录
output_dir = config.MATTERGENMODEL_RESULT_PATH
# 处理字符串输入(如果提供)
if isinstance(properties, str):
try:
properties = json.loads(properties)
except json.JSONDecodeError:
raise ValueError(f"Invalid properties JSON string: {properties}")
# 如果为None默认为空字典
properties = properties or {}
# 根据生成模式处理属性
if not properties:
# 无条件生成
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
properties_to_condition_on = None
generation_type = "unconditional"
property_description = "unconditionally"
else:
# 条件生成(单属性或多属性)
properties_to_condition_on = {}
# 处理每个属性
for property_name, property_value in properties.items():
_, processed_value = preprocess_property(property_name, property_value)
properties_to_condition_on[property_name] = processed_value
# 根据属性确定使用哪个模型
if len(properties) == 1:
# 单属性条件
property_name = list(properties.keys())[0]
property_to_model = {
"dft_mag_density": "dft_mag_density",
"dft_bulk_modulus": "dft_bulk_modulus",
"dft_shear_modulus": "dft_shear_modulus",
"energy_above_hull": "energy_above_hull",
"formation_energy_per_atom": "formation_energy_per_atom",
"space_group": "space_group",
"hhi_score": "hhi_score",
"ml_bulk_modulus": "ml_bulk_modulus",
"chemical_system": "chemical_system",
"dft_band_gap": "dft_band_gap"
}
model_dir = property_to_model.get(property_name, property_name)
generation_type = "single_property"
property_description = f"conditioned on {property_name} = {properties[property_name]}"
else:
# 多属性条件
property_keys = set(properties.keys())
if property_keys == {"dft_mag_density", "hhi_score"}:
model_dir = "dft_mag_density_hhi_score"
elif property_keys == {"chemical_system", "energy_above_hull"}:
model_dir = "chemical_system_energy_above_hull"
else:
# 如果没有特定的多属性模型,使用第一个属性的模型
first_property = list(properties.keys())[0]
model_dir = first_property
generation_type = "multi_property"
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
# 构建完整的模型路径
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
# 检查模型目录是否存在
if not os.path.exists(model_path):
# 如果特定模型不存在,回退到基础模型
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
# 使用适当的参数调用main函数
main(
output_path=output_dir,
model_path=model_path,
batch_size=batch_size,
num_batches=num_batches,
properties_to_condition_on=properties_to_condition_on,
record_trajectories=True,
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0
)
# 创建字典存储文件内容
result_dict = {}
# 定义文件路径
cif_zip_path = os.path.join(output_dir, "generated_crystals_cif.zip")
xyz_file_path = os.path.join(output_dir, "generated_crystals.extxyz")
trajectories_zip_path = os.path.join(output_dir, "generated_trajectories.zip")
# 读取CIF压缩文件
if os.path.exists(cif_zip_path):
with open(cif_zip_path, 'rb') as f:
result_dict['cif_content'] = f.read()
# 根据生成类型创建描述性提示
if generation_type == "unconditional":
title = "Generated Material Structures"
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
elif generation_type == "single_property":
property_name = list(properties.keys())[0]
property_value = properties[property_name]
title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
else: # multi_property
title = "Generated Material Structures Conditioned on Multiple Properties"
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
# 创建完整的提示
prompt = f"""
# {title}
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
{'' if generation_type == 'unconditional' else f'''
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
'''}
## CIF Files (Crystallographic Information Files)
- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools
```
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
```
{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""
# 清理文件(读取后删除)
try:
if os.path.exists(cif_zip_path):
os.remove(cif_zip_path)
if os.path.exists(xyz_file_path):
os.remove(xyz_file_path)
if os.path.exists(trajectories_zip_path):
os.remove(trajectories_zip_path)
except Exception as e:
logger.warning(f"Error cleaning up files: {e}")
return prompt

View File

@@ -0,0 +1,72 @@
"""
Property Prediction Module
This module provides functions for predicting properties of crystal structures.
"""
import asyncio
import torch
import numpy as np
from ase.units import GPa
from mattersim.forcefield import MatterSimCalculator
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.compute.structure_opt import convert_structure
@llm_tool(
name="predict_properties",
description="Predict energy, forces, and stress of crystal structures based on CIF string",
)
async def predict_properties(cif_content: str) -> str:
"""
Use MatterSim to predict energy, forces, and stress of crystal structures.
Args:
cif_content: Crystal structure string in CIF format
Returns:
String containing prediction results
"""
# 使用asyncio.to_thread异步执行可能阻塞的操作
def run_prediction():
# 使用 convert_structure 函数将 CIF 字符串转换为 Atoms 对象
structure = convert_structure("cif", cif_content)
if structure is None:
return "Unable to parse CIF string. Please check if the format is correct."
# 设置设备
device = "cuda" if torch.cuda.is_available() else "cpu"
# 使用 MatterSimCalculator 计算属性
structure.calc = MatterSimCalculator(device=device)
# 直接获取能量、力和应力
energy = structure.get_potential_energy()
forces = structure.get_forces()
stresses = structure.get_stress(voigt=False)
# 计算每原子能量
num_atoms = len(structure)
energy_per_atom = energy / num_atoms
# 计算应力GPa和eV/A^3格式
stresses_ev_a3 = stresses
stresses_gpa = stresses / GPa
# 构建返回的提示信息
prompt = f"""
## {structure.get_chemical_formula()} Crystal Structure Property Prediction Results
Prediction results using the provided CIF structure:
- Total Energy (eV): {energy}
- Energy per Atom (eV/atom): {energy_per_atom:.4f}
- Forces (eV/Angstrom): {forces[0]} # Forces on the first atom
- Stress (GPa): {stresses_gpa[0][0]} # First component of the stress tensor
- Stress (eV/A^3): {stresses_ev_a3[0][0]} # First component of the stress tensor
"""
return prompt
# 异步执行预测操作
return await asyncio.to_thread(run_prediction)

View File

@@ -0,0 +1,192 @@
"""
Structure Optimization Module
This module provides functions for optimizing crystal structures using the FairChem model.
"""
import asyncio
from io import StringIO
import sys
import tempfile
import os
import logging
from typing import Optional, Dict, Any
from ase.io import read, write
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
from mars_toolkit.core.error_handlers import handle_general_error
logger = logging.getLogger(__name__)
# 初始化FairChem模型
calc = None
def init_model():
"""初始化FairChem模型"""
global calc
if calc is not None:
return
try:
from fairchem.core import OCPCalculator
calc = OCPCalculator(checkpoint_path=config.FAIRCHEM_MODEL_PATH)
logger.info("FairChem model initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize FairChem model: {str(e)}")
raise
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
"""
将输入内容转换为Atoms对象
Args:
input_format: 输入格式 (cif, xyz, vasp等)
content: 结构内容字符串
Returns:
ASE Atoms对象如果转换失败则返回None
"""
try:
with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name
atoms = read(tmp_path)
os.unlink(tmp_path)
return atoms
except Exception as e:
logger.error(f"Failed to convert structure: {str(e)}")
return None
def generate_symmetry_cif(structure: Structure) -> str:
"""
生成对称性CIF
Args:
structure: Pymatgen Structure对象
Returns:
CIF格式的字符串
"""
analyzer = SpacegroupAnalyzer(structure)
structure_refined = analyzer.get_refined_structure()
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
cif_writer.write_file(tmp_file.name)
tmp_file.seek(0)
content = tmp_file.read()
os.unlink(tmp_file.name)
return content
def optimize_structure(atoms: Atoms, output_format: str) -> str:
"""
优化晶体结构
Args:
atoms: ASE Atoms对象
output_format: 输出格式 (cif, xyz, vasp等)
Returns:
包含优化结果的格式化字符串
"""
atoms.calc = calc
try:
# 捕获优化过程的输出
temp_output = StringIO()
original_stdout = sys.stdout
sys.stdout = temp_output
# 执行优化
dyn = FIRE(FrechetCellFilter(atoms))
dyn.run(fmax=config.FMAX)
# 恢复标准输出并获取日志
sys.stdout = original_stdout
optimization_log = temp_output.getvalue()
temp_output.close()
# 获取总能量
total_energy = atoms.get_potential_energy()
# 处理优化后的结构
if output_format == "cif":
optimized_structure = Structure.from_ase_atoms(atoms)
content = generate_symmetry_cif(optimized_structure)
content = remove_symmetry_equiv_xyz(content)
else:
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
write(tmp_file.name, atoms)
tmp_file.seek(0)
content = tmp_file.read()
os.unlink(tmp_file.name)
# 格式化返回结果
format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**
#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
return format_result
except Exception as e:
logger.error(f"Failed to optimize structure: {str(e)}")
raise e
@llm_tool(name="optimize_crystal_structure",
description="Optimize crystal structure using FairChem model")
async def optimize_crystal_structure(
content: str,
input_format: str = "cif",
output_format: str = "cif"
) -> str:
"""
Optimize crystal structure using FairChem model.
Args:
content: Crystal structure content string
input_format: Input format (cif, xyz, vasp)
output_format: Output format (cif, xyz, vasp)
Returns:
Optimized structure with energy and optimization log
"""
# 确保模型已初始化
if calc is None:
init_model()
# 使用asyncio.to_thread异步执行可能阻塞的操作
def run_optimization():
# 转换结构
atoms = convert_structure(input_format, content)
if atoms is None:
raise ValueError(f"无法转换输入的{input_format}格式内容,请检查格式是否正确")
# 优化结构
return optimize_structure(atoms, output_format)
try:
# 直接返回结果或抛出异常
return await asyncio.to_thread(run_optimization)
except Exception as e:
return handle_general_error(e)