# MatterGen-based material generation tools.
import ast
|
||
import json
|
||
import logging
|
||
import tempfile
|
||
import os
|
||
import datetime
|
||
import asyncio
|
||
import zipfile
|
||
import shutil
|
||
import re
|
||
import multiprocessing
|
||
from multiprocessing import Process, Queue
|
||
from pathlib import Path
|
||
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
|
||
|
||
# Force the 'spawn' start method for multiprocessing so that CUDA is not
# initialized in a forked child process (fork + CUDA causes init errors).
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    # set_start_method raises RuntimeError when the start method has
    # already been chosen for this interpreter; that is harmless here.
    pass
|
||
|
||
from ase.optimize import FIRE
|
||
from ase.filters import FrechetCellFilter
|
||
from ase.atoms import Atoms
|
||
from ase.io import read, write
|
||
from pymatgen.core.structure import Structure
|
||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||
from pymatgen.io.cif import CifWriter
|
||
|
||
# 导入路径已更新
|
||
from mars_toolkit.core.llm_tools import llm_tool
|
||
from mars_toolkit.core.config import config
|
||
|
||
# 使用mattergen_wrapper
|
||
import sys
|
||
import os
|
||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
|
||
from ..core.mattergen_wrapper import *
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _process_generate_material_worker(args_queue, result_queue):
|
||
"""
|
||
在新进程中处理材料生成的工作函数
|
||
|
||
Args:
|
||
args_queue: 包含生成参数的队列
|
||
result_queue: 用于返回结果的队列
|
||
"""
|
||
try:
|
||
# 配置日志
|
||
import logging
|
||
logger = logging.getLogger(__name__)
|
||
logger.info("子进程开始执行材料生成...")
|
||
|
||
# 从队列获取参数
|
||
args = args_queue.get()
|
||
logger.info(f"子进程获取到参数: {args}")
|
||
|
||
# 导入MatterGenService
|
||
from mars_toolkit.services.mattergen_service import MatterGenService
|
||
logger.info("子进程成功导入MatterGenService")
|
||
|
||
# 获取MatterGenService实例
|
||
service = MatterGenService.get_instance()
|
||
logger.info("子进程成功获取MatterGenService实例")
|
||
|
||
# 使用服务生成材料
|
||
logger.info("子进程开始调用generate方法...")
|
||
result = service.generate(**args)
|
||
logger.info("子进程generate方法调用完成")
|
||
|
||
# 将结果放入结果队列
|
||
result_queue.put(result)
|
||
logger.info("子进程材料生成完成,结果已放入队列")
|
||
except Exception as e:
|
||
# 如果发生错误,将错误信息放入结果队列
|
||
import traceback
|
||
error_msg = f"材料生成过程中出错: {str(e)}\n{traceback.format_exc()}"
|
||
import logging
|
||
logging.getLogger(__name__).error(error_msg)
|
||
result_queue.put(f"Error: {error_msg}")
def format_cif_content(content):
    """
    Clean up raw CIF text and emit each structure as a labelled block.

    Strips zip-archive "PK" header residue, splits the remaining text on
    the ``_chemical_formula_structural`` tag, and wraps every structure
    in ``[cif N begin] ... [cif N end]`` markers preceded by a
    ``data_<formula>`` line.

    Args:
        content: Raw CIF text, possibly polluted with PK headers.

    Returns:
        The formatted string, or '' when no structure can be found.
    """
    # Nothing to do for empty or whitespace-only input.
    if not content or content.strip() == '':
        return ''

    # Drop everything from a PK marker up to the first formula tag.
    content = re.sub(r'PK.*?(?=_chemical_formula_structural)', '', content, flags=re.DOTALL)

    # Drop trailing PK residue that is not followed by a formula tag.
    content = re.sub(r'PK[^_]*$', '', content, flags=re.DOTALL)
    content = re.sub(r'PK.*?(?!.*_chemical_formula_structural)$', '', content, flags=re.DOTALL)

    # Every CIF block begins at a _chemical_formula_structural tag.
    starts = [m.start() for m in re.finditer(r'_chemical_formula_structural', content)]
    if not starts:
        return ''

    # Pair each start with the next one (or the end of the text) and keep
    # only blocks whose formula value can actually be extracted.
    labelled_blocks = []
    for begin, finish in zip(starts, starts[1:] + [len(content)]):
        block = content[begin:finish].strip()
        match = re.search(r'_chemical_formula_structural\s+(\S+)', block)
        if match:
            labelled_blocks.append((match.group(1), block))

    # Wrap each block with numbered begin/end markers.
    return "\n".join(
        f"[cif {index} begin]\ndata_{formula}\n{body}\n[cif {index} end]"
        for index, (formula, body) in enumerate(labelled_blocks, 1)
    )
