重构代码
This commit is contained in:
22
constant.py
22
constant.py
@@ -2,20 +2,20 @@ TOPK_RESULT = 1
|
|||||||
TIME_OUT = 60
|
TIME_OUT = 60
|
||||||
|
|
||||||
# MP Configuration
|
# MP Configuration
|
||||||
MP_API_KEY = "gfBp2in8qxm9Xm2SwLKFwNxDyZvNTAEt"
|
MP_API_KEY = None
|
||||||
MP_ENDPOINT = "https://api.materialsproject.org/"
|
MP_ENDPOINT = None
|
||||||
|
|
||||||
# Proxy
|
# Proxy
|
||||||
HTTP_PROXY = "http://127.0.0.1:7897"
|
HTTP_PROXY = None
|
||||||
HTTPS_PROXY = "http://127.0.0.1:7897"
|
HTTPS_PROXY = None
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
FAIRCHEM_MODEL_PATH = "/home/ubuntu/sas0/LYT/mars1215/mars_toolkit/model/eqV2_86M_omat_mp_salex.pt"
|
FAIRCHEM_MODEL_PATH = None
|
||||||
FMAX = 0.05
|
FMAX = None
|
||||||
|
|
||||||
# MinIO configuration
|
# MinIO configuration
|
||||||
MINIO_ENDPOINT = "https://s3-api.siat-mic.com"
|
MINIO_ENDPOINT = None
|
||||||
INTERNEL_MINIO_ENDPOINT = "http://100.85.52.31:9000" # 内网地址,如果有就填,上传会更快。
|
INTERNEL_MINIO_ENDPOINT = None
|
||||||
MINIO_ACCESS_KEY = "9bUtQL1Gpo9JB6o3pSGr"
|
MINIO_ACCESS_KEY = None
|
||||||
MINIO_SECRET_KEY = "1Qug5H73R3kP8boIHvdVcFtcb1jU9GRWnlmMpx0g"
|
MINIO_SECRET_KEY = None
|
||||||
MINIO_BUCKET = "temp"
|
MINIO_BUCKET = None
|
||||||
|
|||||||
49
error_handlers.py
Normal file
49
error_handlers.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""
|
||||||
|
Author: Yutang LI
|
||||||
|
Institution: SIAT-MIC
|
||||||
|
Contact: yt.li2@siat.ac.cn
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from typing import Any, Dict
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class APIError(HTTPException):
|
||||||
|
"""自定义API错误类"""
|
||||||
|
def __init__(self, status_code: int, detail: Any = None):
|
||||||
|
super().__init__(status_code=status_code, detail=detail)
|
||||||
|
logger.error(f"API Error: {status_code} - {detail}")
|
||||||
|
|
||||||
|
def handle_minio_error(e: Exception) -> Dict[str, str]:
|
||||||
|
"""处理MinIO相关错误"""
|
||||||
|
logger.error(f"MinIO operation failed: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"data": f"MinIO operation failed: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
def handle_http_error(e: Exception) -> Dict[str, str]:
|
||||||
|
"""处理HTTP请求错误"""
|
||||||
|
logger.error(f"HTTP request failed: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"data": f"HTTP request failed: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
def handle_validation_error(e: Exception) -> Dict[str, str]:
|
||||||
|
"""处理数据验证错误"""
|
||||||
|
logger.error(f"Validation failed: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"data": f"Validation failed: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
def handle_general_error(e: Exception) -> Dict[str, str]:
|
||||||
|
"""处理通用错误"""
|
||||||
|
logger.error(f"Unexpected error: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"data": f"Unexpected error: {str(e)}"
|
||||||
|
}
|
||||||
58
main.py
58
main.py
@@ -5,25 +5,55 @@ Contact: yt.li2@siat.ac.cn
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
import logging
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import os
|
from fastapi.middleware import Middleware
|
||||||
from database.material_project_router import router as material_project_router
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from database.oqmd_router import router as oqmd_router
|
from router.mp_router import router as material_router
|
||||||
from model.fairchem_router import router as fairchem_router, init_model
|
from router.oqmd_router import router as oqmd_router
|
||||||
|
from router.fairchem_router import router as fairchem_router
|
||||||
|
from error_handlers import (
|
||||||
logging.basicConfig(
|
handle_general_error,
|
||||||
level=logging.INFO,
|
handle_http_error,
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
handle_validation_error
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
from utils import setup_logging
|
||||||
|
from router.fairchem_router import init_model
|
||||||
|
|
||||||
app = FastAPI()
|
# 初始化日志配置
|
||||||
|
setup_logging()
|
||||||
|
|
||||||
|
# 创建中间件列表
|
||||||
|
middleware = [
|
||||||
|
Middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
app = FastAPI(middleware=middleware)
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
def startup_event():
|
async def startup_event():
|
||||||
|
"""应用启动时初始化模型"""
|
||||||
init_model()
|
init_model()
|
||||||
|
|
||||||
app.include_router(material_project_router)
|
# 注册路由
|
||||||
|
app.include_router(material_router)
|
||||||
app.include_router(oqmd_router)
|
app.include_router(oqmd_router)
|
||||||
app.include_router(fairchem_router)
|
app.include_router(fairchem_router)
|
||||||
|
|
||||||
|
# 添加全局异常处理
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def global_exception_handler(request, exc):
|
||||||
|
return handle_general_error(exc)
|
||||||
|
|
||||||
|
@app.exception_handler(ValueError)
|
||||||
|
async def validation_exception_handler(request, exc):
|
||||||
|
return handle_validation_error(exc)
|
||||||
|
|
||||||
|
@app.exception_handler(ConnectionError)
|
||||||
|
async def http_exception_handler(request, exc):
|
||||||
|
return handle_http_error(exc)
|
||||||
|
|||||||
@@ -1,165 +0,0 @@
|
|||||||
from fastapi import APIRouter, Body, Query
|
|
||||||
from fairchem.core import OCPCalculator
|
|
||||||
from ase.optimize import FIRE
|
|
||||||
from ase.filters import FrechetCellFilter
|
|
||||||
from ase.atoms import Atoms
|
|
||||||
from ase.io import read, write
|
|
||||||
from pymatgen.core.structure import Structure
|
|
||||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
|
||||||
from pymatgen.io.cif import CifWriter
|
|
||||||
import tempfile
|
|
||||||
import os
|
|
||||||
import boto3
|
|
||||||
from constant import FAIRCHEM_MODEL_PATH, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT, FMAX
|
|
||||||
from typing import Optional
|
|
||||||
import logging
|
|
||||||
import datetime
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/fairchem", tags=["fairchem"])
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 初始化模型
|
|
||||||
calc = None
|
|
||||||
|
|
||||||
def init_model():
|
|
||||||
global calc
|
|
||||||
calc = OCPCalculator(checkpoint_path=FAIRCHEM_MODEL_PATH)
|
|
||||||
|
|
||||||
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
|
|
||||||
"""将输入内容转换为Atoms对象"""
|
|
||||||
try:
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
|
|
||||||
tmp_file.write(content)
|
|
||||||
tmp_path = tmp_file.name
|
|
||||||
|
|
||||||
atoms = read(tmp_path)
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
return atoms
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to convert structure: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def generate_symmetry_cif(structure: Structure) -> str:
|
|
||||||
"""生成对称性CIF"""
|
|
||||||
analyzer = SpacegroupAnalyzer(structure)
|
|
||||||
structure = analyzer.get_refined_structure()
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
|
|
||||||
cif_writer = CifWriter(structure, symprec=0.1, refine_struct=True)
|
|
||||||
cif_writer.write_file(tmp_file.name)
|
|
||||||
tmp_file.seek(0)
|
|
||||||
return tmp_file.read()
|
|
||||||
|
|
||||||
def upload_to_minio(file_path: str, file_name: str) -> str:
|
|
||||||
"""上传文件到MinIO并返回预签名URL"""
|
|
||||||
try:
|
|
||||||
minio_client = boto3.client(
|
|
||||||
's3',
|
|
||||||
endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT,
|
|
||||||
aws_access_key_id=MINIO_ACCESS_KEY,
|
|
||||||
aws_secret_access_key=MINIO_SECRET_KEY
|
|
||||||
)
|
|
||||||
|
|
||||||
bucket_name = MINIO_BUCKET
|
|
||||||
minio_client.upload_file(file_path, bucket_name, file_name, ExtraArgs={"ACL": "private"})
|
|
||||||
|
|
||||||
# 生成预签名 URL
|
|
||||||
url = minio_client.generate_presigned_url(
|
|
||||||
'get_object',
|
|
||||||
Params={'Bucket': bucket_name, 'Key': file_name},
|
|
||||||
ExpiresIn=3600
|
|
||||||
)
|
|
||||||
return url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to upload to MinIO: {str(e)}")
|
|
||||||
raise RuntimeError(f"MinIO upload failed: {str(e)}") from e
|
|
||||||
|
|
||||||
from io import StringIO
|
|
||||||
import sys
|
|
||||||
|
|
||||||
def optimize_structure(atoms: Atoms, output_format: str):
|
|
||||||
"""优化晶体结构"""
|
|
||||||
atoms.calc = calc
|
|
||||||
|
|
||||||
# 捕获优化日志
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = log_capture = StringIO()
|
|
||||||
|
|
||||||
try:
|
|
||||||
dyn = FIRE(FrechetCellFilter(atoms))
|
|
||||||
dyn.run(fmax=FMAX)
|
|
||||||
total_energy = atoms.get_total_energy()
|
|
||||||
optimization_log = log_capture.getvalue()
|
|
||||||
finally:
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
|
|
||||||
# 处理对称性
|
|
||||||
if output_format == "cif":
|
|
||||||
optimized_structure = Structure.from_ase_atoms(atoms)
|
|
||||||
content = generate_symmetry_cif(optimized_structure)
|
|
||||||
else:
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
|
|
||||||
write(tmp_file.name, atoms)
|
|
||||||
tmp_file.seek(0)
|
|
||||||
content = tmp_file.read()
|
|
||||||
|
|
||||||
# 保存优化结果到临时文件
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
|
||||||
file_name = f"optimized_structure_{timestamp}.{output_format}"
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w", delete=False) as tmp_file:
|
|
||||||
tmp_file.write(content)
|
|
||||||
tmp_path = tmp_file.name
|
|
||||||
try:
|
|
||||||
# 上传到MinIO
|
|
||||||
url = upload_to_minio(tmp_path, file_name)
|
|
||||||
return total_energy, content, url, optimization_log
|
|
||||||
finally:
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
|
|
||||||
@router.post("/optimize_structure")
|
|
||||||
async def optimize_structure_endpoint(
|
|
||||||
content: str = Body(..., description="Input structure content"),
|
|
||||||
input_format: str = Query("cif", description="Input format (cif, poscar, json, xyz)"),
|
|
||||||
output_format: str = Query("cif", description="Output format (cif, poscar, json, xyz)")
|
|
||||||
):
|
|
||||||
# 转换输入结构
|
|
||||||
atoms = convert_structure(input_format, content)
|
|
||||||
if atoms is None:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"data": f"Invalid {input_format} content"
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 优化结构
|
|
||||||
total_energy, optimized_content, download_url, optimization_log = optimize_structure(atoms, output_format)
|
|
||||||
|
|
||||||
# 格式化返回结果
|
|
||||||
format_result = f"""
|
|
||||||
The following is the optimized crystal structure information:
|
|
||||||
|
|
||||||
### Optimization Results (using FIRE(eqV2_86M) algorithm):
|
|
||||||
```text
|
|
||||||
{optimization_log}
|
|
||||||
```
|
|
||||||
Finally, the Total Energy is: {total_energy} eV
|
|
||||||
Due to length limitations, the complete {output_format.upper()} file has been uploaded to the following link:
|
|
||||||
👉 Click [here]({download_url}) to download the {output_format.upper()} file
|
|
||||||
|
|
||||||
Please ensure that the Optimization Results and download link are fully conveyed to the user, as this is very important for them.
|
|
||||||
"""
|
|
||||||
print(format_result)
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"data": format_result
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Optimization failed: {str(e)}")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"data": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
init_model()
|
|
||||||
60
router/fairchem_router.py
Normal file
60
router/fairchem_router.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""
|
||||||
|
Author: Yutang LI
|
||||||
|
Institution: SIAT-MIC
|
||||||
|
Contact: yt.li2@siat.ac.cn
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Query
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
import logging
|
||||||
|
from error_handlers import handle_general_error
|
||||||
|
from services.fairchem_service import (
|
||||||
|
init_model,
|
||||||
|
convert_structure,
|
||||||
|
optimize_structure
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/fairchem", tags=["fairchem"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 初始化模型
|
||||||
|
init_model()
|
||||||
|
|
||||||
|
@router.post("/optimize_structure")
|
||||||
|
async def optimize_structure_endpoint(
|
||||||
|
content: str = Body(..., description="Input structure content"),
|
||||||
|
input_format: str = Query("cif", description="Input format (cif, poscar, json, xyz)"),
|
||||||
|
output_format: str = Query("cif", description="Output format (cif, poscar, json, xyz)")
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
# 转换输入结构
|
||||||
|
atoms = convert_structure(input_format, content)
|
||||||
|
if atoms is None:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={"status": "error", "data": f"Invalid {input_format} content"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 优化结构
|
||||||
|
total_energy, optimized_content, download_url = optimize_structure(atoms, output_format)
|
||||||
|
|
||||||
|
# 格式化返回结果
|
||||||
|
format_result = f"""
|
||||||
|
The following is the optimized crystal structure information:
|
||||||
|
|
||||||
|
### Optimization Results (using FIRE(eqV2_86M) algorithm):
|
||||||
|
Total Energy: {total_energy} eV
|
||||||
|
|
||||||
|
Due to length limitations, the complete {output_format.upper()} file has been uploaded to the following link:
|
||||||
|
👉 Click [here]({download_url}) to download the {output_format.upper()} file
|
||||||
|
"""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={"status": "success", "data": format_result}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return handle_general_error(e)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
init_model()
|
||||||
70
router/mp_router.py
Normal file
70
router/mp_router.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""
|
||||||
|
Author: Yutang LI
|
||||||
|
Institution: SIAT-MIC
|
||||||
|
Contact: yt.li2@siat.ac.cn
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from fastapi import APIRouter, Request
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import datetime
|
||||||
|
from typing import Dict
|
||||||
|
from services.mp_service import (
|
||||||
|
parse_search_parameters,
|
||||||
|
process_search_results,
|
||||||
|
execute_search
|
||||||
|
)
|
||||||
|
from utils import handle_minio_upload
|
||||||
|
from error_handlers import handle_general_error
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/mp", tags=["Material Project"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@router.get("/search")
|
||||||
|
async def search_from_material_project(request: Request):
|
||||||
|
# 打印请求日志
|
||||||
|
logger.info(f"Received request: {request.method} {request.url}")
|
||||||
|
logger.info(f"Query parameters: {request.query_params}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析查询参数
|
||||||
|
search_args = parse_search_parameters(request.query_params)
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
docs = await execute_search(search_args)
|
||||||
|
|
||||||
|
# 处理搜索结果
|
||||||
|
res = process_search_results(docs)
|
||||||
|
|
||||||
|
if len(res) == 0:
|
||||||
|
return {"status": "success", "data": "No results found, please try again."}
|
||||||
|
|
||||||
|
# 上传结果到MinIO
|
||||||
|
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
|
file_name = f"mp_search_results_{timestamp}.json"
|
||||||
|
|
||||||
|
# 将结果写入临时文件
|
||||||
|
with open(file_name, 'w') as f:
|
||||||
|
json.dump(res, f, indent=2)
|
||||||
|
|
||||||
|
# 上传并获取URL
|
||||||
|
url = handle_minio_upload(file_name, file_name)
|
||||||
|
|
||||||
|
# 删除临时文件
|
||||||
|
os.remove(file_name)
|
||||||
|
|
||||||
|
# 格式化返回结果
|
||||||
|
res_chunk = "```json\n" + json.dumps(res[:5], indent=2) + "\n```"
|
||||||
|
res_template = f"""
|
||||||
|
好的,以下是用户的查询结果:
|
||||||
|
由于返回长度的限制,我们只能返回前5个结果。如下:
|
||||||
|
{res_chunk}
|
||||||
|
如果用户需要更多的结果,请提示用户修改查询条件,或者尝试使用其他查询参数。
|
||||||
|
同时我们将全部的的查询结果上传到MinIO中,请你提示用户可以通过以下链接下载:
|
||||||
|
[Download]({url})
|
||||||
|
"""
|
||||||
|
return {"status": "success", "data": res_template}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return handle_general_error(e)
|
||||||
57
router/oqmd_router.py
Normal file
57
router/oqmd_router.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""
|
||||||
|
Author: Yutang LI
|
||||||
|
Institution: SIAT-MIC
|
||||||
|
Contact: yt.li2@siat.ac.cn
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from fastapi import APIRouter, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
import logging
|
||||||
|
from error_handlers import handle_general_error
|
||||||
|
from services.oqmd_service import (
|
||||||
|
fetch_oqmd_data,
|
||||||
|
parse_oqmd_html,
|
||||||
|
render_and_save_charts
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/oqmd", tags=["OQMD"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@router.get("/search")
|
||||||
|
async def search_from_oqmd_by_composition(request: Request):
|
||||||
|
"""通过成分搜索OQMD数据"""
|
||||||
|
try:
|
||||||
|
# 打印请求日志
|
||||||
|
logger.info(f"Received request: {request.method} {request.url}")
|
||||||
|
logger.info(f"Query parameters: {request.query_params}")
|
||||||
|
|
||||||
|
# 获取并解析数据
|
||||||
|
composition = request.query_params['composition']
|
||||||
|
html = await fetch_oqmd_data(composition)
|
||||||
|
basic_data, table_data, phase_data = parse_oqmd_html(html)
|
||||||
|
|
||||||
|
# 渲染并保存图表
|
||||||
|
phase_diagram_url = await render_and_save_charts(phase_data)
|
||||||
|
|
||||||
|
# 返回格式化后的响应
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"status": "success",
|
||||||
|
"data": format_response(basic_data, table_data, phase_diagram_url)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return handle_general_error(e)
|
||||||
|
|
||||||
|
def format_response(basic_data: list, table_data: str, phase_data: str) -> str:
|
||||||
|
"""格式化响应数据"""
|
||||||
|
response = "### OQMD Data\n"
|
||||||
|
for item in basic_data:
|
||||||
|
response += f"**{item}**\n"
|
||||||
|
response += "\n### Phase Diagram\n\n"
|
||||||
|
response += f"\n\n"
|
||||||
|
response += "\n### Compounds at this composition\n\n"
|
||||||
|
response += f"{table_data}\n"
|
||||||
|
return response
|
||||||
0
services/__init__.py
Normal file
0
services/__init__.py
Normal file
92
services/fairchem_service.py
Normal file
92
services/fairchem_service.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""
|
||||||
|
Author: Yutang LI
|
||||||
|
Institution: SIAT-MIC
|
||||||
|
Contact: yt.li2@siat.ac.cn
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from ase.optimize import FIRE
|
||||||
|
from ase.filters import FrechetCellFilter
|
||||||
|
from ase.atoms import Atoms
|
||||||
|
from ase.io import read, write
|
||||||
|
from pymatgen.core.structure import Structure
|
||||||
|
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||||
|
from pymatgen.io.cif import CifWriter
|
||||||
|
from utils import settings, handle_minio_upload
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 初始化模型
|
||||||
|
calc = None
|
||||||
|
|
||||||
|
def init_model():
|
||||||
|
"""初始化FairChem模型"""
|
||||||
|
global calc
|
||||||
|
try:
|
||||||
|
from fairchem.core import OCPCalculator
|
||||||
|
calc = OCPCalculator(checkpoint_path=settings.fairchem_model_path)
|
||||||
|
logger.info("FairChem model initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize FairChem model: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
|
||||||
|
"""将输入内容转换为Atoms对象"""
|
||||||
|
try:
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
|
||||||
|
tmp_file.write(content)
|
||||||
|
tmp_path = tmp_file.name
|
||||||
|
|
||||||
|
atoms = read(tmp_path)
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
return atoms
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to convert structure: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def generate_symmetry_cif(structure: Structure) -> str:
|
||||||
|
"""生成对称性CIF"""
|
||||||
|
analyzer = SpacegroupAnalyzer(structure)
|
||||||
|
structure = analyzer.get_refined_structure()
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
|
||||||
|
cif_writer = CifWriter(structure, symprec=0.1, refine_struct=True)
|
||||||
|
cif_writer.write_file(tmp_file.name)
|
||||||
|
tmp_file.seek(0)
|
||||||
|
return tmp_file.read()
|
||||||
|
|
||||||
|
def optimize_structure(atoms: Atoms, output_format: str):
|
||||||
|
"""优化晶体结构"""
|
||||||
|
atoms.calc = calc
|
||||||
|
|
||||||
|
try:
|
||||||
|
dyn = FIRE(FrechetCellFilter(atoms))
|
||||||
|
dyn.run(fmax=settings.fmax)
|
||||||
|
total_energy = atoms.get_total_energy()
|
||||||
|
|
||||||
|
# 处理对称性
|
||||||
|
if output_format == "cif":
|
||||||
|
optimized_structure = Structure.from_ase_atoms(atoms)
|
||||||
|
content = generate_symmetry_cif(optimized_structure)
|
||||||
|
else:
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
|
||||||
|
write(tmp_file.name, atoms)
|
||||||
|
tmp_file.seek(0)
|
||||||
|
content = tmp_file.read()
|
||||||
|
|
||||||
|
# 保存优化结果到临时文件
|
||||||
|
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
|
file_name = f"optimized_structure_{timestamp}.{output_format}"
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w", delete=False) as tmp_file:
|
||||||
|
tmp_file.write(content)
|
||||||
|
tmp_path = tmp_file.name
|
||||||
|
|
||||||
|
# 上传到MinIO
|
||||||
|
url = handle_minio_upload(tmp_path, file_name)
|
||||||
|
return total_energy, content, url
|
||||||
|
finally:
|
||||||
|
os.unlink(tmp_path)
|
||||||
@@ -4,117 +4,28 @@ Institution: SIAT-MIC
|
|||||||
Contact: yt.li2@siat.ac.cn
|
Contact: yt.li2@siat.ac.cn
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import boto3
|
|
||||||
from fastapi import APIRouter, Request
|
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import datetime
|
import datetime
|
||||||
from mp_api.client import MPRester
|
|
||||||
from multiprocessing import Process, Manager
|
from multiprocessing import Process, Manager
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List
|
||||||
from constant import MP_ENDPOINT, MP_API_KEY, TIME_OUT, TOPK_RESULT, HTTP_PROXY, HTTPS_PROXY, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT
|
from mp_api.client import MPRester
|
||||||
|
from utils import settings, handle_minio_upload
|
||||||
router = APIRouter(prefix="/mp", tags=["Material Project"])
|
from error_handlers import handle_general_error
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/search")
|
|
||||||
async def search_from_material_project(request: Request):
|
|
||||||
# 打印请求日志
|
|
||||||
logger.info(f"Received request: {request.method} {request.url}")
|
|
||||||
logger.info(f"Query parameters: {request.query_params}")
|
|
||||||
|
|
||||||
# 解析查询参数
|
|
||||||
search_args = parse_search_parameters(request.query_params)
|
|
||||||
|
|
||||||
# 检查API key
|
|
||||||
if MP_API_KEY is None or MP_API_KEY == '':
|
|
||||||
return 'Material Project API CANNOT Be None'
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 执行搜索
|
|
||||||
docs = await execute_search(search_args)
|
|
||||||
|
|
||||||
# 处理搜索结果
|
|
||||||
res = process_search_results(docs)
|
|
||||||
url = ""
|
|
||||||
# 返回结果
|
|
||||||
if len(res) == 0:
|
|
||||||
return {"status": "success", "data": "No results found, please try again."}
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 上传结果到MinIO
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
|
||||||
file_name = f"mp_search_results_{timestamp}.json"
|
|
||||||
|
|
||||||
try:
|
|
||||||
minio_client = boto3.client(
|
|
||||||
's3',
|
|
||||||
endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT,
|
|
||||||
aws_access_key_id=MINIO_ACCESS_KEY,
|
|
||||||
aws_secret_access_key=MINIO_SECRET_KEY
|
|
||||||
)
|
|
||||||
|
|
||||||
# 将结果写入临时文件
|
|
||||||
with open(file_name, 'w') as f:
|
|
||||||
json.dump(res, f, indent=2)
|
|
||||||
|
|
||||||
# 上传到MinIO
|
|
||||||
minio_client.upload_file(file_name, MINIO_BUCKET, file_name, ExtraArgs={"ACL": "private"})
|
|
||||||
|
|
||||||
# 生成预签名URL
|
|
||||||
url = minio_client.generate_presigned_url(
|
|
||||||
'get_object',
|
|
||||||
Params={'Bucket': MINIO_BUCKET, 'Key': file_name},
|
|
||||||
ExpiresIn=3600
|
|
||||||
)
|
|
||||||
url = url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
|
|
||||||
|
|
||||||
# 删除临时文件
|
|
||||||
os.remove(file_name)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to upload to MinIO: {str(e)}")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"data": f"Failed to upload results to MinIO: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 格式化返回结果
|
|
||||||
res_chunk = "```json\n" + json.dumps(res[:TOPK_RESULT], indent=2) + "\n```"
|
|
||||||
res_template = f"""
|
|
||||||
好的,以下是用户的查询结果:
|
|
||||||
由于返回长度的限制,我们只能返回前{TOPK_RESULT}个结果。如下:
|
|
||||||
{res_chunk}
|
|
||||||
如果用户需要更多的结果,请提示用户修改查询条件,或者尝试使用其他查询参数。
|
|
||||||
同时我们将全部的的查询结果上传到MinIO中,请你提示用户可以通过以下链接下载:
|
|
||||||
[Download]({url})
|
|
||||||
"""
|
|
||||||
return {"status": "success", "data": res_template}
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error(f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again.")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"data": f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again."
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def parse_bool(param: str) -> bool | None:
|
def parse_bool(param: str) -> bool | None:
|
||||||
if not param:
|
if not param:
|
||||||
return None
|
return None
|
||||||
return param.lower() == 'true'
|
return param.lower() == 'true'
|
||||||
|
|
||||||
|
|
||||||
def parse_list(param: str) -> List[str] | None:
|
def parse_list(param: str) -> List[str] | None:
|
||||||
if not param:
|
if not param:
|
||||||
return None
|
return None
|
||||||
return param.split(',')
|
return param.split(',')
|
||||||
|
|
||||||
|
|
||||||
def parse_tuple(param: str) -> tuple[float, float] | None:
|
def parse_tuple(param: str) -> tuple[float, float] | None:
|
||||||
if not param:
|
if not param:
|
||||||
return None
|
return None
|
||||||
@@ -124,7 +35,6 @@ def parse_tuple(param: str) -> tuple[float, float] | None:
|
|||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
|
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
|
||||||
"""解析搜索参数"""
|
"""解析搜索参数"""
|
||||||
return {
|
return {
|
||||||
@@ -147,7 +57,6 @@ def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
|
|||||||
'chunk_size': int(query_params.get('chunk_size', '5'))
|
'chunk_size': int(query_params.get('chunk_size', '5'))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
"""处理搜索结果"""
|
"""处理搜索结果"""
|
||||||
fields = [
|
fields = [
|
||||||
@@ -171,13 +80,12 @@ def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|||||||
continue
|
continue
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
async def execute_search(search_args: Dict[str, Any], timeout: int = 30) -> List[Dict[str, Any]]:
|
||||||
async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) -> List[Dict[str, Any]]:
|
|
||||||
"""执行搜索"""
|
"""执行搜索"""
|
||||||
manager = Manager()
|
manager = Manager()
|
||||||
queue = manager.Queue()
|
queue = manager.Queue()
|
||||||
|
|
||||||
p = Process(target=_search_worker, args=(queue, MP_API_KEY), kwargs=search_args)
|
p = Process(target=_search_worker, args=(queue, settings.mp_api_key), kwargs=search_args)
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
logger.info(f"Started worker process with PID: {p.pid}")
|
logger.info(f"Started worker process with PID: {p.pid}")
|
||||||
@@ -207,14 +115,13 @@ async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) -
|
|||||||
logger.info(f"Successfully retrieved {len(result)} documents")
|
logger.info(f"Successfully retrieved {len(result)} documents")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _search_worker(queue, api_key, **kwargs):
|
def _search_worker(queue, api_key, **kwargs):
|
||||||
"""搜索工作线程"""
|
"""搜索工作线程"""
|
||||||
try:
|
try:
|
||||||
import os
|
import os
|
||||||
os.environ['HTTP_PROXY'] = HTTP_PROXY
|
os.environ['HTTP_PROXY'] = settings.http_proxy or ''
|
||||||
os.environ['HTTPS_PROXY'] = HTTPS_PROXY
|
os.environ['HTTPS_PROXY'] = settings.https_proxy or ''
|
||||||
mpr = MPRester(api_key, endpoint=MP_ENDPOINT)
|
mpr = MPRester(api_key, endpoint=settings.mp_endpoint)
|
||||||
result = mpr.materials.summary.search(**kwargs)
|
result = mpr.materials.summary.search(**kwargs)
|
||||||
queue.put([doc.model_dump() for doc in result])
|
queue.put([doc.model_dump() for doc in result])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -3,69 +3,22 @@ Author: Yutang LI
|
|||||||
Institution: SIAT-MIC
|
Institution: SIAT-MIC
|
||||||
Contact: yt.li2@siat.ac.cn
|
Contact: yt.li2@siat.ac.cn
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import boto3
|
|
||||||
from fastapi import APIRouter, Request
|
|
||||||
from io import StringIO
|
|
||||||
import logging
|
|
||||||
import httpx
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import httpx
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from playwright.async_api import async_playwright
|
from playwright.async_api import async_playwright
|
||||||
from constant import MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT
|
from io import StringIO
|
||||||
|
from utils import settings, handle_minio_upload
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/oqmd", tags=["OQMD"])
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/search")
|
|
||||||
async def search_from_oqmd_by_composition(request: Request):
|
|
||||||
# 打印请求日志
|
|
||||||
logger.info(f"Received request: {request.method} {request.url}")
|
|
||||||
logger.info(f"Query parameters: {request.query_params}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取并解析数据
|
|
||||||
composition = request.query_params['composition']
|
|
||||||
html = await fetch_oqmd_data(composition)
|
|
||||||
basic_data, table_data, phase_data = parse_oqmd_html(html)
|
|
||||||
|
|
||||||
# 渲染并保存图表
|
|
||||||
phase_diagram_name = await render_and_save_charts(phase_data)
|
|
||||||
# 返回格式化后的响应
|
|
||||||
return format_response(basic_data, table_data, phase_diagram_name)
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
logger.error(f"OQMD API request failed: {str(e)}")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": f"OQMD API request failed: {str(e)}"
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error: {str(e)}")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": f"Unexpected error: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_oqmd_data(composition: str) -> str:
|
async def fetch_oqmd_data(composition: str) -> str:
|
||||||
"""
|
"""从OQMD获取数据"""
|
||||||
从OQMD获取数据
|
|
||||||
Args:
|
|
||||||
composition: 材料组成字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
HTML内容字符串
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
httpx.HTTPError: 当发生HTTP相关错误时抛出
|
|
||||||
ValueError: 当响应内容无效时抛出
|
|
||||||
"""
|
|
||||||
url = f"https://www.oqmd.org/materials/composition/{composition}"
|
url = f"https://www.oqmd.org/materials/composition/{composition}"
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
@@ -79,55 +32,35 @@ async def fetch_oqmd_data(composition: str) -> str:
|
|||||||
return response.text
|
return response.text
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
status_code = e.response.status_code
|
logger.error(f"OQMD API request failed: {str(e)}")
|
||||||
if status_code == 401:
|
raise
|
||||||
logger.error("OQMD API: Unauthorized access")
|
except httpx.TimeoutException:
|
||||||
raise httpx.HTTPError("Unauthorized access to OQMD API") from e
|
|
||||||
elif status_code == 403:
|
|
||||||
logger.error("OQMD API: Forbidden access")
|
|
||||||
raise httpx.HTTPError("Forbidden access to OQMD API") from e
|
|
||||||
elif status_code == 404:
|
|
||||||
logger.error("OQMD API: Resource not found")
|
|
||||||
raise httpx.HTTPError("Resource not found on OQMD API") from e
|
|
||||||
elif status_code >= 500:
|
|
||||||
logger.error(f"OQMD API: Server error ({status_code})")
|
|
||||||
raise httpx.HTTPError(f"OQMD API server error ({status_code})") from e
|
|
||||||
else:
|
|
||||||
logger.error(f"OQMD API request failed: {str(e)}")
|
|
||||||
raise httpx.HTTPError(f"OQMD API request failed: {str(e)}") from e
|
|
||||||
|
|
||||||
except httpx.TimeoutException as e:
|
|
||||||
logger.error("OQMD API request timed out")
|
logger.error("OQMD API request timed out")
|
||||||
raise httpx.HTTPError("OQMD API request timed out") from e
|
raise
|
||||||
|
|
||||||
except httpx.NetworkError as e:
|
except httpx.NetworkError as e:
|
||||||
logger.error(f"Network error occurred: {str(e)}")
|
logger.error(f"Network error occurred: {str(e)}")
|
||||||
raise httpx.HTTPError(f"Network error: {str(e)}") from e
|
raise
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Invalid response content: {str(e)}")
|
logger.error(f"Invalid response content: {str(e)}")
|
||||||
raise ValueError(f"Invalid response content: {str(e)}") from e
|
raise
|
||||||
|
|
||||||
def parse_oqmd_html(html: str) -> tuple[list, str, list]:
|
def parse_oqmd_html(html: str) -> tuple[list, str, list]:
|
||||||
"""
|
"""解析OQMD HTML数据"""
|
||||||
解析OQMD HTML数据
|
|
||||||
"""
|
|
||||||
soup = BeautifulSoup(html, 'html.parser')
|
soup = BeautifulSoup(html, 'html.parser')
|
||||||
|
|
||||||
# 解析基本数据
|
# 解析基本数据
|
||||||
basic_data = []
|
basic_data = []
|
||||||
basic_data.append(soup.find('h1').text.strip())
|
basic_data.append(soup.find('h1').text.strip())
|
||||||
for script in soup.find_all('p'):
|
for script in soup.find_all('p'):
|
||||||
if script:
|
if script:
|
||||||
combined_text = ""
|
combined_text = ""
|
||||||
for element in script.contents: # 遍历 <p> 的子元素
|
for element in script.contents:
|
||||||
if element.name == 'a': # 如果是 <a> 标签
|
if element.name == 'a':
|
||||||
url = "https://www.oqmd.org" + element['href']
|
url = "https://www.oqmd.org" + element['href']
|
||||||
combined_text += f"[{element.text.strip()}]({url}) "
|
combined_text += f"[{element.text.strip()}]({url}) "
|
||||||
else: # 如果是文本
|
else:
|
||||||
combined_text += element.text.strip() + " "
|
combined_text += element.text.strip() + " "
|
||||||
basic_data.append(combined_text.strip())
|
basic_data.append(combined_text.strip())
|
||||||
# import pdb
|
|
||||||
# pdb.set_trace()
|
|
||||||
|
|
||||||
# 解析表格数据
|
# 解析表格数据
|
||||||
table = soup.find('table')
|
table = soup.find('table')
|
||||||
@@ -135,7 +68,6 @@ def parse_oqmd_html(html: str) -> tuple[list, str, list]:
|
|||||||
df = pd.read_html(StringIO(str(table)))[0]
|
df = pd.read_html(StringIO(str(table)))[0]
|
||||||
df = df.fillna('')
|
df = df.fillna('')
|
||||||
df = df.replace([float('inf'), float('-inf')], '')
|
df = df.replace([float('inf'), float('-inf')], '')
|
||||||
# table_data = df.to_dict(orient='records')
|
|
||||||
table_data = df.to_markdown(index=False)
|
table_data = df.to_markdown(index=False)
|
||||||
|
|
||||||
# 提取JavaScript数据
|
# 提取JavaScript数据
|
||||||
@@ -149,15 +81,8 @@ def parse_oqmd_html(html: str) -> tuple[list, str, list]:
|
|||||||
|
|
||||||
return basic_data, table_data, phase_data
|
return basic_data, table_data, phase_data
|
||||||
|
|
||||||
|
|
||||||
async def render_and_save_charts(script_data: list) -> str:
|
async def render_and_save_charts(script_data: list) -> str:
|
||||||
"""
|
"""渲染并保存图表到MinIO"""
|
||||||
渲染并保存图表到MinIO
|
|
||||||
Returns:
|
|
||||||
str: 图片的预签名URL
|
|
||||||
Raises:
|
|
||||||
RuntimeError: 如果图片生成或上传失败
|
|
||||||
"""
|
|
||||||
browser = None
|
browser = None
|
||||||
temp_files = []
|
temp_files = []
|
||||||
try:
|
try:
|
||||||
@@ -200,7 +125,6 @@ async def render_and_save_charts(script_data: list) -> str:
|
|||||||
await page.wait_for_timeout(5000)
|
await page.wait_for_timeout(5000)
|
||||||
|
|
||||||
# 分别截图两个图表
|
# 分别截图两个图表
|
||||||
# 获取placeholder元素位置并扩大截图区域
|
|
||||||
placeholder = page.locator('#placeholder')
|
placeholder = page.locator('#placeholder')
|
||||||
placeholder_box = await placeholder.bounding_box()
|
placeholder_box = await placeholder.bounding_box()
|
||||||
await page.screenshot(
|
await page.screenshot(
|
||||||
@@ -213,7 +137,6 @@ async def render_and_save_charts(script_data: list) -> str:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取phasediagram元素位置并扩大截图区域
|
|
||||||
phasediagram = page.locator('#phasediagram')
|
phasediagram = page.locator('#phasediagram')
|
||||||
phasediagram_box = await phasediagram.bounding_box()
|
phasediagram_box = await phasediagram.bounding_box()
|
||||||
await page.screenshot(
|
await page.screenshot(
|
||||||
@@ -245,61 +168,24 @@ async def render_and_save_charts(script_data: list) -> str:
|
|||||||
logger.error(f"Failed to process images: {str(e)}")
|
logger.error(f"Failed to process images: {str(e)}")
|
||||||
raise RuntimeError(f"Image processing failed: {str(e)}") from e
|
raise RuntimeError(f"Image processing failed: {str(e)}") from e
|
||||||
|
|
||||||
# 上传到 MinIO 的逻辑
|
# 上传到 MinIO
|
||||||
try:
|
url = handle_minio_upload(file_name, file_name)
|
||||||
minio_client = boto3.client(
|
return url
|
||||||
's3',
|
|
||||||
endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT,
|
|
||||||
aws_access_key_id=MINIO_ACCESS_KEY,
|
|
||||||
aws_secret_access_key=MINIO_SECRET_KEY
|
|
||||||
)
|
|
||||||
|
|
||||||
bucket_name = MINIO_BUCKET
|
|
||||||
minio_client.upload_file(file_name, bucket_name, file_name, ExtraArgs={"ACL": "private"})
|
|
||||||
|
|
||||||
# 生成预签名 URL
|
|
||||||
url = minio_client.generate_presigned_url(
|
|
||||||
'get_object',
|
|
||||||
Params={'Bucket': bucket_name, 'Key': file_name},
|
|
||||||
ExpiresIn=3600
|
|
||||||
)
|
|
||||||
return url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to upload to MinIO: {str(e)}")
|
|
||||||
raise RuntimeError(f"MinIO upload failed: {str(e)}") from e
|
|
||||||
finally:
|
|
||||||
# 清理临时文件
|
|
||||||
for temp_file in temp_files:
|
|
||||||
try:
|
|
||||||
if os.path.exists(temp_file):
|
|
||||||
os.remove(temp_file)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to remove temporary file {temp_file}: {str(e)}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to render and save charts: {str(e)}")
|
logger.error(f"Failed to render and save charts: {str(e)}")
|
||||||
raise RuntimeError(f"Chart rendering failed: {str(e)}") from e
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
for temp_file in temp_files:
|
||||||
|
try:
|
||||||
|
if os.path.exists(temp_file):
|
||||||
|
os.remove(temp_file)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to remove temporary file {temp_file}: {str(e)}")
|
||||||
# 确保浏览器关闭
|
# 确保浏览器关闭
|
||||||
if browser:
|
if browser:
|
||||||
try:
|
try:
|
||||||
await browser.close()
|
await browser.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to close browser: {str(e)}")
|
logger.warning(f"Failed to close browser: {str(e)}")
|
||||||
|
|
||||||
def format_response(basic_data: list, table_data: str, phase_data: str) -> str:
|
|
||||||
"""
|
|
||||||
格式化响应数据
|
|
||||||
"""
|
|
||||||
response = "### OQMD Data\n"
|
|
||||||
for item in basic_data:
|
|
||||||
response += f"**{item}**\n"
|
|
||||||
response += "\n### Phase Diagram\n\n"
|
|
||||||
response += f"\n\n"
|
|
||||||
response += "\n### Compounds at this composition\n\n"
|
|
||||||
response += f"{table_data}\n"
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"data": response
|
|
||||||
}
|
|
||||||
102
utils.py
102
utils.py
@@ -1,8 +1,102 @@
|
|||||||
|
"""
|
||||||
|
Author: Yutang LI
|
||||||
|
Institution: SIAT-MIC
|
||||||
|
Contact: yt.li2@siat.ac.cn
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import boto3
|
||||||
import logging
|
import logging
|
||||||
from multiprocessing import Process, Manager
|
from typing import Optional
|
||||||
import asyncio
|
from pydantic import Field
|
||||||
from typing import Dict, Any, List
|
from pydantic_settings import BaseSettings
|
||||||
from mp_api.client import MPRester
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Material Project
|
||||||
|
mp_api_key: Optional[str] = Field(None, env="MP_API_KEY")
|
||||||
|
mp_endpoint: Optional[str] = Field(None, env="MP_ENDPOINT")
|
||||||
|
|
||||||
|
# Proxy
|
||||||
|
http_proxy: Optional[str] = Field(None, env="HTTP_PROXY")
|
||||||
|
https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY")
|
||||||
|
|
||||||
|
# FairChem
|
||||||
|
fairchem_model_path: Optional[str] = Field(None, env="FAIRCHEM_MODEL_PATH")
|
||||||
|
fmax: Optional[float] = Field(0.05, env="FMAX")
|
||||||
|
|
||||||
|
# MinIO
|
||||||
|
minio_endpoint: Optional[str] = Field(None, env="MINIO_ENDPOINT")
|
||||||
|
internal_minio_endpoint: Optional[str] = Field(None, env="INTERNAL_MINIO_ENDPOINT")
|
||||||
|
minio_access_key: Optional[str] = Field(None, env="MINIO_ACCESS_KEY")
|
||||||
|
minio_secret_key: Optional[str] = Field(None, env="MINIO_SECRET_KEY")
|
||||||
|
minio_bucket: Optional[str] = Field("mars-toolkit", env="MINIO_BUCKET")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
env_file_encoding = "utf-8"
|
||||||
|
|
||||||
|
def get_minio_client(settings: Settings):
|
||||||
|
"""获取MinIO客户端"""
|
||||||
|
return boto3.client(
|
||||||
|
's3',
|
||||||
|
endpoint_url=settings.internal_minio_endpoint or settings.minio_endpoint,
|
||||||
|
aws_access_key_id=settings.minio_access_key,
|
||||||
|
aws_secret_access_key=settings.minio_secret_key
|
||||||
|
)
|
||||||
|
|
||||||
|
def handle_minio_upload(file_path: str, file_name: str) -> str:
|
||||||
|
"""统一处理MinIO上传"""
|
||||||
|
try:
|
||||||
|
client = get_minio_client(settings)
|
||||||
|
client.upload_file(file_path, settings.minio_bucket, file_name, ExtraArgs={"ACL": "private"})
|
||||||
|
|
||||||
|
# 生成预签名 URL
|
||||||
|
url = client.generate_presigned_url(
|
||||||
|
'get_object',
|
||||||
|
Params={'Bucket': settings.minio_bucket, 'Key': file_name},
|
||||||
|
ExpiresIn=3600
|
||||||
|
)
|
||||||
|
return url.replace(settings.internal_minio_endpoint or "", settings.minio_endpoint)
|
||||||
|
except Exception as e:
|
||||||
|
from error_handlers import handle_minio_error
|
||||||
|
return handle_minio_error(e)
|
||||||
|
|
||||||
|
def setup_logging():
|
||||||
|
"""配置日志记录"""
|
||||||
|
logging.config.dictConfig({
|
||||||
|
'version': 1,
|
||||||
|
'disable_existing_loggers': False,
|
||||||
|
'formatters': {
|
||||||
|
'standard': {
|
||||||
|
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
'datefmt': '%Y-%m-%d %H:%M:%S'
|
||||||
|
},
|
||||||
|
},
|
||||||
|
'handlers': {
|
||||||
|
'console': {
|
||||||
|
'level': 'INFO',
|
||||||
|
'class': 'logging.StreamHandler',
|
||||||
|
'formatter': 'standard'
|
||||||
|
},
|
||||||
|
'file': {
|
||||||
|
'level': 'DEBUG',
|
||||||
|
'class': 'logging.handlers.RotatingFileHandler',
|
||||||
|
'filename': 'mars_toolkit.log',
|
||||||
|
'maxBytes': 10485760, # 10MB
|
||||||
|
'backupCount': 5,
|
||||||
|
'formatter': 'standard'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'loggers': {
|
||||||
|
'': {
|
||||||
|
'handlers': ['console', 'file'],
|
||||||
|
'level': 'INFO',
|
||||||
|
'propagate': True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# 初始化配置
|
||||||
|
settings = Settings()
|
||||||
|
|||||||
Reference in New Issue
Block a user