diff --git a/constant.py b/constant.py index e3e23cc..f221c78 100644 --- a/constant.py +++ b/constant.py @@ -2,20 +2,20 @@ TOPK_RESULT = 1 TIME_OUT = 60 # MP Configuration -MP_API_KEY = "gfBp2in8qxm9Xm2SwLKFwNxDyZvNTAEt" -MP_ENDPOINT = "https://api.materialsproject.org/" +MP_API_KEY = None +MP_ENDPOINT = None # Proxy -HTTP_PROXY = "http://127.0.0.1:7897" -HTTPS_PROXY = "http://127.0.0.1:7897" +HTTP_PROXY = None +HTTPS_PROXY = None # Model -FAIRCHEM_MODEL_PATH = "/home/ubuntu/sas0/LYT/mars1215/mars_toolkit/model/eqV2_86M_omat_mp_salex.pt" -FMAX = 0.05 +FAIRCHEM_MODEL_PATH = None +FMAX = None # MinIO configuration -MINIO_ENDPOINT = "https://s3-api.siat-mic.com" -INTERNEL_MINIO_ENDPOINT = "http://100.85.52.31:9000" # 内网地址,如果有就填,上传会更快。 -MINIO_ACCESS_KEY = "9bUtQL1Gpo9JB6o3pSGr" -MINIO_SECRET_KEY = "1Qug5H73R3kP8boIHvdVcFtcb1jU9GRWnlmMpx0g" -MINIO_BUCKET = "temp" +MINIO_ENDPOINT = None +INTERNEL_MINIO_ENDPOINT = None +MINIO_ACCESS_KEY = None +MINIO_SECRET_KEY = None +MINIO_BUCKET = None diff --git a/error_handlers.py b/error_handlers.py new file mode 100644 index 0000000..9669eb6 --- /dev/null +++ b/error_handlers.py @@ -0,0 +1,49 @@ +""" +Author: Yutang LI +Institution: SIAT-MIC +Contact: yt.li2@siat.ac.cn +""" + +from fastapi import HTTPException +from typing import Any, Dict +import logging + +logger = logging.getLogger(__name__) + +class APIError(HTTPException): + """自定义API错误类""" + def __init__(self, status_code: int, detail: Any = None): + super().__init__(status_code=status_code, detail=detail) + logger.error(f"API Error: {status_code} - {detail}") + +def handle_minio_error(e: Exception) -> Dict[str, str]: + """处理MinIO相关错误""" + logger.error(f"MinIO operation failed: {str(e)}") + return { + "status": "error", + "data": f"MinIO operation failed: {str(e)}" + } + +def handle_http_error(e: Exception) -> Dict[str, str]: + """处理HTTP请求错误""" + logger.error(f"HTTP request failed: {str(e)}") + return { + "status": "error", + "data": f"HTTP request failed: {str(e)}" + } + +def handle_validation_error(e: Exception) -> Dict[str, str]: + """处理数据验证错误""" + logger.error(f"Validation failed: {str(e)}") + return { + "status": "error", + "data": f"Validation failed: {str(e)}" + } + +def handle_general_error(e: Exception) -> Dict[str, str]: + """处理通用错误""" + logger.error(f"Unexpected error: {str(e)}") + return { + "status": "error", + "data": f"Unexpected error: {str(e)}" + } diff --git a/main.py b/main.py index 6f3fb5f..b88571e 100644 --- a/main.py +++ b/main.py @@ -5,25 +5,55 @@ Contact: yt.li2@siat.ac.cn """ from fastapi import FastAPI -import logging -import os -from database.material_project_router import router as material_project_router -from database.oqmd_router import router as oqmd_router -from model.fairchem_router import router as fairchem_router, init_model - - -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware +from router.mp_router import router as material_router +from router.oqmd_router import router as oqmd_router +from router.fairchem_router import router as fairchem_router +from error_handlers import ( + handle_general_error, + handle_http_error, + handle_validation_error ) -logger = logging.getLogger(__name__) +from utils import setup_logging +from router.fairchem_router import init_model -app = FastAPI() +# 初始化日志配置 +setup_logging() + +# 创建中间件列表 +middleware = [ + Middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) +] + +app = FastAPI(middleware=middleware) @app.on_event("startup") -def startup_event(): +async def startup_event(): + """应用启动时初始化模型""" init_model() -app.include_router(material_project_router) +# 注册路由 +app.include_router(material_router) app.include_router(oqmd_router) app.include_router(fairchem_router) + +# 添加全局异常处理 +@app.exception_handler(Exception) +async def global_exception_handler(request, exc): + return handle_general_error(exc) + +@app.exception_handler(ValueError) +async def validation_exception_handler(request, exc): + return handle_validation_error(exc) + +@app.exception_handler(ConnectionError) +async def http_exception_handler(request, exc): + return handle_http_error(exc) diff --git a/model/fairchem_router.py b/model/fairchem_router.py deleted file mode 100644 index 4db578c..0000000 --- a/model/fairchem_router.py +++ /dev/null @@ -1,165 +0,0 @@ -from fastapi import APIRouter, Body, Query -from fairchem.core import OCPCalculator -from ase.optimize import FIRE -from ase.filters import FrechetCellFilter -from ase.atoms import Atoms -from ase.io import read, write -from pymatgen.core.structure import Structure -from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from pymatgen.io.cif import CifWriter -import tempfile -import os -import boto3 -from constant import FAIRCHEM_MODEL_PATH, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT, FMAX -from typing import Optional -import logging -import datetime - -router = APIRouter(prefix="/fairchem", tags=["fairchem"]) -logger = logging.getLogger(__name__) - -# 初始化模型 -calc = None - -def init_model(): - global calc - calc = OCPCalculator(checkpoint_path=FAIRCHEM_MODEL_PATH) - -def convert_structure(input_format: str, content: str) -> Optional[Atoms]: - """将输入内容转换为Atoms对象""" - try: - with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file: - tmp_file.write(content) - tmp_path = tmp_file.name - - atoms = read(tmp_path) - os.unlink(tmp_path) - return atoms - except Exception as e: - logger.error(f"Failed to convert structure: {str(e)}") - return None - -def generate_symmetry_cif(structure: Structure) -> str: - """生成对称性CIF""" - analyzer = SpacegroupAnalyzer(structure) - structure = analyzer.get_refined_structure() - - with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file: - cif_writer = CifWriter(structure, symprec=0.1, refine_struct=True) - cif_writer.write_file(tmp_file.name) - tmp_file.seek(0) - return tmp_file.read() - -def upload_to_minio(file_path: str, file_name: str) -> str: - """上传文件到MinIO并返回预签名URL""" - try: - minio_client = boto3.client( - 's3', - endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT, - aws_access_key_id=MINIO_ACCESS_KEY, - aws_secret_access_key=MINIO_SECRET_KEY - ) - - bucket_name = MINIO_BUCKET - minio_client.upload_file(file_path, bucket_name, file_name, ExtraArgs={"ACL": "private"}) - - # 生成预签名 URL - url = minio_client.generate_presigned_url( - 'get_object', - Params={'Bucket': bucket_name, 'Key': file_name}, - ExpiresIn=3600 - ) - return url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT) - except Exception as e: - logger.error(f"Failed to upload to MinIO: {str(e)}") - raise RuntimeError(f"MinIO upload failed: {str(e)}") from e - -from io import StringIO -import sys - -def optimize_structure(atoms: Atoms, output_format: str): - """优化晶体结构""" - atoms.calc = calc - - # 捕获优化日志 - old_stdout = sys.stdout - sys.stdout = log_capture = StringIO() - - try: - dyn = FIRE(FrechetCellFilter(atoms)) - dyn.run(fmax=FMAX) - total_energy = atoms.get_total_energy() - optimization_log = log_capture.getvalue() - finally: - sys.stdout = old_stdout - - # 处理对称性 - if output_format == "cif": - optimized_structure = Structure.from_ase_atoms(atoms) - content = generate_symmetry_cif(optimized_structure) - else: - with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file: - write(tmp_file.name, atoms) - tmp_file.seek(0) - content = tmp_file.read() - - # 保存优化结果到临时文件 - timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - file_name = f"optimized_structure_{timestamp}.{output_format}" - with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w", delete=False) as tmp_file: - tmp_file.write(content) - tmp_path = tmp_file.name - try: - # 上传到MinIO - url = upload_to_minio(tmp_path, file_name) - return total_energy, content, url, optimization_log - finally: - os.unlink(tmp_path) - -@router.post("/optimize_structure") -async def optimize_structure_endpoint( - content: str = Body(..., description="Input structure content"), - input_format: str = Query("cif", description="Input format (cif, poscar, json, xyz)"), - output_format: str = Query("cif", description="Output format (cif, poscar, json, xyz)") -): - # 转换输入结构 - atoms = convert_structure(input_format, content) - if atoms is None: - return { - "status": "error", - "data": f"Invalid {input_format} content" - } - - try: - # 优化结构 - total_energy, optimized_content, download_url, optimization_log = optimize_structure(atoms, output_format) - - # 格式化返回结果 - format_result = f""" -The following is the optimized crystal structure information: - -### Optimization Results (using FIRE(eqV2_86M) algorithm): -```text -{optimization_log} -``` -Finally, the Total Energy is: {total_energy} eV -Due to length limitations, the complete {output_format.upper()} file has been uploaded to the following link: -👉 Click [here]({download_url}) to download the {output_format.upper()} file - -Please ensure that the Optimization Results and download link are fully conveyed to the user, as this is very important for them. -""" - print(format_result) - return { - "status": "success", - "data": format_result - } - - except Exception as e: - logger.error(f"Optimization failed: {str(e)}") - return { - "status": "error", - "data": str(e) - } - -if __name__ == "__main__": - init_model() diff --git a/database/__init__.py b/router/__init__.py similarity index 100% rename from database/__init__.py rename to router/__init__.py diff --git a/router/fairchem_router.py b/router/fairchem_router.py new file mode 100644 index 0000000..5d26612 --- /dev/null +++ b/router/fairchem_router.py @@ -0,0 +1,60 @@ +""" +Author: Yutang LI +Institution: SIAT-MIC +Contact: yt.li2@siat.ac.cn +""" + +from fastapi import APIRouter, Body, Query +from fastapi.responses import JSONResponse +import logging +from error_handlers import handle_general_error +from services.fairchem_service import ( + init_model, + convert_structure, + optimize_structure +) + +router = APIRouter(prefix="/fairchem", tags=["fairchem"]) +logger = logging.getLogger(__name__) + +# 初始化模型 +init_model() + +@router.post("/optimize_structure") +async def optimize_structure_endpoint( + content: str = Body(..., description="Input structure content"), + input_format: str = Query("cif", description="Input format (cif, poscar, json, xyz)"), + output_format: str = Query("cif", description="Output format (cif, poscar, json, xyz)") +): + try: + # 转换输入结构 + atoms = convert_structure(input_format, content) + if atoms is None: + return JSONResponse( + status_code=400, + content={"status": "error", "data": f"Invalid {input_format} content"} + ) + + # 优化结构 + total_energy, optimized_content, download_url = optimize_structure(atoms, output_format) + + # 格式化返回结果 + format_result = f""" +The following is the optimized crystal structure information: + +### Optimization Results (using FIRE(eqV2_86M) algorithm): +Total Energy: {total_energy} eV + +Due to length limitations, the complete {output_format.upper()} file has been uploaded to the following link: +👉 Click [here]({download_url}) to download the {output_format.upper()} file +""" + return JSONResponse( + status_code=200, + content={"status": "success", "data": format_result} + ) + + except Exception as e: + return handle_general_error(e) + +if __name__ == "__main__": + init_model() diff --git a/router/mp_router.py b/router/mp_router.py new file mode 100644 index 0000000..5992ec6 --- /dev/null +++ b/router/mp_router.py @@ -0,0 +1,70 @@ +""" +Author: Yutang LI +Institution: SIAT-MIC +Contact: yt.li2@siat.ac.cn +""" + +import os +from fastapi import APIRouter, Request +import json +import logging +import datetime +from typing import Dict +from services.mp_service import ( + parse_search_parameters, + process_search_results, + execute_search +) +from utils import handle_minio_upload +from error_handlers import handle_general_error + +router = APIRouter(prefix="/mp", tags=["Material Project"]) +logger = logging.getLogger(__name__) + +@router.get("/search") +async def search_from_material_project(request: Request): + # 打印请求日志 + logger.info(f"Received request: {request.method} {request.url}") + logger.info(f"Query parameters: {request.query_params}") + + try: + # 解析查询参数 + search_args = parse_search_parameters(request.query_params) + + # 执行搜索 + docs = await execute_search(search_args) + + # 处理搜索结果 + res = process_search_results(docs) + + if len(res) == 0: + return {"status": "success", "data": "No results found, please try again."} + + # 上传结果到MinIO + timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + file_name = f"mp_search_results_{timestamp}.json" + + # 将结果写入临时文件 + with open(file_name, 'w') as f: + json.dump(res, f, indent=2) + + # 上传并获取URL + url = handle_minio_upload(file_name, file_name) + + # 删除临时文件 + os.remove(file_name) + + # 格式化返回结果 + res_chunk = "```json\n" + json.dumps(res[:5], indent=2) + "\n```" + res_template = f""" +好的,以下是用户的查询结果: +由于返回长度的限制,我们只能返回前5个结果。如下: +{res_chunk} +如果用户需要更多的结果,请提示用户修改查询条件,或者尝试使用其他查询参数。 +同时我们将全部的的查询结果上传到MinIO中,请你提示用户可以通过以下链接下载: +[Download]({url}) +""" + return {"status": "success", "data": res_template} + + except Exception as e: + return handle_general_error(e) diff --git a/router/oqmd_router.py b/router/oqmd_router.py new file mode 100644 index 0000000..3d7c294 --- /dev/null +++ b/router/oqmd_router.py @@ -0,0 +1,57 @@ +""" +Author: Yutang LI +Institution: SIAT-MIC +Contact: yt.li2@siat.ac.cn +""" +import os +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse +import logging +from error_handlers import handle_general_error +from services.oqmd_service import ( + fetch_oqmd_data, + parse_oqmd_html, + render_and_save_charts +) + +router = APIRouter(prefix="/oqmd", tags=["OQMD"]) +logger = logging.getLogger(__name__) + +@router.get("/search") +async def search_from_oqmd_by_composition(request: Request): + """通过成分搜索OQMD数据""" + try: + # 打印请求日志 + logger.info(f"Received request: {request.method} {request.url}") + logger.info(f"Query parameters: {request.query_params}") + + # 获取并解析数据 + composition = request.query_params['composition'] + html = await fetch_oqmd_data(composition) + basic_data, table_data, phase_data = parse_oqmd_html(html) + + # 渲染并保存图表 + phase_diagram_url = await render_and_save_charts(phase_data) + + # 返回格式化后的响应 + return JSONResponse( + status_code=200, + content={ + "status": "success", + "data": format_response(basic_data, table_data, phase_diagram_url) + } + ) + + except Exception as e: + return handle_general_error(e) + +def format_response(basic_data: list, table_data: str, phase_data: str) -> str: + """格式化响应数据""" + response = "### OQMD Data\n" + for item in basic_data: + response += f"**{item}**\n" + response += "\n### Phase Diagram\n\n" + response += f"![Phase Diagram]({phase_data})\n\n" + response += "\n### Compounds at this composition\n\n" + response += f"{table_data}\n" + return response diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/fairchem_service.py b/services/fairchem_service.py new file mode 100644 index 0000000..155b9c6 --- /dev/null +++ b/services/fairchem_service.py @@ -0,0 +1,92 @@ +""" +Author: Yutang LI +Institution: SIAT-MIC +Contact: yt.li2@siat.ac.cn +""" + +import logging +import tempfile +import os +import datetime +from typing import Optional +from ase.optimize import FIRE +from ase.filters import FrechetCellFilter +from ase.atoms import Atoms +from ase.io import read, write +from pymatgen.core.structure import Structure +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +from pymatgen.io.cif import CifWriter +from utils import settings, handle_minio_upload + +logger = logging.getLogger(__name__) + +# 初始化模型 +calc = None + +def init_model(): + """初始化FairChem模型""" + global calc + try: + from fairchem.core import OCPCalculator + calc = OCPCalculator(checkpoint_path=settings.fairchem_model_path) + logger.info("FairChem model initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize FairChem model: {str(e)}") + raise + +def convert_structure(input_format: str, content: str) -> Optional[Atoms]: + """将输入内容转换为Atoms对象""" + try: + with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file: + tmp_file.write(content) + tmp_path = tmp_file.name + + atoms = read(tmp_path) + os.unlink(tmp_path) + return atoms + except Exception as e: + logger.error(f"Failed to convert structure: {str(e)}") + return None + +def generate_symmetry_cif(structure: Structure) -> str: + """生成对称性CIF""" + analyzer = SpacegroupAnalyzer(structure) + structure = analyzer.get_refined_structure() + + with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file: + cif_writer = CifWriter(structure, symprec=0.1, refine_struct=True) + cif_writer.write_file(tmp_file.name) + tmp_file.seek(0) + return tmp_file.read() + +def optimize_structure(atoms: Atoms, output_format: str): + """优化晶体结构""" + atoms.calc = calc + + try: + dyn = FIRE(FrechetCellFilter(atoms)) + dyn.run(fmax=settings.fmax) + total_energy = atoms.get_total_energy() + + # 处理对称性 + if output_format == "cif": + optimized_structure = Structure.from_ase_atoms(atoms) + content = generate_symmetry_cif(optimized_structure) + else: + with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file: + write(tmp_file.name, atoms) + tmp_file.seek(0) + content = tmp_file.read() + + # 保存优化结果到临时文件 + timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + file_name = f"optimized_structure_{timestamp}.{output_format}" + with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w", delete=False) as tmp_file: + tmp_file.write(content) + tmp_path = tmp_file.name + + # 上传到MinIO + url = handle_minio_upload(tmp_path, file_name) + return total_energy, content, url + finally: + os.unlink(tmp_path) diff --git a/database/material_project_router.py b/services/mp_service.py similarity index 52% rename from database/material_project_router.py rename to services/mp_service.py index d3d106e..0ad1c46 100644 --- a/database/material_project_router.py +++ b/services/mp_service.py @@ -4,117 +4,28 @@ Institution: SIAT-MIC Contact: yt.li2@siat.ac.cn """ -import os -import boto3 -from fastapi import APIRouter, Request import json import asyncio import logging import datetime -from mp_api.client import MPRester from multiprocessing import Process, Manager from typing import Dict, Any, List -from constant import MP_ENDPOINT, MP_API_KEY, TIME_OUT, TOPK_RESULT, HTTP_PROXY, HTTPS_PROXY, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT - -router = APIRouter(prefix="/mp", tags=["Material Project"]) +from mp_api.client import MPRester +from utils import settings, handle_minio_upload +from error_handlers import handle_general_error logger = logging.getLogger(__name__) - -@router.get("/search") -async def search_from_material_project(request: Request): - # 打印请求日志 - logger.info(f"Received request: {request.method} {request.url}") - logger.info(f"Query parameters: {request.query_params}") - - # 解析查询参数 - search_args = parse_search_parameters(request.query_params) - - # 检查API key - if MP_API_KEY is None or MP_API_KEY == '': - return 'Material Project API CANNOT Be None' - - try: - # 执行搜索 - docs = await execute_search(search_args) - - # 处理搜索结果 - res = process_search_results(docs) - url = "" - # 返回结果 - if len(res) == 0: - return {"status": "success", "data": "No results found, please try again."} - - else: - # 上传结果到MinIO - timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - file_name = f"mp_search_results_{timestamp}.json" - - try: - minio_client = boto3.client( - 's3', - endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT, - aws_access_key_id=MINIO_ACCESS_KEY, - aws_secret_access_key=MINIO_SECRET_KEY - ) - - # 将结果写入临时文件 - with open(file_name, 'w') as f: - json.dump(res, f, indent=2) - - # 上传到MinIO - minio_client.upload_file(file_name, MINIO_BUCKET, file_name, ExtraArgs={"ACL": "private"}) - - # 生成预签名URL - url = minio_client.generate_presigned_url( - 'get_object', - Params={'Bucket': MINIO_BUCKET, 'Key': file_name}, - ExpiresIn=3600 - ) - url = url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT) - - # 删除临时文件 - os.remove(file_name) - - except Exception as e: - logger.error(f"Failed to upload to MinIO: {str(e)}") - return { - "status": "error", - "data": f"Failed to upload results to MinIO: {str(e)}" - } - - # 格式化返回结果 - res_chunk = "```json\n" + json.dumps(res[:TOPK_RESULT], indent=2) + "\n```" - res_template = f""" -好的,以下是用户的查询结果: -由于返回长度的限制,我们只能返回前{TOPK_RESULT}个结果。如下: -{res_chunk} -如果用户需要更多的结果,请提示用户修改查询条件,或者尝试使用其他查询参数。 -同时我们将全部的的查询结果上传到MinIO中,请你提示用户可以通过以下链接下载: -[Download]({url}) -""" - return {"status": "success", "data": res_template} - - except asyncio.TimeoutError: - logger.error(f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again.") - return { - "status": "error", - "data": f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again." - } - - def parse_bool(param: str) -> bool | None: if not param: return None return param.lower() == 'true' - def parse_list(param: str) -> List[str] | None: if not param: return None return param.split(',') - def parse_tuple(param: str) -> tuple[float, float] | None: if not param: return None @@ -124,7 +35,6 @@ def parse_tuple(param: str) -> tuple[float, float] | None: except (ValueError, IndexError): return None - def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]: """解析搜索参数""" return { @@ -147,7 +57,6 @@ def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]: 'chunk_size': int(query_params.get('chunk_size', '5')) } - def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """处理搜索结果""" fields = [ @@ -171,13 +80,12 @@ def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: continue return res - -async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) -> List[Dict[str, Any]]: +async def execute_search(search_args: Dict[str, Any], timeout: int = 30) -> List[Dict[str, Any]]: """执行搜索""" manager = Manager() queue = manager.Queue() - p = Process(target=_search_worker, args=(queue, MP_API_KEY), kwargs=search_args) + p = Process(target=_search_worker, args=(queue, settings.mp_api_key), kwargs=search_args) p.start() logger.info(f"Started worker process with PID: {p.pid}") @@ -207,14 +115,13 @@ async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) - logger.info(f"Successfully retrieved {len(result)} documents") return result - def _search_worker(queue, api_key, **kwargs): """搜索工作线程""" try: import os - os.environ['HTTP_PROXY'] = HTTP_PROXY - os.environ['HTTPS_PROXY'] = HTTPS_PROXY - mpr = MPRester(api_key, endpoint=MP_ENDPOINT) + os.environ['HTTP_PROXY'] = settings.http_proxy or '' + os.environ['HTTPS_PROXY'] = settings.https_proxy or '' + mpr = MPRester(api_key, endpoint=settings.mp_endpoint) result = mpr.materials.summary.search(**kwargs) queue.put([doc.model_dump() for doc in result]) except Exception as e: diff --git a/database/oqmd_router.py b/services/oqmd_service.py similarity index 54% rename from database/oqmd_router.py rename to services/oqmd_service.py index c5153c2..1fd3f48 100644 --- a/database/oqmd_router.py +++ b/services/oqmd_service.py @@ -3,69 +3,22 @@ Author: Yutang LI Institution: SIAT-MIC Contact: yt.li2@siat.ac.cn """ -import os -import boto3 -from fastapi import APIRouter, Request -from io import StringIO -import logging -import httpx + import datetime +import logging +import os +import httpx import pandas as pd from bs4 import BeautifulSoup from PIL import Image from playwright.async_api import async_playwright -from constant import MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT +from io import StringIO +from utils import settings, handle_minio_upload - -router = APIRouter(prefix="/oqmd", tags=["OQMD"]) logger = logging.getLogger(__name__) - -@router.get("/search") -async def search_from_oqmd_by_composition(request: Request): - # 打印请求日志 - logger.info(f"Received request: {request.method} {request.url}") - logger.info(f"Query parameters: {request.query_params}") - - try: - # 获取并解析数据 - composition = request.query_params['composition'] - html = await fetch_oqmd_data(composition) - basic_data, table_data, phase_data = parse_oqmd_html(html) - - # 渲染并保存图表 - phase_diagram_name = await render_and_save_charts(phase_data) - # 返回格式化后的响应 - return format_response(basic_data, table_data, phase_diagram_name) - - except httpx.HTTPStatusError as e: - logger.error(f"OQMD API request failed: {str(e)}") - return { - "status": "error", - "message": f"OQMD API request failed: {str(e)}" - } - except Exception as e: - logger.error(f"Unexpected error: {str(e)}") - return { - "status": "error", - "message": f"Unexpected error: {str(e)}" - } - - - async def fetch_oqmd_data(composition: str) -> str: - """ - 从OQMD获取数据 - Args: - composition: 材料组成字符串 - - Returns: - HTML内容字符串 - - Raises: - httpx.HTTPError: 当发生HTTP相关错误时抛出 - ValueError: 当响应内容无效时抛出 - """ + """从OQMD获取数据""" url = f"https://www.oqmd.org/materials/composition/{composition}" try: async with httpx.AsyncClient(timeout=30.0) as client: @@ -79,55 +32,35 @@ async def fetch_oqmd_data(composition: str) -> str: return response.text except httpx.HTTPStatusError as e: - status_code = e.response.status_code - if status_code == 401: - logger.error("OQMD API: Unauthorized access") - raise httpx.HTTPError("Unauthorized access to OQMD API") from e - elif status_code == 403: - logger.error("OQMD API: Forbidden access") - raise httpx.HTTPError("Forbidden access to OQMD API") from e - elif status_code == 404: - logger.error("OQMD API: Resource not found") - raise httpx.HTTPError("Resource not found on OQMD API") from e - elif status_code >= 500: - logger.error(f"OQMD API: Server error ({status_code})") - raise httpx.HTTPError(f"OQMD API server error ({status_code})") from e - else: - logger.error(f"OQMD API request failed: {str(e)}") - raise httpx.HTTPError(f"OQMD API request failed: {str(e)}") from e - - except httpx.TimeoutException as e: + logger.error(f"OQMD API request failed: {str(e)}") + raise + except httpx.TimeoutException: logger.error("OQMD API request timed out") - raise httpx.HTTPError("OQMD API request timed out") from e - + raise except httpx.NetworkError as e: logger.error(f"Network error occurred: {str(e)}") - raise httpx.HTTPError(f"Network error: {str(e)}") from e - + raise except ValueError as e: logger.error(f"Invalid response content: {str(e)}") - raise ValueError(f"Invalid response content: {str(e)}") from e + raise def parse_oqmd_html(html: str) -> tuple[list, str, list]: - """ - 解析OQMD HTML数据 - """ + """解析OQMD HTML数据""" soup = BeautifulSoup(html, 'html.parser') + # 解析基本数据 basic_data = [] basic_data.append(soup.find('h1').text.strip()) for script in soup.find_all('p'): if script: combined_text = "" - for element in script.contents: # 遍历

