重构代码

This commit is contained in:
2025-01-06 14:54:41 +08:00
parent c2417fec25
commit c7d2d482da
13 changed files with 519 additions and 439 deletions

View File

@@ -2,20 +2,20 @@ TOPK_RESULT = 1
TIME_OUT = 60
# MP Configuration
MP_API_KEY = "gfBp2in8qxm9Xm2SwLKFwNxDyZvNTAEt"
MP_ENDPOINT = "https://api.materialsproject.org/"
MP_API_KEY = None
MP_ENDPOINT = None
# Proxy
HTTP_PROXY = "http://127.0.0.1:7897"
HTTPS_PROXY = "http://127.0.0.1:7897"
HTTP_PROXY = None
HTTPS_PROXY = None
# Model
FAIRCHEM_MODEL_PATH = "/home/ubuntu/sas0/LYT/mars1215/mars_toolkit/model/eqV2_86M_omat_mp_salex.pt"
FMAX = 0.05
FAIRCHEM_MODEL_PATH = None
FMAX = None
# MinIO configuration
MINIO_ENDPOINT = "https://s3-api.siat-mic.com"
INTERNEL_MINIO_ENDPOINT = "http://100.85.52.31:9000" # 内网地址,如果有就填,上传会更快。
MINIO_ACCESS_KEY = "9bUtQL1Gpo9JB6o3pSGr"
MINIO_SECRET_KEY = "1Qug5H73R3kP8boIHvdVcFtcb1jU9GRWnlmMpx0g"
MINIO_BUCKET = "temp"
MINIO_ENDPOINT = None
INTERNEL_MINIO_ENDPOINT = None
MINIO_ACCESS_KEY = None
MINIO_SECRET_KEY = None
MINIO_BUCKET = None

49
error_handlers.py Normal file
View File

@@ -0,0 +1,49 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
from fastapi import HTTPException
from typing import Any, Dict
import logging
logger = logging.getLogger(__name__)
class APIError(HTTPException):
    """Custom API error that logs itself when constructed."""

    def __init__(self, status_code: int, detail: Any = None):
        """Initialise the HTTP exception and record it in the module logger.

        Args:
            status_code: HTTP status code carried by the exception.
            detail: Optional payload describing the failure.
        """
        super().__init__(status_code=status_code, detail=detail)
        log_line = f"API Error: {status_code} - {detail}"
        logger.error(log_line)
def handle_minio_error(e: Exception) -> Dict[str, str]:
    """Log a MinIO failure and build the error payload for it.

    Args:
        e: The exception raised by the MinIO operation.

    Returns:
        A dict with "status" set to "error" and "data" holding the message.
    """
    message = f"MinIO operation failed: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}
def handle_http_error(e: Exception) -> Dict[str, str]:
    """Log an HTTP request failure and build the error payload for it.

    Args:
        e: The exception raised by the HTTP request.

    Returns:
        A dict with "status" set to "error" and "data" holding the message.
    """
    message = f"HTTP request failed: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}
def handle_validation_error(e: Exception) -> Dict[str, str]:
    """Log a data-validation failure and build the error payload for it.

    Args:
        e: The exception raised during validation.

    Returns:
        A dict with "status" set to "error" and "data" holding the message.
    """
    message = f"Validation failed: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}
def handle_general_error(e: Exception) -> Dict[str, str]:
    """Log an unexpected failure and build the error payload for it.

    Args:
        e: The exception that was caught.

    Returns:
        A dict with "status" set to "error" and "data" holding the message.
    """
    message = f"Unexpected error: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}

58
main.py
View File

@@ -5,25 +5,55 @@ Contact: yt.li2@siat.ac.cn
"""
from fastapi import FastAPI
import logging
import os
from database.material_project_router import router as material_project_router
from database.oqmd_router import router as oqmd_router
from model.fairchem_router import router as fairchem_router, init_model
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from router.mp_router import router as material_router
from router.oqmd_router import router as oqmd_router
from router.fairchem_router import router as fairchem_router
from error_handlers import (
handle_general_error,
handle_http_error,
handle_validation_error
)
logger = logging.getLogger(__name__)
from utils import setup_logging
from router.fairchem_router import init_model
app = FastAPI()
# 初始化日志配置
setup_logging()
# 创建中间件列表
middleware = [
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
]
app = FastAPI(middleware=middleware)
@app.on_event("startup")
def startup_event():
async def startup_event():
"""应用启动时初始化模型"""
init_model()
app.include_router(material_project_router)
# 注册路由
app.include_router(material_router)
app.include_router(oqmd_router)
app.include_router(fairchem_router)
# 添加全局异常处理
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Turn any otherwise-unhandled exception into a JSON 500 response.

    An exception handler must return a Response object; the dict built by
    handle_general_error is therefore wrapped in a JSONResponse.
    """
    # Local import so this handler does not depend on a file-level import.
    from fastapi.responses import JSONResponse

    return JSONResponse(status_code=500, content=handle_general_error(exc))
