306 lines
11 KiB
Python
306 lines
11 KiB
Python
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
|
|
import os
|
|
import boto3
|
|
from fastapi import APIRouter, Request
|
|
from io import StringIO
|
|
import logging
|
|
import httpx
|
|
import datetime
|
|
import pandas as pd
|
|
from bs4 import BeautifulSoup
|
|
from PIL import Image
|
|
from playwright.async_api import async_playwright
|
|
from constant import MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT
|
|
|
|
|
|
# Router for the OQMD endpoints; its routes are registered under the "/oqmd" prefix.
router = APIRouter(prefix="/oqmd", tags=["OQMD"])
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@router.get("/search")
async def search_from_oqmd_by_composition(request: Request):
    """Search OQMD by composition and return a Markdown summary.

    Reads the ``composition`` query parameter, downloads the matching OQMD
    page, parses it, renders the phase-diagram charts and formats the result.

    Args:
        request: Incoming request; ``composition`` is read from its query
            string.

    Returns:
        The dict produced by ``format_response`` on success, otherwise
        ``{"status": "error", "message": ...}``.
    """
    # Log the incoming request.
    logger.info(f"Received request: {request.method} {request.url}")
    logger.info(f"Query parameters: {request.query_params}")

    # Validate the input up front so a missing parameter yields a clear
    # message instead of surfacing as a KeyError in the generic handler.
    composition = request.query_params.get("composition", "").strip()
    if not composition:
        message = "Missing required query parameter: composition"
        logger.error(message)
        return {"status": "error", "message": message}

    try:
        # Fetch and parse the OQMD page.
        html = await fetch_oqmd_data(composition)
        basic_data, table_data, phase_data = parse_oqmd_html(html)

        # Render the charts and upload the combined image.
        phase_diagram_name = await render_and_save_charts(phase_data)

        return format_response(basic_data, table_data, phase_diagram_name)

    except httpx.HTTPError as e:
        # fetch_oqmd_data re-raises its failures as the httpx.HTTPError base
        # class; catching only HTTPStatusError would let all of them fall
        # through to the generic handler below.
        logger.error(f"OQMD API request failed: {str(e)}")
        return {
            "status": "error",
            "message": f"OQMD API request failed: {str(e)}",
        }
    except Exception as e:
        # Top-level boundary: keep the traceback in the log, return a
        # structured error to the client.
        logger.exception(f"Unexpected error: {str(e)}")
        return {
            "status": "error",
            "message": f"Unexpected error: {str(e)}",
        }
|
|
|
|
|
|
|
|
async def fetch_oqmd_data(composition: str) -> str:
    """Download the OQMD composition page.

    Args:
        composition: Composition string. It is percent-encoded before being
            placed in the URL path.

    Returns:
        The HTML text of the response.

    Raises:
        httpx.HTTPError: On an error status code, a timeout or a network
            error (re-raised with a descriptive message).
        ValueError: If the response body is shorter than 100 characters.
    """
    from urllib.parse import quote

    # The value comes straight from the query string; encode it so that
    # "/", "?", "#" and similar characters cannot alter the request target.
    encoded_composition = quote(composition, safe="")
    url = f"https://www.oqmd.org/materials/composition/{encoded_composition}"

    # status code -> (log message, exception message)
    known_statuses = {
        401: ("OQMD API: Unauthorized access", "Unauthorized access to OQMD API"),
        403: ("OQMD API: Forbidden access", "Forbidden access to OQMD API"),
        404: ("OQMD API: Resource not found", "Resource not found on OQMD API"),
    }

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(url)
            response.raise_for_status()
            body = response.text

    except httpx.HTTPStatusError as e:
        status_code = e.response.status_code
        if status_code in known_statuses:
            log_message, error_message = known_statuses[status_code]
        elif status_code >= 500:
            log_message = f"OQMD API: Server error ({status_code})"
            error_message = f"OQMD API server error ({status_code})"
        else:
            log_message = f"OQMD API request failed: {str(e)}"
            error_message = f"OQMD API request failed: {str(e)}"
        logger.error(log_message)
        raise httpx.HTTPError(error_message) from e

    except httpx.TimeoutException as e:
        logger.error("OQMD API request timed out")
        raise httpx.HTTPError("OQMD API request timed out") from e

    except httpx.NetworkError as e:
        logger.error(f"Network error occurred: {str(e)}")
        raise httpx.HTTPError(f"Network error: {str(e)}") from e

    # Validate the response content; an empty body is covered by the
    # length check.
    if len(body) < 100:
        detail = "Invalid response content from OQMD API"
        logger.error(f"Invalid response content: {detail}")
        raise ValueError(f"Invalid response content: {detail}")

    return body
