152 lines
4.9 KiB
Python
Executable File
152 lines
4.9 KiB
Python
Executable File
import asyncio
import logging
import sys
import threading
import traceback
from typing import Dict, Any, Optional, Union, List

import uvicorn
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
|
||
|
||
# Logging setup: root logger at INFO level, records written to stdout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[_stdout_handler],
)
logger = logging.getLogger(__name__)
|
||
|
||
# FastAPI application object; the routes, middleware and exception handler
# below register themselves onto it via decorators.
_API_TITLE = "MatterGen API Service"
app = FastAPI(title=_API_TITLE)
|
||
|
||
class MaterialGenerationRequest(BaseModel):
    """Request body for POST /generate_material.

    Each field is passed by the endpoint to MatterGenService.generate() under
    the same keyword name.
    """

    # Optional property constraints. Values may be numbers, strings, or one
    # level of nested name -> number/string mappings.
    properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None
    # Passed as batch_size to generate(). Values below 1 are rejected during
    # validation (HTTP 422) instead of reaching the generator.
    batch_size: int = Field(2, ge=1)
    # Passed as num_batches to generate(). Values below 1 are rejected during
    # validation (HTTP 422) instead of reaching the generator.
    num_batches: int = Field(1, ge=1)
    # Passed as diffusion_guidance_factor to generate(); left unconstrained
    # because the valid range is not visible here.
    diffusion_guidance_factor: float = 2.0


class MaterialGenerationResponse(BaseModel):
    """Response body for POST /generate_material."""

    # Output of MatterGenService.generate() on success; "" on failure.
    content: str
    # True when generation completed without raising.
    success: bool
    # Human-readable status message, or the error description on failure.
    message: str
|
||
|
||
# Process-wide service state, filled in by the MatterGenService initialization:
#   initialized       -- True once the service instance was created successfully
#   error             -- initialization error message, or None
#   mattergen_service -- the service instance, or None
service_status = dict(
    initialized=False,
    error=None,
    mattergen_service=None,
)
|
||
|
||
# Create the MatterGenService singleton once, at import time. A failure does not
# abort startup: the message is stored in service_status["error"], where the
# health endpoint and the service-status middleware read it.
# Bound up front so the module-level name exists even when initialization fails.
mattergen_service = None
try:
    logger.info("Importing MatterGenService...")
    from mars_toolkit.services.mattergen_service import MatterGenService

    logger.info("Initializing MatterGenService...")
    mattergen_service = MatterGenService.get_instance()
    service_status["mattergen_service"] = mattergen_service
    service_status["initialized"] = True
    logger.info("MatterGenService initialized successfully")
except Exception as e:
    # Deliberately broad: any import/initialization failure is recorded rather
    # than crashing the process. logger.exception attaches the traceback to the
    # same log record.
    error_msg = f"Failed to initialize MatterGenService: {e}"
    logger.exception(error_msg)
    service_status["error"] = error_msg
|
||
|
||
# Paths that are served even when MatterGenService failed to initialize. None
# of them call the service:
#   - the health check;
#   - the root info endpoint, which itself reports "unhealthy" in that state;
#   - FastAPI's default documentation routes.
_PATHS_WITHOUT_SERVICE = frozenset(
    {"/", "/health", "/docs", "/docs/oauth2-redirect", "/redoc", "/openapi.json"}
)


@app.middleware("http")
async def check_service_status(request: Request, call_next):
    """Short-circuit with HTTP 503 while MatterGenService is unavailable.

    Requests whose path is in ``_PATHS_WITHOUT_SERVICE`` are always passed
    through. Any other request, when initialization did not succeed, gets
    ``{"detail": <initialization error>}`` with status 503 and never reaches
    its endpoint.
    """
    if request.url.path not in _PATHS_WITHOUT_SERVICE and not service_status["initialized"]:
        error_msg = service_status["error"] or "MatterGenService not initialized"
        return JSONResponse(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            content={"detail": error_msg},
        )

    return await call_next(request)