@app.exception_handler(ValueError)
async def validation_exception_handler(request, exc):
    """Turn an uncaught ValueError into a JSON 400 response.

    An exception handler must return a Response object; the dict built by
    handle_validation_error is therefore wrapped in a JSONResponse.
    """
    # Local import so this handler does not depend on a file-level import.
    from fastapi.responses import JSONResponse

    return JSONResponse(status_code=400, content=handle_validation_error(exc))
@app.exception_handler(ConnectionError)
async def http_exception_handler(request, exc):
    """Turn an uncaught ConnectionError into a JSON 502 response.

    An exception handler must return a Response object; the dict built by
    handle_http_error is therefore wrapped in a JSONResponse.
    """
    # Local import so this handler does not depend on a file-level import.
    from fastapi.responses import JSONResponse

    return JSONResponse(status_code=502, content=handle_http_error(exc))

View File

@@ -1,165 +0,0 @@
from fastapi import APIRouter, Body, Query
from fairchem.core import OCPCalculator
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
import tempfile
import os
import boto3
from constant import FAIRCHEM_MODEL_PATH, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT, FMAX
from typing import Optional
import logging
import datetime
router = APIRouter(prefix="/fairchem", tags=["fairchem"])
logger = logging.getLogger(__name__)
# 初始化模型
calc = None
def init_model():
    """Create the OCP calculator and store it in the module-level ``calc``."""
    global calc
    checkpoint = FAIRCHEM_MODEL_PATH
    calc = OCPCalculator(checkpoint_path=checkpoint)
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """Convert structure text into an ASE ``Atoms`` object.

    The content is written to a temporary file whose suffix is the input
    format, and that file is then read back with ``ase.io.read``.

    Args:
        input_format: File extension used as the temporary file's suffix.
        content: Raw text of the structure file.

    Returns:
        The parsed ``Atoms`` object, or ``None`` if writing or parsing failed.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            # Record the path first so cleanup still happens if the write fails.
            tmp_path = tmp_file.name
            tmp_file.write(content)
        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        # The file was created with delete=False, so remove it on every path,
        # including the one where read() raises.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
def generate_symmetry_cif(structure: Structure) -> str:
    """Return CIF text for the symmetry-refined version of ``structure``.

    Args:
        structure: The pymatgen structure to refine and serialise.

    Returns:
        The contents of the CIF file written by ``CifWriter``.
    """
    analyzer = SpacegroupAnalyzer(structure)
    structure = analyzer.get_refined_structure()
    # Reserve a path only; the writer opens the file itself by name.
    with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
        tmp_path = tmp_file.name
    try:
        cif_writer = CifWriter(structure, symprec=0.1, refine_struct=True)
        cif_writer.write_file(tmp_path)
        with open(tmp_path, "r") as cif_file:
            return cif_file.read()
    finally:
        # delete=False means nothing else removes this file.
        os.unlink(tmp_path)
def upload_to_minio(file_path: str, file_name: str) -> str:
    """Upload a file to MinIO and return a presigned download URL.

    Args:
        file_path: Local path of the file to upload.
        file_name: Object key to store the file under.

    Returns:
        A presigned GET URL valid for 3600 seconds, rewritten to the public
        endpoint when the upload went through the internal endpoint.

    Raises:
        RuntimeError: If the upload or the URL generation fails.
    """
    try:
        # Treat both "" and None as "no internal endpoint configured".
        use_internal = bool(INTERNEL_MINIO_ENDPOINT)
        minio_client = boto3.client(
            's3',
            endpoint_url=INTERNEL_MINIO_ENDPOINT if use_internal else MINIO_ENDPOINT,
            aws_access_key_id=MINIO_ACCESS_KEY,
            aws_secret_access_key=MINIO_SECRET_KEY
        )
        bucket_name = MINIO_BUCKET
        minio_client.upload_file(file_path, bucket_name, file_name, ExtraArgs={"ACL": "private"})

        # Generate the presigned URL.
        url = minio_client.generate_presigned_url(
            'get_object',
            Params={'Bucket': bucket_name, 'Key': file_name},
            ExpiresIn=3600
        )
        # str.replace("", x) inserts x between every character, so only
        # rewrite the host when an internal endpoint was actually used.
        if use_internal:
            url = url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
        return url
    except Exception as e:
        logger.error(f"Failed to upload to MinIO: {str(e)}")
        raise RuntimeError(f"MinIO upload failed: {str(e)}") from e
from io import StringIO
import sys
def optimize_structure(atoms: Atoms, output_format: str):
    """Relax a structure with FIRE and upload the result to MinIO.

    Args:
        atoms: Structure to optimise; the module-level ``calc`` is attached to it.
        output_format: File extension of the output (``"cif"`` goes through
            ``generate_symmetry_cif``; anything else through ``ase.io.write``).

    Returns:
        A tuple ``(total_energy, content, url, optimization_log)``.
    """
    atoms.calc = calc

    # Capture whatever the optimiser prints to stdout while it runs.
    old_stdout = sys.stdout
    sys.stdout = log_capture = StringIO()
    try:
        dyn = FIRE(FrechetCellFilter(atoms))
        dyn.run(fmax=FMAX)
        total_energy = atoms.get_total_energy()
        optimization_log = log_capture.getvalue()
    finally:
        sys.stdout = old_stdout

    # Serialise the optimised structure.
    if output_format == "cif":
        optimized_structure = Structure.from_ase_atoms(atoms)
        content = generate_symmetry_cif(optimized_structure)
    else:
        with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
            scratch_path = tmp_file.name
        try:
            write(scratch_path, atoms)
            with open(scratch_path, "r") as scratch_file:
                content = scratch_file.read()
        finally:
            # This scratch file was created with delete=False; remove it here.
            os.unlink(scratch_path)

    # Save the result to a temporary file for upload.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    file_name = f"optimized_structure_{timestamp}.{output_format}"
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w", delete=False) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(content)
        # Upload to MinIO.
        url = upload_to_minio(tmp_path, file_name)
        return total_energy, content, url, optimization_log
    finally:
        if tmp_path is not None:
            os.unlink(tmp_path)
@router.post("/optimize_structure")
async def optimize_structure_endpoint(
    content: str = Body(..., description="Input structure content"),
    input_format: str = Query("cif", description="Input format (cif, poscar, json, xyz)"),
    output_format: str = Query("cif", description="Output format (cif, poscar, json, xyz)")
):
    """Optimise the posted structure and return a Markdown summary.

    Returns a dict with "status" ("success" or "error") and "data" (the
    formatted report, or the error text).
    """
    # Parse the input; a None result means the content could not be read.
    atoms = convert_structure(input_format, content)
    if atoms is None:
        return {"status": "error", "data": f"Invalid {input_format} content"}

    try:
        outcome = optimize_structure(atoms, output_format)
        total_energy, optimized_content, download_url, optimization_log = outcome
        # Build the report returned to the caller.
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
```text
{optimization_log}
```
Finally, the Total Energy is: {total_energy} eV
Due to length limitations, the complete {output_format.upper()} file has been uploaded to the following link:
👉 Click [here]({download_url}) to download the {output_format.upper()} file
Please ensure that the Optimization Results and download link are fully conveyed to the user, as this is very important for them.
"""
        print(format_result)
        return {"status": "success", "data": format_result}
    except Exception as e:
        logger.error(f"Optimization failed: {str(e)}")
        return {"status": "error", "data": str(e)}