|
|
|
|
def parse_oqmd_html(html: str) -> tuple[list, str, list]:
    """Parse an OQMD composition page.

    Args:
        html: HTML text of the page.

    Returns:
        A ``(basic_data, table_data, phase_data)`` tuple:

        * ``basic_data``: the stripped ``<h1>`` text (if present) followed by
          one string per ``<p>`` element, with ``<a>`` children rendered as
          Markdown links.
        * ``table_data``: the first ``<table>`` as a Markdown table, or an
          empty string when the page has no table.
        * ``phase_data``: one ``{'type', 'content'}`` dict per ``<script>``
          whose text contains ``$(function()``.
    """
    from urllib.parse import urljoin

    soup = BeautifulSoup(html, 'html.parser')

    # Basic data: page heading plus the text of every paragraph.
    basic_data = []
    heading = soup.find('h1')
    if heading is not None:
        basic_data.append(heading.text.strip())

    for paragraph in soup.find_all('p'):
        parts = []
        for node in paragraph.contents:
            label = node.text.strip()
            if node.name == 'a' and node.get('href'):
                # urljoin resolves site-relative links against the OQMD host
                # and leaves already-absolute links untouched.
                link = urljoin("https://www.oqmd.org", node['href'])
                parts.append(f"[{label}]({link})")
            else:
                # Plain text, or an anchor without a target.
                parts.append(label)
        basic_data.append(" ".join(parts).strip())

    # Table data: default to an empty string so the return below is always
    # bound, even when the page carries no table.
    table_data = ""
    table = soup.find('table')
    if table is not None:
        df = pd.read_html(StringIO(str(table)))[0]
        df = df.fillna('')
        df = df.replace([float('inf'), float('-inf')], '')
        table_data = df.to_markdown(index=False)

    # Inline JavaScript blocks; these are re-executed later to draw the charts.
    phase_data = [
        {
            'type': script.get('type', 'text/javascript'),
            'content': script.string.strip(),
        }
        for script in soup.find_all('script')
        if script.string and '$(function()' in script.string
    ]

    return basic_data, table_data, phase_data
|
|
|
|
|
|
# HTML page used to re-run the chart scripts extracted from OQMD. The two
# {..._content} fields are filled with the script bodies via str.format.
_CHART_PAGE_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/jquery.flot@0.8.3/jquery.flot.js"></script>
<title>Phase Diagram</title>
</head>
<body>
<div class="diagram">
<div id="placeholder" width="200" height="400" style="direction: ltr; position: absolute; left: 550px; top: 0px; width: 200px; height: 400px;"></div>
<script>
{placeholder_content}
</script>

