diff --git a/mars_toolkit/__init__.py b/mars_toolkit/__init__.py index fba5167..8c760db 100644 --- a/mars_toolkit/__init__.py +++ b/mars_toolkit/__init__.py @@ -1,21 +1,7 @@ -""" -Mars Toolkit - -A comprehensive toolkit for materials science research, providing tools for: -- Material generation and property prediction -- Structure optimization -- Database queries (Materials Project, OQMD) -- Knowledge base retrieval -- Web search - -Author: Yutang LI -Institution: SIAT-MIC -Contact: yt.li2@siat.ac.cn -""" # Core modules from mars_toolkit.core.config import config -from mars_toolkit.core.utils import setup_logging + # Basic tools from mars_toolkit.misc.misc_tools import get_current_time @@ -35,12 +21,14 @@ from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD from mars_toolkit.query.dify_search import retrieval_from_knowledge_base from mars_toolkit.query.web_search import search_online +# Visualization modules + + from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas -# Initialize logging -setup_logging() + __version__ = "0.1.0" __all__ = ["llm_tool", "get_tools", "get_tool_schemas"] diff --git a/mars_toolkit/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/__pycache__/__init__.cpython-310.pyc index 27f48c3..e4e2215 100644 Binary files a/mars_toolkit/__pycache__/__init__.cpython-310.pyc and b/mars_toolkit/__pycache__/__init__.cpython-310.pyc differ diff --git a/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc b/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc index 25707b7..319fd14 100644 Binary files a/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc and b/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc differ diff --git a/mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc b/mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc index b6732e2..b6b0529 100644 Binary files a/mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc and 
b/mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc differ diff --git a/mars_toolkit/compute/material_gen.py b/mars_toolkit/compute/material_gen.py index 530790c..ac60b4a 100644 --- a/mars_toolkit/compute/material_gen.py +++ b/mars_toolkit/compute/material_gen.py @@ -1,12 +1,3 @@ -""" -Material Generation Module - -This module provides functions for generating crystal structures with optional property constraints. - -Author: Yutang LI -Institution: SIAT-MIC -Contact: yt.li2@siat.ac.cn -""" import ast import json @@ -276,148 +267,16 @@ async def generate_material( Returns: Descriptive text with generated crystal structures in CIF format """ - # 使用配置中的结果目录 - output_dir = config.MATTERGENMODEL_RESULT_PATH + # 导入MatterGenService + from mars_toolkit.services.mattergen_service import MatterGenService - # 处理字符串输入(如果提供) - if isinstance(properties, str): - try: - properties = json.loads(properties) - except json.JSONDecodeError: - raise ValueError(f"Invalid properties JSON string: {properties}") + # 获取MatterGenService实例 + service = MatterGenService.get_instance() - # 如果为None,默认为空字典 - properties = properties or {} - - # 根据生成模式处理属性 - if not properties: - # 无条件生成 - model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base") - properties_to_condition_on = None - generation_type = "unconditional" - property_description = "unconditionally" - else: - # 条件生成(单属性或多属性) - properties_to_condition_on = {} - - # 处理每个属性 - for property_name, property_value in properties.items(): - _, processed_value = preprocess_property(property_name, property_value) - properties_to_condition_on[property_name] = processed_value - - # 根据属性确定使用哪个模型 - if len(properties) == 1: - # 单属性条件 - property_name = list(properties.keys())[0] - property_to_model = { - "dft_mag_density": "dft_mag_density", - "dft_bulk_modulus": "dft_bulk_modulus", - "dft_shear_modulus": "dft_shear_modulus", - "energy_above_hull": "energy_above_hull", - "formation_energy_per_atom": "formation_energy_per_atom", - 
"space_group": "space_group", - "hhi_score": "hhi_score", - "ml_bulk_modulus": "ml_bulk_modulus", - "chemical_system": "chemical_system", - "dft_band_gap": "dft_band_gap" - } - model_dir = property_to_model.get(property_name, property_name) - generation_type = "single_property" - property_description = f"conditioned on {property_name} = {properties[property_name]}" - else: - # 多属性条件 - property_keys = set(properties.keys()) - if property_keys == {"dft_mag_density", "hhi_score"}: - model_dir = "dft_mag_density_hhi_score" - elif property_keys == {"chemical_system", "energy_above_hull"}: - model_dir = "chemical_system_energy_above_hull" - else: - # 如果没有特定的多属性模型,使用第一个属性的模型 - first_property = list(properties.keys())[0] - model_dir = first_property - generation_type = "multi_property" - property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}" - - # 构建完整的模型路径 - model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir) - - # 检查模型目录是否存在 - if not os.path.exists(model_path): - # 如果特定模型不存在,回退到基础模型 - logger.warning(f"Model directory for {model_dir} not found. 
Using base model instead.") - model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base") - - # 使用适当的参数调用main函数 - main( - output_path=output_dir, - model_path=model_path, + # 使用服务生成材料 + return service.generate( + properties=properties, batch_size=batch_size, num_batches=num_batches, - properties_to_condition_on=properties_to_condition_on, - record_trajectories=True, - diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0 + diffusion_guidance_factor=diffusion_guidance_factor ) - - # 创建字典存储文件内容 - result_dict = {} - - # 定义文件路径 - cif_zip_path = os.path.join(output_dir, "generated_crystals_cif.zip") - xyz_file_path = os.path.join(output_dir, "generated_crystals.extxyz") - trajectories_zip_path = os.path.join(output_dir, "generated_trajectories.zip") - - # 读取CIF压缩文件 - if os.path.exists(cif_zip_path): - with open(cif_zip_path, 'rb') as f: - result_dict['cif_content'] = f.read() - - # 根据生成类型创建描述性提示 - if generation_type == "unconditional": - title = "Generated Material Structures" - description = "These structures were generated unconditionally, meaning no specific properties were targeted." - elif generation_type == "single_property": - property_name = list(properties.keys())[0] - property_value = properties[property_name] - title = f"Generated Material Structures Conditioned on {property_name} = {property_value}" - description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}." - else: # multi_property - title = "Generated Material Structures Conditioned on Multiple Properties" - description = "These structures were generated with multi-property conditioning, targeting the specified property values." - - # 创建完整的提示 - prompt = f""" -# {title} - -This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}. 
- - {'' if generation_type == 'unconditional' else f''' - A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly - the generation adheres to the specified property values. Higher values produce samples that more - closely match the target properties but may reduce diversity. - '''} - -## CIF Files (Crystallographic Information Files) - -- Standard format for crystallographic structures -- Contains unit cell parameters, atomic positions, and symmetry information -- Used by crystallographic software and visualization tools - -``` -{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))} -``` - -{description} -You can use these structures for materials discovery, property prediction, or further analysis. -""" - - # 清理文件(读取后删除) - try: - if os.path.exists(cif_zip_path): - os.remove(cif_zip_path) - if os.path.exists(xyz_file_path): - os.remove(xyz_file_path) - if os.path.exists(trajectories_zip_path): - os.remove(trajectories_zip_path) - except Exception as e: - logger.warning(f"Error cleaning up files: {e}") - return prompt diff --git a/mars_toolkit/compute/structure_opt.py b/mars_toolkit/compute/structure_opt.py index 9fed655..132fbf8 100644 --- a/mars_toolkit/compute/structure_opt.py +++ b/mars_toolkit/compute/structure_opt.py @@ -23,7 +23,6 @@ from pymatgen.io.cif import CifWriter from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz from mars_toolkit.core.llm_tools import llm_tool from mars_toolkit.core.config import config -from mars_toolkit.core.error_handlers import handle_general_error logger = logging.getLogger(__name__) @@ -189,4 +188,4 @@ async def optimize_crystal_structure( # 直接返回结果或抛出异常 return await asyncio.to_thread(run_optimization) except Exception as e: - return handle_general_error(e) + return str(e) diff --git a/mars_toolkit/core/__init__.py 
b/mars_toolkit/core/__init__.py index 771eee2..ba4e1fc 100644 --- a/mars_toolkit/core/__init__.py +++ b/mars_toolkit/core/__init__.py @@ -5,9 +5,4 @@ This module provides core functionality for the Mars Toolkit. """ from mars_toolkit.core.config import config -from mars_toolkit.core.utils import settings, setup_logging -from mars_toolkit.core.error_handlers import ( - handle_minio_error, handle_http_error, - handle_validation_error, handle_general_error -) from mars_toolkit.core.llm_tools import llm_tool diff --git a/mars_toolkit/core/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/core/__pycache__/__init__.cpython-310.pyc index 8aa07d2..12e8a11 100644 Binary files a/mars_toolkit/core/__pycache__/__init__.cpython-310.pyc and b/mars_toolkit/core/__pycache__/__init__.cpython-310.pyc differ diff --git a/mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc b/mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc index a7c6fcb..ed92f14 100644 Binary files a/mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc and b/mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc differ diff --git a/mars_toolkit/core/__pycache__/config.cpython-310.pyc b/mars_toolkit/core/__pycache__/config.cpython-310.pyc index 9b78045..7b77453 100644 Binary files a/mars_toolkit/core/__pycache__/config.cpython-310.pyc and b/mars_toolkit/core/__pycache__/config.cpython-310.pyc differ diff --git a/mars_toolkit/core/cif_utils.py b/mars_toolkit/core/cif_utils.py index 7fa40ae..044f0a6 100644 --- a/mars_toolkit/core/cif_utils.py +++ b/mars_toolkit/core/cif_utils.py @@ -3,10 +3,6 @@ CIF Utilities Module This module provides basic functions for handling CIF (Crystallographic Information File) files, which are commonly used in materials science for representing crystal structures. 
- -Author: Yutang LI -Institution: SIAT-MIC -Contact: yt.li2@siat.ac.cn """ import json diff --git a/mars_toolkit/core/error_handlers.py b/mars_toolkit/core/error_handlers.py deleted file mode 100644 index 64e6484..0000000 --- a/mars_toolkit/core/error_handlers.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -Error Handlers Module - -This module provides error handling utilities for the Mars Toolkit. -It includes functions for handling various types of errors that may occur -during toolkit operations. - -Author: Yutang LI -Institution: SIAT-MIC -Contact: yt.li2@siat.ac.cn -""" - -from fastapi import HTTPException -from typing import Any, Dict -import logging - -logger = logging.getLogger(__name__) - -class APIError(HTTPException): - """自定义API错误类""" - def __init__(self, status_code: int, detail: Any = None): - super().__init__(status_code=status_code, detail=detail) - logger.error(f"API Error: {status_code} - {detail}") - -def handle_minio_error(e: Exception) -> Dict[str, str]: - """处理MinIO相关错误""" - logger.error(f"MinIO operation failed: {str(e)}") - return { - "status": "error", - "data": f"MinIO operation failed: {str(e)}" - } - -def handle_http_error(e: Exception) -> Dict[str, str]: - """处理HTTP请求错误""" - logger.error(f"HTTP request failed: {str(e)}") - return { - "status": "error", - "data": f"HTTP request failed: {str(e)}" - } - -def handle_validation_error(e: Exception) -> Dict[str, str]: - """处理数据验证错误""" - logger.error(f"Validation failed: {str(e)}") - return { - "status": "error", - "data": f"Validation failed: {str(e)}" - } - -def handle_general_error(e: Exception) -> Dict[str, str]: - """处理通用错误""" - logger.error(f"Unexpected error: {str(e)}") - return { - "status": "error", - "data": f"Unexpected error: {str(e)}" - } diff --git a/mars_toolkit/core/utils.py b/mars_toolkit/core/utils.py deleted file mode 100644 index 1a96f9b..0000000 --- a/mars_toolkit/core/utils.py +++ /dev/null @@ -1,75 +0,0 @@ -import os -import boto3 -import logging -import logging.config -from 
typing import Optional -from pydantic import Field -from pydantic_settings import BaseSettings - -logger = logging.getLogger(__name__) - -class Settings(BaseSettings): - # Material Project - mp_api_key: Optional[str] = Field(None, env="MP_API_KEY") - mp_endpoint: Optional[str] = Field(None, env="MP_ENDPOINT") - mp_topk: Optional[int] = Field(3, env="MP_TOPK") - - # Proxy - http_proxy: Optional[str] = Field(None, env="HTTP_PROXY") - https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY") - - # FairChem - fairchem_model_path: Optional[str] = Field(None, env="FAIRCHEM_MODEL_PATH") - fmax: Optional[float] = Field(0.05, env="FMAX") - - # MinIO - minio_endpoint: Optional[str] = Field(None, env="MINIO_ENDPOINT") - internal_minio_endpoint: Optional[str] = Field(None, env="INTERNAL_MINIO_ENDPOINT") - minio_access_key: Optional[str] = Field(None, env="MINIO_ACCESS_KEY") - minio_secret_key: Optional[str] = Field(None, env="MINIO_SECRET_KEY") - minio_bucket: Optional[str] = Field("mars-toolkit", env="MINIO_BUCKET") - - class Config: - env_file = ".env" - env_file_encoding = "utf-8" - -def setup_logging(): - """配置日志记录""" - parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) - log_file_path = os.path.join(parent_dir, 'mars_toolkit.log') - - logging.config.dictConfig({ - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'standard': { - 'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s', - 'datefmt': '%Y-%m-%d %H:%M:%S' - }, - }, - 'handlers': { - 'console': { - 'level': 'INFO', - 'class': 'logging.StreamHandler', - 'formatter': 'standard' - }, - 'file': { - 'level': 'DEBUG', - 'class': 'logging.handlers.RotatingFileHandler', - 'filename': log_file_path, - 'maxBytes': 10485760, # 10MB - 'backupCount': 5, - 'formatter': 'standard' - } - }, - 'loggers': { - '': { - 'handlers': ['console', 'file'], - 'level': 'INFO', - 'propagate': True - } - } - }) - -# 初始化配置 -settings = Settings() diff --git 
a/mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc b/mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc index db793d7..f398bea 100644 Binary files a/mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc and b/mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc differ diff --git a/mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc b/mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc index 254d8f9..acb5ae2 100644 Binary files a/mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc and b/mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc differ diff --git a/mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc b/mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc index 2f44daf..8473082 100644 Binary files a/mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc and b/mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc differ diff --git a/mars_toolkit/query/dify_search.py b/mars_toolkit/query/dify_search.py index 82c8672..04ba69b 100644 --- a/mars_toolkit/query/dify_search.py +++ b/mars_toolkit/query/dify_search.py @@ -1,13 +1,4 @@ -""" -Dify Search Module -This module provides functions for retrieving information from local materials science -literature knowledge base using Dify API. - -Author: Yutang LI -Institution: SIAT-MIC -Contact: yt.li2@siat.ac.cn -""" import asyncio import json diff --git a/mars_toolkit/query/mp_query.py b/mars_toolkit/query/mp_query.py index 3ac2d44..a7a1853 100644 --- a/mars_toolkit/query/mp_query.py +++ b/mars_toolkit/query/mp_query.py @@ -1,14 +1,3 @@ -""" -Materials Project Query Module - -This module provides functions for querying the Materials Project database, -processing search results, and formatting responses. 
- -Author: Yutang LI -Institution: SIAT-MIC -Contact: yt.li2@siat.ac.cn -""" - import glob import json import asyncio @@ -25,7 +14,6 @@ from pymatgen.io.cif import CifWriter from mars_toolkit.core.llm_tools import llm_tool from mars_toolkit.core.config import config -from mars_toolkit.core.error_handlers import handle_general_error from mars_toolkit.core.cif_utils import read_cif_txt_file, extract_cif_info, remove_symmetry_equiv_xyz logger = logging.getLogger(__name__) diff --git a/mars_toolkit/query/oqmd_query.py b/mars_toolkit/query/oqmd_query.py index d4ed1ea..1a582a2 100644 --- a/mars_toolkit/query/oqmd_query.py +++ b/mars_toolkit/query/oqmd_query.py @@ -1,13 +1,3 @@ -""" -OQMD Query Module - -This module provides functions for querying the Open Quantum Materials Database (OQMD). - -Author: Yutang LI -Institution: SIAT-MIC -Contact: yt.li2@siat.ac.cn -""" - import logging import httpx import pandas as pd diff --git a/mars_toolkit/services/__init__.py b/mars_toolkit/services/__init__.py new file mode 100644 index 0000000..268869a --- /dev/null +++ b/mars_toolkit/services/__init__.py @@ -0,0 +1,12 @@ +""" +Services module for mars_toolkit. + +This module contains service classes that provide persistent functionality +across multiple function calls, such as maintaining initialized models. 
+""" + +# Import services for easy access +from mars_toolkit.services.mattergen_service import MatterGenService + +# Export services +__all__ = ['MatterGenService'] diff --git a/mars_toolkit/services/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/services/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..cad4626 Binary files /dev/null and b/mars_toolkit/services/__pycache__/__init__.cpython-310.pyc differ diff --git a/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc b/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc new file mode 100644 index 0000000..77575bd Binary files /dev/null and b/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc differ diff --git a/mars_toolkit/services/mattergen_service.py b/mars_toolkit/services/mattergen_service.py new file mode 100644 index 0000000..1578268 --- /dev/null +++ b/mars_toolkit/services/mattergen_service.py @@ -0,0 +1,342 @@ +""" +MatterGen service for mars_toolkit. + +This module provides a service for generating crystal structures using MatterGen. +The service initializes the CrystalGenerator once and reuses it for multiple +generation requests, improving performance. +""" + +import os +import logging +import json +from pathlib import Path +from typing import Dict, Any, Optional, Union, List +import threading + +# 导入mattergen相关模块 +import sys +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../'))) +from mattergen_wrapper import generator +CrystalGenerator = generator.CrystalGenerator +from mattergen.common.data.types import TargetProperty +from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo +from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME + +# 导入mars_toolkit配置 +from mars_toolkit.core.config import config + +logger = logging.getLogger(__name__) + +class MatterGenService: + """ + Service for generating crystal structures using MatterGen. 
+ + This service initializes the CrystalGenerator once and reuses it for multiple + generation requests, improving performance. + """ + + _instance = None + _lock = threading.Lock() + + @classmethod + def get_instance(cls): + """ + Get the singleton instance of MatterGenService. + + Returns: + MatterGenService: The singleton instance. + """ + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + """ + Initialize the MatterGenService. + + This initializes the base generator without any property conditioning. + Specific generators for different property conditions will be initialized + on demand. + """ + self._generators = {} + self._output_dir = config.MATTERGENMODEL_RESULT_PATH + + # 确保输出目录存在 + if not os.path.exists(self._output_dir): + os.makedirs(self._output_dir) + + # 初始化基础生成器(无条件生成) + self._init_base_generator() + + def _init_base_generator(self): + """ + Initialize the base generator for unconditional generation. + """ + model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base") + + if not os.path.exists(model_path): + logger.warning(f"Base model directory not found at {model_path}. 
MatterGen service may not work properly.") + return + + logger.info(f"Initializing base MatterGen generator from {model_path}") + + try: + checkpoint_info = MatterGenCheckpointInfo( + model_path=Path(model_path).resolve(), + load_epoch="last", + config_overrides=[], + strict_checkpoint_loading=True, + ) + + generator = CrystalGenerator( + checkpoint_info=checkpoint_info, + properties_to_condition_on=None, + batch_size=2, # 默认值,可在生成时覆盖 + num_batches=1, # 默认值,可在生成时覆盖 + sampling_config_name="default", + sampling_config_path=None, + sampling_config_overrides=[], + record_trajectories=True, + diffusion_guidance_factor=0.0, + target_compositions_dict=[], + ) + + self._generators["base"] = generator + logger.info("Base MatterGen generator initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize base MatterGen generator: {e}") + + def _get_or_create_generator( + self, + properties: Optional[Dict[str, Any]] = None, + batch_size: int = 2, + num_batches: int = 1, + diffusion_guidance_factor: float = 2.0 + ): + """ + Get or create a generator for the specified properties. 
+ + Args: + properties: Optional property constraints + batch_size: Number of structures per batch + num_batches: Number of batches to generate + diffusion_guidance_factor: Controls adherence to target properties + + Returns: + tuple: (generator, generator_key, properties_to_condition_on) + """ + # 如果没有属性约束,使用基础生成器 + if not properties: + if "base" not in self._generators: + self._init_base_generator() + return self._generators.get("base"), "base", None + + # 处理属性约束 + properties_to_condition_on = {} + for property_name, property_value in properties.items(): + properties_to_condition_on[property_name] = property_value + + # 确定模型目录 + if len(properties) == 1: + # 单属性条件 + property_name = list(properties.keys())[0] + property_to_model = { + "dft_mag_density": "dft_mag_density", + "dft_bulk_modulus": "dft_bulk_modulus", + "dft_shear_modulus": "dft_shear_modulus", + "energy_above_hull": "energy_above_hull", + "formation_energy_per_atom": "formation_energy_per_atom", + "space_group": "space_group", + "hhi_score": "hhi_score", + "ml_bulk_modulus": "ml_bulk_modulus", + "chemical_system": "chemical_system", + "dft_band_gap": "dft_band_gap" + } + model_dir = property_to_model.get(property_name, property_name) + generator_key = f"single_{property_name}" + else: + # 多属性条件 + property_keys = set(properties.keys()) + if property_keys == {"dft_mag_density", "hhi_score"}: + model_dir = "dft_mag_density_hhi_score" + generator_key = "multi_dft_mag_density_hhi_score" + elif property_keys == {"chemical_system", "energy_above_hull"}: + model_dir = "chemical_system_energy_above_hull" + generator_key = "multi_chemical_system_energy_above_hull" + else: + # 如果没有特定的多属性模型,使用第一个属性的模型 + first_property = list(properties.keys())[0] + model_dir = first_property + generator_key = f"multi_{first_property}_etc" + + # 构建完整的模型路径 + model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir) + + # 检查模型目录是否存在 + if not os.path.exists(model_path): + # 如果特定模型不存在,回退到基础模型 + logger.warning(f"Model directory 
for {model_dir} not found. Using base model instead.") + model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base") + generator_key = "base" + + # 检查是否已经有这个生成器 + if generator_key in self._generators: + # 更新生成器的参数 + generator = self._generators[generator_key] + generator.batch_size = batch_size + generator.num_batches = num_batches + generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0 + return generator, generator_key, properties_to_condition_on + + # 创建新的生成器 + try: + logger.info(f"Initializing new MatterGen generator for {generator_key} from {model_path}") + + checkpoint_info = MatterGenCheckpointInfo( + model_path=Path(model_path).resolve(), + load_epoch="last", + config_overrides=[], + strict_checkpoint_loading=True, + ) + + generator = CrystalGenerator( + checkpoint_info=checkpoint_info, + properties_to_condition_on=properties_to_condition_on, + batch_size=batch_size, + num_batches=num_batches, + sampling_config_name="default", + sampling_config_path=None, + sampling_config_overrides=[], + record_trajectories=True, + diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0, + target_compositions_dict=[], + ) + + self._generators[generator_key] = generator + logger.info(f"MatterGen generator for {generator_key} initialized successfully") + return generator, generator_key, properties_to_condition_on + except Exception as e: + logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}") + # 回退到基础生成器 + if "base" not in self._generators: + self._init_base_generator() + return self._generators.get("base"), "base", None + + def generate( + self, + properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None, + batch_size: int = 2, + num_batches: int = 1, + diffusion_guidance_factor: float = 2.0 + ) -> str: + """ + Generate crystal structures with optional property constraints. 
+ + Args: + properties: Optional property constraints + batch_size: Number of structures per batch + num_batches: Number of batches to generate + diffusion_guidance_factor: Controls adherence to target properties + + Returns: + str: Descriptive text with generated crystal structures in CIF format + """ + from mars_toolkit.compute.material_gen import format_cif_content + + # 处理字符串输入(如果提供) + if isinstance(properties, str): + try: + properties = json.loads(properties) + except json.JSONDecodeError: + raise ValueError(f"Invalid properties JSON string: {properties}") + + # 如果为None,默认为空字典 + properties = properties or {} + + # 获取或创建生成器 + generator, generator_key, properties_to_condition_on = self._get_or_create_generator( + properties, batch_size, num_batches, diffusion_guidance_factor + ) + + if generator is None: + return "Error: Failed to initialize MatterGen generator" + + # 生成结构 + try: + generator.generate(output_dir=Path(self._output_dir)) + except Exception as e: + logger.error(f"Error generating structures: {e}") + return f"Error generating structures: {e}" + + # 创建字典存储文件内容 + result_dict = {} + + # 定义文件路径 + cif_zip_path = os.path.join(self._output_dir, "generated_crystals_cif.zip") + xyz_file_path = os.path.join(self._output_dir, "generated_crystals.extxyz") + trajectories_zip_path = os.path.join(self._output_dir, "generated_trajectories.zip") + + # 读取CIF压缩文件 + if os.path.exists(cif_zip_path): + with open(cif_zip_path, 'rb') as f: + result_dict['cif_content'] = f.read() + + # 根据生成类型创建描述性提示 + if not properties: + generation_type = "unconditional" + title = "Generated Material Structures" + description = "These structures were generated unconditionally, meaning no specific properties were targeted." 
+ property_description = "unconditionally" + elif len(properties) == 1: + generation_type = "single_property" + property_name = list(properties.keys())[0] + property_value = properties[property_name] + title = f"Generated Material Structures Conditioned on {property_name} = {property_value}" + description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}." + property_description = f"conditioned on {property_name} = {property_value}" + else: + generation_type = "multi_property" + title = "Generated Material Structures Conditioned on Multiple Properties" + description = "These structures were generated with multi-property conditioning, targeting the specified property values." + property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}" + + # 创建完整的提示 + prompt = f""" +# {title} + +This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}. + + {'' if generation_type == 'unconditional' else f''' + A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly + the generation adheres to the specified property values. Higher values produce samples that more + closely match the target properties but may reduce diversity. + '''} + +## CIF Files (Crystallographic Information Files) + +- Standard format for crystallographic structures +- Contains unit cell parameters, atomic positions, and symmetry information +- Used by crystallographic software and visualization tools + +``` +{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))} +``` + +{description} +You can use these structures for materials discovery, property prediction, or further analysis. 
+""" + + # 清理文件(读取后删除) + try: + if os.path.exists(cif_zip_path): + os.remove(cif_zip_path) + if os.path.exists(xyz_file_path): + os.remove(xyz_file_path) + if os.path.exists(trajectories_zip_path): + os.remove(trajectories_zip_path) + except Exception as e: + logger.warning(f"Error cleaning up files: {e}") + + return prompt diff --git a/mars_toolkit/visualization/__init__.py b/mars_toolkit/visualization/__init__.py index e69de29..e714c1f 100644 --- a/mars_toolkit/visualization/__init__.py +++ b/mars_toolkit/visualization/__init__.py @@ -0,0 +1,2 @@ + +f diff --git a/mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..ce3c413 Binary files /dev/null and b/mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc differ diff --git a/mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc b/mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc new file mode 100644 index 0000000..4673119 Binary files /dev/null and b/mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc differ diff --git a/mars_toolkit/visualization/__pycache__/crystal_vis.cpython-310.pyc b/mars_toolkit/visualization/__pycache__/crystal_vis.cpython-310.pyc new file mode 100644 index 0000000..305ad81 Binary files /dev/null and b/mars_toolkit/visualization/__pycache__/crystal_vis.cpython-310.pyc differ diff --git a/mattergen_api.py b/mattergen_api.py new file mode 100644 index 0000000..f8ac878 --- /dev/null +++ b/mattergen_api.py @@ -0,0 +1,151 @@ +from fastapi import FastAPI, HTTPException, Request, status +from fastapi.responses import JSONResponse +from pydantic import BaseModel +import uvicorn +from typing import Dict, Any, Optional, Union, List +import logging +import traceback +import sys + +# 配置日志 +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[logging.StreamHandler(sys.stdout)]) 
+logger = logging.getLogger(__name__) + +# Initialize the FastAPI application +app = FastAPI(title="MatterGen API Service") + +# Request model +class MaterialGenerationRequest(BaseModel): + properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None + batch_size: int = 2 + num_batches: int = 1 + diffusion_guidance_factor: float = 2.0 + +# Response model +class MaterialGenerationResponse(BaseModel): + content: str + success: bool + message: str + +# Module-level state used to track the service status +service_status = { + "initialized": False, + "error": None, + "mattergen_service": None +} + +# Initialize MatterGenService +try: + logger.info("Importing MatterGenService...") + from mars_toolkit.services.mattergen_service import MatterGenService + + logger.info("Initializing MatterGenService...") + mattergen_service = MatterGenService.get_instance() + service_status["mattergen_service"] = mattergen_service + service_status["initialized"] = True + logger.info("MatterGenService initialized successfully") +except Exception as e: + error_msg = f"Failed to initialize MatterGenService: {str(e)}" + logger.error(error_msg) + logger.error(traceback.format_exc()) + service_status["error"] = error_msg + +# Middleware: check the service status +@app.middleware("http") +async def check_service_status(request: Request, call_next): + # The health-check endpoint bypasses the service-status check + if request.url.path == "/health": + return await call_next(request) + + # Return a 503 error if the service has not been initialized + if not service_status["initialized"]: + error_msg = service_status["error"] or "MatterGenService not initialized" + return JSONResponse( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + content={"detail": error_msg} + ) + + # Continue processing the request + return await call_next(request) + +@app.post("/generate_material", response_model=MaterialGenerationResponse) +async def generate_material(request: MaterialGenerationRequest): + """Generate crystal structures, optionally with property constraints.""" + try: + logger.info(f"Received material generation request with properties: {request.properties}") + print("request",request) + # Call MatterGenService to generate the materials + result = mattergen_service.generate( 
+ properties=request.properties, + batch_size=request.batch_size, + num_batches=request.num_batches, + diffusion_guidance_factor=request.diffusion_guidance_factor + ) + + logger.info("Material generation completed successfully") + + return { + "content": result, + "success": True, + "message": "Material generation successful" + } + except Exception as e: + # Log detailed error information + error_msg = f"Error generating material: {str(e)}" + logger.error(error_msg) + logger.error(traceback.format_exc()) + + # Return an error response + return { + "content": "", + "success": False, + "message": error_msg + } + +@app.get("/health") +async def health_check(): + """Health-check endpoint reporting the MatterGenService status.""" + if service_status["initialized"]: + return { + "status": "healthy", + "service": "MatterGen API", + "mattergen_service": "initialized" + } + else: + error_msg = service_status["error"] or "MatterGenService not initialized" + return { + "status": "unhealthy", + "service": "MatterGen API", + "error": error_msg + } + +@app.get("/") +async def root(): + """API root endpoint providing basic information.""" + return { + "service": "MatterGen API Service", + "description": "API for generating crystal structures with optional property constraints", + "status": "healthy" if service_status["initialized"] else "unhealthy", + "endpoints": { + "/generate_material": "POST - Generate crystal structures", + "/health": "GET - Health check", + "/docs": "GET - API documentation" + } + } + +# Global exception handler +@app.exception_handler(Exception) +async def global_exception_handler(request: Request, exc: Exception): + logger.error(f"Unhandled exception: {str(exc)}") + logger.error(traceback.format_exc()) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": f"Internal server error: {str(exc)}"} + ) + +if __name__ == "__main__": + # Start the service + logger.info("Starting MatterGen API Service...") + uvicorn.run(app, host="0.0.0.0", port=8051) diff --git a/mattergen_client_example.py b/mattergen_client_example.py new file mode 100644 index 
0000000..928237c --- /dev/null +++ b/mattergen_client_example.py @@ -0,0 +1,134 @@ +import requests +import json +import argparse +import sys + +def generate_material( + url="http://localhost:8051/generate_material", + properties=None, + batch_size=2, + num_batches=1, + diffusion_guidance_factor=2.0 +): + """ + Call the MatterGen API to generate crystal structures. + + Args: + url: URL of the API endpoint + properties: Optional property constraints, e.g. {"dft_band_gap": 2.0} + batch_size: Number of structures generated per batch + num_batches: Number of batches + diffusion_guidance_factor: Controls how closely generated structures match the target properties + + Returns: + The generated structure content, or None on failure (errors are printed) + """ + # Build the request payload + payload = { + "properties": properties , + "batch_size": batch_size, + "num_batches": num_batches, + "diffusion_guidance_factor": diffusion_guidance_factor + } + + print(f"发送请求到 {url}") + print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}") + + try: + # Request headers, including the accept header + headers = { + "Content-Type": "application/json", + "accept": "application/json" + } + + # Print the full request details (for debugging) + print(f"完整请求URL: {url}") + print(f"请求头: {headers}") + print(f"请求体: {json.dumps(payload)}") + + # Disable proxy settings + proxies = { + "http": None, + "https": None + } + + # Send the POST request with the headers, proxies disabled, and an extended timeout + response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300) + + # Print the response details (for debugging) + print(f"响应状态码: {response.status_code}") + print(f"响应头: {dict(response.headers)}") + print(f"响应内容: {response.text[:500]}...") # Print only the first 500 characters to keep the output short + + # Check the response status + if response.status_code == 200: + result = response.json() + + if result["success"]: + print("\n生成成功!") + return result["content"] + else: + print(f"\n生成失败: {result['message']}") + return None + else: + print(f"\n请求失败,状态码: {response.status_code}") + print(f"响应内容: {response.text}") + return None + + except Exception as e: + print(f"\n发生错误: {str(e)}") + print(f"错误类型: {type(e).__name__}") + import traceback + print(f"错误堆栈: {traceback.format_exc()}") + return None + +def main(): + """Command-line entry point.""" + parser = argparse.ArgumentParser(description="MatterGen API客户端示例") + + # Add command-line arguments + 
parser.add_argument("--url", default="http://localhost:8051/generate_material", + help="MatterGen API端点URL") + parser.add_argument("--property-name", default='dft_mag_density',help="属性名称,例如dft_band_gap") + parser.add_argument("--property-value",default=0.15,help="属性值,例如2.0") + parser.add_argument("--batch-size", type=int, default=2, help="每批生成的结构数量") + parser.add_argument("--num-batches", type=int, default=1, help="批次数量") + parser.add_argument("--guidance-factor", type=float, default=2.0, + help="控制生成结构与目标属性的符合程度") + + args = parser.parse_args() + + # Build the properties dictionary + properties = None + if args.property_name and args.property_value: + try: + # Try to convert the property value to a number + try: + value = float(args.property_value) + # If the value is integral, convert it to int + if value.is_integer(): + value = int(value) + except ValueError: + # If it cannot be converted to a number, keep it as a string + value = args.property_value + + properties = {args.property_name: value} + except Exception as e: + print(f"解析属性值时出错: {str(e)}") + return + + # Call the API + result = generate_material( + url=args.url, + properties=properties, + batch_size=args.batch_size, + num_batches=args.num_batches, + diffusion_guidance_factor=args.guidance_factor + ) + + if result: + print("\n生成的结构:") + print(result) + +if __name__ == "__main__": + main() diff --git a/test_mars_toolkit.py b/test_mars_toolkit.py index 0c2b8a9..f3b997d 100644 --- a/test_mars_toolkit.py +++ b/test_mars_toolkit.py @@ -171,7 +171,7 @@ if __name__ == "__main__": ] # 选择要测试的工具 - tool_name = tools_to_test[5] # 测试 search_online 工具 + tool_name = tools_to_test[6] # 测试 search_online 工具 # 运行测试 result = asyncio.run(test_tool(tool_name))