2025-01-04 20:08:03 +08:00
parent 1f8557a918
commit f214f51e12
9 changed files with 585 additions and 0 deletions

0
database/__init__.py Normal file

@@ -0,0 +1,163 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
from fastapi import APIRouter, Request
import json
import asyncio
import logging
from mp_api.client import MPRester
from multiprocessing import Process, Manager
from queue import Empty
from typing import Dict, Any, List
from constant import MP_API_KEY, TIME_OUT, TOPK_RESULT
router = APIRouter(prefix="/mp", tags=["Material Project"])
logger = logging.getLogger(__name__)
@router.get("/search")
async def search_from_material_project(request: Request):
    # Log the incoming request
logger.info(f"Received request: {request.method} {request.url}")
logger.info(f"Query parameters: {request.query_params}")
    # Parse the query parameters
search_args = parse_search_parameters(request.query_params)
    # Make sure the Materials Project API key is configured
    if not MP_API_KEY:
        return {"error": "Materials Project API key is not configured"}
try:
        # Run the search
        docs = await execute_search(search_args)
        # Post-process the search results
        res = process_search_results(docs)
        # Return at most TOPK_RESULT entries
        return json.dumps(res[:TOPK_RESULT], indent=2)
except asyncio.TimeoutError:
logger.error(f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again.")
return {"error": f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again."}
def parse_bool(param: str | None) -> bool | None:
if not param:
return None
return param.lower() == 'true'
def parse_list(param: str | None) -> List[str] | None:
if not param:
return None
return param.split(',')
def parse_tuple(param: str | None) -> tuple[float, float] | None:
if not param:
return None
try:
values = param.split(',')
return (float(values[0]), float(values[1]))
except (ValueError, IndexError):
return None
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
"""解析搜索参数"""
return {
'band_gap': parse_tuple(query_params.get('band_gap')),
'chemsys': parse_list(query_params.get('chemsys')),
'crystal_system': parse_list(query_params.get('crystal_system')),
'density': parse_tuple(query_params.get('density')),
'formation_energy': parse_tuple(query_params.get('formation_energy')),
'elements': parse_list(query_params.get('elements')),
'exclude_elements': parse_list(query_params.get('exclude_elements')),
'formula': parse_list(query_params.get('formula')),
'is_gap_direct': parse_bool(query_params.get('is_gap_direct')),
'is_metal': parse_bool(query_params.get('is_metal')),
'is_stable': parse_bool(query_params.get('is_stable')),
'magnetic_ordering': query_params.get('magnetic_ordering'),
'material_ids': parse_list(query_params.get('material_ids')),
'total_energy': parse_tuple(query_params.get('total_energy')),
'num_elements': parse_tuple(query_params.get('num_elements')),
'volume': parse_tuple(query_params.get('volume')),
'chunk_size': int(query_params.get('chunk_size', '5'))
}
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""处理搜索结果"""
fields = [
'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
'universal_anisotropy', 'theoretical'
]
res = []
for doc in docs:
try:
new_docs = {}
for field_name in fields:
new_docs[field_name] = doc.get(field_name, '')
res.append(new_docs)
except Exception as e:
logger.warning(f"Error processing document: {str(e)}")
continue
return res
async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) -> List[Dict[str, Any]]:
"""执行搜索"""
manager = Manager()
queue = manager.Queue()
p = Process(target=_search_worker, args=(queue, MP_API_KEY), kwargs=search_args)
p.start()
logger.info(f"Started worker process with PID: {p.pid}")
p.join(timeout=timeout)
if p.is_alive():
logger.warning(f"Terminating worker process {p.pid} due to timeout")
p.terminate()
p.join()
raise asyncio.TimeoutError(f"Request timed out after {timeout} seconds")
    try:
        if queue.empty():
            logger.warning("Queue is empty after process completion")
            raise RuntimeError("Worker process returned no data")
        logger.info("Queue contains data, retrieving...")
        result = queue.get(timeout=15)
    except Empty:
        logger.error("Failed to retrieve data from queue")
        raise RuntimeError("Failed to retrieve data from worker process")
    if isinstance(result, Exception):
        logger.error(f"Error in search worker: {str(result)}")
        raise result
    logger.info(f"Successfully retrieved {len(result)} documents")
    return result
def _search_worker(queue, api_key, **kwargs):
"""搜索工作线程"""
try:
mpr = MPRester(api_key)
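        # The returned summary docs are pydantic models; model_dump() converts them to plain
        # dicts so they can be pickled and sent back through the multiprocessing queue.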
result = mpr.materials.summary.search(**kwargs)
queue.put([doc.model_dump() for doc in result])
except Exception as e:
queue.put(e)

305
database/oqmd_router.py Normal file

@@ -0,0 +1,305 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import os
import boto3
from fastapi import APIRouter, Request
from io import StringIO
import logging
import httpx
import datetime
import pandas as pd
from bs4 import BeautifulSoup
from PIL import Image
from playwright.async_api import async_playwright
from constant import MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT
router = APIRouter(prefix="/oqmd", tags=["OQMD"])
logger = logging.getLogger(__name__)
@router.get("/search")
async def search_from_oqmd_by_composition(request: Request):
    # Log the incoming request
logger.info(f"Received request: {request.method} {request.url}")
logger.info(f"Query parameters: {request.query_params}")
    try:
        # Fetch and parse the OQMD page for the requested composition
        composition = request.query_params.get('composition')
        if not composition:
            return {"status": "error", "message": "Missing required query parameter: composition"}
        html = await fetch_oqmd_data(composition)
        basic_data, table_data, phase_data = parse_oqmd_html(html)
        # Render the phase-diagram charts and upload them, getting back a presigned URL
        phase_diagram_url = await render_and_save_charts(phase_data)
        # Return the formatted response
        return format_response(basic_data, table_data, phase_diagram_url)
except httpx.HTTPStatusError as e:
logger.error(f"OQMD API request failed: {str(e)}")
return {
"status": "error",
"message": f"OQMD API request failed: {str(e)}"
}
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
return {
"status": "error",
"message": f"Unexpected error: {str(e)}"
}
async def fetch_oqmd_data(composition: str) -> str:
"""
从OQMD获取数据
Args:
composition: 材料组成字符串
Returns:
HTML内容字符串
Raises:
httpx.HTTPError: 当发生HTTP相关错误时抛出
ValueError: 当响应内容无效时抛出
"""
url = f"https://www.oqmd.org/materials/composition/{composition}"
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
            # Sanity-check the response content
if not response.text or len(response.text) < 100:
raise ValueError("Invalid response content from OQMD API")
return response.text
except httpx.HTTPStatusError as e:
status_code = e.response.status_code
if status_code == 401:
logger.error("OQMD API: Unauthorized access")
raise httpx.HTTPError("Unauthorized access to OQMD API") from e
elif status_code == 403:
logger.error("OQMD API: Forbidden access")
raise httpx.HTTPError("Forbidden access to OQMD API") from e
elif status_code == 404:
logger.error("OQMD API: Resource not found")
raise httpx.HTTPError("Resource not found on OQMD API") from e
elif status_code >= 500:
logger.error(f"OQMD API: Server error ({status_code})")
raise httpx.HTTPError(f"OQMD API server error ({status_code})") from e
else:
logger.error(f"OQMD API request failed: {str(e)}")
raise httpx.HTTPError(f"OQMD API request failed: {str(e)}") from e
except httpx.TimeoutException as e:
logger.error("OQMD API request timed out")
raise httpx.HTTPError("OQMD API request timed out") from e
except httpx.NetworkError as e:
logger.error(f"Network error occurred: {str(e)}")
raise httpx.HTTPError(f"Network error: {str(e)}") from e
except ValueError as e:
logger.error(f"Invalid response content: {str(e)}")
raise ValueError(f"Invalid response content: {str(e)}") from e
def parse_oqmd_html(html: str) -> tuple[list, str, list]:
"""
解析OQMD HTML数据
"""
soup = BeautifulSoup(html, 'html.parser')
    # Parse the basic summary data
    basic_data = []
    title = soup.find('h1')
    if title:
        basic_data.append(title.text.strip())
    for paragraph in soup.find_all('p'):
        combined_text = ""
        for element in paragraph.contents:  # iterate over the children of <p>
            if element.name == 'a':  # keep links as markdown
                url = "https://www.oqmd.org" + element['href']
                combined_text += f"[{element.text.strip()}]({url}) "
            else:  # plain text node
                combined_text += element.text.strip() + " "
        basic_data.append(combined_text.strip())
    # Parse the compounds table
    table_data = ""
    table = soup.find('table')
    if table:
        df = pd.read_html(StringIO(str(table)))[0]
        df = df.fillna('')
        df = df.replace([float('inf'), float('-inf')], '')
        table_data = df.to_markdown(index=False)
    # Extract the inline JavaScript that draws the phase diagram
    phase_data = []
    for script in soup.find_all('script'):
        if script.string and '$(function()' in script.string:
            phase_data.append({
                'type': script.get('type', 'text/javascript'),
                'content': script.string.strip()
            })
    return basic_data, table_data, phase_data
async def render_and_save_charts(script_data: list) -> str:
    """
    Render the phase-diagram charts with Playwright and upload the combined image to MinIO.

    Returns:
        str: presigned URL of the uploaded image
    Raises:
        RuntimeError: if image generation or upload fails
    """
browser = None
temp_files = []
    try:
        # Start Playwright and open a headless Chromium page
        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            page = await browser.new_page()
            # Build an HTML document that embeds the JavaScript extracted from the OQMD page
html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/jquery.flot@0.8.3/jquery.flot.js"></script>
<title>Phase Diagram</title>
</head>
<body>
<div class="diagram">
<div id="placeholder" width="200" height="400" style="direction: ltr; position: absolute; left: 550px; top: 0px; width: 200px; height: 400px;"></div>
<script>
{placeholder_content}
</script>
<div id="phasediagram" width="500" height="400" style="direction: ltr; position: absolute; left: 0px; top: 0px; width: 500px; height: 400px;"></div>
<script>
{phasediagram_content}
</script>
</div>
</body>
</html>
"""
            if len(script_data) < 2:
                raise RuntimeError("Expected two phase-diagram scripts in the OQMD page")
            html_content = html_content.format(
                placeholder_content=script_data[0]['content'],
                phasediagram_content=script_data[1]['content'])
            await page.set_content(html_content)
            # Give the charts time to render before taking screenshots
            await page.wait_for_timeout(5000)
            # Screenshot the two charts separately
            # Locate the placeholder element and pad the clip region
placeholder = page.locator('#placeholder')
placeholder_box = await placeholder.bounding_box()
await page.screenshot(
path="placeholder.png",
clip={
'x': placeholder_box['x'],
'y': placeholder_box['y'],
'width': placeholder_box['width'] + 40,
'height': placeholder_box['height'] + 40
}
)
            # Locate the phasediagram element and pad the clip region
phasediagram = page.locator('#phasediagram')
phasediagram_box = await phasediagram.bounding_box()
await page.screenshot(
path="phasediagram.png",
clip={
'x': phasediagram_box['x'],
'y': phasediagram_box['y'],
'width': phasediagram_box['width'] + 40,
'height': phasediagram_box['height'] + 40
}
)
await browser.close()
            # Stitch the two screenshots into a single image
try:
img1 = Image.open("placeholder.png")
temp_files.append("placeholder.png")
img2 = Image.open("phasediagram.png")
temp_files.append("phasediagram.png")
new_img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
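                # The phase diagram (img2) goes on the left, the placeholder chart (img1) on the right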
new_img.paste(img2, (0, 0))
new_img.paste(img1, (img2.width, 0))
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
file_name = f"oqmd_phase_diagram_{timestamp}.png"
new_img.save(file_name)
temp_files.append(file_name)
except Exception as e:
logger.error(f"Failed to process images: {str(e)}")
raise RuntimeError(f"Image processing failed: {str(e)}") from e
            # Upload the combined image to MinIO
try:
minio_client = boto3.client(
's3',
endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT,
aws_access_key_id=MINIO_ACCESS_KEY,
aws_secret_access_key=MINIO_SECRET_KEY
)
bucket_name = MINIO_BUCKET
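                # Objects are uploaded as private; access goes through the time-limited presigned URL generated below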
minio_client.upload_file(file_name, bucket_name, file_name, ExtraArgs={"ACL": "private"})
                # Generate a presigned URL for the uploaded object
url = minio_client.generate_presigned_url(
'get_object',
Params={'Bucket': bucket_name, 'Key': file_name},
ExpiresIn=3600
)
                return url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT) if INTERNEL_MINIO_ENDPOINT else url
except Exception as e:
logger.error(f"Failed to upload to MinIO: {str(e)}")
raise RuntimeError(f"MinIO upload failed: {str(e)}") from e
finally:
            # Clean up temporary files
for temp_file in temp_files:
try:
if os.path.exists(temp_file):
os.remove(temp_file)
except Exception as e:
logger.warning(f"Failed to remove temporary file {temp_file}: {str(e)}")
except Exception as e:
logger.error(f"Failed to render and save charts: {str(e)}")
raise RuntimeError(f"Chart rendering failed: {str(e)}") from e
    finally:
        # Make sure the browser is closed even if rendering failed
        if browser and browser.is_connected():
            try:
                await browser.close()
            except Exception as e:
                logger.warning(f"Failed to close browser: {str(e)}")
def format_response(basic_data: list, table_data: str, phase_diagram_url: str) -> dict:
    """
    Format the scraped data as a markdown report wrapped in a status dict.
    """
    response = "### OQMD Data\n"
    for item in basic_data:
        response += f"**{item}**\n"
    response += "\n### Phase Diagram\n\n"
    response += f"![Phase Diagram]({phase_diagram_url})\n\n"
    response += "\n### Compounds at this composition\n\n"
    response += f"{table_data}\n"
    return {
        "status": "success",
        "data": response
    }
}