|
||
|
||
# Serializes calls into MatterGenService. Calling generate() directly on the
# event loop (as this endpoint used to) implicitly allowed only one generation at
# a time. Running it in a worker thread keeps that one-at-a-time behaviour
# without assuming the service is thread-safe.
_generation_lock = threading.Lock()


def _run_generation(service, request: "MaterialGenerationRequest"):
    """Call ``service.generate`` with the request's parameters, one call at a time."""
    with _generation_lock:
        return service.generate(
            properties=request.properties,
            batch_size=request.batch_size,
            num_batches=request.num_batches,
            diffusion_guidance_factor=request.diffusion_guidance_factor,
        )


@app.post("/generate_material", response_model=MaterialGenerationResponse)
async def generate_material(request: MaterialGenerationRequest):
    """Generate crystal structures, optionally constrained by target properties.

    Returns a dict matching MaterialGenerationResponse. Generation failures are
    reported in-band (``success=False``, error text in ``message``) with an
    HTTP 200 status, not as an HTTP error.
    """
    logger.info(
        "Received material generation request with properties: %s", request.properties
    )
    # Read the instance from service_status (set during initialization) rather
    # than the module global, which is only meaningful after a successful init.
    service = service_status["mattergen_service"]
    try:
        # generate() is synchronous (its result is used directly, not awaited).
        # Running it in the default thread pool keeps the event loop free, so
        # /health and other requests are still answered during a long generation.
        # NOTE(review): confirm MatterGenService tolerates being called from a
        # non-main thread.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _run_generation, service, request)
    except Exception as e:
        error_msg = f"Error generating material: {e}"
        # logger.exception attaches the traceback to the same record.
        logger.exception(error_msg)
        return {
            "content": "",
            "success": False,
            "message": error_msg,
        }

    logger.info("Material generation completed successfully")
    return {
        "content": result,
        "success": True,
        "message": "Material generation successful",
    }
|
||
|
||
@app.get("/health")
async def health_check():
    """Report whether MatterGenService initialized successfully.

    Healthy: ``{"status": "healthy", "service": ..., "mattergen_service": "initialized"}``.
    Unhealthy: ``{"status": "unhealthy", "service": ..., "error": <message>}``.

    NOTE(review): both cases are returned with HTTP 200, so a probe that only
    checks the status code will treat an uninitialized service as healthy.
    """
    report = {"status": "healthy", "service": "MatterGen API"}

    if not service_status["initialized"]:
        report["status"] = "unhealthy"
        report["error"] = service_status["error"] or "MatterGenService not initialized"
        return report

    report["mattergen_service"] = "initialized"
    return report
|
||
|
||
@app.get("/")
async def root():
    """Describe the API: name, purpose, current health, and available endpoints."""
    available_endpoints = {
        "/generate_material": "POST - Generate crystal structures",
        "/health": "GET - Health check",
        "/docs": "GET - API documentation",
    }

    if service_status["initialized"]:
        current_status = "healthy"
    else:
        current_status = "unhealthy"

    return dict(
        service="MatterGen API Service",
        description="API for generating crystal structures with optional property constraints",
        status=current_status,
        endpoints=available_endpoints,
    )
|
||
|
||
# Global handler for exceptions not handled anywhere else.
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception and convert it into a JSON HTTP 500 response.

    Response body: ``{"detail": "Internal server error: <exception message>"}``.
    NOTE(review): this echoes the raw exception text to the client; confirm
    that is acceptable for this deployment.
    """
    # Pass the exception explicitly. traceback.format_exc() describes whatever
    # exception is "currently being handled", which is not guaranteed to be
    # exc, and is "NoneType: None" when there is none.
    logger.error(
        "Unhandled exception on %s %s: %s",
        request.method,
        request.url.path,
        exc,
        exc_info=exc,
    )
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"detail": f"Internal server error: {exc}"},
    )
|
||
|
||
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface.
    bind_host, bind_port = "0.0.0.0", 8051
    logger.info("Starting MatterGen API Service...")
    uvicorn.run(app, host=bind_host, port=bind_port)
|