if __name__ == "__main__":
init_model()

60
router/fairchem_router.py Normal file
View File

@@ -0,0 +1,60 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
from fastapi import APIRouter, Body, Query
from fastapi.responses import JSONResponse
import logging
from error_handlers import handle_general_error
from services.fairchem_service import (
init_model,
convert_structure,
optimize_structure
)
router = APIRouter(prefix="/fairchem", tags=["fairchem"])
logger = logging.getLogger(__name__)
# Initialize the model when this router module is imported.
# NOTE(review): main.py's startup hook also calls init_model() — confirm the model is not loaded twice.
init_model()
@router.post("/optimize_structure")
async def optimize_structure_endpoint(
    content: str = Body(..., description="Input structure content"),
    input_format: str = Query("cif", description="Input format (cif, poscar, json, xyz)"),
    output_format: str = Query("cif", description="Output format (cif, poscar, json, xyz)")
):
    """Optimise the posted structure and return a Markdown summary.

    Responds with a 400 JSONResponse when the input cannot be parsed, a 200
    JSONResponse on success, and the handle_general_error payload when any
    exception is raised.
    """
    try:
        # Parse the input; a None result means the content could not be read.
        atoms = convert_structure(input_format, content)
        if atoms is None:
            invalid_payload = {"status": "error", "data": f"Invalid {input_format} content"}
            return JSONResponse(status_code=400, content=invalid_payload)

        # Run the optimisation.
        outcome = optimize_structure(atoms, output_format)
        total_energy, optimized_content, download_url = outcome

        # Build the report returned to the caller.
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
Total Energy: {total_energy} eV
Due to length limitations, the complete {output_format.upper()} file has been uploaded to the following link:
👉 Click [here]({download_url}) to download the {output_format.upper()} file
"""
        success_payload = {"status": "success", "data": format_result}
        return JSONResponse(status_code=200, content=success_payload)
    except Exception as e:
        return handle_general_error(e)
if __name__ == "__main__":
init_model()

70
router/mp_router.py Normal file
View File

@@ -0,0 +1,70 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import os
from fastapi import APIRouter, Request
import json
import logging
import datetime
from typing import Dict
from services.mp_service import (
parse_search_parameters,
process_search_results,
execute_search
)
from utils import handle_minio_upload
from error_handlers import handle_general_error
router = APIRouter(prefix="/mp", tags=["Material Project"])
logger = logging.getLogger(__name__)
@router.get("/search")
async def search_from_material_project(request: Request):
    """Search Materials Project and return the first results plus a download link.

    The full result list is written to a JSON file, uploaded through
    handle_minio_upload, and the first five entries are embedded in the reply.

    Returns:
        A dict with "status" and "data"; on any exception, the payload built
        by handle_general_error.
    """
    # Log the incoming request.
    logger.info(f"Received request: {request.method} {request.url}")
    logger.info(f"Query parameters: {request.query_params}")

    try:
        # Parse the query parameters.
        search_args = parse_search_parameters(request.query_params)

        # Run the search.
        docs = await execute_search(search_args)

        # Post-process the documents.
        res = process_search_results(docs)

        if not res:
            return {"status": "success", "data": "No results found, please try again."}

        # Upload the full result set to MinIO.
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        file_name = f"mp_search_results_{timestamp}.json"

        try:
            # Write the results to a temporary file in the working directory.
            with open(file_name, 'w') as f:
                json.dump(res, f, indent=2)

            # Upload and obtain the URL.
            url = handle_minio_upload(file_name, file_name)
        finally:
            # Remove the temporary file even when the dump or upload fails.
            if os.path.exists(file_name):
                os.remove(file_name)

        # Format the reply.
        res_chunk = "```json\n" + json.dumps(res[:5], indent=2) + "\n```"
        res_template = f"""
