import ast
import json
import logging
import tempfile
import os
import datetime
import asyncio
import zipfile
import shutil
import re
import multiprocessing
from multiprocessing import Process, Queue
from pathlib import Path
from typing import Literal, Dict, Any, Tuple, Union, Optional, List

# Use the "spawn" start method for multiprocessing to avoid CUDA initialization errors
# in forked child processes.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    # A RuntimeError is raised if the start method has already been set.
    pass

from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter

# Import paths updated for the mars_toolkit package layout.
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config

# Use mattergen_wrapper; the wildcard import supplies the MatterGen types used below
# (PRETRAINED_MODEL_NAME, TargetProperty, MatterGenCheckpointInfo, CrystalGenerator).
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from ..core.mattergen_wrapper import *

logger = logging.getLogger(__name__)


def _process_generate_material_worker(args_queue, result_queue):
    """
    Worker that runs material generation in a separate (spawned) process.

    Args:
        args_queue: Queue holding the generation keyword arguments.
        result_queue: Queue used to return the result to the parent process.
    """
    try:
        # Get a logger for this worker process.
        logger = logging.getLogger(__name__)
        logger.info("Worker process started for material generation...")

        # Fetch the generation arguments from the queue.
        args = args_queue.get()
        logger.info(f"Worker process received arguments: {args}")

        # Import MatterGenService inside the worker so CUDA is initialized in this process.
        from mars_toolkit.services.mattergen_service import MatterGenService
        logger.info("Worker process imported MatterGenService")

        # Obtain the MatterGenService singleton.
        service = MatterGenService.get_instance()
        logger.info("Worker process obtained MatterGenService instance")

        # Run the generation through the service.
        logger.info("Worker process calling generate()...")
        result = service.generate(**args)
        logger.info("Worker process finished generate()")

        # Hand the result back to the parent process.
        result_queue.put(result)
        logger.info("Worker process done; result placed on the queue")
    except Exception as e:
        # On failure, send the error message (with traceback) back to the parent.
        import traceback
        error_msg = f"Error during material generation: {str(e)}\n{traceback.format_exc()}"
        logging.getLogger(__name__).error(error_msg)
        result_queue.put(f"Error: {error_msg}")


def format_cif_content(content):
    """
    Format CIF content by removing unnecessary headers and organizing each CIF file.

    Args:
        content: String containing CIF content, possibly with PK (zip) headers.

    Returns:
        Formatted string with each CIF file properly labeled and formatted.
    """
    # Return an empty string for empty input.
    if not content or content.strip() == '':
        return ''

    # Strip everything from a PK (zip) header up to the first _chemical_formula_structural tag.
    content = re.sub(r'PK.*?(?=_chemical_formula_structural)', '', content, flags=re.DOTALL)
    # Strip trailing PK content that contains no _chemical_formula_structural tag.
    content = re.sub(r'PK[^_]*$', '', content, flags=re.DOTALL)
    content = re.sub(r'PK.*?(?!.*_chemical_formula_structural)$', '', content, flags=re.DOTALL)

    # Split the content into individual CIF blocks, using _chemical_formula_structural as the
    # delimiter while keeping the tag inside each block.
    cif_blocks = []
    formula_positions = [m.start() for m in re.finditer(r'_chemical_formula_structural', content)]

    # If no _chemical_formula_structural tag was found, there is nothing to format.
    if not formula_positions:
        return ''

    for i in range(len(formula_positions)):
        start_pos = formula_positions[i]
        # The last block runs to the end of the string.
        end_pos = formula_positions[i + 1] if i < len(formula_positions) - 1 else len(content)
        cif_block = content[start_pos:end_pos].strip()

        # Extract the structural formula to use as the block label.
        formula_match = re.search(r'_chemical_formula_structural\s+(\S+)', cif_block)
        if formula_match:
            formula = formula_match.group(1)
            cif_blocks.append((formula, cif_block))

    # Format the output with explicit begin/end markers for each CIF block.
    result = []
    for i, (formula, cif_content) in enumerate(cif_blocks, 1):
        formatted = f"[cif {i} begin]\ndata_{formula}\n{cif_content}\n[cif {i} end]"
        result.append(formatted)

    return "\n".join(result)
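
# Illustrative sketch (not called anywhere): shows the shape of input format_cif_content
# expects -- raw text extracted from MatterGen's zipped CIF output, including a stray
# "PK" archive header -- and the labeled blocks it returns. The CIF fields below are
# hypothetical and abbreviated.
def _format_cif_content_example():
    raw = (
        "PK\x03\x04<zip header bytes>"
        "_chemical_formula_structural   LiFePO4\n"
        "_cell_length_a   10.3\n"
    )
    formatted = format_cif_content(raw)
    # Expected result, roughly:
    # [cif 1 begin]
    # data_LiFePO4
    # _chemical_formula_structural   LiFePO4
    # _cell_length_a   10.3
    # [cif 1 end]
    return formatted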
def convert_values(data_str):
    """
    Parse a JSON string into a Python object.

    Args:
        data_str: JSON string.

    Returns:
        The parsed data, or the original string if it cannot be parsed as JSON.
    """
    try:
        data = json.loads(data_str)
    except json.JSONDecodeError:
        return data_str  # Not valid JSON; return the original string.
    return data


def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
    """
    Preprocess a property value based on its name, converting it to the appropriate type.

    Args:
        property_name: Name of the property.
        property_value: Value of the property (can be string, float, or int).

    Returns:
        Tuple of (property_name, processed_value).

    Raises:
        ValueError: If the property value is invalid for the given property name.
    """
    valid_properties = [
        "dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
        "energy_above_hull", "formation_energy_per_atom", "space_group",
        "hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap"
    ]

    if property_name not in valid_properties:
        raise ValueError(
            f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}"
        )

    # If the value is a string, try to convert it to a float for numeric properties.
    if isinstance(property_value, str):
        try:
            if property_name != "chemical_system":
                property_value = float(property_value)
        except ValueError:
            # Conversion failed; keep the string (e.g., for chemical_system).
            pass

    # Handle properties that require a specific type.
    if property_name == "chemical_system":
        if isinstance(property_value, (int, float)):
            logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
            property_value = str(property_value)
    elif property_name == "space_group":
        # Space group numbers are integers in the range 1-230.
        space_group = int(property_value)
        if space_group < 1 or space_group > 230:
            raise ValueError(f"Invalid space_group value: {space_group}. Must be an integer between 1 and 230.")
        property_value = space_group

    return property_name, property_value
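
# Illustrative sketch of how preprocess_property coerces values (not called anywhere).
# The example inputs are hypothetical; only the property names come from valid_properties above.
def _preprocess_property_example():
    # Numeric properties given as strings are converted to float.
    assert preprocess_property("dft_band_gap", "1.5") == ("dft_band_gap", 1.5)
    # chemical_system values stay as strings.
    assert preprocess_property("chemical_system", "Li-Fe-O") == ("chemical_system", "Li-Fe-O")
    # space_group must be an integer between 1 and 230; out-of-range values raise ValueError.
    try:
        preprocess_property("space_group", 500)
    except ValueError:
        pass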
def main(
    output_path: str,
    pretrained_name: PRETRAINED_MODEL_NAME | None = None,
    model_path: str | None = None,
    batch_size: int = 2,
    num_batches: int = 1,
    config_overrides: list[str] | None = None,
    checkpoint_epoch: Literal["best", "last"] | int = "last",
    properties_to_condition_on: TargetProperty | None = None,
    sampling_config_path: str | None = None,
    sampling_config_name: str = "default",
    sampling_config_overrides: list[str] | None = None,
    record_trajectories: bool = True,
    diffusion_guidance_factor: float | None = None,
    strict_checkpoint_loading: bool = True,
    target_compositions: list[dict[str, int]] | None = None,
):
    """
    Generate crystal structures with a trained MatterGen diffusion model and write them to `output_path`.

    Args:
        output_path: Path to output directory.
        pretrained_name: Name of a pretrained model on the Hugging Face Hub. Exactly one of
            pretrained_name or model_path must be provided.
        model_path: Path to DiffusionLightningModule checkpoint directory.
        config_overrides: Overrides for the model config, e.g.,
            `model.num_layers=3 model.hidden_dim=128`.
        checkpoint_epoch: Which epoch to load from the checkpoint: "best", "last", or an
            integer epoch number. (default: "last")
        properties_to_condition_on: Property value to draw conditional sampling with respect to.
            When this value is an empty dictionary (default), unconditional samples are drawn.
        sampling_config_path: Path to the sampling config file. (default: None, in which case we
            use `DEFAULT_SAMPLING_CONFIG_PATH` from explorers.common.utils.utils.py)
        sampling_config_name: Name of the sampling config (corresponds to
            `{sampling_config_path}/{sampling_config_name}.yaml` on disk). (default: default)
        sampling_config_overrides: Overrides for the sampling config, e.g.,
            `condition_loader_partial.batch_size=32`.
        record_trajectories: Whether to record the trajectories of the generated structures.
            (default: True)
        strict_checkpoint_loading: Whether to raise an exception when not all parameters from the
            checkpoint can be matched to the model.
        target_compositions: List of dictionaries with target compositions to condition on. Each
            dictionary should have the form `{element: number_of_atoms}`. If None, the target
            compositions are not conditioned on. Only supported for models trained for crystal
            structure prediction (CSP) (default: None)

    NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between
    the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
    """
    assert (
        pretrained_name is not None or model_path is not None
    ), "Either pretrained_name or model_path must be provided."
    assert (
        pretrained_name is None or model_path is None
    ), "Only one of pretrained_name or model_path can be provided."

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    sampling_config_overrides = sampling_config_overrides or []
    config_overrides = config_overrides or []
    properties_to_condition_on = properties_to_condition_on or {}
    target_compositions = target_compositions or []

    if pretrained_name is not None:
        checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
            pretrained_name, config_overrides=config_overrides
        )
    else:
        checkpoint_info = MatterGenCheckpointInfo(
            model_path=Path(model_path).resolve(),
            load_epoch=checkpoint_epoch,
            config_overrides=config_overrides,
            strict_checkpoint_loading=strict_checkpoint_loading,
        )
    _sampling_config_path = Path(sampling_config_path) if sampling_config_path is not None else None

    generator = CrystalGenerator(
        checkpoint_info=checkpoint_info,
        properties_to_condition_on=properties_to_condition_on,
        batch_size=batch_size,
        num_batches=num_batches,
        sampling_config_name=sampling_config_name,
        sampling_config_path=_sampling_config_path,
        sampling_config_overrides=sampling_config_overrides,
        record_trajectories=record_trajectories,
        diffusion_guidance_factor=(
            diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
        ),
        target_compositions_dict=target_compositions,
    )
    generator.generate(output_dir=Path(output_path))
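
# Illustrative sketch (assumptions flagged): a direct call to main() for unconditional generation.
# "mattergen_base" is assumed to be a valid pretrained model name exposed by mattergen_wrapper;
# substitute whatever PRETRAINED_MODEL_NAME values the wrapper actually defines. The output
# directory is hypothetical.
def _main_example():
    main(
        output_path="./generated_structures",  # hypothetical output directory
        pretrained_name="mattergen_base",      # assumed pretrained checkpoint name
        batch_size=2,
        num_batches=1,
        properties_to_condition_on=None,       # None / empty dict -> unconditional sampling
    )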
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
def generate_material(
    properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
    batch_size: int = 2,
    num_batches: int = 1,
    diffusion_guidance_factor: float = 2.0
) -> str:
    """
    Generate crystal structures with optional property constraints.

    This unified function can generate materials in three modes:
    1. Unconditional generation (no properties specified)
    2. Single property conditional generation (one property specified)
    3. Multi-property conditional generation (multiple properties specified)

    Args:
        properties: Optional property constraints. Can be:
            - None or empty dict for unconditional generation
            - Dict with single key-value pair for single property conditioning
            - Dict with multiple key-value pairs for multi-property conditioning
            Valid property names include: "dft_band_gap", "chemical_system", etc.
        batch_size: Number of structures per batch
        num_batches: Number of batches to generate
        diffusion_guidance_factor: Controls adherence to target properties

    Returns:
        Descriptive text with generated crystal structures in CIF format
    """
    # # Create queues for inter-process communication
    # args_queue = Queue()
    # result_queue = Queue()
    # # Put the generation arguments on the queue
    # args_queue.put({
    #     "properties": properties,
    #     "batch_size": batch_size,
    #     "num_batches": num_batches,
    #     "diffusion_guidance_factor": diffusion_guidance_factor
    # })
    # # Create and start a new process
    # logger.info("Starting a new process for material generation...")
    # p = Process(target=_process_generate_material_worker, args=(args_queue, result_queue))
    # p.start()
    # # Wait for the process to finish and collect the result
    # p.join()
    # result = result_queue.get()
    # # Check whether the result is an error message
    # if isinstance(result, str) and result.startswith("Error:"):
    #     # Log the error
    #     logger.error(result)

    # Import MatterGenService lazily so the cost is only paid when the tool is invoked.
    from mars_toolkit.services.mattergen_service import MatterGenService
    logger.info("Imported MatterGenService")

    # Obtain the MatterGenService singleton.
    service = MatterGenService.get_instance()
    logger.info("Obtained MatterGenService instance")

    # Run the generation through the service.
    logger.info("Calling MatterGenService.generate()...")
    result = service.generate(
        properties=properties,
        batch_size=batch_size,
        num_batches=num_batches,
        diffusion_guidance_factor=diffusion_guidance_factor,
    )
    logger.info("MatterGenService.generate() finished")

    if "Error generating structures" in result:
        return f"Error: Invalid properties {properties}."
    else:
        return result
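

# Illustrative sketch of invoking the LLM tool directly. The property key comes from
# preprocess_property's valid_properties; the target value below is hypothetical.
if __name__ == "__main__":
    # Unconditional generation: two structures in a single batch.
    print(generate_material(batch_size=2, num_batches=1))

    # Conditional generation: target a DFT band gap of 3.0 eV with the default guidance strength.
    print(generate_material(
        properties={"dft_band_gap": 3.0},
        batch_size=2,
        num_batches=1,
        diffusion_guidance_factor=2.0,
    ))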