def convert_values(data_str):
    """
    Parse a JSON string into Python data.

    Args:
        data_str: The JSON text to decode.

    Returns:
        The decoded object, or ``data_str`` unchanged when it is not
        valid JSON.
    """
    try:
        parsed = json.loads(data_str)
    except json.JSONDecodeError:
        return data_str  # Not valid JSON: hand the raw string back.
    return parsed
def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
    """
    Preprocess a property value based on its name, converting it to the
    appropriate type.

    Args:
        property_name: Name of the property.
        property_value: Value of the property (can be string, float, or int).

    Returns:
        Tuple of (property_name, processed_value).

    Raises:
        ValueError: If the property name is unknown, or the property value
            is invalid for the given property name.
    """
    valid_properties = [
        "dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
        "energy_above_hull", "formation_energy_per_atom", "space_group",
        "hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap"
    ]

    if property_name not in valid_properties:
        raise ValueError(f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}")

    # Values often arrive as strings (e.g. from an LLM tool call); convert
    # numeric properties to float, leaving chemical_system as text.
    if isinstance(property_value, str):
        try:
            if property_name != "chemical_system":
                property_value = float(property_value)
        except ValueError:
            # Keep as string; the per-property checks below decide whether
            # that is acceptable for this property.
            pass

    if property_name == "chemical_system":
        # chemical_system must be a string such as "Li-Fe-O".
        if isinstance(property_value, (int, float)):
            logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
            property_value = str(property_value)
    elif property_name == "space_group":
        # Space groups are the integers 1..230.  Reject non-numeric values
        # explicitly: the original comparison raised a TypeError for
        # strings that failed the float conversion above.
        if not isinstance(property_value, (int, float)):
            raise ValueError(f"Invalid space_group value: {property_value}. Must be an integer between 1 and 230.")
        if property_value < 1 or property_value > 230:
            raise ValueError(f"Invalid space_group value: {property_value}. Must be an integer between 1 and 230.")

    return property_name, property_value