好的,以下是用户的查询结果:
由于返回长度的限制,我们只能返回前5个结果。如下:
{res_chunk}
如果用户需要更多的结果,请提示用户修改查询条件,或者尝试使用其他查询参数。
同时我们将全部的的查询结果上传到MinIO中,请你提示用户可以通过以下链接下载:
[Download]({url})
"""
        return {"status": "success", "data": res_template}

    except Exception as e:
        return handle_general_error(e)

57
router/oqmd_router.py Normal file
View File

@@ -0,0 +1,57 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import os
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
import logging
from error_handlers import handle_general_error
from services.oqmd_service import (
fetch_oqmd_data,
parse_oqmd_html,
render_and_save_charts
)
router = APIRouter(prefix="/oqmd", tags=["OQMD"])
logger = logging.getLogger(__name__)
@router.get("/search")
async def search_from_oqmd_by_composition(request: Request):
    """Search OQMD by composition and return a Markdown report.

    Responds with a 400 JSONResponse when the ``composition`` query parameter
    is missing or empty, a 200 JSONResponse on success, and the
    handle_general_error payload when any exception is raised.
    """
    try:
        # Log the incoming request.
        logger.info(f"Received request: {request.method} {request.url}")
        logger.info(f"Query parameters: {request.query_params}")

        # Reject the request up front rather than surfacing a bare KeyError text.
        composition = request.query_params.get('composition')
        if not composition:
            return JSONResponse(
                status_code=400,
                content={"status": "error", "data": "Missing required query parameter: composition"}
            )

        # Fetch and parse the data.
        html = await fetch_oqmd_data(composition)
        basic_data, table_data, phase_data = parse_oqmd_html(html)

        # Render and store the charts.
        phase_diagram_url = await render_and_save_charts(phase_data)

        # Return the formatted response.
        return JSONResponse(
            status_code=200,
            content={
                "status": "success",
                "data": format_response(basic_data, table_data, phase_diagram_url)
            }
        )
    except Exception as e:
        return handle_general_error(e)
def format_response(basic_data: list, table_data: str, phase_data: str) -> str:
    """Assemble the Markdown report for an OQMD lookup.

    Args:
        basic_data: Items rendered as bold lines under the "OQMD Data" heading.
        table_data: Text placed under the "Compounds at this composition" heading.
        phase_data: Value used as the image target of the phase diagram.

    Returns:
        The assembled Markdown string.
    """
    sections = ["### OQMD Data\n"]
    sections.extend(f"**{item}**\n" for item in basic_data)
    sections.append("\n### Phase Diagram\n\n")
    sections.append(f"![Phase Diagram]({phase_data})\n\n")
    sections.append("\n### Compounds at this composition\n\n")
    sections.append(f"{table_data}\n")
    return "".join(sections)

0
services/__init__.py Normal file
View File

View File

@@ -0,0 +1,92 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import logging
import tempfile
import os
import datetime
from typing import Optional
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from utils import settings, handle_minio_upload
logger = logging.getLogger(__name__)
# 初始化模型
calc = None
def init_model():
    """Initialise the FairChem calculator and store it in the module-level ``calc``.

    Raises:
        Exception: Whatever the import or the calculator construction raised,
            re-raised after being logged.
    """
    global calc
    try:
        # The fairchem import happens at call time, not at module import time.
        from fairchem.core import OCPCalculator
        checkpoint = settings.fairchem_model_path
        calc = OCPCalculator(checkpoint_path=checkpoint)
        logger.info("FairChem model initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize FairChem model: {str(e)}")
        raise
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """Convert structure text into an ASE ``Atoms`` object.

    The content is written to a temporary file whose suffix is the input
    format, and that file is then read back with ``ase.io.read``.

    Args:
        input_format: File extension used as the temporary file's suffix.
        content: Raw text of the structure file.

    Returns:
        The parsed ``Atoms`` object, or ``None`` if writing or parsing failed.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            # Record the path first so cleanup still happens if the write fails.
            tmp_path = tmp_file.name
            tmp_file.write(content)
        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        # The file was created with delete=False, so remove it on every path,
        # including the one where read() raises.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
