mattergen转服务

This commit is contained in:
lzy
2025-04-02 16:24:50 +08:00
parent a77c2cd377
commit 7034566ee6
30 changed files with 656 additions and 339 deletions

View File

@@ -1,21 +1,7 @@
"""
Mars Toolkit
A comprehensive toolkit for materials science research, providing tools for:
- Material generation and property prediction
- Structure optimization
- Database queries (Materials Project, OQMD)
- Knowledge base retrieval
- Web search
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
# Core modules
from mars_toolkit.core.config import config
from mars_toolkit.core.utils import setup_logging
# Basic tools
from mars_toolkit.misc.misc_tools import get_current_time
@@ -35,12 +21,14 @@ from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
from mars_toolkit.query.web_search import search_online
# Visualization modules
from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
# Initialize logging
setup_logging()
__version__ = "0.1.0"
__all__ = ["llm_tool", "get_tools", "get_tool_schemas"]

View File

@@ -1,12 +1,3 @@
"""
Material Generation Module
This module provides functions for generating crystal structures with optional property constraints.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import ast
import json
@@ -276,148 +267,16 @@ async def generate_material(
Returns:
Descriptive text with generated crystal structures in CIF format
"""
# 使用配置中的结果目录
output_dir = config.MATTERGENMODEL_RESULT_PATH
# 导入MatterGenService
from mars_toolkit.services.mattergen_service import MatterGenService
# 处理字符串输入(如果提供)
if isinstance(properties, str):
try:
properties = json.loads(properties)
except json.JSONDecodeError:
raise ValueError(f"Invalid properties JSON string: {properties}")
# 获取MatterGenService实例
service = MatterGenService.get_instance()
# 如果为None默认为空字典
properties = properties or {}
# 根据生成模式处理属性
if not properties:
# 无条件生成
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
properties_to_condition_on = None
generation_type = "unconditional"
property_description = "unconditionally"
else:
# 条件生成(单属性或多属性)
properties_to_condition_on = {}
# 处理每个属性
for property_name, property_value in properties.items():
_, processed_value = preprocess_property(property_name, property_value)
properties_to_condition_on[property_name] = processed_value
# 根据属性确定使用哪个模型
if len(properties) == 1:
# 单属性条件
property_name = list(properties.keys())[0]
property_to_model = {
"dft_mag_density": "dft_mag_density",
"dft_bulk_modulus": "dft_bulk_modulus",
"dft_shear_modulus": "dft_shear_modulus",
"energy_above_hull": "energy_above_hull",
"formation_energy_per_atom": "formation_energy_per_atom",
"space_group": "space_group",
"hhi_score": "hhi_score",
"ml_bulk_modulus": "ml_bulk_modulus",
"chemical_system": "chemical_system",
"dft_band_gap": "dft_band_gap"
}
model_dir = property_to_model.get(property_name, property_name)
generation_type = "single_property"
property_description = f"conditioned on {property_name} = {properties[property_name]}"
else:
# 多属性条件
property_keys = set(properties.keys())
if property_keys == {"dft_mag_density", "hhi_score"}:
model_dir = "dft_mag_density_hhi_score"
elif property_keys == {"chemical_system", "energy_above_hull"}:
model_dir = "chemical_system_energy_above_hull"
else:
# 如果没有特定的多属性模型,使用第一个属性的模型
first_property = list(properties.keys())[0]
model_dir = first_property
generation_type = "multi_property"
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
# 构建完整的模型路径
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
# 检查模型目录是否存在
if not os.path.exists(model_path):
# 如果特定模型不存在,回退到基础模型
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
# 使用适当的参数调用main函数
main(
output_path=output_dir,
model_path=model_path,
# 使用服务生成材料
return service.generate(
properties=properties,
batch_size=batch_size,
num_batches=num_batches,
properties_to_condition_on=properties_to_condition_on,
record_trajectories=True,
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0
diffusion_guidance_factor=diffusion_guidance_factor
)
# 创建字典存储文件内容
result_dict = {}
# 定义文件路径
cif_zip_path = os.path.join(output_dir, "generated_crystals_cif.zip")
xyz_file_path = os.path.join(output_dir, "generated_crystals.extxyz")
trajectories_zip_path = os.path.join(output_dir, "generated_trajectories.zip")
# 读取CIF压缩文件
if os.path.exists(cif_zip_path):
with open(cif_zip_path, 'rb') as f:
result_dict['cif_content'] = f.read()
# 根据生成类型创建描述性提示
if generation_type == "unconditional":
title = "Generated Material Structures"
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
elif generation_type == "single_property":
property_name = list(properties.keys())[0]
property_value = properties[property_name]
title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
else: # multi_property
title = "Generated Material Structures Conditioned on Multiple Properties"
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
# 创建完整的提示
prompt = f"""
# {title}
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
{'' if generation_type == 'unconditional' else f'''
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
'''}
## CIF Files (Crystallographic Information Files)
- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools
```
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
```
{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""
# 清理文件(读取后删除)
try:
if os.path.exists(cif_zip_path):
os.remove(cif_zip_path)
if os.path.exists(xyz_file_path):
os.remove(xyz_file_path)
if os.path.exists(trajectories_zip_path):
os.remove(trajectories_zip_path)
except Exception as e:
logger.warning(f"Error cleaning up files: {e}")
return prompt

View File

@@ -23,7 +23,6 @@ from pymatgen.io.cif import CifWriter
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
from mars_toolkit.core.error_handlers import handle_general_error
logger = logging.getLogger(__name__)
@@ -189,4 +188,4 @@ async def optimize_crystal_structure(
# 直接返回结果或抛出异常
return await asyncio.to_thread(run_optimization)
except Exception as e:
return handle_general_error(e)
return str(e)

View File

@@ -5,9 +5,4 @@ This module provides core functionality for the Mars Toolkit.
"""
from mars_toolkit.core.config import config
from mars_toolkit.core.utils import settings, setup_logging
from mars_toolkit.core.error_handlers import (
handle_minio_error, handle_http_error,
handle_validation_error, handle_general_error
)
from mars_toolkit.core.llm_tools import llm_tool

View File

@@ -3,10 +3,6 @@ CIF Utilities Module
This module provides basic functions for handling CIF (Crystallographic Information File) files,
which are commonly used in materials science for representing crystal structures.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import json

View File

@@ -1,55 +0,0 @@
"""
Error Handlers Module
This module provides error handling utilities for the Mars Toolkit.
It includes functions for handling various types of errors that may occur
during toolkit operations.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
from fastapi import HTTPException
from typing import Any, Dict
import logging
logger = logging.getLogger(__name__)
class APIError(HTTPException):
"""自定义API错误类"""
def __init__(self, status_code: int, detail: Any = None):
super().__init__(status_code=status_code, detail=detail)
logger.error(f"API Error: {status_code} - {detail}")
def handle_minio_error(e: Exception) -> Dict[str, str]:
"""处理MinIO相关错误"""
logger.error(f"MinIO operation failed: {str(e)}")
return {
"status": "error",
"data": f"MinIO operation failed: {str(e)}"
}
def handle_http_error(e: Exception) -> Dict[str, str]:
"""处理HTTP请求错误"""
logger.error(f"HTTP request failed: {str(e)}")
return {
"status": "error",
"data": f"HTTP request failed: {str(e)}"
}
def handle_validation_error(e: Exception) -> Dict[str, str]:
"""处理数据验证错误"""
logger.error(f"Validation failed: {str(e)}")
return {
"status": "error",
"data": f"Validation failed: {str(e)}"
}
def handle_general_error(e: Exception) -> Dict[str, str]:
"""处理通用错误"""
logger.error(f"Unexpected error: {str(e)}")
return {
"status": "error",
"data": f"Unexpected error: {str(e)}"
}

View File

@@ -1,75 +0,0 @@
import os
import boto3
import logging
import logging.config
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
logger = logging.getLogger(__name__)
class Settings(BaseSettings):
# Material Project
mp_api_key: Optional[str] = Field(None, env="MP_API_KEY")
mp_endpoint: Optional[str] = Field(None, env="MP_ENDPOINT")
mp_topk: Optional[int] = Field(3, env="MP_TOPK")
# Proxy
http_proxy: Optional[str] = Field(None, env="HTTP_PROXY")
https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY")
# FairChem
fairchem_model_path: Optional[str] = Field(None, env="FAIRCHEM_MODEL_PATH")
fmax: Optional[float] = Field(0.05, env="FMAX")
# MinIO
minio_endpoint: Optional[str] = Field(None, env="MINIO_ENDPOINT")
internal_minio_endpoint: Optional[str] = Field(None, env="INTERNAL_MINIO_ENDPOINT")
minio_access_key: Optional[str] = Field(None, env="MINIO_ACCESS_KEY")
minio_secret_key: Optional[str] = Field(None, env="MINIO_SECRET_KEY")
minio_bucket: Optional[str] = Field("mars-toolkit", env="MINIO_BUCKET")
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
def setup_logging():
"""配置日志记录"""
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
log_file_path = os.path.join(parent_dir, 'mars_toolkit.log')
logging.config.dictConfig({
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'standard': {
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
'datefmt': '%Y-%m-%d %H:%M:%S'
},
},
'handlers': {
'console': {
'level': 'INFO',
'class': 'logging.StreamHandler',
'formatter': 'standard'
},
'file': {
'level': 'DEBUG',
'class': 'logging.handlers.RotatingFileHandler',
'filename': log_file_path,
'maxBytes': 10485760, # 10MB
'backupCount': 5,
'formatter': 'standard'
}
},
'loggers': {
'': {
'handlers': ['console', 'file'],
'level': 'INFO',
'propagate': True
}
}
})
# 初始化配置
settings = Settings()

View File

@@ -1,13 +1,4 @@
"""
Dify Search Module
This module provides functions for retrieving information from local materials science
literature knowledge base using Dify API.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import asyncio
import json

View File

@@ -1,14 +1,3 @@
"""
Materials Project Query Module
This module provides functions for querying the Materials Project database,
processing search results, and formatting responses.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import glob
import json
import asyncio
@@ -25,7 +14,6 @@ from pymatgen.io.cif import CifWriter
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
from mars_toolkit.core.error_handlers import handle_general_error
from mars_toolkit.core.cif_utils import read_cif_txt_file, extract_cif_info, remove_symmetry_equiv_xyz
logger = logging.getLogger(__name__)

View File

@@ -1,13 +1,3 @@
"""
OQMD Query Module
This module provides functions for querying the Open Quantum Materials Database (OQMD).
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import logging
import httpx
import pandas as pd

View File

@@ -0,0 +1,12 @@
"""
Services module for mars_toolkit.
This module contains service classes that provide persistent functionality
across multiple function calls, such as maintaining initialized models.
"""
# Import services for easy access
from mars_toolkit.services.mattergen_service import MatterGenService
# Export services
__all__ = ['MatterGenService']

View File

@@ -0,0 +1,342 @@
"""
MatterGen service for mars_toolkit.
This module provides a service for generating crystal structures using MatterGen.
The service initializes the CrystalGenerator once and reuses it for multiple
generation requests, improving performance.
"""
import os
import logging
import json
from pathlib import Path
from typing import Dict, Any, Optional, Union, List
import threading
# 导入mattergen相关模块
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../')))
from mattergen_wrapper import generator
CrystalGenerator = generator.CrystalGenerator
from mattergen.common.data.types import TargetProperty
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME
# 导入mars_toolkit配置
from mars_toolkit.core.config import config
logger = logging.getLogger(__name__)
class MatterGenService:
"""
Service for generating crystal structures using MatterGen.
This service initializes the CrystalGenerator once and reuses it for multiple
generation requests, improving performance.
"""
_instance = None
_lock = threading.Lock()
@classmethod
def get_instance(cls):
"""
Get the singleton instance of MatterGenService.
Returns:
MatterGenService: The singleton instance.
"""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
"""
Initialize the MatterGenService.
This initializes the base generator without any property conditioning.
Specific generators for different property conditions will be initialized
on demand.
"""
self._generators = {}
self._output_dir = config.MATTERGENMODEL_RESULT_PATH
# 确保输出目录存在
if not os.path.exists(self._output_dir):
os.makedirs(self._output_dir)
# 初始化基础生成器(无条件生成)
self._init_base_generator()
def _init_base_generator(self):
"""
Initialize the base generator for unconditional generation.
"""
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
if not os.path.exists(model_path):
logger.warning(f"Base model directory not found at {model_path}. MatterGen service may not work properly.")
return
logger.info(f"Initializing base MatterGen generator from {model_path}")
try:
checkpoint_info = MatterGenCheckpointInfo(
model_path=Path(model_path).resolve(),
load_epoch="last",
config_overrides=[],
strict_checkpoint_loading=True,
)
generator = CrystalGenerator(
checkpoint_info=checkpoint_info,
properties_to_condition_on=None,
batch_size=2, # 默认值,可在生成时覆盖
num_batches=1, # 默认值,可在生成时覆盖
sampling_config_name="default",
sampling_config_path=None,
sampling_config_overrides=[],
record_trajectories=True,
diffusion_guidance_factor=0.0,
target_compositions_dict=[],
)
self._generators["base"] = generator
logger.info("Base MatterGen generator initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize base MatterGen generator: {e}")
def _get_or_create_generator(
self,
properties: Optional[Dict[str, Any]] = None,
batch_size: int = 2,
num_batches: int = 1,
diffusion_guidance_factor: float = 2.0
):
"""
Get or create a generator for the specified properties.
Args:
properties: Optional property constraints
batch_size: Number of structures per batch
num_batches: Number of batches to generate
diffusion_guidance_factor: Controls adherence to target properties
Returns:
tuple: (generator, generator_key, properties_to_condition_on)
"""
# 如果没有属性约束,使用基础生成器
if not properties:
if "base" not in self._generators:
self._init_base_generator()
return self._generators.get("base"), "base", None
# 处理属性约束
properties_to_condition_on = {}
for property_name, property_value in properties.items():
properties_to_condition_on[property_name] = property_value
# 确定模型目录
if len(properties) == 1:
# 单属性条件
property_name = list(properties.keys())[0]
property_to_model = {
"dft_mag_density": "dft_mag_density",
"dft_bulk_modulus": "dft_bulk_modulus",
"dft_shear_modulus": "dft_shear_modulus",
"energy_above_hull": "energy_above_hull",
"formation_energy_per_atom": "formation_energy_per_atom",
"space_group": "space_group",
"hhi_score": "hhi_score",
"ml_bulk_modulus": "ml_bulk_modulus",
"chemical_system": "chemical_system",
"dft_band_gap": "dft_band_gap"
}
model_dir = property_to_model.get(property_name, property_name)
generator_key = f"single_{property_name}"
else:
# 多属性条件
property_keys = set(properties.keys())
if property_keys == {"dft_mag_density", "hhi_score"}:
model_dir = "dft_mag_density_hhi_score"
generator_key = "multi_dft_mag_density_hhi_score"
elif property_keys == {"chemical_system", "energy_above_hull"}:
model_dir = "chemical_system_energy_above_hull"
generator_key = "multi_chemical_system_energy_above_hull"
else:
# 如果没有特定的多属性模型,使用第一个属性的模型
first_property = list(properties.keys())[0]
model_dir = first_property
generator_key = f"multi_{first_property}_etc"
# 构建完整的模型路径
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
# 检查模型目录是否存在
if not os.path.exists(model_path):
# 如果特定模型不存在,回退到基础模型
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
generator_key = "base"
# 检查是否已经有这个生成器
if generator_key in self._generators:
# 更新生成器的参数
generator = self._generators[generator_key]
generator.batch_size = batch_size
generator.num_batches = num_batches
generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0
return generator, generator_key, properties_to_condition_on
# 创建新的生成器
try:
logger.info(f"Initializing new MatterGen generator for {generator_key} from {model_path}")
checkpoint_info = MatterGenCheckpointInfo(
model_path=Path(model_path).resolve(),
load_epoch="last",
config_overrides=[],
strict_checkpoint_loading=True,
)
generator = CrystalGenerator(
checkpoint_info=checkpoint_info,
properties_to_condition_on=properties_to_condition_on,
batch_size=batch_size,
num_batches=num_batches,
sampling_config_name="default",
sampling_config_path=None,
sampling_config_overrides=[],
record_trajectories=True,
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0,
target_compositions_dict=[],
)
self._generators[generator_key] = generator
logger.info(f"MatterGen generator for {generator_key} initialized successfully")
return generator, generator_key, properties_to_condition_on
except Exception as e:
logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
# 回退到基础生成器
if "base" not in self._generators:
self._init_base_generator()
return self._generators.get("base"), "base", None
def generate(
self,
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
batch_size: int = 2,
num_batches: int = 1,
diffusion_guidance_factor: float = 2.0
) -> str:
"""
Generate crystal structures with optional property constraints.
Args:
properties: Optional property constraints
batch_size: Number of structures per batch
num_batches: Number of batches to generate
diffusion_guidance_factor: Controls adherence to target properties
Returns:
str: Descriptive text with generated crystal structures in CIF format
"""
from mars_toolkit.compute.material_gen import format_cif_content
# 处理字符串输入(如果提供)
if isinstance(properties, str):
try:
properties = json.loads(properties)
except json.JSONDecodeError:
raise ValueError(f"Invalid properties JSON string: {properties}")
# 如果为None默认为空字典
properties = properties or {}
# 获取或创建生成器
generator, generator_key, properties_to_condition_on = self._get_or_create_generator(
properties, batch_size, num_batches, diffusion_guidance_factor
)
if generator is None:
return "Error: Failed to initialize MatterGen generator"
# 生成结构
try:
generator.generate(output_dir=Path(self._output_dir))
except Exception as e:
logger.error(f"Error generating structures: {e}")
return f"Error generating structures: {e}"
# 创建字典存储文件内容
result_dict = {}
# 定义文件路径
cif_zip_path = os.path.join(self._output_dir, "generated_crystals_cif.zip")
xyz_file_path = os.path.join(self._output_dir, "generated_crystals.extxyz")
trajectories_zip_path = os.path.join(self._output_dir, "generated_trajectories.zip")
# 读取CIF压缩文件
if os.path.exists(cif_zip_path):
with open(cif_zip_path, 'rb') as f:
result_dict['cif_content'] = f.read()
# 根据生成类型创建描述性提示
if not properties:
generation_type = "unconditional"
title = "Generated Material Structures"
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
property_description = "unconditionally"
elif len(properties) == 1:
generation_type = "single_property"
property_name = list(properties.keys())[0]
property_value = properties[property_name]
title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
property_description = f"conditioned on {property_name} = {property_value}"
else:
generation_type = "multi_property"
title = "Generated Material Structures Conditioned on Multiple Properties"
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
# 创建完整的提示
prompt = f"""
# {title}
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
{'' if generation_type == 'unconditional' else f'''
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
'''}
## CIF Files (Crystallographic Information Files)
- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools
```
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
```
{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""
# 清理文件(读取后删除)
try:
if os.path.exists(cif_zip_path):
os.remove(cif_zip_path)
if os.path.exists(xyz_file_path):
os.remove(xyz_file_path)
if os.path.exists(trajectories_zip_path):
os.remove(trajectories_zip_path)
except Exception as e:
logger.warning(f"Error cleaning up files: {e}")
return prompt

View File

@@ -0,0 +1,2 @@
f

151
mattergen_api.py Normal file
View File

@@ -0,0 +1,151 @@
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import uvicorn
from typing import Dict, Any, Optional, Union, List
import logging
import traceback
import sys
# 配置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler(sys.stdout)])
logger = logging.getLogger(__name__)
# 初始化FastAPI
app = FastAPI(title="MatterGen API Service")
# 请求模型
class MaterialGenerationRequest(BaseModel):
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None
batch_size: int = 2
num_batches: int = 1
diffusion_guidance_factor: float = 2.0
# 响应模型
class MaterialGenerationResponse(BaseModel):
content: str
success: bool
message: str
# 全局变量,用于跟踪服务状态
service_status = {
"initialized": False,
"error": None,
"mattergen_service": None
}
# 初始化MatterGenService
try:
logger.info("Importing MatterGenService...")
from mars_toolkit.services.mattergen_service import MatterGenService
logger.info("Initializing MatterGenService...")
mattergen_service = MatterGenService.get_instance()
service_status["mattergen_service"] = mattergen_service
service_status["initialized"] = True
logger.info("MatterGenService initialized successfully")
except Exception as e:
error_msg = f"Failed to initialize MatterGenService: {str(e)}"
logger.error(error_msg)
logger.error(traceback.format_exc())
service_status["error"] = error_msg
# 中间件:检查服务状态
@app.middleware("http")
async def check_service_status(request: Request, call_next):
# 健康检查端点不需要检查服务状态
if request.url.path == "/health":
return await call_next(request)
# 如果服务未初始化返回503错误
if not service_status["initialized"]:
error_msg = service_status["error"] or "MatterGenService not initialized"
return JSONResponse(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
content={"detail": error_msg}
)
# 继续处理请求
return await call_next(request)
@app.post("/generate_material", response_model=MaterialGenerationResponse)
async def generate_material(request: MaterialGenerationRequest):
"""生成晶体结构,可选择性地指定属性约束"""
try:
logger.info(f"Received material generation request with properties: {request.properties}")
print("request",request)
# 调用MatterGenService生成材料
result = mattergen_service.generate(
properties=request.properties,
batch_size=request.batch_size,
num_batches=request.num_batches,
diffusion_guidance_factor=request.diffusion_guidance_factor
)
logger.info("Material generation completed successfully")
return {
"content": result,
"success": True,
"message": "Material generation successful"
}
except Exception as e:
# 记录详细错误信息
error_msg = f"Error generating material: {str(e)}"
logger.error(error_msg)
logger.error(traceback.format_exc())
# 返回错误响应
return {
"content": "",
"success": False,
"message": error_msg
}
@app.get("/health")
async def health_check():
"""健康检查端点检查MatterGenService的状态"""
if service_status["initialized"]:
return {
"status": "healthy",
"service": "MatterGen API",
"mattergen_service": "initialized"
}
else:
error_msg = service_status["error"] or "MatterGenService not initialized"
return {
"status": "unhealthy",
"service": "MatterGen API",
"error": error_msg
}
@app.get("/")
async def root():
"""API根端点提供基本信息"""
return {
"service": "MatterGen API Service",
"description": "API for generating crystal structures with optional property constraints",
"status": "healthy" if service_status["initialized"] else "unhealthy",
"endpoints": {
"/generate_material": "POST - Generate crystal structures",
"/health": "GET - Health check",
"/docs": "GET - API documentation"
}
}
# 全局异常处理
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
logger.error(f"Unhandled exception: {str(exc)}")
logger.error(traceback.format_exc())
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": f"Internal server error: {str(exc)}"}
)
if __name__ == "__main__":
# 启动服务
logger.info("Starting MatterGen API Service...")
uvicorn.run(app, host="0.0.0.0", port=8051)

134
mattergen_client_example.py Normal file
View File

@@ -0,0 +1,134 @@
import requests
import json
import argparse
import sys
def generate_material(
url="http://localhost:8051/generate_material",
properties=None,
batch_size=2,
num_batches=1,
diffusion_guidance_factor=2.0
):
"""
调用MatterGen API生成晶体结构
Args:
url: API端点URL
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
batch_size: 每批生成的结构数量
num_batches: 批次数量
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
Returns:
生成的结构内容或错误信息
"""
# 构建请求负载
payload = {
"properties": properties ,
"batch_size": batch_size,
"num_batches": num_batches,
"diffusion_guidance_factor": diffusion_guidance_factor
}
print(f"发送请求到 {url}")
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
try:
# 添加headers参数包含accept头
headers = {
"Content-Type": "application/json",
"accept": "application/json"
}
# 打印完整请求信息(调试用)
print(f"完整请求URL: {url}")
print(f"请求头: {headers}")
print(f"请求体: {json.dumps(payload)}")
# 禁用代理设置
proxies = {
"http": None,
"https": None
}
# 发送POST请求添加headers参数禁用代理增加超时时间
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
# 打印响应信息(调试用)
print(f"响应状态码: {response.status_code}")
print(f"响应头: {dict(response.headers)}")
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符避免输出过长
# 检查响应状态
if response.status_code == 200:
result = response.json()
if result["success"]:
print("\n生成成功!")
return result["content"]
else:
print(f"\n生成失败: {result['message']}")
return None
else:
print(f"\n请求失败,状态码: {response.status_code}")
print(f"响应内容: {response.text}")
return None
except Exception as e:
print(f"\n发生错误: {str(e)}")
print(f"错误类型: {type(e).__name__}")
import traceback
print(f"错误堆栈: {traceback.format_exc()}")
return None
def main():
"""命令行入口函数"""
parser = argparse.ArgumentParser(description="MatterGen API客户端示例")
# 添加命令行参数
parser.add_argument("--url", default="http://localhost:8051/generate_material",
help="MatterGen API端点URL")
parser.add_argument("--property-name", default='dft_mag_density',help="属性名称例如dft_band_gap")
parser.add_argument("--property-value",default=0.15,help="属性值例如2.0")
parser.add_argument("--batch-size", type=int, default=2, help="每批生成的结构数量")
parser.add_argument("--num-batches", type=int, default=1, help="批次数量")
parser.add_argument("--guidance-factor", type=float, default=2.0,
help="控制生成结构与目标属性的符合程度")
args = parser.parse_args()
# 构建属性字典
properties = None
if args.property_name and args.property_value:
try:
# 尝试将属性值转换为数字
try:
value = float(args.property_value)
# 如果是整数,转换为整数
if value.is_integer():
value = int(value)
except ValueError:
# 如果无法转换为数字,保持为字符串
value = args.property_value
properties = {args.property_name: value}
except Exception as e:
print(f"解析属性值时出错: {str(e)}")
return
# 调用API
result = generate_material(
url=args.url,
properties=properties,
batch_size=args.batch_size,
num_batches=args.num_batches,
diffusion_guidance_factor=args.guidance_factor
)
if result:
print("\n生成的结构:")
print(result)
if __name__ == "__main__":
main()

View File

@@ -171,7 +171,7 @@ if __name__ == "__main__":
]
# 选择要测试的工具
tool_name = tools_to_test[5] # 测试 search_online 工具
tool_name = tools_to_test[6] # 测试 search_online 工具
# 运行测试
result = asyncio.run(test_tool(tool_name))