mattergen转服务
This commit is contained in:
@@ -1,21 +1,7 @@
|
||||
"""
|
||||
Mars Toolkit
|
||||
|
||||
A comprehensive toolkit for materials science research, providing tools for:
|
||||
- Material generation and property prediction
|
||||
- Structure optimization
|
||||
- Database queries (Materials Project, OQMD)
|
||||
- Knowledge base retrieval
|
||||
- Web search
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
# Core modules
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.utils import setup_logging
|
||||
|
||||
|
||||
# Basic tools
|
||||
from mars_toolkit.misc.misc_tools import get_current_time
|
||||
@@ -35,12 +21,14 @@ from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||
from mars_toolkit.query.web_search import search_online
|
||||
|
||||
# Visualization modules
|
||||
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
|
||||
|
||||
|
||||
|
||||
# Initialize logging
|
||||
setup_logging()
|
||||
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__all__ = ["llm_tool", "get_tools", "get_tool_schemas"]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,12 +1,3 @@
|
||||
"""
|
||||
Material Generation Module
|
||||
|
||||
This module provides functions for generating crystal structures with optional property constraints.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
@@ -276,148 +267,16 @@ async def generate_material(
|
||||
Returns:
|
||||
Descriptive text with generated crystal structures in CIF format
|
||||
"""
|
||||
# 使用配置中的结果目录
|
||||
output_dir = config.MATTERGENMODEL_RESULT_PATH
|
||||
# 导入MatterGenService
|
||||
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||
|
||||
# 处理字符串输入(如果提供)
|
||||
if isinstance(properties, str):
|
||||
try:
|
||||
properties = json.loads(properties)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid properties JSON string: {properties}")
|
||||
# 获取MatterGenService实例
|
||||
service = MatterGenService.get_instance()
|
||||
|
||||
# 如果为None,默认为空字典
|
||||
properties = properties or {}
|
||||
|
||||
# 根据生成模式处理属性
|
||||
if not properties:
|
||||
# 无条件生成
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
||||
properties_to_condition_on = None
|
||||
generation_type = "unconditional"
|
||||
property_description = "unconditionally"
|
||||
else:
|
||||
# 条件生成(单属性或多属性)
|
||||
properties_to_condition_on = {}
|
||||
|
||||
# 处理每个属性
|
||||
for property_name, property_value in properties.items():
|
||||
_, processed_value = preprocess_property(property_name, property_value)
|
||||
properties_to_condition_on[property_name] = processed_value
|
||||
|
||||
# 根据属性确定使用哪个模型
|
||||
if len(properties) == 1:
|
||||
# 单属性条件
|
||||
property_name = list(properties.keys())[0]
|
||||
property_to_model = {
|
||||
"dft_mag_density": "dft_mag_density",
|
||||
"dft_bulk_modulus": "dft_bulk_modulus",
|
||||
"dft_shear_modulus": "dft_shear_modulus",
|
||||
"energy_above_hull": "energy_above_hull",
|
||||
"formation_energy_per_atom": "formation_energy_per_atom",
|
||||
"space_group": "space_group",
|
||||
"hhi_score": "hhi_score",
|
||||
"ml_bulk_modulus": "ml_bulk_modulus",
|
||||
"chemical_system": "chemical_system",
|
||||
"dft_band_gap": "dft_band_gap"
|
||||
}
|
||||
model_dir = property_to_model.get(property_name, property_name)
|
||||
generation_type = "single_property"
|
||||
property_description = f"conditioned on {property_name} = {properties[property_name]}"
|
||||
else:
|
||||
# 多属性条件
|
||||
property_keys = set(properties.keys())
|
||||
if property_keys == {"dft_mag_density", "hhi_score"}:
|
||||
model_dir = "dft_mag_density_hhi_score"
|
||||
elif property_keys == {"chemical_system", "energy_above_hull"}:
|
||||
model_dir = "chemical_system_energy_above_hull"
|
||||
else:
|
||||
# 如果没有特定的多属性模型,使用第一个属性的模型
|
||||
first_property = list(properties.keys())[0]
|
||||
model_dir = first_property
|
||||
generation_type = "multi_property"
|
||||
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
|
||||
|
||||
# 构建完整的模型路径
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
|
||||
|
||||
# 检查模型目录是否存在
|
||||
if not os.path.exists(model_path):
|
||||
# 如果特定模型不存在,回退到基础模型
|
||||
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
||||
|
||||
# 使用适当的参数调用main函数
|
||||
main(
|
||||
output_path=output_dir,
|
||||
model_path=model_path,
|
||||
# 使用服务生成材料
|
||||
return service.generate(
|
||||
properties=properties,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
properties_to_condition_on=properties_to_condition_on,
|
||||
record_trajectories=True,
|
||||
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0
|
||||
diffusion_guidance_factor=diffusion_guidance_factor
|
||||
)
|
||||
|
||||
# 创建字典存储文件内容
|
||||
result_dict = {}
|
||||
|
||||
# 定义文件路径
|
||||
cif_zip_path = os.path.join(output_dir, "generated_crystals_cif.zip")
|
||||
xyz_file_path = os.path.join(output_dir, "generated_crystals.extxyz")
|
||||
trajectories_zip_path = os.path.join(output_dir, "generated_trajectories.zip")
|
||||
|
||||
# 读取CIF压缩文件
|
||||
if os.path.exists(cif_zip_path):
|
||||
with open(cif_zip_path, 'rb') as f:
|
||||
result_dict['cif_content'] = f.read()
|
||||
|
||||
# 根据生成类型创建描述性提示
|
||||
if generation_type == "unconditional":
|
||||
title = "Generated Material Structures"
|
||||
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
|
||||
elif generation_type == "single_property":
|
||||
property_name = list(properties.keys())[0]
|
||||
property_value = properties[property_name]
|
||||
title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
|
||||
description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
|
||||
else: # multi_property
|
||||
title = "Generated Material Structures Conditioned on Multiple Properties"
|
||||
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
|
||||
|
||||
# 创建完整的提示
|
||||
prompt = f"""
|
||||
# {title}
|
||||
|
||||
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
|
||||
|
||||
{'' if generation_type == 'unconditional' else f'''
|
||||
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
|
||||
the generation adheres to the specified property values. Higher values produce samples that more
|
||||
closely match the target properties but may reduce diversity.
|
||||
'''}
|
||||
|
||||
## CIF Files (Crystallographic Information Files)
|
||||
|
||||
- Standard format for crystallographic structures
|
||||
- Contains unit cell parameters, atomic positions, and symmetry information
|
||||
- Used by crystallographic software and visualization tools
|
||||
|
||||
```
|
||||
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
|
||||
```
|
||||
|
||||
{description}
|
||||
You can use these structures for materials discovery, property prediction, or further analysis.
|
||||
"""
|
||||
|
||||
# 清理文件(读取后删除)
|
||||
try:
|
||||
if os.path.exists(cif_zip_path):
|
||||
os.remove(cif_zip_path)
|
||||
if os.path.exists(xyz_file_path):
|
||||
os.remove(xyz_file_path)
|
||||
if os.path.exists(trajectories_zip_path):
|
||||
os.remove(trajectories_zip_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning up files: {e}")
|
||||
return prompt
|
||||
|
||||
@@ -23,7 +23,6 @@ from pymatgen.io.cif import CifWriter
|
||||
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.error_handlers import handle_general_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -189,4 +188,4 @@ async def optimize_crystal_structure(
|
||||
# 直接返回结果或抛出异常
|
||||
return await asyncio.to_thread(run_optimization)
|
||||
except Exception as e:
|
||||
return handle_general_error(e)
|
||||
return str(e)
|
||||
|
||||
@@ -5,9 +5,4 @@ This module provides core functionality for the Mars Toolkit.
|
||||
"""
|
||||
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.utils import settings, setup_logging
|
||||
from mars_toolkit.core.error_handlers import (
|
||||
handle_minio_error, handle_http_error,
|
||||
handle_validation_error, handle_general_error
|
||||
)
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -3,10 +3,6 @@ CIF Utilities Module
|
||||
|
||||
This module provides basic functions for handling CIF (Crystallographic Information File) files,
|
||||
which are commonly used in materials science for representing crystal structures.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
"""
|
||||
Error Handlers Module
|
||||
|
||||
This module provides error handling utilities for the Mars Toolkit.
|
||||
It includes functions for handling various types of errors that may occur
|
||||
during toolkit operations.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException
|
||||
from typing import Any, Dict
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class APIError(HTTPException):
|
||||
"""自定义API错误类"""
|
||||
def __init__(self, status_code: int, detail: Any = None):
|
||||
super().__init__(status_code=status_code, detail=detail)
|
||||
logger.error(f"API Error: {status_code} - {detail}")
|
||||
|
||||
def handle_minio_error(e: Exception) -> Dict[str, str]:
|
||||
"""处理MinIO相关错误"""
|
||||
logger.error(f"MinIO operation failed: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"data": f"MinIO operation failed: {str(e)}"
|
||||
}
|
||||
|
||||
def handle_http_error(e: Exception) -> Dict[str, str]:
|
||||
"""处理HTTP请求错误"""
|
||||
logger.error(f"HTTP request failed: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"data": f"HTTP request failed: {str(e)}"
|
||||
}
|
||||
|
||||
def handle_validation_error(e: Exception) -> Dict[str, str]:
|
||||
"""处理数据验证错误"""
|
||||
logger.error(f"Validation failed: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"data": f"Validation failed: {str(e)}"
|
||||
}
|
||||
|
||||
def handle_general_error(e: Exception) -> Dict[str, str]:
|
||||
"""处理通用错误"""
|
||||
logger.error(f"Unexpected error: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"data": f"Unexpected error: {str(e)}"
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
import os
|
||||
import boto3
|
||||
import logging
|
||||
import logging.config
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# Material Project
|
||||
mp_api_key: Optional[str] = Field(None, env="MP_API_KEY")
|
||||
mp_endpoint: Optional[str] = Field(None, env="MP_ENDPOINT")
|
||||
mp_topk: Optional[int] = Field(3, env="MP_TOPK")
|
||||
|
||||
# Proxy
|
||||
http_proxy: Optional[str] = Field(None, env="HTTP_PROXY")
|
||||
https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY")
|
||||
|
||||
# FairChem
|
||||
fairchem_model_path: Optional[str] = Field(None, env="FAIRCHEM_MODEL_PATH")
|
||||
fmax: Optional[float] = Field(0.05, env="FMAX")
|
||||
|
||||
# MinIO
|
||||
minio_endpoint: Optional[str] = Field(None, env="MINIO_ENDPOINT")
|
||||
internal_minio_endpoint: Optional[str] = Field(None, env="INTERNAL_MINIO_ENDPOINT")
|
||||
minio_access_key: Optional[str] = Field(None, env="MINIO_ACCESS_KEY")
|
||||
minio_secret_key: Optional[str] = Field(None, env="MINIO_SECRET_KEY")
|
||||
minio_bucket: Optional[str] = Field("mars-toolkit", env="MINIO_BUCKET")
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
def setup_logging():
|
||||
"""配置日志记录"""
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
|
||||
log_file_path = os.path.join(parent_dir, 'mars_toolkit.log')
|
||||
|
||||
logging.config.dictConfig({
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
'formatters': {
|
||||
'standard': {
|
||||
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
'datefmt': '%Y-%m-%d %H:%M:%S'
|
||||
},
|
||||
},
|
||||
'handlers': {
|
||||
'console': {
|
||||
'level': 'INFO',
|
||||
'class': 'logging.StreamHandler',
|
||||
'formatter': 'standard'
|
||||
},
|
||||
'file': {
|
||||
'level': 'DEBUG',
|
||||
'class': 'logging.handlers.RotatingFileHandler',
|
||||
'filename': log_file_path,
|
||||
'maxBytes': 10485760, # 10MB
|
||||
'backupCount': 5,
|
||||
'formatter': 'standard'
|
||||
}
|
||||
},
|
||||
'loggers': {
|
||||
'': {
|
||||
'handlers': ['console', 'file'],
|
||||
'level': 'INFO',
|
||||
'propagate': True
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
# 初始化配置
|
||||
settings = Settings()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,13 +1,4 @@
|
||||
"""
|
||||
Dify Search Module
|
||||
|
||||
This module provides functions for retrieving information from local materials science
|
||||
literature knowledge base using Dify API.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
@@ -1,14 +1,3 @@
|
||||
"""
|
||||
Materials Project Query Module
|
||||
|
||||
This module provides functions for querying the Materials Project database,
|
||||
processing search results, and formatting responses.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import glob
|
||||
import json
|
||||
import asyncio
|
||||
@@ -25,7 +14,6 @@ from pymatgen.io.cif import CifWriter
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.error_handlers import handle_general_error
|
||||
from mars_toolkit.core.cif_utils import read_cif_txt_file, extract_cif_info, remove_symmetry_equiv_xyz
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,13 +1,3 @@
|
||||
"""
|
||||
OQMD Query Module
|
||||
|
||||
This module provides functions for querying the Open Quantum Materials Database (OQMD).
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
import pandas as pd
|
||||
|
||||
12
mars_toolkit/services/__init__.py
Normal file
12
mars_toolkit/services/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Services module for mars_toolkit.
|
||||
|
||||
This module contains service classes that provide persistent functionality
|
||||
across multiple function calls, such as maintaining initialized models.
|
||||
"""
|
||||
|
||||
# Import services for easy access
|
||||
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||
|
||||
# Export services
|
||||
__all__ = ['MatterGenService']
|
||||
BIN
mars_toolkit/services/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/services/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
342
mars_toolkit/services/mattergen_service.py
Normal file
342
mars_toolkit/services/mattergen_service.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
MatterGen service for mars_toolkit.
|
||||
|
||||
This module provides a service for generating crystal structures using MatterGen.
|
||||
The service initializes the CrystalGenerator once and reuses it for multiple
|
||||
generation requests, improving performance.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Union, List
|
||||
import threading
|
||||
|
||||
# 导入mattergen相关模块
|
||||
import sys
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../')))
|
||||
from mattergen_wrapper import generator
|
||||
CrystalGenerator = generator.CrystalGenerator
|
||||
from mattergen.common.data.types import TargetProperty
|
||||
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
|
||||
from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME
|
||||
|
||||
# 导入mars_toolkit配置
|
||||
from mars_toolkit.core.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MatterGenService:
|
||||
"""
|
||||
Service for generating crystal structures using MatterGen.
|
||||
|
||||
This service initializes the CrystalGenerator once and reuses it for multiple
|
||||
generation requests, improving performance.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""
|
||||
Get the singleton instance of MatterGenService.
|
||||
|
||||
Returns:
|
||||
MatterGenService: The singleton instance.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the MatterGenService.
|
||||
|
||||
This initializes the base generator without any property conditioning.
|
||||
Specific generators for different property conditions will be initialized
|
||||
on demand.
|
||||
"""
|
||||
self._generators = {}
|
||||
self._output_dir = config.MATTERGENMODEL_RESULT_PATH
|
||||
|
||||
# 确保输出目录存在
|
||||
if not os.path.exists(self._output_dir):
|
||||
os.makedirs(self._output_dir)
|
||||
|
||||
# 初始化基础生成器(无条件生成)
|
||||
self._init_base_generator()
|
||||
|
||||
def _init_base_generator(self):
|
||||
"""
|
||||
Initialize the base generator for unconditional generation.
|
||||
"""
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
logger.warning(f"Base model directory not found at {model_path}. MatterGen service may not work properly.")
|
||||
return
|
||||
|
||||
logger.info(f"Initializing base MatterGen generator from {model_path}")
|
||||
|
||||
try:
|
||||
checkpoint_info = MatterGenCheckpointInfo(
|
||||
model_path=Path(model_path).resolve(),
|
||||
load_epoch="last",
|
||||
config_overrides=[],
|
||||
strict_checkpoint_loading=True,
|
||||
)
|
||||
|
||||
generator = CrystalGenerator(
|
||||
checkpoint_info=checkpoint_info,
|
||||
properties_to_condition_on=None,
|
||||
batch_size=2, # 默认值,可在生成时覆盖
|
||||
num_batches=1, # 默认值,可在生成时覆盖
|
||||
sampling_config_name="default",
|
||||
sampling_config_path=None,
|
||||
sampling_config_overrides=[],
|
||||
record_trajectories=True,
|
||||
diffusion_guidance_factor=0.0,
|
||||
target_compositions_dict=[],
|
||||
)
|
||||
|
||||
self._generators["base"] = generator
|
||||
logger.info("Base MatterGen generator initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize base MatterGen generator: {e}")
|
||||
|
||||
def _get_or_create_generator(
|
||||
self,
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
batch_size: int = 2,
|
||||
num_batches: int = 1,
|
||||
diffusion_guidance_factor: float = 2.0
|
||||
):
|
||||
"""
|
||||
Get or create a generator for the specified properties.
|
||||
|
||||
Args:
|
||||
properties: Optional property constraints
|
||||
batch_size: Number of structures per batch
|
||||
num_batches: Number of batches to generate
|
||||
diffusion_guidance_factor: Controls adherence to target properties
|
||||
|
||||
Returns:
|
||||
tuple: (generator, generator_key, properties_to_condition_on)
|
||||
"""
|
||||
# 如果没有属性约束,使用基础生成器
|
||||
if not properties:
|
||||
if "base" not in self._generators:
|
||||
self._init_base_generator()
|
||||
return self._generators.get("base"), "base", None
|
||||
|
||||
# 处理属性约束
|
||||
properties_to_condition_on = {}
|
||||
for property_name, property_value in properties.items():
|
||||
properties_to_condition_on[property_name] = property_value
|
||||
|
||||
# 确定模型目录
|
||||
if len(properties) == 1:
|
||||
# 单属性条件
|
||||
property_name = list(properties.keys())[0]
|
||||
property_to_model = {
|
||||
"dft_mag_density": "dft_mag_density",
|
||||
"dft_bulk_modulus": "dft_bulk_modulus",
|
||||
"dft_shear_modulus": "dft_shear_modulus",
|
||||
"energy_above_hull": "energy_above_hull",
|
||||
"formation_energy_per_atom": "formation_energy_per_atom",
|
||||
"space_group": "space_group",
|
||||
"hhi_score": "hhi_score",
|
||||
"ml_bulk_modulus": "ml_bulk_modulus",
|
||||
"chemical_system": "chemical_system",
|
||||
"dft_band_gap": "dft_band_gap"
|
||||
}
|
||||
model_dir = property_to_model.get(property_name, property_name)
|
||||
generator_key = f"single_{property_name}"
|
||||
else:
|
||||
# 多属性条件
|
||||
property_keys = set(properties.keys())
|
||||
if property_keys == {"dft_mag_density", "hhi_score"}:
|
||||
model_dir = "dft_mag_density_hhi_score"
|
||||
generator_key = "multi_dft_mag_density_hhi_score"
|
||||
elif property_keys == {"chemical_system", "energy_above_hull"}:
|
||||
model_dir = "chemical_system_energy_above_hull"
|
||||
generator_key = "multi_chemical_system_energy_above_hull"
|
||||
else:
|
||||
# 如果没有特定的多属性模型,使用第一个属性的模型
|
||||
first_property = list(properties.keys())[0]
|
||||
model_dir = first_property
|
||||
generator_key = f"multi_{first_property}_etc"
|
||||
|
||||
# 构建完整的模型路径
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
|
||||
|
||||
# 检查模型目录是否存在
|
||||
if not os.path.exists(model_path):
|
||||
# 如果特定模型不存在,回退到基础模型
|
||||
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
||||
generator_key = "base"
|
||||
|
||||
# 检查是否已经有这个生成器
|
||||
if generator_key in self._generators:
|
||||
# 更新生成器的参数
|
||||
generator = self._generators[generator_key]
|
||||
generator.batch_size = batch_size
|
||||
generator.num_batches = num_batches
|
||||
generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0
|
||||
return generator, generator_key, properties_to_condition_on
|
||||
|
||||
# 创建新的生成器
|
||||
try:
|
||||
logger.info(f"Initializing new MatterGen generator for {generator_key} from {model_path}")
|
||||
|
||||
checkpoint_info = MatterGenCheckpointInfo(
|
||||
model_path=Path(model_path).resolve(),
|
||||
load_epoch="last",
|
||||
config_overrides=[],
|
||||
strict_checkpoint_loading=True,
|
||||
)
|
||||
|
||||
generator = CrystalGenerator(
|
||||
checkpoint_info=checkpoint_info,
|
||||
properties_to_condition_on=properties_to_condition_on,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
sampling_config_name="default",
|
||||
sampling_config_path=None,
|
||||
sampling_config_overrides=[],
|
||||
record_trajectories=True,
|
||||
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0,
|
||||
target_compositions_dict=[],
|
||||
)
|
||||
|
||||
self._generators[generator_key] = generator
|
||||
logger.info(f"MatterGen generator for {generator_key} initialized successfully")
|
||||
return generator, generator_key, properties_to_condition_on
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
|
||||
# 回退到基础生成器
|
||||
if "base" not in self._generators:
|
||||
self._init_base_generator()
|
||||
return self._generators.get("base"), "base", None
|
||||
|
||||
def generate(
|
||||
self,
|
||||
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
|
||||
batch_size: int = 2,
|
||||
num_batches: int = 1,
|
||||
diffusion_guidance_factor: float = 2.0
|
||||
) -> str:
|
||||
"""
|
||||
Generate crystal structures with optional property constraints.
|
||||
|
||||
Args:
|
||||
properties: Optional property constraints
|
||||
batch_size: Number of structures per batch
|
||||
num_batches: Number of batches to generate
|
||||
diffusion_guidance_factor: Controls adherence to target properties
|
||||
|
||||
Returns:
|
||||
str: Descriptive text with generated crystal structures in CIF format
|
||||
"""
|
||||
from mars_toolkit.compute.material_gen import format_cif_content
|
||||
|
||||
# 处理字符串输入(如果提供)
|
||||
if isinstance(properties, str):
|
||||
try:
|
||||
properties = json.loads(properties)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid properties JSON string: {properties}")
|
||||
|
||||
# 如果为None,默认为空字典
|
||||
properties = properties or {}
|
||||
|
||||
# 获取或创建生成器
|
||||
generator, generator_key, properties_to_condition_on = self._get_or_create_generator(
|
||||
properties, batch_size, num_batches, diffusion_guidance_factor
|
||||
)
|
||||
|
||||
if generator is None:
|
||||
return "Error: Failed to initialize MatterGen generator"
|
||||
|
||||
# 生成结构
|
||||
try:
|
||||
generator.generate(output_dir=Path(self._output_dir))
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating structures: {e}")
|
||||
return f"Error generating structures: {e}"
|
||||
|
||||
# 创建字典存储文件内容
|
||||
result_dict = {}
|
||||
|
||||
# 定义文件路径
|
||||
cif_zip_path = os.path.join(self._output_dir, "generated_crystals_cif.zip")
|
||||
xyz_file_path = os.path.join(self._output_dir, "generated_crystals.extxyz")
|
||||
trajectories_zip_path = os.path.join(self._output_dir, "generated_trajectories.zip")
|
||||
|
||||
# 读取CIF压缩文件
|
||||
if os.path.exists(cif_zip_path):
|
||||
with open(cif_zip_path, 'rb') as f:
|
||||
result_dict['cif_content'] = f.read()
|
||||
|
||||
# 根据生成类型创建描述性提示
|
||||
if not properties:
|
||||
generation_type = "unconditional"
|
||||
title = "Generated Material Structures"
|
||||
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
|
||||
property_description = "unconditionally"
|
||||
elif len(properties) == 1:
|
||||
generation_type = "single_property"
|
||||
property_name = list(properties.keys())[0]
|
||||
property_value = properties[property_name]
|
||||
title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
|
||||
description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
|
||||
property_description = f"conditioned on {property_name} = {property_value}"
|
||||
else:
|
||||
generation_type = "multi_property"
|
||||
title = "Generated Material Structures Conditioned on Multiple Properties"
|
||||
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
|
||||
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
|
||||
|
||||
# 创建完整的提示
|
||||
prompt = f"""
|
||||
# {title}
|
||||
|
||||
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
|
||||
|
||||
{'' if generation_type == 'unconditional' else f'''
|
||||
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
|
||||
the generation adheres to the specified property values. Higher values produce samples that more
|
||||
closely match the target properties but may reduce diversity.
|
||||
'''}
|
||||
|
||||
## CIF Files (Crystallographic Information Files)
|
||||
|
||||
- Standard format for crystallographic structures
|
||||
- Contains unit cell parameters, atomic positions, and symmetry information
|
||||
- Used by crystallographic software and visualization tools
|
||||
|
||||
```
|
||||
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
|
||||
```
|
||||
|
||||
{description}
|
||||
You can use these structures for materials discovery, property prediction, or further analysis.
|
||||
"""
|
||||
|
||||
# 清理文件(读取后删除)
|
||||
try:
|
||||
if os.path.exists(cif_zip_path):
|
||||
os.remove(cif_zip_path)
|
||||
if os.path.exists(xyz_file_path):
|
||||
os.remove(xyz_file_path)
|
||||
if os.path.exists(trajectories_zip_path):
|
||||
os.remove(trajectories_zip_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning up files: {e}")
|
||||
|
||||
return prompt
|
||||
@@ -0,0 +1,2 @@
|
||||
|
||||
f
|
||||
|
||||
BIN
mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc
Normal file
BIN
mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
151
mattergen_api.py
Normal file
151
mattergen_api.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
from typing import Dict, Any, Optional, Union, List
|
||||
import logging
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.StreamHandler(sys.stdout)])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化FastAPI
|
||||
app = FastAPI(title="MatterGen API Service")
|
||||
|
||||
# 请求模型
|
||||
class MaterialGenerationRequest(BaseModel):
|
||||
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None
|
||||
batch_size: int = 2
|
||||
num_batches: int = 1
|
||||
diffusion_guidance_factor: float = 2.0
|
||||
|
||||
# 响应模型
|
||||
class MaterialGenerationResponse(BaseModel):
|
||||
content: str
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
# 全局变量,用于跟踪服务状态
|
||||
service_status = {
|
||||
"initialized": False,
|
||||
"error": None,
|
||||
"mattergen_service": None
|
||||
}
|
||||
|
||||
# 初始化MatterGenService
|
||||
try:
|
||||
logger.info("Importing MatterGenService...")
|
||||
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||
|
||||
logger.info("Initializing MatterGenService...")
|
||||
mattergen_service = MatterGenService.get_instance()
|
||||
service_status["mattergen_service"] = mattergen_service
|
||||
service_status["initialized"] = True
|
||||
logger.info("MatterGenService initialized successfully")
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to initialize MatterGenService: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.error(traceback.format_exc())
|
||||
service_status["error"] = error_msg
|
||||
|
||||
# 中间件:检查服务状态
|
||||
@app.middleware("http")
|
||||
async def check_service_status(request: Request, call_next):
|
||||
# 健康检查端点不需要检查服务状态
|
||||
if request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
# 如果服务未初始化,返回503错误
|
||||
if not service_status["initialized"]:
|
||||
error_msg = service_status["error"] or "MatterGenService not initialized"
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
content={"detail": error_msg}
|
||||
)
|
||||
|
||||
# 继续处理请求
|
||||
return await call_next(request)
|
||||
|
||||
@app.post("/generate_material", response_model=MaterialGenerationResponse)
|
||||
async def generate_material(request: MaterialGenerationRequest):
|
||||
"""生成晶体结构,可选择性地指定属性约束"""
|
||||
try:
|
||||
logger.info(f"Received material generation request with properties: {request.properties}")
|
||||
print("request",request)
|
||||
# 调用MatterGenService生成材料
|
||||
result = mattergen_service.generate(
|
||||
properties=request.properties,
|
||||
batch_size=request.batch_size,
|
||||
num_batches=request.num_batches,
|
||||
diffusion_guidance_factor=request.diffusion_guidance_factor
|
||||
)
|
||||
|
||||
logger.info("Material generation completed successfully")
|
||||
|
||||
return {
|
||||
"content": result,
|
||||
"success": True,
|
||||
"message": "Material generation successful"
|
||||
}
|
||||
except Exception as e:
|
||||
# 记录详细错误信息
|
||||
error_msg = f"Error generating material: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 返回错误响应
|
||||
return {
|
||||
"content": "",
|
||||
"success": False,
|
||||
"message": error_msg
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查端点,检查MatterGenService的状态"""
|
||||
if service_status["initialized"]:
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "MatterGen API",
|
||||
"mattergen_service": "initialized"
|
||||
}
|
||||
else:
|
||||
error_msg = service_status["error"] or "MatterGenService not initialized"
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"service": "MatterGen API",
|
||||
"error": error_msg
|
||||
}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""API根端点,提供基本信息"""
|
||||
return {
|
||||
"service": "MatterGen API Service",
|
||||
"description": "API for generating crystal structures with optional property constraints",
|
||||
"status": "healthy" if service_status["initialized"] else "unhealthy",
|
||||
"endpoints": {
|
||||
"/generate_material": "POST - Generate crystal structures",
|
||||
"/health": "GET - Health check",
|
||||
"/docs": "GET - API documentation"
|
||||
}
|
||||
}
|
||||
|
||||
# 全局异常处理
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
logger.error(f"Unhandled exception: {str(exc)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"detail": f"Internal server error: {str(exc)}"}
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动服务
|
||||
logger.info("Starting MatterGen API Service...")
|
||||
uvicorn.run(app, host="0.0.0.0", port=8051)
|
||||
134
mattergen_client_example.py
Normal file
134
mattergen_client_example.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import requests
|
||||
import json
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
def generate_material(
|
||||
url="http://localhost:8051/generate_material",
|
||||
properties=None,
|
||||
batch_size=2,
|
||||
num_batches=1,
|
||||
diffusion_guidance_factor=2.0
|
||||
):
|
||||
"""
|
||||
调用MatterGen API生成晶体结构
|
||||
|
||||
Args:
|
||||
url: API端点URL
|
||||
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
|
||||
batch_size: 每批生成的结构数量
|
||||
num_batches: 批次数量
|
||||
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
|
||||
|
||||
Returns:
|
||||
生成的结构内容或错误信息
|
||||
"""
|
||||
# 构建请求负载
|
||||
payload = {
|
||||
"properties": properties ,
|
||||
"batch_size": batch_size,
|
||||
"num_batches": num_batches,
|
||||
"diffusion_guidance_factor": diffusion_guidance_factor
|
||||
}
|
||||
|
||||
print(f"发送请求到 {url}")
|
||||
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
|
||||
|
||||
try:
|
||||
# 添加headers参数,包含accept头
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"accept": "application/json"
|
||||
}
|
||||
|
||||
# 打印完整请求信息(调试用)
|
||||
print(f"完整请求URL: {url}")
|
||||
print(f"请求头: {headers}")
|
||||
print(f"请求体: {json.dumps(payload)}")
|
||||
|
||||
# 禁用代理设置
|
||||
proxies = {
|
||||
"http": None,
|
||||
"https": None
|
||||
}
|
||||
|
||||
# 发送POST请求,添加headers参数,禁用代理,增加超时时间
|
||||
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
|
||||
|
||||
# 打印响应信息(调试用)
|
||||
print(f"响应状态码: {response.status_code}")
|
||||
print(f"响应头: {dict(response.headers)}")
|
||||
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符,避免输出过长
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
|
||||
if result["success"]:
|
||||
print("\n生成成功!")
|
||||
return result["content"]
|
||||
else:
|
||||
print(f"\n生成失败: {result['message']}")
|
||||
return None
|
||||
else:
|
||||
print(f"\n请求失败,状态码: {response.status_code}")
|
||||
print(f"响应内容: {response.text}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n发生错误: {str(e)}")
|
||||
print(f"错误类型: {type(e).__name__}")
|
||||
import traceback
|
||||
print(f"错误堆栈: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
def main():
|
||||
"""命令行入口函数"""
|
||||
parser = argparse.ArgumentParser(description="MatterGen API客户端示例")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument("--url", default="http://localhost:8051/generate_material",
|
||||
help="MatterGen API端点URL")
|
||||
parser.add_argument("--property-name", default='dft_mag_density',help="属性名称,例如dft_band_gap")
|
||||
parser.add_argument("--property-value",default=0.15,help="属性值,例如2.0")
|
||||
parser.add_argument("--batch-size", type=int, default=2, help="每批生成的结构数量")
|
||||
parser.add_argument("--num-batches", type=int, default=1, help="批次数量")
|
||||
parser.add_argument("--guidance-factor", type=float, default=2.0,
|
||||
help="控制生成结构与目标属性的符合程度")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 构建属性字典
|
||||
properties = None
|
||||
if args.property_name and args.property_value:
|
||||
try:
|
||||
# 尝试将属性值转换为数字
|
||||
try:
|
||||
value = float(args.property_value)
|
||||
# 如果是整数,转换为整数
|
||||
if value.is_integer():
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
# 如果无法转换为数字,保持为字符串
|
||||
value = args.property_value
|
||||
|
||||
properties = {args.property_name: value}
|
||||
except Exception as e:
|
||||
print(f"解析属性值时出错: {str(e)}")
|
||||
return
|
||||
|
||||
# 调用API
|
||||
result = generate_material(
|
||||
url=args.url,
|
||||
properties=properties,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.num_batches,
|
||||
diffusion_guidance_factor=args.guidance_factor
|
||||
)
|
||||
|
||||
if result:
|
||||
print("\n生成的结构:")
|
||||
print(result)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -171,7 +171,7 @@ if __name__ == "__main__":
|
||||
]
|
||||
|
||||
# 选择要测试的工具
|
||||
tool_name = tools_to_test[5] # 测试 search_online 工具
|
||||
tool_name = tools_to_test[6] # 测试 search_online 工具
|
||||
|
||||
# 运行测试
|
||||
result = asyncio.run(test_tool(tool_name))
|
||||
|
||||
Reference in New Issue
Block a user