This commit is contained in:
2025-01-04 20:08:03 +08:00
parent 1f8557a918
commit f214f51e12
9 changed files with 585 additions and 0 deletions

5
.gitignore vendored
View File

@@ -1,4 +1,9 @@
# ---> Python
# model
*.pt
/home/ubuntu/sas0/LYT/mars1215/mars_toolkit/model/eqV2_86M_omat_mp_salex.pt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

BIN
chart_screenshot.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

10
constant.py Normal file
View File

@@ -0,0 +1,10 @@
import os

# Maximum number of search results returned to the caller.
TOPK_RESULT = 5
# Timeout, in seconds, applied to a Materials Project search.
TIME_OUT = 30

# SECURITY: the literal values below are credentials committed to the
# repository. They remain only as fallbacks so existing deployments keep
# working; rotate them and supply the real values through the environment.
MP_API_KEY = os.environ.get("MP_API_KEY", "gfBp2in8qxm9Xm2SwLKFwNxDyZvNTAEt")

# MinIO configuration
MINIO_ENDPOINT = os.environ.get("MINIO_ENDPOINT", "https://s3-api.siat-mic.com")
# Internal-network address; fill it in if one is available, uploads will be
# faster. (The name keeps its original spelling because other modules import it.)
INTERNEL_MINIO_ENDPOINT = os.environ.get("INTERNEL_MINIO_ENDPOINT", "http://100.85.52.31:9000")
MINIO_ACCESS_KEY = os.environ.get("MINIO_ACCESS_KEY", "9bUtQL1Gpo9JB6o3pSGr")
MINIO_SECRET_KEY = os.environ.get("MINIO_SECRET_KEY", "1Qug5H73R3kP8boIHvdVcFtcb1jU9GRWnlmMpx0g")
MINIO_BUCKET = os.environ.get("MINIO_BUCKET", "temp")

0
database/__init__.py Normal file
View File

View File

