379 lines
16 KiB
Python
379 lines
16 KiB
Python
"""
MatterGen service for mars_toolkit.

This module provides a service for generating crystal structures using MatterGen.
The service initializes the CrystalGenerator once and reuses it for multiple
generation requests, improving performance.
"""
|
||
|
||
import os
|
||
import logging
|
||
import json
|
||
from pathlib import Path
|
||
from typing import Dict, Any, Optional, Union, List
|
||
import threading
|
||
import torch
|
||
|
||
# Import the MatterGen-related modules
|
||
# import sys
|
||
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../')))
|
||
# from mars_toolkit.core.mattergen_wrapper import generator
|
||
# CrystalGenerator = generator.CrystalGenerator
|
||
# from mattergen.common.data.types import TargetProperty
|
||
# from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
|
||
# from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME
|
||
from ..core.mattergen_wrapper import *
|
||
# Import the mars_toolkit configuration
|
||
from mars_toolkit.core.config import config
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class MatterGenService:
    """
    Service for generating crystal structures using MatterGen.

    Generators are cached in ``self._generators`` (keyed by ``"base"``,
    ``"single_<property>"`` or ``"multi_<...>"``) so that a model checkpoint is
    loaded once and reused for later generation requests.

    NOTE(review): ``generate`` mutates the cached generators and uses fixed
    file names inside one shared output directory, and it is not synchronized.
    Concurrent calls on the singleton may interfere with each other -- confirm
    whether callers serialize requests.
    """

    _instance = None
    _lock = threading.Lock()

    # Maps a model directory name to the GPU id (kept as a string) used for it.
    # ``generate`` converts the id with ``int()`` and passes it to
    # ``torch.cuda.set_device``. Unknown models default to "0".
    MODEL_TO_GPU = {
        "mattergen_base": "0",  # base (unconditional) model
        "dft_mag_density": "1",  # magnetic density model
        "dft_bulk_modulus": "2",  # bulk modulus model
        "dft_shear_modulus": "3",  # shear modulus model
        "energy_above_hull": "4",  # energy above hull model
        "formation_energy_per_atom": "5",  # formation energy model
        "space_group": "6",  # space group model
        "hhi_score": "7",  # HHI score model
        "ml_bulk_modulus": "0",  # ML bulk modulus model
        "chemical_system": "1",  # chemical system model
        "dft_band_gap": "2",  # band gap model
        "dft_mag_density_hhi_score": "3",  # multi-property model
        "chemical_system_energy_above_hull": "4",  # multi-property model
    }

    # Property combinations that have a dedicated model directory.
    _MULTI_PROPERTY_MODELS = {
        frozenset({"dft_mag_density", "hhi_score"}): "dft_mag_density_hhi_score",
        frozenset({"chemical_system", "energy_above_hull"}): "chemical_system_energy_above_hull",
    }

    @classmethod
    def get_instance(cls):
        """
        Get the singleton instance of MatterGenService.

        The instance is created on first use. ``_instance`` is re-checked after
        acquiring ``_lock`` so that only one instance is ever constructed.

        Returns:
            MatterGenService: The singleton instance.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """
        Initialize the MatterGenService.

        Creates the output directory from ``config.MATTERGENMODEL_RESULT_PATH``
        if needed and tries to initialize the base (unconditional) generator.
        Generators for specific property conditions are created on demand.
        """
        self._generators = {}
        self._output_dir = config.MATTERGENMODEL_RESULT_PATH

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self._output_dir, exist_ok=True)

        self._init_base_generator()

    @staticmethod
    def _build_generator(
        model_path,
        properties_to_condition_on,
        batch_size,
        num_batches,
        diffusion_guidance_factor,
    ):
        """Construct a CrystalGenerator for the checkpoint stored at ``model_path``."""
        checkpoint_info = MatterGenCheckpointInfo(
            model_path=Path(model_path).resolve(),
            load_epoch="last",
            config_overrides=[],
            strict_checkpoint_loading=True,
        )
        return CrystalGenerator(
            checkpoint_info=checkpoint_info,
            properties_to_condition_on=properties_to_condition_on,
            batch_size=batch_size,
            num_batches=num_batches,
            sampling_config_name="default",
            sampling_config_path=None,
            sampling_config_overrides=[],
            record_trajectories=True,
            diffusion_guidance_factor=diffusion_guidance_factor,
            target_compositions_dict=[],
        )

    def _init_base_generator(self):
        """
        Initialize the base generator for unconditional generation.

        On success the generator is stored under ``self._generators["base"]``.
        A missing model directory or a construction failure is logged and
        leaves the cache without a ``"base"`` entry; nothing is raised.
        """
        model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")

        if not os.path.exists(model_path):
            logger.warning(f"Base model directory not found at {model_path}. MatterGen service may not work properly.")
            return

        logger.info(f"Initializing base MatterGen generator from {model_path}")

        try:
            # batch_size / num_batches are placeholders; they are overwritten
            # per request before the generator is used.
            self._generators["base"] = self._build_generator(
                model_path,
                properties_to_condition_on=None,
                batch_size=2,
                num_batches=1,
                diffusion_guidance_factor=0.0,
            )
            logger.info("Base MatterGen generator initialized successfully")
        except Exception as e:
            # Best effort: the service stays usable for reporting errors later.
            logger.exception(f"Failed to initialize base MatterGen generator: {e}")

    def _get_base_generator(self, batch_size, num_batches):
        """Return the base-generator result tuple, applying the requested batch settings."""
        if "base" not in self._generators:
            self._init_base_generator()

        generator = self._generators.get("base")
        if generator is not None:
            # Without this the base generator would keep the batch settings it
            # was constructed with, regardless of what the caller asked for.
            generator.batch_size = batch_size
            generator.num_batches = num_batches

        gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")
        return generator, "base", None, gpu_id

    def _get_or_create_generator(
        self,
        properties: Optional[Dict[str, Any]] = None,
        batch_size: int = 2,
        num_batches: int = 1,
        diffusion_guidance_factor: float = 2.0
    ):
        """
        Get or create a generator for the specified properties.

        Args:
            properties: Optional property constraints
            batch_size: Number of structures per batch
            num_batches: Number of batches to generate
            diffusion_guidance_factor: Controls adherence to target properties

        Returns:
            tuple: (generator, generator_key, properties_to_condition_on, gpu_id).
            When the base model is used (no properties, missing model directory,
            or a failed initialization) the tuple is
            ``(base_generator_or_None, "base", None, base_gpu_id)``.
        """
        # No property constraints: use the base (unconditional) generator.
        if not properties:
            return self._get_base_generator(batch_size, num_batches)

        properties_to_condition_on = dict(properties)

        # Work out which model directory and cache key serve this request.
        if len(properties) == 1:
            (property_name,) = properties
            # Single-property models live in a directory named after the property.
            model_dir = property_name
            generator_key = f"single_{property_name}"
        else:
            model_dir = self._MULTI_PROPERTY_MODELS.get(frozenset(properties))
            if model_dir is not None:
                generator_key = f"multi_{model_dir}"
            else:
                # No dedicated multi-property model: use the first property's model.
                # NOTE(review): that model is still conditioned on *all* requested
                # properties -- confirm the checkpoint supports this.
                first_property = next(iter(properties))
                model_dir = first_property
                generator_key = f"multi_{first_property}_etc"

        gpu_id = self.MODEL_TO_GPU.get(model_dir, "0")  # default to GPU 0

        # Reuse a cached generator, refreshing every per-request setting.
        generator = self._generators.get(generator_key)
        if generator is not None:
            generator.batch_size = batch_size
            generator.num_batches = num_batches
            generator.diffusion_guidance_factor = diffusion_guidance_factor
            # The cache key does not include the target values, so the targets
            # must be refreshed too; otherwise the first request's values would
            # be reused. Assumes CrystalGenerator reads this attribute at
            # generate() time, like batch_size above -- TODO confirm.
            generator.properties_to_condition_on = properties_to_condition_on
            return generator, generator_key, properties_to_condition_on, gpu_id

        model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)

        if not os.path.exists(model_path):
            # Fall back to the unconditional base generator without applying
            # conditioning to it or caching a conditioned generator as "base".
            logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
            return self._get_base_generator(batch_size, num_batches)

        try:
            logger.info(f"Initializing new MatterGen generator for {generator_key} from {model_path}")
            generator = self._build_generator(
                model_path,
                properties_to_condition_on=properties_to_condition_on,
                batch_size=batch_size,
                num_batches=num_batches,
                diffusion_guidance_factor=diffusion_guidance_factor,
            )
        except Exception as e:
            logger.exception(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
            return self._get_base_generator(batch_size, num_batches)

        self._generators[generator_key] = generator
        logger.info(f"MatterGen generator for {generator_key} initialized successfully")
        return generator, generator_key, properties_to_condition_on, gpu_id

    @staticmethod
    def _remove_output_files(paths):
        """Delete each path if present; a failure is logged and does not stop the rest."""
        for path in paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                continue
            except OSError as e:
                logger.warning(f"Error cleaning up files: {e}")

    def generate(
        self,
        properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
        batch_size: int = 2,
        num_batches: int = 1,
        diffusion_guidance_factor: float = 2.0
    ) -> str:
        """
        Generate crystal structures with optional property constraints.

        Args:
            properties: Optional property constraints, either a dict or a JSON
                string that decodes to a dict
            batch_size: Number of structures per batch
            num_batches: Number of batches to generate
            diffusion_guidance_factor: Controls adherence to target properties

        Returns:
            str: Descriptive text with generated crystal structures in CIF format,
            or a string starting with "Error" when no generator is available or
            generation fails.

        Raises:
            ValueError: If ``properties`` is an invalid JSON string or does not
                describe a mapping.
        """
        # Imported here as in the original code -- presumably to avoid a
        # circular import; verify before moving it to module level.
        from mars_toolkit.compute.material_gen import format_cif_content

        # Accept a JSON string in place of a dict.
        if isinstance(properties, str):
            try:
                properties = json.loads(properties)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid properties JSON string: {properties}") from exc

        properties = properties or {}
        if not isinstance(properties, dict):
            raise ValueError(
                f"properties must be a mapping of property name to target value, got {type(properties).__name__}"
            )

        generator, generator_key, properties_to_condition_on, gpu_id = self._get_or_create_generator(
            properties, batch_size, num_batches, diffusion_guidance_factor
        )
        logger.debug(f"gpu_id {gpu_id}")
        if generator is None:
            return "Error: Failed to initialize MatterGen generator"

        # Properties were requested but the base model had to be used, so the
        # result is in fact unconditional and must be reported that way.
        fell_back_to_base = bool(properties) and properties_to_condition_on is None
        if fell_back_to_base:
            logger.warning("Requested property conditioning is unavailable; generating unconditionally with the base model.")
        effective_properties = {} if fell_back_to_base else properties

        # Select the GPU mapped to this model; on failure keep the current device.
        try:
            cuda_device_id = int(gpu_id)
            torch.cuda.set_device(cuda_device_id)
            logger.info(f"Setting CUDA device to GPU {cuda_device_id} for model {generator_key}")
        except Exception as e:
            logger.warning(f"Error setting CUDA device: {e}. Falling back to default device.")

        try:
            generator.generate(output_dir=Path(self._output_dir))
        except Exception as e:
            logger.exception(f"Error generating structures: {e}")
            return f"Error generating structures: {e}"

        cif_zip_path = os.path.join(self._output_dir, "generated_crystals_cif.zip")
        xyz_file_path = os.path.join(self._output_dir, "generated_crystals.extxyz")
        trajectories_zip_path = os.path.join(self._output_dir, "generated_trajectories.zip")

        # Read the CIF archive, then always remove the generated files so a
        # later call cannot pick up stale output.
        cif_bytes = b""
        try:
            if os.path.exists(cif_zip_path):
                with open(cif_zip_path, "rb") as f:
                    cif_bytes = f.read()
        finally:
            self._remove_output_files((cif_zip_path, xyz_file_path, trajectories_zip_path))

        # NOTE(review): the zip archive's raw bytes are decoded as UTF-8 text and
        # handed to format_cif_content, as in the original code -- confirm that
        # helper expects this rather than extracted CIF text.
        cif_text = format_cif_content(cif_bytes.decode("utf-8", errors="replace"))

        # Build the descriptive text according to the kind of generation performed.
        if not effective_properties:
            generation_type = "unconditional"
            title = "Generated Material Structures"
            description = "These structures were generated unconditionally, meaning no specific properties were targeted."
            property_description = "unconditionally"
            if fell_back_to_base:
                description += (
                    " The requested property conditioning could not be applied because the"
                    " conditioned model was unavailable, so the base model was used instead."
                )
        elif len(effective_properties) == 1:
            generation_type = "single_property"
            property_name, property_value = next(iter(effective_properties.items()))
            title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
            description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
            property_description = f"conditioned on {property_name} = {property_value}"
        else:
            generation_type = "multi_property"
            title = "Generated Material Structures Conditioned on Multiple Properties"
            description = "These structures were generated with multi-property conditioning, targeting the specified property values."
            joined = ", ".join(f"{name} = {value}" for name, value in effective_properties.items())
            property_description = f"conditioned on multiple properties: {joined}"

        # The guidance factor only matters for conditioned generation.
        if generation_type == "unconditional":
            guidance_note = ""
        else:
            guidance_note = f"""
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
"""

        prompt = f"""
# {title}

This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.

{guidance_note}

## CIF Files (Crystallographic Information Files)

- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools

```
{cif_text}
```

{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""

        logger.info(f"Generation completed on GPU for model {generator_key}")

        return prompt
|