mattergen转服务
This commit is contained in:
@@ -1,21 +1,7 @@
|
|||||||
"""
|
|
||||||
Mars Toolkit
|
|
||||||
|
|
||||||
A comprehensive toolkit for materials science research, providing tools for:
|
|
||||||
- Material generation and property prediction
|
|
||||||
- Structure optimization
|
|
||||||
- Database queries (Materials Project, OQMD)
|
|
||||||
- Knowledge base retrieval
|
|
||||||
- Web search
|
|
||||||
|
|
||||||
Author: Yutang LI
|
|
||||||
Institution: SIAT-MIC
|
|
||||||
Contact: yt.li2@siat.ac.cn
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Core modules
|
# Core modules
|
||||||
from mars_toolkit.core.config import config
|
from mars_toolkit.core.config import config
|
||||||
from mars_toolkit.core.utils import setup_logging
|
|
||||||
|
|
||||||
# Basic tools
|
# Basic tools
|
||||||
from mars_toolkit.misc.misc_tools import get_current_time
|
from mars_toolkit.misc.misc_tools import get_current_time
|
||||||
@@ -35,12 +21,14 @@ from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
|||||||
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||||
from mars_toolkit.query.web_search import search_online
|
from mars_toolkit.query.web_search import search_online
|
||||||
|
|
||||||
|
# Visualization modules
|
||||||
|
|
||||||
|
|
||||||
from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
|
from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Initialize logging
|
|
||||||
setup_logging()
|
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
__all__ = ["llm_tool", "get_tools", "get_tool_schemas"]
|
__all__ = ["llm_tool", "get_tools", "get_tool_schemas"]
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,12 +1,3 @@
|
|||||||
"""
|
|
||||||
Material Generation Module
|
|
||||||
|
|
||||||
This module provides functions for generating crystal structures with optional property constraints.
|
|
||||||
|
|
||||||
Author: Yutang LI
|
|
||||||
Institution: SIAT-MIC
|
|
||||||
Contact: yt.li2@siat.ac.cn
|
|
||||||
"""
|
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
import json
|
import json
|
||||||
@@ -276,148 +267,16 @@ async def generate_material(
|
|||||||
Returns:
|
Returns:
|
||||||
Descriptive text with generated crystal structures in CIF format
|
Descriptive text with generated crystal structures in CIF format
|
||||||
"""
|
"""
|
||||||
# 使用配置中的结果目录
|
# 导入MatterGenService
|
||||||
output_dir = config.MATTERGENMODEL_RESULT_PATH
|
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||||
|
|
||||||
# 处理字符串输入(如果提供)
|
# 获取MatterGenService实例
|
||||||
if isinstance(properties, str):
|
service = MatterGenService.get_instance()
|
||||||
try:
|
|
||||||
properties = json.loads(properties)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f"Invalid properties JSON string: {properties}")
|
|
||||||
|
|
||||||
# 如果为None,默认为空字典
|
# 使用服务生成材料
|
||||||
properties = properties or {}
|
return service.generate(
|
||||||
|
properties=properties,
|
||||||
# 根据生成模式处理属性
|
|
||||||
if not properties:
|
|
||||||
# 无条件生成
|
|
||||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
|
||||||
properties_to_condition_on = None
|
|
||||||
generation_type = "unconditional"
|
|
||||||
property_description = "unconditionally"
|
|
||||||
else:
|
|
||||||
# 条件生成(单属性或多属性)
|
|
||||||
properties_to_condition_on = {}
|
|
||||||
|
|
||||||
# 处理每个属性
|
|
||||||
for property_name, property_value in properties.items():
|
|
||||||
_, processed_value = preprocess_property(property_name, property_value)
|
|
||||||
properties_to_condition_on[property_name] = processed_value
|
|
||||||
|
|
||||||
# 根据属性确定使用哪个模型
|
|
||||||
if len(properties) == 1:
|
|
||||||
# 单属性条件
|
|
||||||
property_name = list(properties.keys())[0]
|
|
||||||
property_to_model = {
|
|
||||||
"dft_mag_density": "dft_mag_density",
|
|
||||||
"dft_bulk_modulus": "dft_bulk_modulus",
|
|
||||||
"dft_shear_modulus": "dft_shear_modulus",
|
|
||||||
"energy_above_hull": "energy_above_hull",
|
|
||||||
"formation_energy_per_atom": "formation_energy_per_atom",
|
|
||||||
"space_group": "space_group",
|
|
||||||
"hhi_score": "hhi_score",
|
|
||||||
"ml_bulk_modulus": "ml_bulk_modulus",
|
|
||||||
"chemical_system": "chemical_system",
|
|
||||||
"dft_band_gap": "dft_band_gap"
|
|
||||||
}
|
|
||||||
model_dir = property_to_model.get(property_name, property_name)
|
|
||||||
generation_type = "single_property"
|
|
||||||
property_description = f"conditioned on {property_name} = {properties[property_name]}"
|
|
||||||
else:
|
|
||||||
# 多属性条件
|
|
||||||
property_keys = set(properties.keys())
|
|
||||||
if property_keys == {"dft_mag_density", "hhi_score"}:
|
|
||||||
model_dir = "dft_mag_density_hhi_score"
|
|
||||||
elif property_keys == {"chemical_system", "energy_above_hull"}:
|
|
||||||
model_dir = "chemical_system_energy_above_hull"
|
|
||||||
else:
|
|
||||||
# 如果没有特定的多属性模型,使用第一个属性的模型
|
|
||||||
first_property = list(properties.keys())[0]
|
|
||||||
model_dir = first_property
|
|
||||||
generation_type = "multi_property"
|
|
||||||
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
|
|
||||||
|
|
||||||
# 构建完整的模型路径
|
|
||||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
|
|
||||||
|
|
||||||
# 检查模型目录是否存在
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
# 如果特定模型不存在,回退到基础模型
|
|
||||||
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
|
|
||||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
|
||||||
|
|
||||||
# 使用适当的参数调用main函数
|
|
||||||
main(
|
|
||||||
output_path=output_dir,
|
|
||||||
model_path=model_path,
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_batches=num_batches,
|
num_batches=num_batches,
|
||||||
properties_to_condition_on=properties_to_condition_on,
|
diffusion_guidance_factor=diffusion_guidance_factor
|
||||||
record_trajectories=True,
|
|
||||||
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建字典存储文件内容
|
|
||||||
result_dict = {}
|
|
||||||
|
|
||||||
# 定义文件路径
|
|
||||||
cif_zip_path = os.path.join(output_dir, "generated_crystals_cif.zip")
|
|
||||||
xyz_file_path = os.path.join(output_dir, "generated_crystals.extxyz")
|
|
||||||
trajectories_zip_path = os.path.join(output_dir, "generated_trajectories.zip")
|
|
||||||
|
|
||||||
# 读取CIF压缩文件
|
|
||||||
if os.path.exists(cif_zip_path):
|
|
||||||
with open(cif_zip_path, 'rb') as f:
|
|
||||||
result_dict['cif_content'] = f.read()
|
|
||||||
|
|
||||||
# 根据生成类型创建描述性提示
|
|
||||||
if generation_type == "unconditional":
|
|
||||||
title = "Generated Material Structures"
|
|
||||||
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
|
|
||||||
elif generation_type == "single_property":
|
|
||||||
property_name = list(properties.keys())[0]
|
|
||||||
property_value = properties[property_name]
|
|
||||||
title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
|
|
||||||
description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
|
|
||||||
else: # multi_property
|
|
||||||
title = "Generated Material Structures Conditioned on Multiple Properties"
|
|
||||||
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
|
|
||||||
|
|
||||||
# 创建完整的提示
|
|
||||||
prompt = f"""
|
|
||||||
# {title}
|
|
||||||
|
|
||||||
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
|
|
||||||
|
|
||||||
{'' if generation_type == 'unconditional' else f'''
|
|
||||||
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
|
|
||||||
the generation adheres to the specified property values. Higher values produce samples that more
|
|
||||||
closely match the target properties but may reduce diversity.
|
|
||||||
'''}
|
|
||||||
|
|
||||||
## CIF Files (Crystallographic Information Files)
|
|
||||||
|
|
||||||
- Standard format for crystallographic structures
|
|
||||||
- Contains unit cell parameters, atomic positions, and symmetry information
|
|
||||||
- Used by crystallographic software and visualization tools
|
|
||||||
|
|
||||||
```
|
|
||||||
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
|
|
||||||
```
|
|
||||||
|
|
||||||
{description}
|
|
||||||
You can use these structures for materials discovery, property prediction, or further analysis.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 清理文件(读取后删除)
|
|
||||||
try:
|
|
||||||
if os.path.exists(cif_zip_path):
|
|
||||||
os.remove(cif_zip_path)
|
|
||||||
if os.path.exists(xyz_file_path):
|
|
||||||
os.remove(xyz_file_path)
|
|
||||||
if os.path.exists(trajectories_zip_path):
|
|
||||||
os.remove(trajectories_zip_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error cleaning up files: {e}")
|
|
||||||
return prompt
|
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from pymatgen.io.cif import CifWriter
|
|||||||
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
|
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
|
||||||
from mars_toolkit.core.llm_tools import llm_tool
|
from mars_toolkit.core.llm_tools import llm_tool
|
||||||
from mars_toolkit.core.config import config
|
from mars_toolkit.core.config import config
|
||||||
from mars_toolkit.core.error_handlers import handle_general_error
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -189,4 +188,4 @@ async def optimize_crystal_structure(
|
|||||||
# 直接返回结果或抛出异常
|
# 直接返回结果或抛出异常
|
||||||
return await asyncio.to_thread(run_optimization)
|
return await asyncio.to_thread(run_optimization)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return handle_general_error(e)
|
return str(e)
|
||||||
|
|||||||
@@ -5,9 +5,4 @@ This module provides core functionality for the Mars Toolkit.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from mars_toolkit.core.config import config
|
from mars_toolkit.core.config import config
|
||||||
from mars_toolkit.core.utils import settings, setup_logging
|
|
||||||
from mars_toolkit.core.error_handlers import (
|
|
||||||
handle_minio_error, handle_http_error,
|
|
||||||
handle_validation_error, handle_general_error
|
|
||||||
)
|
|
||||||
from mars_toolkit.core.llm_tools import llm_tool
|
from mars_toolkit.core.llm_tools import llm_tool
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -3,10 +3,6 @@ CIF Utilities Module
|
|||||||
|
|
||||||
This module provides basic functions for handling CIF (Crystallographic Information File) files,
|
This module provides basic functions for handling CIF (Crystallographic Information File) files,
|
||||||
which are commonly used in materials science for representing crystal structures.
|
which are commonly used in materials science for representing crystal structures.
|
||||||
|
|
||||||
Author: Yutang LI
|
|
||||||
Institution: SIAT-MIC
|
|
||||||
Contact: yt.li2@siat.ac.cn
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|||||||
@@ -1,55 +0,0 @@
|
|||||||
"""
|
|
||||||
Error Handlers Module
|
|
||||||
|
|
||||||
This module provides error handling utilities for the Mars Toolkit.
|
|
||||||
It includes functions for handling various types of errors that may occur
|
|
||||||
during toolkit operations.
|
|
||||||
|
|
||||||
Author: Yutang LI
|
|
||||||
Institution: SIAT-MIC
|
|
||||||
Contact: yt.li2@siat.ac.cn
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from typing import Any, Dict
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class APIError(HTTPException):
|
|
||||||
"""自定义API错误类"""
|
|
||||||
def __init__(self, status_code: int, detail: Any = None):
|
|
||||||
super().__init__(status_code=status_code, detail=detail)
|
|
||||||
logger.error(f"API Error: {status_code} - {detail}")
|
|
||||||
|
|
||||||
def handle_minio_error(e: Exception) -> Dict[str, str]:
|
|
||||||
"""处理MinIO相关错误"""
|
|
||||||
logger.error(f"MinIO operation failed: {str(e)}")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"data": f"MinIO operation failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def handle_http_error(e: Exception) -> Dict[str, str]:
|
|
||||||
"""处理HTTP请求错误"""
|
|
||||||
logger.error(f"HTTP request failed: {str(e)}")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"data": f"HTTP request failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def handle_validation_error(e: Exception) -> Dict[str, str]:
|
|
||||||
"""处理数据验证错误"""
|
|
||||||
logger.error(f"Validation failed: {str(e)}")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"data": f"Validation failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def handle_general_error(e: Exception) -> Dict[str, str]:
|
|
||||||
"""处理通用错误"""
|
|
||||||
logger.error(f"Unexpected error: {str(e)}")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"data": f"Unexpected error: {str(e)}"
|
|
||||||
}
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
import os
|
|
||||||
import boto3
|
|
||||||
import logging
|
|
||||||
import logging.config
|
|
||||||
from typing import Optional
|
|
||||||
from pydantic import Field
|
|
||||||
from pydantic_settings import BaseSettings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
# Material Project
|
|
||||||
mp_api_key: Optional[str] = Field(None, env="MP_API_KEY")
|
|
||||||
mp_endpoint: Optional[str] = Field(None, env="MP_ENDPOINT")
|
|
||||||
mp_topk: Optional[int] = Field(3, env="MP_TOPK")
|
|
||||||
|
|
||||||
# Proxy
|
|
||||||
http_proxy: Optional[str] = Field(None, env="HTTP_PROXY")
|
|
||||||
https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY")
|
|
||||||
|
|
||||||
# FairChem
|
|
||||||
fairchem_model_path: Optional[str] = Field(None, env="FAIRCHEM_MODEL_PATH")
|
|
||||||
fmax: Optional[float] = Field(0.05, env="FMAX")
|
|
||||||
|
|
||||||
# MinIO
|
|
||||||
minio_endpoint: Optional[str] = Field(None, env="MINIO_ENDPOINT")
|
|
||||||
internal_minio_endpoint: Optional[str] = Field(None, env="INTERNAL_MINIO_ENDPOINT")
|
|
||||||
minio_access_key: Optional[str] = Field(None, env="MINIO_ACCESS_KEY")
|
|
||||||
minio_secret_key: Optional[str] = Field(None, env="MINIO_SECRET_KEY")
|
|
||||||
minio_bucket: Optional[str] = Field("mars-toolkit", env="MINIO_BUCKET")
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
env_file = ".env"
|
|
||||||
env_file_encoding = "utf-8"
|
|
||||||
|
|
||||||
def setup_logging():
|
|
||||||
"""配置日志记录"""
|
|
||||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
|
|
||||||
log_file_path = os.path.join(parent_dir, 'mars_toolkit.log')
|
|
||||||
|
|
||||||
logging.config.dictConfig({
|
|
||||||
'version': 1,
|
|
||||||
'disable_existing_loggers': False,
|
|
||||||
'formatters': {
|
|
||||||
'standard': {
|
|
||||||
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
||||||
'datefmt': '%Y-%m-%d %H:%M:%S'
|
|
||||||
},
|
|
||||||
},
|
|
||||||
'handlers': {
|
|
||||||
'console': {
|
|
||||||
'level': 'INFO',
|
|
||||||
'class': 'logging.StreamHandler',
|
|
||||||
'formatter': 'standard'
|
|
||||||
},
|
|
||||||
'file': {
|
|
||||||
'level': 'DEBUG',
|
|
||||||
'class': 'logging.handlers.RotatingFileHandler',
|
|
||||||
'filename': log_file_path,
|
|
||||||
'maxBytes': 10485760, # 10MB
|
|
||||||
'backupCount': 5,
|
|
||||||
'formatter': 'standard'
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'loggers': {
|
|
||||||
'': {
|
|
||||||
'handlers': ['console', 'file'],
|
|
||||||
'level': 'INFO',
|
|
||||||
'propagate': True
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
# 初始化配置
|
|
||||||
settings = Settings()
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,13 +1,4 @@
|
|||||||
"""
|
|
||||||
Dify Search Module
|
|
||||||
|
|
||||||
This module provides functions for retrieving information from local materials science
|
|
||||||
literature knowledge base using Dify API.
|
|
||||||
|
|
||||||
Author: Yutang LI
|
|
||||||
Institution: SIAT-MIC
|
|
||||||
Contact: yt.li2@siat.ac.cn
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
|||||||
@@ -1,14 +1,3 @@
|
|||||||
"""
|
|
||||||
Materials Project Query Module
|
|
||||||
|
|
||||||
This module provides functions for querying the Materials Project database,
|
|
||||||
processing search results, and formatting responses.
|
|
||||||
|
|
||||||
Author: Yutang LI
|
|
||||||
Institution: SIAT-MIC
|
|
||||||
Contact: yt.li2@siat.ac.cn
|
|
||||||
"""
|
|
||||||
|
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -25,7 +14,6 @@ from pymatgen.io.cif import CifWriter
|
|||||||
|
|
||||||
from mars_toolkit.core.llm_tools import llm_tool
|
from mars_toolkit.core.llm_tools import llm_tool
|
||||||
from mars_toolkit.core.config import config
|
from mars_toolkit.core.config import config
|
||||||
from mars_toolkit.core.error_handlers import handle_general_error
|
|
||||||
from mars_toolkit.core.cif_utils import read_cif_txt_file, extract_cif_info, remove_symmetry_equiv_xyz
|
from mars_toolkit.core.cif_utils import read_cif_txt_file, extract_cif_info, remove_symmetry_equiv_xyz
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -1,13 +1,3 @@
|
|||||||
"""
|
|
||||||
OQMD Query Module
|
|
||||||
|
|
||||||
This module provides functions for querying the Open Quantum Materials Database (OQMD).
|
|
||||||
|
|
||||||
Author: Yutang LI
|
|
||||||
Institution: SIAT-MIC
|
|
||||||
Contact: yt.li2@siat.ac.cn
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import httpx
|
import httpx
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|||||||
12
mars_toolkit/services/__init__.py
Normal file
12
mars_toolkit/services/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Services module for mars_toolkit.
|
||||||
|
|
||||||
|
This module contains service classes that provide persistent functionality
|
||||||
|
across multiple function calls, such as maintaining initialized models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Import services for easy access
|
||||||
|
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||||
|
|
||||||
|
# Export services
|
||||||
|
__all__ = ['MatterGenService']
|
||||||
BIN
mars_toolkit/services/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/services/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
342
mars_toolkit/services/mattergen_service.py
Normal file
342
mars_toolkit/services/mattergen_service.py
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
"""
|
||||||
|
MatterGen service for mars_toolkit.
|
||||||
|
|
||||||
|
This module provides a service for generating crystal structures using MatterGen.
|
||||||
|
The service initializes the CrystalGenerator once and reuses it for multiple
|
||||||
|
generation requests, improving performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, Optional, Union, List
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# 导入mattergen相关模块
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../')))
|
||||||
|
from mattergen_wrapper import generator
|
||||||
|
CrystalGenerator = generator.CrystalGenerator
|
||||||
|
from mattergen.common.data.types import TargetProperty
|
||||||
|
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
|
||||||
|
from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME
|
||||||
|
|
||||||
|
# 导入mars_toolkit配置
|
||||||
|
from mars_toolkit.core.config import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class MatterGenService:
|
||||||
|
"""
|
||||||
|
Service for generating crystal structures using MatterGen.
|
||||||
|
|
||||||
|
This service initializes the CrystalGenerator once and reuses it for multiple
|
||||||
|
generation requests, improving performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls):
|
||||||
|
"""
|
||||||
|
Get the singleton instance of MatterGenService.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MatterGenService: The singleton instance.
|
||||||
|
"""
|
||||||
|
if cls._instance is None:
|
||||||
|
with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Initialize the MatterGenService.
|
||||||
|
|
||||||
|
This initializes the base generator without any property conditioning.
|
||||||
|
Specific generators for different property conditions will be initialized
|
||||||
|
on demand.
|
||||||
|
"""
|
||||||
|
self._generators = {}
|
||||||
|
self._output_dir = config.MATTERGENMODEL_RESULT_PATH
|
||||||
|
|
||||||
|
# 确保输出目录存在
|
||||||
|
if not os.path.exists(self._output_dir):
|
||||||
|
os.makedirs(self._output_dir)
|
||||||
|
|
||||||
|
# 初始化基础生成器(无条件生成)
|
||||||
|
self._init_base_generator()
|
||||||
|
|
||||||
|
def _init_base_generator(self):
|
||||||
|
"""
|
||||||
|
Initialize the base generator for unconditional generation.
|
||||||
|
"""
|
||||||
|
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
logger.warning(f"Base model directory not found at {model_path}. MatterGen service may not work properly.")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Initializing base MatterGen generator from {model_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
checkpoint_info = MatterGenCheckpointInfo(
|
||||||
|
model_path=Path(model_path).resolve(),
|
||||||
|
load_epoch="last",
|
||||||
|
config_overrides=[],
|
||||||
|
strict_checkpoint_loading=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
generator = CrystalGenerator(
|
||||||
|
checkpoint_info=checkpoint_info,
|
||||||
|
properties_to_condition_on=None,
|
||||||
|
batch_size=2, # 默认值,可在生成时覆盖
|
||||||
|
num_batches=1, # 默认值,可在生成时覆盖
|
||||||
|
sampling_config_name="default",
|
||||||
|
sampling_config_path=None,
|
||||||
|
sampling_config_overrides=[],
|
||||||
|
record_trajectories=True,
|
||||||
|
diffusion_guidance_factor=0.0,
|
||||||
|
target_compositions_dict=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
self._generators["base"] = generator
|
||||||
|
logger.info("Base MatterGen generator initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize base MatterGen generator: {e}")
|
||||||
|
|
||||||
|
def _get_or_create_generator(
|
||||||
|
self,
|
||||||
|
properties: Optional[Dict[str, Any]] = None,
|
||||||
|
batch_size: int = 2,
|
||||||
|
num_batches: int = 1,
|
||||||
|
diffusion_guidance_factor: float = 2.0
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get or create a generator for the specified properties.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
properties: Optional property constraints
|
||||||
|
batch_size: Number of structures per batch
|
||||||
|
num_batches: Number of batches to generate
|
||||||
|
diffusion_guidance_factor: Controls adherence to target properties
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (generator, generator_key, properties_to_condition_on)
|
||||||
|
"""
|
||||||
|
# 如果没有属性约束,使用基础生成器
|
||||||
|
if not properties:
|
||||||
|
if "base" not in self._generators:
|
||||||
|
self._init_base_generator()
|
||||||
|
return self._generators.get("base"), "base", None
|
||||||
|
|
||||||
|
# 处理属性约束
|
||||||
|
properties_to_condition_on = {}
|
||||||
|
for property_name, property_value in properties.items():
|
||||||
|
properties_to_condition_on[property_name] = property_value
|
||||||
|
|
||||||
|
# 确定模型目录
|
||||||
|
if len(properties) == 1:
|
||||||
|
# 单属性条件
|
||||||
|
property_name = list(properties.keys())[0]
|
||||||
|
property_to_model = {
|
||||||
|
"dft_mag_density": "dft_mag_density",
|
||||||
|
"dft_bulk_modulus": "dft_bulk_modulus",
|
||||||
|
"dft_shear_modulus": "dft_shear_modulus",
|
||||||
|
"energy_above_hull": "energy_above_hull",
|
||||||
|
"formation_energy_per_atom": "formation_energy_per_atom",
|
||||||
|
"space_group": "space_group",
|
||||||
|
"hhi_score": "hhi_score",
|
||||||
|
"ml_bulk_modulus": "ml_bulk_modulus",
|
||||||
|
"chemical_system": "chemical_system",
|
||||||
|
"dft_band_gap": "dft_band_gap"
|
||||||
|
}
|
||||||
|
model_dir = property_to_model.get(property_name, property_name)
|
||||||
|
generator_key = f"single_{property_name}"
|
||||||
|
else:
|
||||||
|
# 多属性条件
|
||||||
|
property_keys = set(properties.keys())
|
||||||
|
if property_keys == {"dft_mag_density", "hhi_score"}:
|
||||||
|
model_dir = "dft_mag_density_hhi_score"
|
||||||
|
generator_key = "multi_dft_mag_density_hhi_score"
|
||||||
|
elif property_keys == {"chemical_system", "energy_above_hull"}:
|
||||||
|
model_dir = "chemical_system_energy_above_hull"
|
||||||
|
generator_key = "multi_chemical_system_energy_above_hull"
|
||||||
|
else:
|
||||||
|
# 如果没有特定的多属性模型,使用第一个属性的模型
|
||||||
|
first_property = list(properties.keys())[0]
|
||||||
|
model_dir = first_property
|
||||||
|
generator_key = f"multi_{first_property}_etc"
|
||||||
|
|
||||||
|
# 构建完整的模型路径
|
||||||
|
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
|
||||||
|
|
||||||
|
# 检查模型目录是否存在
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
# 如果特定模型不存在,回退到基础模型
|
||||||
|
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
|
||||||
|
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
||||||
|
generator_key = "base"
|
||||||
|
|
||||||
|
# 检查是否已经有这个生成器
|
||||||
|
if generator_key in self._generators:
|
||||||
|
# 更新生成器的参数
|
||||||
|
generator = self._generators[generator_key]
|
||||||
|
generator.batch_size = batch_size
|
||||||
|
generator.num_batches = num_batches
|
||||||
|
generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0
|
||||||
|
return generator, generator_key, properties_to_condition_on
|
||||||
|
|
||||||
|
# 创建新的生成器
|
||||||
|
try:
|
||||||
|
logger.info(f"Initializing new MatterGen generator for {generator_key} from {model_path}")
|
||||||
|
|
||||||
|
checkpoint_info = MatterGenCheckpointInfo(
|
||||||
|
model_path=Path(model_path).resolve(),
|
||||||
|
load_epoch="last",
|
||||||
|
config_overrides=[],
|
||||||
|
strict_checkpoint_loading=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
generator = CrystalGenerator(
|
||||||
|
checkpoint_info=checkpoint_info,
|
||||||
|
properties_to_condition_on=properties_to_condition_on,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_batches=num_batches,
|
||||||
|
sampling_config_name="default",
|
||||||
|
sampling_config_path=None,
|
||||||
|
sampling_config_overrides=[],
|
||||||
|
record_trajectories=True,
|
||||||
|
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0,
|
||||||
|
target_compositions_dict=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
self._generators[generator_key] = generator
|
||||||
|
logger.info(f"MatterGen generator for {generator_key} initialized successfully")
|
||||||
|
return generator, generator_key, properties_to_condition_on
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
|
||||||
|
# 回退到基础生成器
|
||||||
|
if "base" not in self._generators:
|
||||||
|
self._init_base_generator()
|
||||||
|
return self._generators.get("base"), "base", None
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
|
||||||
|
batch_size: int = 2,
|
||||||
|
num_batches: int = 1,
|
||||||
|
diffusion_guidance_factor: float = 2.0
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate crystal structures with optional property constraints.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
properties: Optional property constraints
|
||||||
|
batch_size: Number of structures per batch
|
||||||
|
num_batches: Number of batches to generate
|
||||||
|
diffusion_guidance_factor: Controls adherence to target properties
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Descriptive text with generated crystal structures in CIF format
|
||||||
|
"""
|
||||||
|
from mars_toolkit.compute.material_gen import format_cif_content
|
||||||
|
|
||||||
|
# 处理字符串输入(如果提供)
|
||||||
|
if isinstance(properties, str):
|
||||||
|
try:
|
||||||
|
properties = json.loads(properties)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid properties JSON string: {properties}")
|
||||||
|
|
||||||
|
# 如果为None,默认为空字典
|
||||||
|
properties = properties or {}
|
||||||
|
|
||||||
|
# 获取或创建生成器
|
||||||
|
generator, generator_key, properties_to_condition_on = self._get_or_create_generator(
|
||||||
|
properties, batch_size, num_batches, diffusion_guidance_factor
|
||||||
|
)
|
||||||
|
|
||||||
|
if generator is None:
|
||||||
|
return "Error: Failed to initialize MatterGen generator"
|
||||||
|
|
||||||
|
# 生成结构
|
||||||
|
try:
|
||||||
|
generator.generate(output_dir=Path(self._output_dir))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating structures: {e}")
|
||||||
|
return f"Error generating structures: {e}"
|
||||||
|
|
||||||
|
# 创建字典存储文件内容
|
||||||
|
result_dict = {}
|
||||||
|
|
||||||
|
# 定义文件路径
|
||||||
|
cif_zip_path = os.path.join(self._output_dir, "generated_crystals_cif.zip")
|
||||||
|
xyz_file_path = os.path.join(self._output_dir, "generated_crystals.extxyz")
|
||||||
|
trajectories_zip_path = os.path.join(self._output_dir, "generated_trajectories.zip")
|
||||||
|
|
||||||
|
# 读取CIF压缩文件
|
||||||
|
if os.path.exists(cif_zip_path):
|
||||||
|
with open(cif_zip_path, 'rb') as f:
|
||||||
|
result_dict['cif_content'] = f.read()
|
||||||
|
|
||||||
|
# 根据生成类型创建描述性提示
|
||||||
|
if not properties:
|
||||||
|
generation_type = "unconditional"
|
||||||
|
title = "Generated Material Structures"
|
||||||
|
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
|
||||||
|
property_description = "unconditionally"
|
||||||
|
elif len(properties) == 1:
|
||||||
|
generation_type = "single_property"
|
||||||
|
property_name = list(properties.keys())[0]
|
||||||
|
property_value = properties[property_name]
|
||||||
|
title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
|
||||||
|
description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
|
||||||
|
property_description = f"conditioned on {property_name} = {property_value}"
|
||||||
|
else:
|
||||||
|
generation_type = "multi_property"
|
||||||
|
title = "Generated Material Structures Conditioned on Multiple Properties"
|
||||||
|
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
|
||||||
|
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
|
||||||
|
|
||||||
|
# 创建完整的提示
|
||||||
|
prompt = f"""
|
||||||
|
# {title}
|
||||||
|
|
||||||
|
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
|
||||||
|
|
||||||
|
{'' if generation_type == 'unconditional' else f'''
|
||||||
|
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
|
||||||
|
the generation adheres to the specified property values. Higher values produce samples that more
|
||||||
|
closely match the target properties but may reduce diversity.
|
||||||
|
'''}
|
||||||
|
|
||||||
|
## CIF Files (Crystallographic Information Files)
|
||||||
|
|
||||||
|
- Standard format for crystallographic structures
|
||||||
|
- Contains unit cell parameters, atomic positions, and symmetry information
|
||||||
|
- Used by crystallographic software and visualization tools
|
||||||
|
|
||||||
|
```
|
||||||
|
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
|
||||||
|
```
|
||||||
|
|
||||||
|
{description}
|
||||||
|
You can use these structures for materials discovery, property prediction, or further analysis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 清理文件(读取后删除)
|
||||||
|
try:
|
||||||
|
if os.path.exists(cif_zip_path):
|
||||||
|
os.remove(cif_zip_path)
|
||||||
|
if os.path.exists(xyz_file_path):
|
||||||
|
os.remove(xyz_file_path)
|
||||||
|
if os.path.exists(trajectories_zip_path):
|
||||||
|
os.remove(trajectories_zip_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning up files: {e}")
|
||||||
|
|
||||||
|
return prompt
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
|
||||||
|
f
|
||||||
|
|||||||
BIN
mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc
Normal file
BIN
mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
151
mattergen_api.py
Normal file
151
mattergen_api.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
from fastapi import FastAPI, HTTPException, Request, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import uvicorn
|
||||||
|
from typing import Dict, Any, Optional, Union, List
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 初始化FastAPI
|
||||||
|
app = FastAPI(title="MatterGen API Service")
|
||||||
|
|
||||||
|
# 请求模型
|
||||||
|
class MaterialGenerationRequest(BaseModel):
|
||||||
|
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None
|
||||||
|
batch_size: int = 2
|
||||||
|
num_batches: int = 1
|
||||||
|
diffusion_guidance_factor: float = 2.0
|
||||||
|
|
||||||
|
# 响应模型
|
||||||
|
class MaterialGenerationResponse(BaseModel):
|
||||||
|
content: str
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
# 全局变量,用于跟踪服务状态
|
||||||
|
service_status = {
|
||||||
|
"initialized": False,
|
||||||
|
"error": None,
|
||||||
|
"mattergen_service": None
|
||||||
|
}
|
||||||
|
|
||||||
|
# 初始化MatterGenService
|
||||||
|
try:
|
||||||
|
logger.info("Importing MatterGenService...")
|
||||||
|
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||||
|
|
||||||
|
logger.info("Initializing MatterGenService...")
|
||||||
|
mattergen_service = MatterGenService.get_instance()
|
||||||
|
service_status["mattergen_service"] = mattergen_service
|
||||||
|
service_status["initialized"] = True
|
||||||
|
logger.info("MatterGenService initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Failed to initialize MatterGenService: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
service_status["error"] = error_msg
|
||||||
|
|
||||||
|
# 中间件:检查服务状态
|
||||||
|
@app.middleware("http")
|
||||||
|
async def check_service_status(request: Request, call_next):
|
||||||
|
# 健康检查端点不需要检查服务状态
|
||||||
|
if request.url.path == "/health":
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# 如果服务未初始化,返回503错误
|
||||||
|
if not service_status["initialized"]:
|
||||||
|
error_msg = service_status["error"] or "MatterGenService not initialized"
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
content={"detail": error_msg}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 继续处理请求
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
@app.post("/generate_material", response_model=MaterialGenerationResponse)
|
||||||
|
async def generate_material(request: MaterialGenerationRequest):
|
||||||
|
"""生成晶体结构,可选择性地指定属性约束"""
|
||||||
|
try:
|
||||||
|
logger.info(f"Received material generation request with properties: {request.properties}")
|
||||||
|
print("request",request)
|
||||||
|
# 调用MatterGenService生成材料
|
||||||
|
result = mattergen_service.generate(
|
||||||
|
properties=request.properties,
|
||||||
|
batch_size=request.batch_size,
|
||||||
|
num_batches=request.num_batches,
|
||||||
|
diffusion_guidance_factor=request.diffusion_guidance_factor
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Material generation completed successfully")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": result,
|
||||||
|
"success": True,
|
||||||
|
"message": "Material generation successful"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
# 记录详细错误信息
|
||||||
|
error_msg = f"Error generating material: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
# 返回错误响应
|
||||||
|
return {
|
||||||
|
"content": "",
|
||||||
|
"success": False,
|
||||||
|
"message": error_msg
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""健康检查端点,检查MatterGenService的状态"""
|
||||||
|
if service_status["initialized"]:
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"service": "MatterGen API",
|
||||||
|
"mattergen_service": "initialized"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
error_msg = service_status["error"] or "MatterGenService not initialized"
|
||||||
|
return {
|
||||||
|
"status": "unhealthy",
|
||||||
|
"service": "MatterGen API",
|
||||||
|
"error": error_msg
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
"""API根端点,提供基本信息"""
|
||||||
|
return {
|
||||||
|
"service": "MatterGen API Service",
|
||||||
|
"description": "API for generating crystal structures with optional property constraints",
|
||||||
|
"status": "healthy" if service_status["initialized"] else "unhealthy",
|
||||||
|
"endpoints": {
|
||||||
|
"/generate_material": "POST - Generate crystal structures",
|
||||||
|
"/health": "GET - Health check",
|
||||||
|
"/docs": "GET - API documentation"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 全局异常处理
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def global_exception_handler(request: Request, exc: Exception):
|
||||||
|
logger.error(f"Unhandled exception: {str(exc)}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
content={"detail": f"Internal server error: {str(exc)}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 启动服务
|
||||||
|
logger.info("Starting MatterGen API Service...")
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8051)
|
||||||
134
mattergen_client_example.py
Normal file
134
mattergen_client_example.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def generate_material(
|
||||||
|
url="http://localhost:8051/generate_material",
|
||||||
|
properties=None,
|
||||||
|
batch_size=2,
|
||||||
|
num_batches=1,
|
||||||
|
diffusion_guidance_factor=2.0
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
调用MatterGen API生成晶体结构
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: API端点URL
|
||||||
|
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
|
||||||
|
batch_size: 每批生成的结构数量
|
||||||
|
num_batches: 批次数量
|
||||||
|
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成的结构内容或错误信息
|
||||||
|
"""
|
||||||
|
# 构建请求负载
|
||||||
|
payload = {
|
||||||
|
"properties": properties ,
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"num_batches": num_batches,
|
||||||
|
"diffusion_guidance_factor": diffusion_guidance_factor
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"发送请求到 {url}")
|
||||||
|
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 添加headers参数,包含accept头
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"accept": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 打印完整请求信息(调试用)
|
||||||
|
print(f"完整请求URL: {url}")
|
||||||
|
print(f"请求头: {headers}")
|
||||||
|
print(f"请求体: {json.dumps(payload)}")
|
||||||
|
|
||||||
|
# 禁用代理设置
|
||||||
|
proxies = {
|
||||||
|
"http": None,
|
||||||
|
"https": None
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送POST请求,添加headers参数,禁用代理,增加超时时间
|
||||||
|
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
|
||||||
|
|
||||||
|
# 打印响应信息(调试用)
|
||||||
|
print(f"响应状态码: {response.status_code}")
|
||||||
|
print(f"响应头: {dict(response.headers)}")
|
||||||
|
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符,避免输出过长
|
||||||
|
|
||||||
|
# 检查响应状态
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if result["success"]:
|
||||||
|
print("\n生成成功!")
|
||||||
|
return result["content"]
|
||||||
|
else:
|
||||||
|
print(f"\n生成失败: {result['message']}")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
print(f"\n请求失败,状态码: {response.status_code}")
|
||||||
|
print(f"响应内容: {response.text}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n发生错误: {str(e)}")
|
||||||
|
print(f"错误类型: {type(e).__name__}")
|
||||||
|
import traceback
|
||||||
|
print(f"错误堆栈: {traceback.format_exc()}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""命令行入口函数"""
|
||||||
|
parser = argparse.ArgumentParser(description="MatterGen API客户端示例")
|
||||||
|
|
||||||
|
# 添加命令行参数
|
||||||
|
parser.add_argument("--url", default="http://localhost:8051/generate_material",
|
||||||
|
help="MatterGen API端点URL")
|
||||||
|
parser.add_argument("--property-name", default='dft_mag_density',help="属性名称,例如dft_band_gap")
|
||||||
|
parser.add_argument("--property-value",default=0.15,help="属性值,例如2.0")
|
||||||
|
parser.add_argument("--batch-size", type=int, default=2, help="每批生成的结构数量")
|
||||||
|
parser.add_argument("--num-batches", type=int, default=1, help="批次数量")
|
||||||
|
parser.add_argument("--guidance-factor", type=float, default=2.0,
|
||||||
|
help="控制生成结构与目标属性的符合程度")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 构建属性字典
|
||||||
|
properties = None
|
||||||
|
if args.property_name and args.property_value:
|
||||||
|
try:
|
||||||
|
# 尝试将属性值转换为数字
|
||||||
|
try:
|
||||||
|
value = float(args.property_value)
|
||||||
|
# 如果是整数,转换为整数
|
||||||
|
if value.is_integer():
|
||||||
|
value = int(value)
|
||||||
|
except ValueError:
|
||||||
|
# 如果无法转换为数字,保持为字符串
|
||||||
|
value = args.property_value
|
||||||
|
|
||||||
|
properties = {args.property_name: value}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解析属性值时出错: {str(e)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 调用API
|
||||||
|
result = generate_material(
|
||||||
|
url=args.url,
|
||||||
|
properties=properties,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
num_batches=args.num_batches,
|
||||||
|
diffusion_guidance_factor=args.guidance_factor
|
||||||
|
)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
print("\n生成的结构:")
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -171,7 +171,7 @@ if __name__ == "__main__":
|
|||||||
]
|
]
|
||||||
|
|
||||||
# 选择要测试的工具
|
# 选择要测试的工具
|
||||||
tool_name = tools_to_test[5] # 测试 search_online 工具
|
tool_name = tools_to_test[6] # 测试 search_online 工具
|
||||||
|
|
||||||
# 运行测试
|
# 运行测试
|
||||||
result = asyncio.run(test_tool(tool_name))
|
result = asyncio.run(test_tool(tool_name))
|
||||||
|
|||||||
Reference in New Issue
Block a user