def generate_symmetry_cif(structure: Structure) -> str:
    """Return CIF text for the symmetry-refined version of ``structure``.

    Args:
        structure: The pymatgen structure to refine and serialise.

    Returns:
        The contents of the CIF file written by ``CifWriter``.
    """
    analyzer = SpacegroupAnalyzer(structure)
    structure = analyzer.get_refined_structure()
    # Reserve a path only; the writer opens the file itself by name.
    with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
        tmp_path = tmp_file.name
    try:
        cif_writer = CifWriter(structure, symprec=0.1, refine_struct=True)
        cif_writer.write_file(tmp_path)
        with open(tmp_path, "r") as cif_file:
            return cif_file.read()
    finally:
        # delete=False means nothing else removes this file.
        os.unlink(tmp_path)
def optimize_structure(atoms: Atoms, output_format: str):
    """Relax a structure with FIRE and upload the result to MinIO.

    Args:
        atoms: Structure to optimise; the module-level ``calc`` is attached to it.
        output_format: File extension of the output (``"cif"`` goes through
            ``generate_symmetry_cif``; anything else through ``ase.io.write``).

    Returns:
        A tuple ``(total_energy, content, url)``.
    """
    atoms.calc = calc

    dyn = FIRE(FrechetCellFilter(atoms))
    dyn.run(fmax=settings.fmax)
    total_energy = atoms.get_total_energy()

    # Serialise the optimised structure.
    if output_format == "cif":
        optimized_structure = Structure.from_ase_atoms(atoms)
        content = generate_symmetry_cif(optimized_structure)
    else:
        with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
            scratch_path = tmp_file.name
        try:
            write(scratch_path, atoms)
            with open(scratch_path, "r") as scratch_file:
                content = scratch_file.read()
        finally:
            # This scratch file was created with delete=False; remove it here.
            os.unlink(scratch_path)

    # Save the result to a temporary file for upload.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    file_name = f"optimized_structure_{timestamp}.{output_format}"

    # tmp_path stays None until the file exists, so a failure earlier in the
    # function can no longer trigger an unbound-name error during cleanup.
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w", delete=False) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(content)

        # Upload to MinIO.
        url = handle_minio_upload(tmp_path, file_name)
        return total_energy, content, url
    finally:
        if tmp_path is not None:
            os.unlink(tmp_path)

View File