的子元素 - if element.name == 'a': # 如果是 标签 + for element in script.contents: + if element.name == 'a': url = "https://www.oqmd.org" + element['href'] combined_text += f"[{element.text.strip()}]({url}) " - else: # 如果是文本 + else: combined_text += element.text.strip() + " " basic_data.append(combined_text.strip()) - # import pdb - # pdb.set_trace() # 解析表格数据 table = soup.find('table') @@ -135,7 +68,6 @@ def parse_oqmd_html(html: str) -> tuple[list, str, list]: df = pd.read_html(StringIO(str(table)))[0] df = df.fillna('') df = df.replace([float('inf'), float('-inf')], '') - # table_data = df.to_dict(orient='records') table_data = df.to_markdown(index=False) # 提取JavaScript数据 @@ -149,15 +81,8 @@ def parse_oqmd_html(html: str) -> tuple[list, str, list]: return basic_data, table_data, phase_data - async def render_and_save_charts(script_data: list) -> str: - """ - 渲染并保存图表到MinIO - Returns: - str: 图片的预签名URL - Raises: - RuntimeError: 如果图片生成或上传失败 - """ + """渲染并保存图表到MinIO""" browser = None temp_files = [] try: @@ -200,7 +125,6 @@ async def render_and_save_charts(script_data: list) -> str: await page.wait_for_timeout(5000) # 分别截图两个图表 - # 获取placeholder元素位置并扩大截图区域 placeholder = page.locator('#placeholder') placeholder_box = await placeholder.bounding_box() await page.screenshot( @@ -213,7 +137,6 @@ async def render_and_save_charts(script_data: list) -> str: } ) - # 获取phasediagram元素位置并扩大截图区域 phasediagram = page.locator('#phasediagram') phasediagram_box = await phasediagram.bounding_box() await page.screenshot( @@ -245,61 +168,24 @@ async def render_and_save_charts(script_data: list) -> str: logger.error(f"Failed to process images: {str(e)}") raise RuntimeError(f"Image processing failed: {str(e)}") from e - # 上传到 MinIO 的逻辑 - try: - minio_client = boto3.client( - 's3', - endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT, - aws_access_key_id=MINIO_ACCESS_KEY, - aws_secret_access_key=MINIO_SECRET_KEY - ) - - bucket_name = MINIO_BUCKET - minio_client.upload_file(file_name, bucket_name, file_name, ExtraArgs={"ACL": "private"}) - - # 生成预签名 URL - url = minio_client.generate_presigned_url( - 'get_object', - Params={'Bucket': bucket_name, 'Key': file_name}, - ExpiresIn=3600 - ) - return url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT) + # 上传到 MinIO + url = handle_minio_upload(file_name, file_name) + return url - except Exception as e: - logger.error(f"Failed to upload to MinIO: {str(e)}") - raise RuntimeError(f"MinIO upload failed: {str(e)}") from e - finally: - # 清理临时文件 - for temp_file in temp_files: - try: - if os.path.exists(temp_file): - os.remove(temp_file) - except Exception as e: - logger.warning(f"Failed to remove temporary file {temp_file}: {str(e)}") except Exception as e: logger.error(f"Failed to render and save charts: {str(e)}") - raise RuntimeError(f"Chart rendering failed: {str(e)}") from e + raise finally: + # 清理临时文件 + for temp_file in temp_files: + try: + if os.path.exists(temp_file): + os.remove(temp_file) + except Exception as e: + logger.warning(f"Failed to remove temporary file {temp_file}: {str(e)}") # 确保浏览器关闭 if browser: try: await browser.close() except Exception as e: logger.warning(f"Failed to close browser: {str(e)}") - -def format_response(basic_data: list, table_data: str, phase_data: str) -> str: - """ - 格式化响应数据 - """ - response = "### OQMD Data\n" - for item in basic_data: - response += f"**{item}**\n" - response += "\n### Phase Diagram\n\n" - response += f"![Phase Diagram]({phase_data})\n\n" - response += "\n### Compounds at this composition\n\n" - response += f"{table_data}\n" - - return { - "status": "success", - "data": response - } diff --git a/utils.py b/utils.py index 417f879..aa2223f 100644 --- a/utils.py +++ b/utils.py @@ -1,8 +1,102 @@ +""" +Author: Yutang LI +Institution: SIAT-MIC +Contact: yt.li2@siat.ac.cn +""" + +import os +import boto3 import logging -from multiprocessing import Process, Manager -import asyncio -from typing import Dict, Any, List -from mp_api.client import MPRester +from typing import Optional +from pydantic import Field +from pydantic_settings import BaseSettings logger = logging.getLogger(__name__) +class Settings(BaseSettings): + # Material Project + mp_api_key: Optional[str] = Field(None, env="MP_API_KEY") + mp_endpoint: Optional[str] = Field(None, env="MP_ENDPOINT") + + # Proxy + http_proxy: Optional[str] = Field(None, env="HTTP_PROXY") + https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY") + + # FairChem + fairchem_model_path: Optional[str] = Field(None, env="FAIRCHEM_MODEL_PATH") + fmax: Optional[float] = Field(0.05, env="FMAX") + + # MinIO + minio_endpoint: Optional[str] = Field(None, env="MINIO_ENDPOINT") + internal_minio_endpoint: Optional[str] = Field(None, env="INTERNAL_MINIO_ENDPOINT") + minio_access_key: Optional[str] = Field(None, env="MINIO_ACCESS_KEY") + minio_secret_key: Optional[str] = Field(None, env="MINIO_SECRET_KEY") + minio_bucket: Optional[str] = Field("mars-toolkit", env="MINIO_BUCKET") + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + +def get_minio_client(settings: Settings): + """获取MinIO客户端""" + return boto3.client( + 's3', + endpoint_url=settings.internal_minio_endpoint or settings.minio_endpoint, + aws_access_key_id=settings.minio_access_key, + aws_secret_access_key=settings.minio_secret_key + ) + +def handle_minio_upload(file_path: str, file_name: str) -> str: + """统一处理MinIO上传""" + try: + client = get_minio_client(settings) + client.upload_file(file_path, settings.minio_bucket, file_name, ExtraArgs={"ACL": "private"}) + + # 生成预签名 URL + url = client.generate_presigned_url( + 'get_object', + Params={'Bucket': settings.minio_bucket, 'Key': file_name}, + ExpiresIn=3600 + ) + return url.replace(settings.internal_minio_endpoint or "", settings.minio_endpoint) + except Exception as e: + from error_handlers import handle_minio_error + return handle_minio_error(e) + +def setup_logging(): + """配置日志记录""" + logging.config.dictConfig({ + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'standard': { + 'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + 'datefmt': '%Y-%m-%d %H:%M:%S' + }, + }, + 'handlers': { + 'console': { + 'level': 'INFO', + 'class': 'logging.StreamHandler', + 'formatter': 'standard' + }, + 'file': { + 'level': 'DEBUG', + 'class': 'logging.handlers.RotatingFileHandler', + 'filename': 'mars_toolkit.log', + 'maxBytes': 10485760, # 10MB + 'backupCount': 5, + 'formatter': 'standard' + } + }, + 'loggers': { + '': { + 'handlers': ['console', 'file'], + 'level': 'INFO', + 'propagate': True + } + } + }) + +# 初始化配置 +settings = Settings()