def main(
    output_path: str,
    pretrained_name: PRETRAINED_MODEL_NAME | None = None,
    model_path: str | None = None,
    batch_size: int = 2,
    num_batches: int = 1,
    config_overrides: list[str] | None = None,
    checkpoint_epoch: Literal["best", "last"] | int = "last",
    properties_to_condition_on: TargetProperty | None = None,
    sampling_config_path: str | None = None,
    sampling_config_name: str = "default",
    sampling_config_overrides: list[str] | None = None,
    record_trajectories: bool = True,
    diffusion_guidance_factor: float | None = None,
    strict_checkpoint_loading: bool = True,
    target_compositions: list[dict[str, int]] | None = None,
):
    """
    Generate crystal structures with a MatterGen diffusion model.

    Exactly one of ``pretrained_name`` or ``model_path`` must be provided.

    Args:
        output_path: Path to output directory (created if missing).
        pretrained_name: Name of a pretrained checkpoint on the Hugging Face hub.
        model_path: Path to DiffusionLightningModule checkpoint directory.
        batch_size: Number of structures generated per batch.
        num_batches: Number of batches to generate.
        config_overrides: Overrides for the model config, e.g., `model.num_layers=3 model.hidden_dim=128`.
        checkpoint_epoch: Epoch to load from the checkpoint: "best", "last", or an epoch number. (default: "last")
        properties_to_condition_on: Property value to draw conditional sampling with respect to. When this value is an empty dictionary (default), unconditional samples are drawn.
        sampling_config_path: Path to the sampling config file. (default: None, in which case we use `DEFAULT_SAMPLING_CONFIG_PATH` from explorers.common.utils.utils.py)
        sampling_config_name: Name of the sampling config (corresponds to `{sampling_config_path}/{sampling_config_name}.yaml` on disk). (default: default)
        sampling_config_overrides: Overrides for the sampling config, e.g., `condition_loader_partial.batch_size=32`.
        record_trajectories: Whether to record the trajectories of the generated structures. (default: True)
        diffusion_guidance_factor: Strength of the diffusion guidance; None is treated as 0.0 (no guidance).
        strict_checkpoint_loading: Whether to raise an exception when not all parameters from the checkpoint can be matched to the model.
        target_compositions: List of dictionaries with target compositions to condition on. Each dictionary should have the form `{element: number_of_atoms}`. If None, the target compositions are not conditioned on.
            Only supported for models trained for crystal structure prediction (CSP) (default: None)

    NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
    """
    assert (
        pretrained_name is not None or model_path is not None
    ), "Either pretrained_name or model_path must be provided."
    assert (
        pretrained_name is None or model_path is None
    ), "Only one of pretrained_name or model_path can be provided."

    # exist_ok avoids a race between the existence check and the creation.
    os.makedirs(output_path, exist_ok=True)

    # Normalize all optional collection arguments to empty containers.
    sampling_config_overrides = sampling_config_overrides or []
    config_overrides = config_overrides or []
    properties_to_condition_on = properties_to_condition_on or {}
    target_compositions = target_compositions or []

    # Resolve the checkpoint either from the HF hub or a local directory.
    if pretrained_name is not None:
        checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
            pretrained_name, config_overrides=config_overrides
        )
    else:
        checkpoint_info = MatterGenCheckpointInfo(
            model_path=Path(model_path).resolve(),
            load_epoch=checkpoint_epoch,
            config_overrides=config_overrides,
            strict_checkpoint_loading=strict_checkpoint_loading,
        )
    _sampling_config_path = Path(sampling_config_path) if sampling_config_path is not None else None
    generator = CrystalGenerator(
        checkpoint_info=checkpoint_info,
        properties_to_condition_on=properties_to_condition_on,
        batch_size=batch_size,
        num_batches=num_batches,
        sampling_config_name=sampling_config_name,
        sampling_config_path=_sampling_config_path,
        sampling_config_overrides=sampling_config_overrides,
        record_trajectories=record_trajectories,
        diffusion_guidance_factor=(
            diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
        ),
        target_compositions_dict=target_compositions,
    )
    generator.generate(output_dir=Path(output_path))
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
def generate_material(
    properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
    batch_size: int = 2,
    num_batches: int = 1,
    diffusion_guidance_factor: float = 2.0
) -> str:
    """
    Generate crystal structures with optional property constraints.

    This unified function can generate materials in three modes:
    1. Unconditional generation (no properties specified)
    2. Single property conditional generation (one property specified)
    3. Multi-property conditional generation (multiple properties specified)

    Args:
        properties: Optional property constraints. Can be:
            - None or empty dict for unconditional generation
            - Dict with single key-value pair for single property conditioning
            - Dict with multiple key-value pairs for multi-property conditioning
            Valid property names include: "dft_band_gap", "chemical_system", etc.
        batch_size: Number of structures per batch
        num_batches: Number of batches to generate
        diffusion_guidance_factor: Controls adherence to target properties

    Returns:
        Descriptive text with generated crystal structures in CIF format,
        or an "Error: ..." message when the property constraints are invalid.
    """
    # Import lazily: the service pulls in heavy model dependencies that
    # should only be loaded when the tool is actually invoked.
    from mars_toolkit.services.mattergen_service import MatterGenService

    # The service is a singleton, so the model checkpoint is loaded once.
    service = MatterGenService.get_instance()

    # Generation runs in-process here (the earlier subprocess-based path
    # was removed in favor of the spawn start method set at module level).
    logger.info("Calling MatterGenService.generate ...")
    result = service.generate(
        properties=properties,
        batch_size=batch_size,
        num_batches=num_batches,
        diffusion_guidance_factor=diffusion_guidance_factor,
    )
    logger.info("MatterGenService.generate finished")

    # The service reports failures in-band; surface them as a short,
    # caller-friendly error message instead of raw generation output.
    if "Error generating structures" in result:
        return f"Error: Invalid properties {properties}."
    else:
        return result
|