@@ -0,0 +1,163 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
from fastapi import APIRouter, Request
import json
import asyncio
import logging
from mp_api.client import MPRester
from multiprocessing import Process, Manager
from typing import Dict, Any, List
from constant import MP_API_KEY, TIME_OUT, TOPK_RESULT
# Router for the Materials Project endpoints; every route is served under /mp.
router = APIRouter(prefix="/mp", tags=["Material Project"])
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@router.get("/search")
async def search_from_material_project(request: Request):
    """Search the Materials Project using the request's query parameters.

    Args:
        request: Incoming request; its query parameters are converted into
            search arguments by ``parse_search_parameters``.

    Returns:
        On success, a JSON string containing at most ``TOPK_RESULT``
        documents. A plain message string when no API key is configured, or
        an ``{"error": ...}`` dict when the query is invalid, times out, or
        the search fails.
    """
    # Log the incoming request.
    logger.info(f"Received request: {request.method} {request.url}")
    logger.info(f"Query parameters: {request.query_params}")

    # Check the API key first: nothing else is useful without it.
    if not MP_API_KEY:
        return 'Material Project API CANNOT Be None'

    # Parse the query parameters; a non-integer chunk_size raises ValueError.
    try:
        search_args = parse_search_parameters(request.query_params)
    except ValueError as e:
        logger.error(f"Invalid query parameters: {str(e)}")
        return {"error": f"Invalid query parameters: {str(e)}"}

    try:
        # Run the search and trim each document to the reported fields.
        docs = await execute_search(search_args)
        res = process_search_results(docs)
        # Slicing already copes with fewer than TOPK_RESULT results.
        return json.dumps(res[:TOPK_RESULT], indent=2)
    except asyncio.TimeoutError:
        logger.error(f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again.")
        return {"error": f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again."}
    except Exception as e:
        # Errors re-raised from the worker process or raised while
        # serialising end up here instead of escaping as an unhandled error.
        logger.exception(f"Material Project search failed: {str(e)}")
        return {"error": f"Material Project search failed: {str(e)}"}
def parse_bool(param: str) -> bool | None:
if not param:
return None
return param.lower() == 'true'
def parse_list(param: str) -> List[str] | None:
if not param:
return None
return param.split(',')
def parse_tuple(param: str) -> tuple[float, float] | None:
if not param:
return None
try:
values = param.split(',')
return (float(values[0]), float(values[1]))
except (ValueError, IndexError):
return None
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
    """Convert raw query parameters into the search-argument dictionary.

    Each supported parameter is read with ``query_params.get`` and run
    through its parser; ``magnetic_ordering`` is passed through unchanged.
    ``chunk_size`` is converted with ``int`` and defaults to 5.
    """
    # (parameter name, parser) in output order; None means "use the raw value".
    field_parsers = (
        ('band_gap', parse_tuple),
        ('chemsys', parse_list),
        ('crystal_system', parse_list),
        ('density', parse_tuple),
        ('formation_energy', parse_tuple),
        ('elements', parse_list),
        ('exclude_elements', parse_list),
        ('formula', parse_list),
        ('is_gap_direct', parse_bool),
        ('is_metal', parse_bool),
        ('is_stable', parse_bool),
        ('magnetic_ordering', None),
        ('material_ids', parse_list),
        ('total_energy', parse_tuple),
        ('num_elements', parse_tuple),
        ('volume', parse_tuple),
    )
    parsed: Dict[str, Any] = {}
    for name, parser in field_parsers:
        raw_value = query_params.get(name)
        parsed[name] = raw_value if parser is None else parser(raw_value)
    parsed['chunk_size'] = int(query_params.get('chunk_size', '5'))
    return parsed
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Reduce each search document to a fixed set of summary fields.

    A field missing from a document is reported as an empty string. A
    document that cannot be processed is logged and left out of the result.
    """
    wanted = (
        'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
        'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
        'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
        'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
        'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
        'universal_anisotropy', 'theoretical',
    )
    trimmed = []
    for document in docs:
        try:
            trimmed.append({key: document.get(key, '') for key in wanted})
        except Exception as e:
            logger.warning(f"Error processing document: {str(e)}")
    return trimmed
async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) -> List[Dict[str, Any]]:
    """Run the Materials Project search in a worker process with a timeout.

    The search runs in a separate process so that it can be terminated when
    it exceeds ``timeout``. The worker hands back either the list of
    documents or the exception it hit, through a manager queue.

    Args:
        search_args: Keyword arguments forwarded to ``_search_worker``.
        timeout: Maximum number of seconds to wait for the worker.

    Returns:
        The list of documents produced by the worker.

    Raises:
        asyncio.TimeoutError: The worker was still running after ``timeout``
            seconds and has been terminated.
        RuntimeError: The worker exited without putting anything on the queue.
        Exception: Any exception the worker reported is re-raised here.
    """
    # Imported locally under another name: this module does not import the
    # stdlib ``queue`` module, and the queue proxy has no ``Empty`` attribute.
    from queue import Empty as QueueEmpty

    # The context manager shuts the manager's server process down again.
    with Manager() as manager:
        result_queue = manager.Queue()
        p = Process(target=_search_worker, args=(result_queue, MP_API_KEY), kwargs=search_args)
        p.start()
        logger.info(f"Started worker process with PID: {p.pid}")

        # Wait in a thread so the event loop keeps serving other requests.
        await asyncio.to_thread(p.join, timeout)
        if p.is_alive():
            logger.warning(f"Terminating worker process {p.pid} due to timeout")
            p.terminate()
            await asyncio.to_thread(p.join)
            raise asyncio.TimeoutError(f"Request timed out after {timeout} seconds")

        # The worker has exited, so its result is either already queued or
        # will never arrive; there is nothing to wait for.
        try:
            result = result_queue.get(block=False)
        except QueueEmpty:
            logger.error("Failed to retrieve data from queue")
            raise RuntimeError("Failed to retrieve data from worker process") from None

    if isinstance(result, Exception):
        logger.error(f"Error in search worker: {str(result)}")
        raise result
    logger.info(f"Successfully retrieved {len(result)} documents")
    return result
def _search_worker(queue, api_key, **kwargs):
    """Search worker, run as the target of a separate process.

    Queries the Materials Project summary endpoint with ``kwargs`` and puts
    the documents, converted with ``model_dump()``, on ``queue``. If anything
    raises, the exception object is put on the queue instead.
    """
    try:
        client = MPRester(api_key)
        documents = client.materials.summary.search(**kwargs)
        payload = [document.model_dump() for document in documents]
        queue.put(payload)
    except Exception as error:
        queue.put(error)

305
database/oqmd_router.py Normal file
View File

@@ -0,0 +1,305 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import os
import boto3
from fastapi import APIRouter, Request
from io import StringIO
import logging
import httpx
import datetime
import pandas as pd
from bs4 import BeautifulSoup
from PIL import Image
from playwright.async_api import async_playwright
from constant import MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT
# Router for the OQMD endpoints; every route is served under /oqmd.
router = APIRouter(prefix="/oqmd", tags=["OQMD"])
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@router.get("/search")
async def search_from_oqmd_by_composition(request: Request):
    """Look up a composition on OQMD and return a Markdown summary.

    Args:
        request: Incoming request; the ``composition`` query parameter is
            required.

    Returns:
        The dict built by ``format_response`` on success, otherwise
        ``{"status": "error", "message": ...}``.
    """
    # Log the incoming request.
    logger.info(f"Received request: {request.method} {request.url}")
    logger.info(f"Query parameters: {request.query_params}")

    # Report a missing parameter explicitly instead of as a bare KeyError text.
    composition = request.query_params.get('composition')
    if not composition:
        message = "Missing required query parameter: composition"
        logger.error(message)
        return {
            "status": "error",
            "message": message
        }

    try:
        # Fetch and parse the page.
        html = await fetch_oqmd_data(composition)
        basic_data, table_data, phase_data = parse_oqmd_html(html)
        # Render the charts and store the image.
        phase_diagram_name = await render_and_save_charts(phase_data)
        # Return the formatted response.
        return format_response(basic_data, table_data, phase_diagram_name)
    except httpx.HTTPError as e:
        # fetch_oqmd_data re-raises its failures as the base httpx.HTTPError,
        # which also covers httpx.HTTPStatusError.
        logger.error(f"OQMD API request failed: {str(e)}")
        return {
            "status": "error",
            "message": f"OQMD API request failed: {str(e)}"
        }
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return {
            "status": "error",
            "message": f"Unexpected error: {str(e)}"
        }
async def fetch_oqmd_data(composition: str) -> str:
    """
    Fetch the OQMD page for a composition.

    Args:
        composition: Material composition string, taken from the request.

    Returns:
        The HTML content of the page.

    Raises:
        httpx.HTTPError: On an HTTP error status, a timeout or a network error.
        ValueError: When the response body is empty or shorter than 100
            characters.
    """
    # Local import: this module does not import urllib at file level.
    from urllib.parse import quote

    # Percent-encode the caller-supplied value so characters such as '/',
    # '?' or '#' cannot alter the path or query of the requested URL.
    url = f"https://www.oqmd.org/materials/composition/{quote(composition, safe='')}"

    # status code -> (log message, message of the raised error)
    known_statuses = {
        401: ("OQMD API: Unauthorized access", "Unauthorized access to OQMD API"),
        403: ("OQMD API: Forbidden access", "Forbidden access to OQMD API"),
        404: ("OQMD API: Resource not found", "Resource not found on OQMD API"),
    }

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(url)
            response.raise_for_status()
            # Validate the response content.
            if not response.text or len(response.text) < 100:
                raise ValueError("Invalid response content from OQMD API")
            return response.text
    except httpx.HTTPStatusError as e:
        status_code = e.response.status_code
        if status_code in known_statuses:
            log_message, error_message = known_statuses[status_code]
        elif status_code >= 500:
            log_message = f"OQMD API: Server error ({status_code})"
            error_message = f"OQMD API server error ({status_code})"
        else:
            log_message = f"OQMD API request failed: {str(e)}"
            error_message = f"OQMD API request failed: {str(e)}"
        logger.error(log_message)
        raise httpx.HTTPError(error_message) from e
    except httpx.TimeoutException as e:
        logger.error("OQMD API request timed out")
        raise httpx.HTTPError("OQMD API request timed out") from e
    except httpx.NetworkError as e:
        logger.error(f"Network error occurred: {str(e)}")
        raise httpx.HTTPError(f"Network error: {str(e)}") from e
    except ValueError as e:
        logger.error(f"Invalid response content: {str(e)}")
        raise ValueError(f"Invalid response content: {str(e)}") from e
def parse_oqmd_html(html: str) -> tuple[list, str, list]:
    """
    Parse an OQMD composition page.

    Returns:
        A tuple ``(basic_data, table_data, phase_data)``:

        - ``basic_data``: the ``<h1>`` text (when present) followed by one
          string per ``<p>``, with links rewritten as Markdown links that
          point at https://www.oqmd.org.
        - ``table_data``: the first ``<table>`` rendered as a Markdown table,
          or an empty string when the page has no table.
        - ``phase_data``: one ``{'type', 'content'}`` dict for every
          ``<script>`` whose text contains ``$(function()``.
    """
    soup = BeautifulSoup(html, 'html.parser')

    # Basic data: heading plus paragraphs.
    basic_data = []
    heading = soup.find('h1')
    if heading is not None:
        basic_data.append(heading.text.strip())
    for paragraph in soup.find_all('p'):
        parts = []
        for element in paragraph.contents:  # walk the children of the <p>
            if element.name == 'a':  # an <a> tag becomes a Markdown link
                url = "https://www.oqmd.org" + element.get('href', '')
                parts.append(f"[{element.text.strip()}]({url}) ")
            else:  # anything else contributes its text
                parts.append(element.text.strip() + " ")
        basic_data.append("".join(parts).strip())

    # Table data; stays empty when the page has no table.
    table_data = ""
    table = soup.find('table')
    if table is not None:
        df = pd.read_html(StringIO(str(table)))[0]
        df = df.fillna('')
        df = df.replace([float('inf'), float('-inf')], '')
        table_data = df.to_markdown(index=False)

    # JavaScript blocks that draw the charts.
    phase_data = []
    for script in soup.find_all('script'):
        if script.string and '$(function()' in script.string:
            phase_data.append({
                'type': script.get('type', 'text/javascript'),
                'content': script.string.strip()
            })
    return basic_data, table_data, phase_data
async def _screenshot_with_margin(page, selector: str, path: str, margin: int = 40) -> None:
    """Screenshot the element at ``selector`` into ``path``, widening the clip by ``margin`` pixels to the right and bottom."""
    box = await page.locator(selector).bounding_box()
    if box is None:
        raise RuntimeError(f"Element {selector} has no bounding box; cannot take screenshot")
    await page.screenshot(
        path=path,
        clip={
            'x': box['x'],
            'y': box['y'],
            'width': box['width'] + margin,
            'height': box['height'] + margin,
        },
    )


async def render_and_save_charts(script_data: list) -> str:
    """
    Render the two chart scripts, stitch the screenshots and upload to MinIO.

    The scripts are embedded in a small HTML page, rendered in headless
    Chromium and captured as two screenshots. The screenshots are pasted side
    by side into one PNG, which is uploaded to the configured bucket.

    Args:
        script_data: Script dicts as produced by ``parse_oqmd_html``; the
            ``'content'`` of the first two entries is used.

    Returns:
        str: Presigned URL of the uploaded image, valid for 3600 seconds.

    Raises:
        RuntimeError: If rendering, image processing or the upload fails.
    """
    # Local import: this module does not import tempfile at file level.
    import tempfile

    # HTML page that loads jQuery and Flot and hosts both chart containers.
    html_template = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/jquery.flot@0.8.3/jquery.flot.js"></script>
    <title>Phase Diagram</title>
    </head>
    <body>
    <div class="diagram">
    <div id="placeholder" width="200" height="400" style="direction: ltr; position: absolute; left: 550px; top: 0px; width: 200px; height: 400px;"></div>
    <script>
    {placeholder_content}
    </script>
    <div id="phasediagram" width="500" height="400" style="direction: ltr; position: absolute; left: 0px; top: 0px; width: 500px; height: 400px;"></div>
    <script>
    {phasediagram_content}
    </script>
    </div>
    </body>
    </html>
    """

    try:
        if len(script_data) < 2:
            raise ValueError(f"expected 2 chart scripts, got {len(script_data)}")
        html_content = html_template.format(
            placeholder_content=script_data[0]['content'],
            phasediagram_content=script_data[1]['content'])

        # A private directory per call keeps concurrent requests from
        # overwriting each other's files and is removed on exit.
        with tempfile.TemporaryDirectory(prefix="oqmd_charts_") as work_dir:
            placeholder_path = os.path.join(work_dir, "placeholder.png")
            phasediagram_path = os.path.join(work_dir, "phasediagram.png")

            # Render the page and capture both charts.
            async with async_playwright() as p:
                browser = await p.chromium.launch(headless=True)
                try:
                    page = await browser.new_page()
                    await page.set_content(html_content)
                    # Fixed delay before capturing, giving the externally
                    # hosted scripts time to load and the charts to appear.
                    await page.wait_for_timeout(5000)
                    await _screenshot_with_margin(page, '#placeholder', placeholder_path)
                    await _screenshot_with_margin(page, '#phasediagram', phasediagram_path)
                finally:
                    # Close the browser exactly once, on every path.
                    try:
                        await browser.close()
                    except Exception as e:
                        logger.warning(f"Failed to close browser: {str(e)}")

            # Stitch the images: phase diagram on the left, placeholder on the right.
            try:
                with Image.open(placeholder_path) as img1, Image.open(phasediagram_path) as img2:
                    new_img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
                    new_img.paste(img2, (0, 0))
                    new_img.paste(img1, (img2.width, 0))
                # Microseconds are included so two requests in the same
                # second do not share an object key.
                timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
                file_name = f"oqmd_phase_diagram_{timestamp}.png"
                combined_path = os.path.join(work_dir, file_name)
                new_img.save(combined_path)
            except Exception as e:
                logger.error(f"Failed to process images: {str(e)}")
                raise RuntimeError(f"Image processing failed: {str(e)}") from e

            # Upload to MinIO, preferring the internal endpoint when configured.
            # NOTE(review): these boto3 calls are synchronous and run on the
            # event loop; consider moving them to a worker thread.
            try:
                minio_client = boto3.client(
                    's3',
                    endpoint_url=INTERNEL_MINIO_ENDPOINT or MINIO_ENDPOINT,
                    aws_access_key_id=MINIO_ACCESS_KEY,
                    aws_secret_access_key=MINIO_SECRET_KEY
                )
                bucket_name = MINIO_BUCKET
                minio_client.upload_file(combined_path, bucket_name, file_name, ExtraArgs={"ACL": "private"})
                # Generate the presigned URL.
                url = minio_client.generate_presigned_url(
                    'get_object',
                    Params={'Bucket': bucket_name, 'Key': file_name},
                    ExpiresIn=3600
                )
                # Only rewrite the host when an internal endpoint was used:
                # str.replace with an empty pattern would insert the public
                # endpoint between every character of the URL.
                if INTERNEL_MINIO_ENDPOINT:
                    url = url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
                return url
            except Exception as e:
                logger.error(f"Failed to upload to MinIO: {str(e)}")
                raise RuntimeError(f"MinIO upload failed: {str(e)}") from e
    except Exception as e:
        logger.error(f"Failed to render and save charts: {str(e)}")
        raise RuntimeError(f"Chart rendering failed: {str(e)}") from e
def format_response(basic_data: list, table_data: str, phase_data: str) -> dict[str, str]:
    """
    Build the Markdown report returned by the OQMD search endpoint.

    Args:
        basic_data: Items listed in bold, one per line, under "OQMD Data".
        table_data: Text placed under "Compounds at this composition".
        phase_data: Image URL embedded under "Phase Diagram".

    Returns:
        ``{"status": "success", "data": <markdown text>}``.
    """
    sections = ["### OQMD Data\n"]
    sections.extend(f"**{item}**\n" for item in basic_data)
    sections.append("\n### Phase Diagram\n\n")
    sections.append(f"![Phase Diagram]({phase_data})\n\n")
    sections.append("\n### Compounds at this composition\n\n")
    sections.append(f"{table_data}\n")
    return {
        "status": "success",
        "data": "".join(sections)
    }

20
main.py Normal file
View File

@@ -0,0 +1,20 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
from fastapi import FastAPI
import logging
from database.material_project_router import router as material_project_router
from database.oqmd_router import router as oqmd_router
# Configure logging at import time: INFO level, with timestamp, logger name,
# level and message in every record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Application object; the Materials Project and OQMD routers are mounted on it.
app = FastAPI()
app.include_router(material_project_router)
app.include_router(oqmd_router)

74
model/fairchem_router.py Normal file
View File

@@ -0,0 +1,74 @@
from fairchem.core import OCPCalculator
from ase.optimize import FIRE # Import your optimizer of choice
from ase.filters import FrechetCellFilter # to include cell relaxations
from ase.io import read
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester
from pymatgen.analysis.phase_diagram import PhaseDiagram, PDEntry
from pymatgen.entries.computed_entries import ComputedStructureEntry
from pymatgen.entries.mixing_scheme import MaterialsProjectDFTMixingScheme
# Build a phase diagram and compute the formation energy and the energy above hull.
def calculate_phase_diagram_properties(structure, total_energy, api_key):
    """
    Compute the formation energy and energy above hull of a structure.

    Reference entries for the structure's chemical system are fetched from
    the Materials Project, processed with the DFT mixing scheme, and combined
    with an entry for the new structure in a phase diagram.

    Args:
        structure: Structure providing ``chemical_system`` and ``composition``.
        total_energy (float): Total energy (eV) of ``structure`` as given,
            i.e. for its full composition, not per formula unit.
        api_key (str): Materials Project API key.

    Returns:
        formation_energy (float): Formation energy per atom (eV/atom).
        e_above_hull (float): Energy above hull per atom (eV/atom).
    """
    chemsys = structure.chemical_system.split("-")
    with MPRester(api_key) as mpr:
        # Fetch every phase in the chemical system.
        entries = mpr.get_entries_in_chemsys(elements=chemsys, additional_criteria={"thermo_types": ["GGA_GGA+U", "R2SCAN"]})

    # Entry for the newly computed structure. The full composition is used so
    # that it matches total_energy; pairing the reduced formula with a
    # whole-cell energy is wrong when the cell holds several formula units.
    pd_entry = PDEntry(composition=structure.composition, energy=total_energy)

    scheme = MaterialsProjectDFTMixingScheme()
    entries = scheme.process_entries(entries)

    # Build the phase diagram.
    pd = PhaseDiagram(entries + [pd_entry])

    # Formation energy and energy above hull.
    formation_energy = pd.get_form_energy_per_atom(pd_entry)
    e_above_hull = pd.get_e_above_hull(pd_entry)
    return formation_energy, e_above_hull
import os  # imported here because the file's import block does not include it

# NOTE(review): the structure and checkpoint paths are machine-specific.
atoms = read("/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/CsPbBr3.cif")  # Read in an atoms object or create your own structure
calc = OCPCalculator(checkpoint_path="/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/meta_fairchem/eqV2_86M_omat_mp_salex.pt")  # Path to downloaded checkpoint
atoms.calc = calc

# Relax atomic positions and the cell until the FIRE optimizer reaches fmax.
dyn = FIRE(FrechetCellFilter(atoms))
dyn.run(fmax=0.01)
total_energy = atoms.get_potential_energy()
print("Predicted Total Energy: ", total_energy)

# Save the optimized structure as a CIF file.
atoms.write("optimized_structure.cif")
print("Geometry optimization completed. Optimized structure saved as 'optimized_structure.cif'.")

# Load the written CIF as a pymatgen structure.
optimized_structure = Structure.from_file("optimized_structure.cif")

# SECURITY: the literal fallback is an API key committed to the repository;
# rotate it and provide the key through the MP_API_KEY environment variable.
api_key = os.environ.get("MP_API_KEY", "gfBp2in8qxm9Xm2SwLKFwNxDyZvNTAEt")

# No MPRester is created here: calculate_phase_diagram_properties opens and
# closes its own client, and the instance previously built at this point was
# never used or closed.
print(f"Chemical Formula: {optimized_structure.composition.reduced_formula}")
formation_energy, e_above_hull = calculate_phase_diagram_properties(
    structure=optimized_structure,
    total_energy=total_energy,
    api_key=api_key
)
print(formation_energy, e_above_hull)
print()

8
utils.py Normal file
View File

@@ -0,0 +1,8 @@
import logging
from multiprocessing import Process, Manager
import asyncio
from typing import Dict, Any, List
from mp_api.client import MPRester
# Module-level logger named after this module.
logger = logging.getLogger(__name__)