Files
mars-mcp/mattergen_api.py
2025-04-16 11:15:01 +08:00

152 lines
4.9 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import uvicorn
from typing import Dict, Any, Optional, Union, List
import logging
import traceback
import sys
# Logging setup: records at INFO and above go to stdout in a timestamped format.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    handlers=[_stdout_handler],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Create the FastAPI application object that the routes, middleware and
# exception handler in this module are registered on.
app = FastAPI(
    title="MatterGen API Service",
)
# Request model
# A single property value: either a number or a string.
_PropertyScalar = Union[float, str]


class MaterialGenerationRequest(BaseModel):
    """Request body for the ``/generate_material`` endpoint.

    Every field has a default, so an empty JSON object is a valid request.
    """

    # Optional property constraints keyed by property name. Each value is a
    # number, a string, or a one-level mapping of names to numbers/strings.
    properties: Optional[
        Dict[str, Union[float, str, Dict[str, _PropertyScalar]]]
    ] = None
    # The three settings below are passed through unchanged to the
    # generation call.
    batch_size: int = 2
    num_batches: int = 1
    diffusion_guidance_factor: float = 2.0
# Response model
class MaterialGenerationResponse(BaseModel):
    """Response envelope for the ``/generate_material`` endpoint."""

    content: str  # generation output; empty string when generation failed
    success: bool  # True when generation completed, False when it raised
    message: str  # status text, or the error description on failure
# Global state used to track whether the backing service came up.
#   initialized       -- flipped to True once MatterGenService is ready
#   error             -- initialisation failure message, or None
#   mattergen_service -- the service instance, or None until initialised
service_status: Dict[str, Any] = dict(
    initialized=False,
    error=None,
    mattergen_service=None,
)
# Initialise MatterGenService once, at import time.
# Bind the name up front so `mattergen_service` always exists at module level,
# even when the import or the initialisation below fails.
mattergen_service = None
try:
    # Only the two operations that can actually fail live inside the try.
    logger.info("Importing MatterGenService...")
    from mars_toolkit.services.mattergen_service import MatterGenService

    logger.info("Initializing MatterGenService...")
    mattergen_service = MatterGenService.get_instance()
except Exception as e:
    # Top-level boundary: a failure here must not prevent the module from
    # loading. Record the message in service_status so it can be reported,
    # and log it together with the traceback.
    error_msg = f"Failed to initialize MatterGenService: {e}"
    logger.exception(error_msg)
    service_status["error"] = error_msg
else:
    service_status["mattergen_service"] = mattergen_service
    service_status["initialized"] = True
    logger.info("MatterGenService initialized successfully")
# Middleware: gate requests on the service initialisation state.
# Paths that only report status or serve the auto-generated documentation.
# They never touch MatterGenService, so they stay reachable while the service
# is down; this lets "/" actually report "unhealthy" instead of being
# short-circuited with a 503. "/docs", "/redoc" and "/openapi.json" are the
# FastAPI default documentation URLs.
_STATUS_EXEMPT_PATHS = frozenset(
    {"/", "/health", "/docs", "/redoc", "/openapi.json"}
)


@app.middleware("http")
async def check_service_status(request: Request, call_next):
    """Reject requests with HTTP 503 while MatterGenService is unavailable.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The downstream response when the path is exempt or the service is
        initialised; otherwise a 503 JSON response whose ``detail`` is the
        recorded initialisation error (or a generic message if none was
        recorded).
    """
    if request.url.path in _STATUS_EXEMPT_PATHS or service_status["initialized"]:
        return await call_next(request)

    error_msg = service_status["error"] or "MatterGenService not initialized"
    return JSONResponse(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        content={"detail": error_msg},
    )
@app.post("/generate_material", response_model=MaterialGenerationResponse)
async def generate_material(request: MaterialGenerationRequest):
    """Generate crystal structures, optionally constrained by properties.

    Args:
        request: Validated request body carrying the optional property
            constraints and the generation settings.

    Returns:
        A dict matching ``MaterialGenerationResponse``. Failures are reported
        in-band: any exception raised during generation is logged and turned
        into ``success=False`` with an empty ``content`` and the error text
        in ``message``, rather than being raised to the caller.
    """
    logger.info(
        "Received material generation request with properties: %s",
        request.properties,
    )
    try:
        # Read the instance from service_status rather than a module global,
        # which is not bound when initialisation failed.
        service = service_status["mattergen_service"]
        # NOTE(review): generate() is called synchronously inside an async
        # handler, so the event loop is occupied for the whole call. Consider
        # offloading it to a worker thread once the service's thread-safety
        # is confirmed.
        result = service.generate(
            properties=request.properties,
            batch_size=request.batch_size,
            num_batches=request.num_batches,
            diffusion_guidance_factor=request.diffusion_guidance_factor,
        )
    except Exception as e:
        # Log the message together with the traceback, then answer in-band.
        error_msg = f"Error generating material: {e}"
        logger.exception(error_msg)
        return {
            "content": "",
            "success": False,
            "message": error_msg,
        }

    logger.info("Material generation completed successfully")
    return {
        "content": result,
        "success": True,
        "message": "Material generation successful",
    }
@app.get("/health")
async def health_check():
    """Health-check endpoint reporting the state of MatterGenService.

    Returns:
        A dict with ``status`` and ``service`` keys, plus either
        ``mattergen_service`` (when initialised) or ``error`` (when not).
    """
    # Guard clause: describe the failure first, then fall through to healthy.
    if not service_status["initialized"]:
        failure_reason = service_status["error"] or "MatterGenService not initialized"
        return dict(
            status="unhealthy",
            service="MatterGen API",
            error=failure_reason,
        )

    return dict(
        status="healthy",
        service="MatterGen API",
        mattergen_service="initialized",
    )
@app.get("/")
async def root():
    """API root endpoint providing basic information about the service.

    Returns:
        A dict with the service name, a description, the current status
        string and a map of the available endpoints.
    """
    if service_status["initialized"]:
        current_status = "healthy"
    else:
        current_status = "unhealthy"

    endpoint_index = {
        "/generate_material": "POST - Generate crystal structures",
        "/health": "GET - Health check",
        "/docs": "GET - API documentation",
    }

    return {
        "service": "MatterGen API Service",
        "description": "API for generating crystal structures with optional property constraints",
        "status": current_status,
        "endpoints": endpoint_index,
    }
# Global exception handling
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log an otherwise unhandled exception and answer with HTTP 500.

    Args:
        request: The request being processed when the exception surfaced.
        exc: The unhandled exception.

    Returns:
        A JSON response with status 500 whose ``detail`` contains the
        exception text.
    """
    # Attach the traceback of the exception that was actually passed in,
    # rather than whatever the ambient exception state happens to be.
    logger.error("Unhandled exception: %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"detail": f"Internal server error: {exc}"},
    )
if __name__ == "__main__":
    # Script entry point: announce start-up, then serve the app with uvicorn
    # on all interfaces at port 8051.
    logger.info("Starting MatterGen API Service...")
    bind_host, bind_port = "0.0.0.0", 8051
    uvicorn.run(app, host=bind_host, port=bind_port)