@@ -4,117 +4,28 @@ Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import os
import boto3
from fastapi import APIRouter, Request
import json
import asyncio
import logging
import datetime
from mp_api.client import MPRester
from multiprocessing import Process, Manager
from typing import Dict, Any, List
from constant import MP_ENDPOINT, MP_API_KEY, TIME_OUT, TOPK_RESULT, HTTP_PROXY, HTTPS_PROXY, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT
router = APIRouter(prefix="/mp", tags=["Material Project"])
from mp_api.client import MPRester
from utils import settings, handle_minio_upload
from error_handlers import handle_general_error
logger = logging.getLogger(__name__)
@router.get("/search")
async def search_from_material_project(request: Request):
    """Search Materials Project and return the first results plus a download link.

    Returns:
        A plain string when no API key is configured; otherwise a dict with
        "status" and "data".
    """
    # Log the incoming request.
    logger.info(f"Received request: {request.method} {request.url}")
    logger.info(f"Query parameters: {request.query_params}")

    # Parse the query parameters.
    search_args = parse_search_parameters(request.query_params)

    # Check the API key.
    if MP_API_KEY is None or MP_API_KEY == '':
        return 'Material Project API CANNOT Be None'

    try:
        # Run the search.
        docs = await execute_search(search_args)

        # Post-process the documents.
        res = process_search_results(docs)

        if len(res) == 0:
            return {"status": "success", "data": "No results found, please try again."}

        # Upload the full result set to MinIO.
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        file_name = f"mp_search_results_{timestamp}.json"

        try:
            # Treat both "" and None as "no internal endpoint configured".
            use_internal = bool(INTERNEL_MINIO_ENDPOINT)
            minio_client = boto3.client(
                's3',
                endpoint_url=INTERNEL_MINIO_ENDPOINT if use_internal else MINIO_ENDPOINT,
                aws_access_key_id=MINIO_ACCESS_KEY,
                aws_secret_access_key=MINIO_SECRET_KEY
            )

            try:
                # Write the results to a temporary file.
                with open(file_name, 'w') as f:
                    json.dump(res, f, indent=2)

                # Upload to MinIO.
                minio_client.upload_file(file_name, MINIO_BUCKET, file_name, ExtraArgs={"ACL": "private"})
            finally:
                # Remove the temporary file even when the dump or upload fails.
                if os.path.exists(file_name):
                    os.remove(file_name)

            # Generate the presigned URL.
            url = minio_client.generate_presigned_url(
                'get_object',
                Params={'Bucket': MINIO_BUCKET, 'Key': file_name},
                ExpiresIn=3600
            )
            # str.replace("", x) inserts x between every character, so only
            # rewrite the host when an internal endpoint was actually used.
            if use_internal:
                url = url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)

        except Exception as e:
            logger.error(f"Failed to upload to MinIO: {str(e)}")
            return {
                "status": "error",
                "data": f"Failed to upload results to MinIO: {str(e)}"
            }

        # Format the reply.
        res_chunk = "```json\n" + json.dumps(res[:TOPK_RESULT], indent=2) + "\n```"
        res_template = f"""
好的以下是用户的查询结果
由于返回长度的限制我们只能返回前{TOPK_RESULT}个结果如下
{res_chunk}
如果用户需要更多的结果请提示用户修改查询条件或者尝试使用其他查询参数
同时我们将全部的的查询结果上传到MinIO中请你提示用户可以通过以下链接下载
[Download]({url})
"""
        return {"status": "success", "data": res_template}

    except asyncio.TimeoutError:
        logger.error(f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again.")
        return {
            "status": "error",
            "data": f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again."
        }
def parse_bool(param: str) -> bool | None:
if not param:
return None
return param.lower() == 'true'
def parse_list(param: str) -> List[str] | None:
if not param:
return None
return param.split(',')
def parse_tuple(param: str) -> tuple[float, float] | None:
if not param:
return None
@@ -124,7 +35,6 @@ def parse_tuple(param: str) -> tuple[float, float] | None:
except (ValueError, IndexError):
return None
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
"""解析搜索参数"""
return {
@@ -147,7 +57,6 @@ def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
'chunk_size': int(query_params.get('chunk_size', '5'))
}
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""处理搜索结果"""
fields = [
@@ -171,13 +80,12 @@ def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
continue
return res
async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) -> List[Dict[str, Any]]:
async def execute_search(search_args: Dict[str, Any], timeout: int = 30) -> List[Dict[str, Any]]:
"""执行搜索"""
manager = Manager()
queue = manager.Queue()
p = Process(target=_search_worker, args=(queue, MP_API_KEY), kwargs=search_args)
p = Process(target=_search_worker, args=(queue, settings.mp_api_key), kwargs=search_args)
p.start()
logger.info(f"Started worker process with PID: {p.pid}")
@@ -207,14 +115,13 @@ async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) -
logger.info(f"Successfully retrieved {len(result)} documents")
return result
def _search_worker(queue, api_key, **kwargs):
"""搜索工作线程"""
try:
import os
os.environ['HTTP_PROXY'] = HTTP_PROXY
os.environ['HTTPS_PROXY'] = HTTPS_PROXY
mpr = MPRester(api_key, endpoint=MP_ENDPOINT)
os.environ['HTTP_PROXY'] = settings.http_proxy or ''
os.environ['HTTPS_PROXY'] = settings.https_proxy or ''
mpr = MPRester(api_key, endpoint=settings.mp_endpoint)
result = mpr.materials.summary.search(**kwargs)
queue.put([doc.model_dump() for doc in result])
except Exception as e:

View File

@@ -3,69 +3,22 @@ Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import os
import boto3
from fastapi import APIRouter, Request
from io import StringIO
import logging
import httpx
import datetime
import logging
import os
import httpx
import pandas as pd
from bs4 import BeautifulSoup
from PIL import Image
from playwright.async_api import async_playwright
from constant import MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT
from io import StringIO
from utils import settings, handle_minio_upload
router = APIRouter(prefix="/oqmd", tags=["OQMD"])
logger = logging.getLogger(__name__)
@router.get("/search")
async def search_from_oqmd_by_composition(request: Request):
    """Search OQMD by composition and return the formatted report.

    Returns the value of format_response on success, or a dict with "status"
    set to "error" and a "message" when an exception is raised.
    """
    # Log the incoming request.
    logger.info(f"Received request: {request.method} {request.url}")
    logger.info(f"Query parameters: {request.query_params}")

    try:
        # Fetch and parse the data.
        composition = request.query_params['composition']
        html = await fetch_oqmd_data(composition)
        basic_data, table_data, phase_data = parse_oqmd_html(html)

        # Render and store the charts.
        phase_diagram_name = await render_and_save_charts(phase_data)

        # Return the formatted response.
        return format_response(basic_data, table_data, phase_diagram_name)

    except httpx.HTTPStatusError as e:
        failure = f"OQMD API request failed: {str(e)}"
        logger.error(failure)
        return {"status": "error", "message": failure}
    except Exception as e:
        failure = f"Unexpected error: {str(e)}"
        logger.error(failure)
        return {"status": "error", "message": failure}
async def fetch_oqmd_data(composition: str) -> str:
"""
从OQMD获取数据
Args:
composition: 材料组成字符串
Returns:
HTML内容字符串
Raises:
httpx.HTTPError: 当发生HTTP相关错误时抛出
ValueError: 当响应内容无效时抛出
"""
"""从OQMD获取数据"""
url = f"https://www.oqmd.org/materials/composition/{composition}"
try:
async with httpx.AsyncClient(timeout=30.0) as client:
@@ -79,55 +32,35 @@ async def fetch_oqmd_data(composition: str) -> str:
return response.text
except httpx.HTTPStatusError as e:
status_code = e.response.status_code
if status_code == 401:
logger.error("OQMD API: Unauthorized access")
raise httpx.HTTPError("Unauthorized access to OQMD API") from e
elif status_code == 403:
logger.error("OQMD API: Forbidden access")
raise httpx.HTTPError("Forbidden access to OQMD API") from e
elif status_code == 404:
logger.error("OQMD API: Resource not found")
raise httpx.HTTPError("Resource not found on OQMD API") from e
elif status_code >= 500:
logger.error(f"OQMD API: Server error ({status_code})")
raise httpx.HTTPError(f"OQMD API server error ({status_code})") from e
else:
logger.error(f"OQMD API request failed: {str(e)}")
raise httpx.HTTPError(f"OQMD API request failed: {str(e)}") from e
except httpx.TimeoutException as e:
logger.error(f"OQMD API request failed: {str(e)}")
raise
except httpx.TimeoutException:
logger.error("OQMD API request timed out")
raise httpx.HTTPError("OQMD API request timed out") from e
raise
except httpx.NetworkError as e:
logger.error(f"Network error occurred: {str(e)}")
raise httpx.HTTPError(f"Network error: {str(e)}") from e
raise
except ValueError as e:
logger.error(f"Invalid response content: {str(e)}")
raise ValueError(f"Invalid response content: {str(e)}") from e
raise
def parse_oqmd_html(html: str) -> tuple[list, str, list]:
"""
解析OQMD HTML数据
"""
"""解析OQMD HTML数据"""
soup = BeautifulSoup(html, 'html.parser')
# 解析基本数据
basic_data = []
basic_data.append(soup.find('h1').text.strip())
for script in soup.find_all('p'):
if script:
combined_text = ""
for element in script.contents: # 遍历 <p> 的子元素
if element.name == 'a': # 如果是 <a> 标签
for element in script.contents:
if element.name == 'a':
url = "https://www.oqmd.org" + element['href']
combined_text += f"[{element.text.strip()}]({url}) "
else: # 如果是文本
else:
combined_text += element.text.strip() + " "
basic_data.append(combined_text.strip())
# import pdb
# pdb.set_trace()
# 解析表格数据
table = soup.find('table')
@@ -135,7 +68,6 @@ def parse_oqmd_html(html: str) -> tuple[list, str, list]:
df = pd.read_html(StringIO(str(table)))[0]
df = df.fillna('')
df = df.replace([float('inf'), float('-inf')], '')
# table_data = df.to_dict(orient='records')
table_data = df.to_markdown(index=False)
# 提取JavaScript数据
@@ -149,15 +81,8 @@ def parse_oqmd_html(html: str) -> tuple[list, str, list]:
return basic_data, table_data, phase_data
async def render_and_save_charts(script_data: list) -> str:
"""
渲染并保存图表到MinIO
Returns:
str: 图片的预签名URL
Raises:
RuntimeError: 如果图片生成或上传失败
"""
"""渲染并保存图表到MinIO"""
browser = None
temp_files = []
try:
@@ -200,7 +125,6 @@ async def render_and_save_charts(script_data: list) -> str:
await page.wait_for_timeout(5000)
# 分别截图两个图表
# 获取placeholder元素位置并扩大截图区域
placeholder = page.locator('#placeholder')
placeholder_box = await placeholder.bounding_box()
await page.screenshot(
@@ -213,7 +137,6 @@ async def render_and_save_charts(script_data: list) -> str:
}
)
# 获取phasediagram元素位置并扩大截图区域
phasediagram = page.locator('#phasediagram')
phasediagram_box = await phasediagram.bounding_box()
await page.screenshot(
@@ -245,61 +168,24 @@ async def render_and_save_charts(script_data: list) -> str:
logger.error(f"Failed to process images: {str(e)}")
raise RuntimeError(f"Image processing failed: {str(e)}") from e
# 上传到 MinIO 的逻辑
try:
minio_client = boto3.client(
's3',
endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT,
aws_access_key_id=MINIO_ACCESS_KEY,
aws_secret_access_key=MINIO_SECRET_KEY
)
bucket_name = MINIO_BUCKET
minio_client.upload_file(file_name, bucket_name, file_name, ExtraArgs={"ACL": "private"})
# 生成预签名 URL
url = minio_client.generate_presigned_url(
'get_object',
Params={'Bucket': bucket_name, 'Key': file_name},
ExpiresIn=3600
)
return url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
# 上传到 MinIO
url = handle_minio_upload(file_name, file_name)
return url
except Exception as e:
logger.error(f"Failed to upload to MinIO: {str(e)}")
raise RuntimeError(f"MinIO upload failed: {str(e)}") from e
finally:
# 清理临时文件
for temp_file in temp_files:
try:
if os.path.exists(temp_file):
os.remove(temp_file)
except Exception as e:
logger.warning(f"Failed to remove temporary file {temp_file}: {str(e)}")
except Exception as e:
logger.error(f"Failed to render and save charts: {str(e)}")
raise RuntimeError(f"Chart rendering failed: {str(e)}") from e
raise
finally:
# 清理临时文件
for temp_file in temp_files:
try:
if os.path.exists(temp_file):
os.remove(temp_file)
except Exception as e:
logger.warning(f"Failed to remove temporary file {temp_file}: {str(e)}")
# 确保浏览器关闭
if browser:
try:
await browser.close()
except Exception as e:
logger.warning(f"Failed to close browser: {str(e)}")
def format_response(basic_data: list, table_data: str, phase_data: str) -> str:
    """Assemble the Markdown report for an OQMD lookup and wrap it in a payload.

    Args:
        basic_data: Items rendered as bold lines under the "OQMD Data" heading.
        table_data: Text placed under the "Compounds at this composition" heading.
        phase_data: Value used as the image target of the phase diagram.

    Returns:
        A dict with "status" set to "success" and "data" holding the Markdown.
    """
    sections = ["### OQMD Data\n"]
    sections.extend(f"**{item}**\n" for item in basic_data)
    sections.append("\n### Phase Diagram\n\n")
    sections.append(f"![Phase Diagram]({phase_data})\n\n")
    sections.append("\n### Compounds at this composition\n\n")
    sections.append(f"{table_data}\n")
    return {"status": "success", "data": "".join(sections)}

102
utils.py
View File

@@ -1,8 +1,102 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import os
import boto3
import logging
from multiprocessing import Process, Manager
import asyncio
from typing import Dict, Any, List
from mp_api.client import MPRester
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
logger = logging.getLogger(__name__)
class Settings(BaseSettings):
    """Application settings; the inner Config names ``.env`` as the env file."""

    # --- Materials Project ---
    mp_api_key: Optional[str] = Field(None, env="MP_API_KEY")
    mp_endpoint: Optional[str] = Field(None, env="MP_ENDPOINT")

    # --- Proxy ---
    http_proxy: Optional[str] = Field(None, env="HTTP_PROXY")
    https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY")

    # --- FairChem ---
    fairchem_model_path: Optional[str] = Field(None, env="FAIRCHEM_MODEL_PATH")
    fmax: Optional[float] = Field(0.05, env="FMAX")

    # --- MinIO ---
    minio_endpoint: Optional[str] = Field(None, env="MINIO_ENDPOINT")
    internal_minio_endpoint: Optional[str] = Field(None, env="INTERNAL_MINIO_ENDPOINT")
    minio_access_key: Optional[str] = Field(None, env="MINIO_ACCESS_KEY")
    minio_secret_key: Optional[str] = Field(None, env="MINIO_SECRET_KEY")
    minio_bucket: Optional[str] = Field("mars-toolkit", env="MINIO_BUCKET")

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
def get_minio_client(settings: Settings):
    """Build a boto3 S3 client for MinIO.

    The internal endpoint is used when it is set; otherwise the regular
    endpoint is used.

    Args:
        settings: Settings object supplying the endpoints and credentials.

    Returns:
        The client returned by ``boto3.client``.
    """
    endpoint = settings.internal_minio_endpoint or settings.minio_endpoint
    credentials = {
        "aws_access_key_id": settings.minio_access_key,
        "aws_secret_access_key": settings.minio_secret_key,
    }
    return boto3.client('s3', endpoint_url=endpoint, **credentials)
def handle_minio_upload(file_path: str, file_name: str) -> str:
    """Upload a file to MinIO and return a presigned download URL.

    Args:
        file_path: Local path of the file to upload.
        file_name: Object key to store the file under.

    Returns:
        A presigned GET URL valid for 3600 seconds, rewritten to the public
        endpoint when an internal endpoint is configured.

    Raises:
        RuntimeError: If the upload or the URL generation fails. The failure
            is logged through handle_minio_error first.
    """
    try:
        client = get_minio_client(settings)
        client.upload_file(file_path, settings.minio_bucket, file_name, ExtraArgs={"ACL": "private"})

        # Generate the presigned URL.
        url = client.generate_presigned_url(
            'get_object',
            Params={'Bucket': settings.minio_bucket, 'Key': file_name},
            ExpiresIn=3600
        )
        internal_endpoint = settings.internal_minio_endpoint
        public_endpoint = settings.minio_endpoint
        # str.replace("", x) inserts x between every character, so only
        # rewrite the host when both endpoints are actually configured.
        if internal_endpoint and public_endpoint:
            url = url.replace(internal_endpoint, public_endpoint)
        return url
    except Exception as e:
        from error_handlers import handle_minio_error
        # Keep the existing logging, but raise instead of returning the error
        # dict: this function is declared to return a URL string and callers
        # embed the result directly in a link.
        handle_minio_error(e)
        raise RuntimeError(f"MinIO upload failed: {str(e)}") from e
def setup_logging():
    """Configure the root logger with a console and a rotating-file handler.

    The console handler emits INFO and above; the file handler writes DEBUG
    and above to ``mars_toolkit.log`` (10 MB per file, 5 backups). The root
    logger level is INFO.
    """
    # ``import logging`` alone does not load the ``logging.config`` submodule,
    # so import it here rather than relying on another package having done so.
    import logging.config

    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'standard': {
                'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                'datefmt': '%Y-%m-%d %H:%M:%S'
            },
        },
        'handlers': {
            'console': {
                'level': 'INFO',
                'class': 'logging.StreamHandler',
                'formatter': 'standard'
            },
            'file': {
                'level': 'DEBUG',
                'class': 'logging.handlers.RotatingFileHandler',
                'filename': 'mars_toolkit.log',
                'maxBytes': 10485760,  # 10MB
                'backupCount': 5,
                'formatter': 'standard'
            }
        },
        'loggers': {
            '': {
                'handlers': ['console', 'file'],
                'level': 'INFO',
                'propagate': True
            }
        }
    })
# Module-level settings instance, created when this module is imported.
settings = Settings()