mattergen调用指定GPU&规范化mattergen的输入

This commit is contained in:
lzy
2025-04-05 20:19:43 +08:00
parent bac8f067e0
commit 71d8dabd17
6 changed files with 379 additions and 45 deletions

View File

@@ -12,6 +12,7 @@ import json
from pathlib import Path
from typing import Dict, Any, Optional, Union, List
import threading
import torch
# 导入mattergen相关模块
# import sys
@@ -38,6 +39,23 @@ class MatterGenService:
# Lazily-created singleton instance; guarded by _lock below.
_instance = None
# Lock ensuring thread-safe singleton construction in get_instance().
_lock = threading.Lock()
# Mapping from model / conditioning-property name to the CUDA device index
# (as a string) that the corresponding generator should run on.
# NOTE(review): this hard-codes devices 0-7 and therefore assumes an
# 8-GPU host; indices wrap around (e.g. ml_bulk_modulus reuses GPU 0) —
# confirm against the actual deployment hardware.
MODEL_TO_GPU = {
"mattergen_base": "0", # base model uses GPU 0
"dft_mag_density": "1", # magnetic density model uses GPU 1
"dft_bulk_modulus": "2", # bulk modulus model uses GPU 2
"dft_shear_modulus": "3", # shear modulus model uses GPU 3
"energy_above_hull": "4", # energy-above-hull model uses GPU 4
"formation_energy_per_atom": "5", # formation energy model uses GPU 5
"space_group": "6", # space group model uses GPU 6
"hhi_score": "7", # HHI score model uses GPU 7
"ml_bulk_modulus": "0", # ML bulk modulus model uses GPU 0
"chemical_system": "1", # chemical system model uses GPU 1
"dft_band_gap": "2", # band gap model uses GPU 2
"dft_mag_density_hhi_score": "3", # multi-property model uses GPU 3
"chemical_system_energy_above_hull": "4" # multi-property model uses GPU 4
}
@classmethod
def get_instance(cls):
"""
@@ -125,13 +143,14 @@ class MatterGenService:
diffusion_guidance_factor: Controls adherence to target properties
Returns:
tuple: (generator, generator_key, properties_to_condition_on)
tuple: (generator, generator_key, properties_to_condition_on, gpu_id)
"""
# 如果没有属性约束,使用基础生成器
if not properties:
if "base" not in self._generators:
self._init_base_generator()
return self._generators.get("base"), "base", None
gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0") # 默认使用GPU 0
return self._generators.get("base"), "base", None, gpu_id
# 处理属性约束
properties_to_condition_on = {}
@@ -171,6 +190,9 @@ class MatterGenService:
model_dir = first_property
generator_key = f"multi_{first_property}_etc"
# 获取对应的GPU ID
gpu_id = self.MODEL_TO_GPU.get(model_dir, "0") # 默认使用GPU 0
# 构建完整的模型路径
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
@@ -188,7 +210,7 @@ class MatterGenService:
generator.batch_size = batch_size
generator.num_batches = num_batches
generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0
return generator, generator_key, properties_to_condition_on
return generator, generator_key, properties_to_condition_on, gpu_id
# 创建新的生成器
try:
@@ -216,13 +238,14 @@ class MatterGenService:
self._generators[generator_key] = generator
logger.info(f"MatterGen generator for {generator_key} initialized successfully")
return generator, generator_key, properties_to_condition_on
return generator, generator_key, properties_to_condition_on, gpu_id
except Exception as e:
logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
# 回退到基础生成器
if "base" not in self._generators:
self._init_base_generator()
return self._generators.get("base"), "base", None
base_gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")
return self._generators.get("base"), "base", None, base_gpu_id
def generate(
self,
@@ -255,14 +278,24 @@ class MatterGenService:
# 如果为None默认为空字典
properties = properties or {}
# 获取或创建生成器
generator, generator_key, properties_to_condition_on = self._get_or_create_generator(
# 获取或创建生成器和GPU ID
generator, generator_key, properties_to_condition_on, gpu_id = self._get_or_create_generator(
properties, batch_size, num_batches, diffusion_guidance_factor
)
print("gpu_id",gpu_id)
if generator is None:
return "Error: Failed to initialize MatterGen generator"
# 使用torch.cuda.set_device()直接设置当前GPU
try:
# 将字符串类型的gpu_id转换为整数
cuda_device_id = int(gpu_id)
torch.cuda.set_device(cuda_device_id)
logger.info(f"Setting CUDA device to GPU {cuda_device_id} for model {generator_key}")
print(f"Using GPU {cuda_device_id} (CUDA device index) for model {generator_key}")
except Exception as e:
logger.warning(f"Error setting CUDA device: {e}. Falling back to default device.")
# 生成结构
try:
generator.generate(output_dir=Path(self._output_dir))
@@ -339,4 +372,7 @@ You can use these structures for materials discovery, property prediction, or fu
except Exception as e:
logger.warning(f"Error cleaning up files: {e}")
# GPU设备已经在生成前由torch.cuda.set_device()设置,不需要额外清理
logger.info(f"Generation completed on GPU for model {generator_key}")
return prompt