<div id="phasediagram" width="500" height="400" style="direction: ltr; position: absolute; left: 0px; top: 0px; width: 500px; height: 400px;"></div>
<script>
{phasediagram_content}
</script>
</div>
</body>
</html>
"""


async def _screenshot_padded(page, selector: str, path: str, padding: int = 40) -> None:
    """Screenshot the element matching *selector*, extended by *padding* px to the right and bottom."""
    box = await page.locator(selector).bounding_box()
    if box is None:
        raise RuntimeError(f"Element {selector} is not visible; cannot take a screenshot")
    await page.screenshot(
        path=path,
        clip={
            'x': box['x'],
            'y': box['y'],
            'width': box['width'] + padding,
            'height': box['height'] + padding,
        },
    )


async def _screenshot_charts(html_content: str, placeholder_path: str, phasediagram_path: str) -> None:
    """Load *html_content* in headless Chromium and screenshot both chart elements."""
    async with async_playwright() as p:
        browser = await p.chromium.launch(headless=True)
        try:
            page = await browser.new_page()
            await page.set_content(html_content)
            # Give the page scripts time to run before capturing.
            await page.wait_for_timeout(5000)
            await _screenshot_padded(page, '#placeholder', placeholder_path)
            await _screenshot_padded(page, '#phasediagram', phasediagram_path)
        finally:
            # Close exactly once, while Playwright is still running;
            # a failure here is best-effort and must not mask the real error.
            try:
                await browser.close()
            except Exception as e:
                logger.warning(f"Failed to close browser: {str(e)}")


def _stitch_side_by_side(left_path: str, right_path: str, output_path: str) -> None:
    """Paste the two images next to each other (left first) and save the result."""
    try:
        with Image.open(left_path) as left, Image.open(right_path) as right:
            canvas = Image.new(
                'RGB',
                (left.width + right.width, max(left.height, right.height)),
            )
            canvas.paste(left, (0, 0))
            canvas.paste(right, (left.width, 0))
            canvas.save(output_path)
    except Exception as e:
        logger.error(f"Failed to process images: {str(e)}")
        raise RuntimeError(f"Image processing failed: {str(e)}") from e


def _upload_and_presign(file_path: str, object_key: str) -> str:
    """Upload *file_path* to the MinIO bucket and return a presigned GET URL (valid 3600 s)."""
    try:
        minio_client = boto3.client(
            's3',
            # Prefer the internal endpoint when one is configured.
            endpoint_url=INTERNEL_MINIO_ENDPOINT or MINIO_ENDPOINT,
            aws_access_key_id=MINIO_ACCESS_KEY,
            aws_secret_access_key=MINIO_SECRET_KEY,
        )
        minio_client.upload_file(
            file_path, MINIO_BUCKET, object_key, ExtraArgs={"ACL": "private"}
        )
        url = minio_client.generate_presigned_url(
            'get_object',
            Params={'Bucket': MINIO_BUCKET, 'Key': object_key},
            ExpiresIn=3600,
        )
    except Exception as e:
        logger.error(f"Failed to upload to MinIO: {str(e)}")
        raise RuntimeError(f"MinIO upload failed: {str(e)}") from e

    # Only rewrite the host when an internal endpoint is actually set:
    # str.replace with an empty search string would insert the public
    # endpoint between every character of the URL.
    if INTERNEL_MINIO_ENDPOINT:
        url = url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
    return url


async def render_and_save_charts(script_data: list) -> str:
    """Render the OQMD charts, stitch them into one image and upload it to MinIO.

    Args:
        script_data: Script entries as produced by ``parse_oqmd_html``; the
            ``'content'`` of the first entry drives the ``#placeholder``
            element and that of the second the ``#phasediagram`` element.

    Returns:
        str: Presigned URL of the uploaded image.

    Raises:
        RuntimeError: If rendering, image processing or the upload fails.
    """
    import tempfile
    import uuid

    try:
        if len(script_data) < 2:
            raise ValueError(
                f"expected 2 chart scripts, got {len(script_data)}"
            )

        html_content = _CHART_PAGE_TEMPLATE.format(
            placeholder_content=script_data[0]['content'],
            phasediagram_content=script_data[1]['content'],
        )

        # The timestamp has one-second resolution, so a random suffix keeps
        # object keys unique for requests that arrive in the same second.
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        file_name = f"oqmd_phase_diagram_{timestamp}_{uuid.uuid4().hex[:8]}.png"

        # A per-call scratch directory isolates concurrent requests from
        # each other and guarantees cleanup of every intermediate file.
        with tempfile.TemporaryDirectory(prefix="oqmd_charts_") as workdir:
            placeholder_png = os.path.join(workdir, "placeholder.png")
            phasediagram_png = os.path.join(workdir, "phasediagram.png")
            combined_png = os.path.join(workdir, file_name)

            await _screenshot_charts(html_content, placeholder_png, phasediagram_png)
            # The phase diagram goes on the left, the placeholder chart on the right.
            _stitch_side_by_side(phasediagram_png, placeholder_png, combined_png)
            return _upload_and_presign(combined_png, file_name)

    except Exception as e:
        logger.error(f"Failed to render and save charts: {str(e)}")
        raise RuntimeError(f"Chart rendering failed: {str(e)}") from e
|
|
|
|
def format_response(basic_data: list, table_data: str, phase_data: str) -> dict:
    """Build the success response with a Markdown report.

    Args:
        basic_data: Text items; each is rendered as a bold line.
        table_data: Markdown table placed under the compounds heading.
        phase_data: URL of the phase-diagram image. When non-empty it is
            embedded as a Markdown image under the "Phase Diagram" heading.

    Returns:
        ``{"status": "success", "data": <markdown string>}``.
    """
    # Without an image URL the section keeps only its blank lines.
    diagram = f"![Phase Diagram]({phase_data})\n\n" if phase_data else "\n\n"

    sections = [
        "### OQMD Data\n",
        "".join(f"**{item}**\n" for item in basic_data),
        "\n### Phase Diagram\n\n",
        diagram,
        "\n### Compounds at this composition\n\n",
        f"{table_data}\n",
    ]

    return {
        "status": "success",
        "data": "".join(sections),
    }
|