diff --git a/.gitignore b/.gitignore index 5d381cc..416ae75 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,9 @@ # ---> Python + +# model +*.pt +/home/ubuntu/sas0/LYT/mars1215/mars_toolkit/model/eqV2_86M_omat_mp_salex.pt + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/chart_screenshot.png b/chart_screenshot.png new file mode 100644 index 0000000..1dede93 Binary files /dev/null and b/chart_screenshot.png differ diff --git a/constant.py b/constant.py new file mode 100644 index 0000000..a36a7fe --- /dev/null +++ b/constant.py @@ -0,0 +1,10 @@ +TOPK_RESULT = 5 +TIME_OUT = 30 +MP_API_KEY = "gfBp2in8qxm9Xm2SwLKFwNxDyZvNTAEt" + +# MinIO configuration +MINIO_ENDPOINT = "https://s3-api.siat-mic.com" +INTERNEL_MINIO_ENDPOINT = "http://100.85.52.31:9000" # 内网地址,如果有就填,上传会更快。 +MINIO_ACCESS_KEY = "9bUtQL1Gpo9JB6o3pSGr" +MINIO_SECRET_KEY = "1Qug5H73R3kP8boIHvdVcFtcb1jU9GRWnlmMpx0g" +MINIO_BUCKET = "temp" diff --git a/database/__init__.py b/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/database/material_project_router.py b/database/material_project_router.py new file mode 100644 index 0000000..aae911c --- /dev/null +++ b/database/material_project_router.py @@ -0,0 +1,163 @@ +""" +Author: Yutang LI +Institution: SIAT-MIC +Contact: yt.li2@siat.ac.cn +""" + +from fastapi import APIRouter, Request +import json +import asyncio +import logging +from mp_api.client import MPRester +from multiprocessing import Process, Manager +from typing import Dict, Any, List +from constant import MP_API_KEY, TIME_OUT, TOPK_RESULT + +router = APIRouter(prefix="/mp", tags=["Material Project"]) + +logger = logging.getLogger(__name__) + + +@router.get("/search") +async def search_from_material_project(request: Request): + # 打印请求日志 + logger.info(f"Received request: {request.method} {request.url}") + logger.info(f"Query parameters: {request.query_params}") + + # 解析查询参数 + search_args = parse_search_parameters(request.query_params) + + # 检查API key + if MP_API_KEY is None or MP_API_KEY == '': + return 'Material Project API CANNOT Be None' + + try: + # 执行搜索 + docs = await execute_search(search_args) + + # 处理搜索结果 + res = process_search_results(docs) + + # 返回结果 + if len(res) >= TOPK_RESULT: + return json.dumps(res[:TOPK_RESULT], indent=2) + return json.dumps(res, indent=2) + + except asyncio.TimeoutError: + logger.error(f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again.") + return {"error": f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again."} + + +def parse_bool(param: str) -> bool | None: + if not param: + return None + return param.lower() == 'true' + + +def parse_list(param: str) -> List[str] | None: + if not param: + return None + return param.split(',') + + +def parse_tuple(param: str) -> tuple[float, float] | None: + if not param: + return None + try: + values = param.split(',') + return (float(values[0]), float(values[1])) + except (ValueError, IndexError): + return None + + +def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]: + """解析搜索参数""" + return { + 'band_gap': parse_tuple(query_params.get('band_gap')), + 'chemsys': parse_list(query_params.get('chemsys')), + 'crystal_system': parse_list(query_params.get('crystal_system')), + 'density': parse_tuple(query_params.get('density')), + 'formation_energy': parse_tuple(query_params.get('formation_energy')), + 'elements': parse_list(query_params.get('elements')), + 'exclude_elements': parse_list(query_params.get('exclude_elements')), + 'formula': parse_list(query_params.get('formula')), + 'is_gap_direct': parse_bool(query_params.get('is_gap_direct')), + 'is_metal': parse_bool(query_params.get('is_metal')), + 'is_stable': parse_bool(query_params.get('is_stable')), + 'magnetic_ordering': query_params.get('magnetic_ordering'), + 'material_ids': parse_list(query_params.get('material_ids')), + 'total_energy': parse_tuple(query_params.get('total_energy')), + 'num_elements': parse_tuple(query_params.get('num_elements')), + 'volume': parse_tuple(query_params.get('volume')), + 'chunk_size': int(query_params.get('chunk_size', '5')) + } + + +def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """处理搜索结果""" + fields = [ + 'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys', + 'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap', + 'is_gap_direct', 'is_stable', 'formation_energy_per_atom', + 'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi', + 'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus', + 'universal_anisotropy', 'theoretical' + ] + + res = [] + for doc in docs: + try: + new_docs = {} + for field_name in fields: + new_docs[field_name] = doc.get(field_name, '') + res.append(new_docs) + except Exception as e: + logger.warning(f"Error processing document: {str(e)}") + continue + return res + + +async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) -> List[Dict[str, Any]]: + """执行搜索""" + manager = Manager() + queue = manager.Queue() + + p = Process(target=_search_worker, args=(queue, MP_API_KEY), kwargs=search_args) + p.start() + + logger.info(f"Started worker process with PID: {p.pid}") + p.join(timeout=timeout) + + if p.is_alive(): + logger.warning(f"Terminating worker process {p.pid} due to timeout") + p.terminate() + p.join() + raise asyncio.TimeoutError(f"Request timed out after {timeout} seconds") + + try: + if queue.empty(): + logger.warning("Queue is empty after process completion") + else: + logger.info("Queue contains data, retrieving...") + + result = queue.get(timeout=15) + except queue.Empty: + logger.error("Failed to retrieve data from queue") + raise RuntimeError("Failed to retrieve data from worker process") + + if isinstance(result, Exception): + logger.error(f"Error in search worker: {str(result)}") + raise result + + logger.info(f"Successfully retrieved {len(result)} documents") + return result + + +def _search_worker(queue, api_key, **kwargs): + """搜索工作线程""" + try: + mpr = MPRester(api_key) + result = mpr.materials.summary.search(**kwargs) + queue.put([doc.model_dump() for doc in result]) + except Exception as e: + queue.put(e) diff --git a/database/oqmd_router.py b/database/oqmd_router.py new file mode 100644 index 0000000..c5153c2 --- /dev/null +++ b/database/oqmd_router.py @@ -0,0 +1,305 @@ +""" +Author: Yutang LI +Institution: SIAT-MIC +Contact: yt.li2@siat.ac.cn +""" +import os +import boto3 +from fastapi import APIRouter, Request +from io import StringIO +import logging +import httpx +import datetime +import pandas as pd +from bs4 import BeautifulSoup +from PIL import Image +from playwright.async_api import async_playwright +from constant import MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT + + +router = APIRouter(prefix="/oqmd", tags=["OQMD"]) +logger = logging.getLogger(__name__) + + +@router.get("/search") +async def search_from_oqmd_by_composition(request: Request): + # 打印请求日志 + logger.info(f"Received request: {request.method} {request.url}") + logger.info(f"Query parameters: {request.query_params}") + + try: + # 获取并解析数据 + composition = request.query_params['composition'] + html = await fetch_oqmd_data(composition) + basic_data, table_data, phase_data = parse_oqmd_html(html) + + # 渲染并保存图表 + phase_diagram_name = await render_and_save_charts(phase_data) + # 返回格式化后的响应 + return format_response(basic_data, table_data, phase_diagram_name) + + except httpx.HTTPStatusError as e: + logger.error(f"OQMD API request failed: {str(e)}") + return { + "status": "error", + "message": f"OQMD API request failed: {str(e)}" + } + except Exception as e: + logger.error(f"Unexpected error: {str(e)}") + return { + "status": "error", + "message": f"Unexpected error: {str(e)}" + } + + + +async def fetch_oqmd_data(composition: str) -> str: + """ + 从OQMD获取数据 + Args: + composition: 材料组成字符串 + + Returns: + HTML内容字符串 + + Raises: + httpx.HTTPError: 当发生HTTP相关错误时抛出 + ValueError: 当响应内容无效时抛出 + """ + url = f"https://www.oqmd.org/materials/composition/{composition}" + try: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url) + response.raise_for_status() + + # 验证响应内容 + if not response.text or len(response.text) < 100: + raise ValueError("Invalid response content from OQMD API") + + return response.text + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + if status_code == 401: + logger.error("OQMD API: Unauthorized access") + raise httpx.HTTPError("Unauthorized access to OQMD API") from e + elif status_code == 403: + logger.error("OQMD API: Forbidden access") + raise httpx.HTTPError("Forbidden access to OQMD API") from e + elif status_code == 404: + logger.error("OQMD API: Resource not found") + raise httpx.HTTPError("Resource not found on OQMD API") from e + elif status_code >= 500: + logger.error(f"OQMD API: Server error ({status_code})") + raise httpx.HTTPError(f"OQMD API server error ({status_code})") from e + else: + logger.error(f"OQMD API request failed: {str(e)}") + raise httpx.HTTPError(f"OQMD API request failed: {str(e)}") from e + + except httpx.TimeoutException as e: + logger.error("OQMD API request timed out") + raise httpx.HTTPError("OQMD API request timed out") from e + + except httpx.NetworkError as e: + logger.error(f"Network error occurred: {str(e)}") + raise httpx.HTTPError(f"Network error: {str(e)}") from e + + except ValueError as e: + logger.error(f"Invalid response content: {str(e)}") + raise ValueError(f"Invalid response content: {str(e)}") from e + +def parse_oqmd_html(html: str) -> tuple[list, str, list]: + """ + 解析OQMD HTML数据 + """ + soup = BeautifulSoup(html, 'html.parser') + # 解析基本数据 + basic_data = [] + basic_data.append(soup.find('h1').text.strip()) + for script in soup.find_all('p'): + if script: + combined_text = "" + for element in script.contents: # 遍历

的子元素 + if element.name == 'a': # 如果是 标签 + url = "https://www.oqmd.org" + element['href'] + combined_text += f"[{element.text.strip()}]({url}) " + else: # 如果是文本 + combined_text += element.text.strip() + " " + basic_data.append(combined_text.strip()) + # import pdb + # pdb.set_trace() + + # 解析表格数据 + table = soup.find('table') + if table: + df = pd.read_html(StringIO(str(table)))[0] + df = df.fillna('') + df = df.replace([float('inf'), float('-inf')], '') + # table_data = df.to_dict(orient='records') + table_data = df.to_markdown(index=False) + + # 提取JavaScript数据 + phase_data = [] + for script in soup.find_all('script'): + if script.string and '$(function()' in script.string: + phase_data.append({ + 'type': script.get('type', 'text/javascript'), + 'content': script.string.strip() + }) + + return basic_data, table_data, phase_data + + +async def render_and_save_charts(script_data: list) -> str: + """ + 渲染并保存图表到MinIO + Returns: + str: 图片的预签名URL + Raises: + RuntimeError: 如果图片生成或上传失败 + """ + browser = None + temp_files = [] + try: + # 初始化Playwright + async with async_playwright() as p: + browser = await p.chromium.launch(headless=True) + page = await browser.new_page() + + # 构建包含 JavaScript 的 HTML 代码 + html_content = """ + + + + + + + + Phase Diagram + + +

+
+ + +
+ +
+ + +""" + html_content = html_content.format( + placeholder_content=script_data[0]['content'], + phasediagram_content=script_data[1]['content']) + + await page.set_content(html_content) + await page.wait_for_timeout(5000) + + # 分别截图两个图表 + # 获取placeholder元素位置并扩大截图区域 + placeholder = page.locator('#placeholder') + placeholder_box = await placeholder.bounding_box() + await page.screenshot( + path="placeholder.png", + clip={ + 'x': placeholder_box['x'], + 'y': placeholder_box['y'], + 'width': placeholder_box['width'] + 40, + 'height': placeholder_box['height'] + 40 + } + ) + + # 获取phasediagram元素位置并扩大截图区域 + phasediagram = page.locator('#phasediagram') + phasediagram_box = await phasediagram.bounding_box() + await page.screenshot( + path="phasediagram.png", + clip={ + 'x': phasediagram_box['x'], + 'y': phasediagram_box['y'], + 'width': phasediagram_box['width'] + 40, + 'height': phasediagram_box['height'] + 40 + } + ) + + await browser.close() + + # 拼接图片 + try: + img1 = Image.open("placeholder.png") + temp_files.append("placeholder.png") + img2 = Image.open("phasediagram.png") + temp_files.append("phasediagram.png") + new_img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height))) + new_img.paste(img2, (0, 0)) + new_img.paste(img1, (img2.width, 0)) + timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + file_name = f"oqmd_phase_diagram_{timestamp}.png" + new_img.save(file_name) + temp_files.append(file_name) + except Exception as e: + logger.error(f"Failed to process images: {str(e)}") + raise RuntimeError(f"Image processing failed: {str(e)}") from e + + # 上传到 MinIO 的逻辑 + try: + minio_client = boto3.client( + 's3', + endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT, + aws_access_key_id=MINIO_ACCESS_KEY, + aws_secret_access_key=MINIO_SECRET_KEY + ) + + bucket_name = MINIO_BUCKET + minio_client.upload_file(file_name, bucket_name, file_name, ExtraArgs={"ACL": "private"}) + + # 生成预签名 URL + url = minio_client.generate_presigned_url( + 'get_object', + Params={'Bucket': bucket_name, 'Key': file_name}, + ExpiresIn=3600 + ) + return url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT) + + except Exception as e: + logger.error(f"Failed to upload to MinIO: {str(e)}") + raise RuntimeError(f"MinIO upload failed: {str(e)}") from e + finally: + # 清理临时文件 + for temp_file in temp_files: + try: + if os.path.exists(temp_file): + os.remove(temp_file) + except Exception as e: + logger.warning(f"Failed to remove temporary file {temp_file}: {str(e)}") + except Exception as e: + logger.error(f"Failed to render and save charts: {str(e)}") + raise RuntimeError(f"Chart rendering failed: {str(e)}") from e + finally: + # 确保浏览器关闭 + if browser: + try: + await browser.close() + except Exception as e: + logger.warning(f"Failed to close browser: {str(e)}") + +def format_response(basic_data: list, table_data: str, phase_data: str) -> str: + """ + 格式化响应数据 + """ + response = "### OQMD Data\n" + for item in basic_data: + response += f"**{item}**\n" + response += "\n### Phase Diagram\n\n" + response += f"![Phase Diagram]({phase_data})\n\n" + response += "\n### Compounds at this composition\n\n" + response += f"{table_data}\n" + + return { + "status": "success", + "data": response + } diff --git a/main.py b/main.py new file mode 100644 index 0000000..2b82f97 --- /dev/null +++ b/main.py @@ -0,0 +1,20 @@ +""" +Author: Yutang LI +Institution: SIAT-MIC +Contact: yt.li2@siat.ac.cn +""" + +from fastapi import FastAPI +import logging +from database.material_project_router import router as material_project_router +from database.oqmd_router import router as oqmd_router + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +app = FastAPI() +app.include_router(material_project_router) +app.include_router(oqmd_router) diff --git a/model/fairchem_router.py b/model/fairchem_router.py new file mode 100644 index 0000000..f334fa1 --- /dev/null +++ b/model/fairchem_router.py @@ -0,0 +1,74 @@ +from fairchem.core import OCPCalculator +from ase.optimize import FIRE # Import your optimizer of choice +from ase.filters import FrechetCellFilter # to include cell relaxations +from ase.io import read +from pymatgen.core import Structure +from pymatgen.ext.matproj import MPRester +from pymatgen.analysis.phase_diagram import PhaseDiagram, PDEntry +from pymatgen.entries.computed_entries import ComputedStructureEntry +from pymatgen.entries.mixing_scheme import MaterialsProjectDFTMixingScheme + +# 创建相图并计算形成能与 above hull energy 的函数 +def calculate_phase_diagram_properties(structure, total_energy, api_key): + """ + 计算化合物的形成能和 above hull energy + 参数: + - formula (str): 化学式 (如 "CsPbBr3") + - total_energy (float): 化合物的总能量 (eV) + - mpr (MPRester): MPRester 实例 + + 返回: + - formation_energy (float): 每个原子的形成能 (eV/atom) + - e_above_hull (float): 每个原子的 above hull energy (eV/atom) + """ + chemsys = structure.chemical_system.split("-") + formula = structure.reduced_formula + with MPRester(api_key) as mpr: + # 获取化学系统中所有的相 + # entries = mpr.get_entries_in_chemsys(elements=chemsys, additional_criteria={"thermo_types": ["GGA_GGA+U"]}) + entries = mpr.get_entries_in_chemsys(elements=chemsys, additional_criteria={"thermo_types": ["GGA_GGA+U", "R2SCAN"]}) + + # 创建新计算结构的 PDEntry + pd_entry = PDEntry(composition=formula, energy=total_energy) + # entries.append(pd_entry) + + scheme = MaterialsProjectDFTMixingScheme() + entries = scheme.process_entries(entries) + + # 创建相图 + pd = PhaseDiagram(entries + [pd_entry]) + + # 计算形成能和 above hull energy + formation_energy = pd.get_form_energy_per_atom(pd_entry) + e_above_hull = pd.get_e_above_hull(pd_entry) + + return formation_energy, e_above_hull + + +atoms = read("/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/CsPbBr3.cif") # Read in an atoms object or create your own structure +calc = OCPCalculator(checkpoint_path="/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/meta_fairchem/eqV2_86M_omat_mp_salex.pt") # Path to downloaded checkpoint +atoms.calc = calc +dyn = FIRE(FrechetCellFilter(atoms)) +dyn.run(fmax=0.01) + +total_energy = atoms.get_potential_energy() +print("Predicted Total Energy: ", total_energy) + +# 保存优化后的结构 +atoms.write("optimized_structure.cif") # 保存为 CIF 文件 +print("Geometry optimization completed. Optimized structure saved as 'optimized_structure.cif'.") + +# 从 ASE 转换为 Pymatgen 结构 +optimized_structure = Structure.from_file("optimized_structure.cif") + +api_key = "gfBp2in8qxm9Xm2SwLKFwNxDyZvNTAEt" +mpr = MPRester(api_key) +print(f"Chemical Formula: {optimized_structure .composition.reduced_formula}") +formation_energy, e_above_hull = calculate_phase_diagram_properties( + structure=optimized_structure, + total_energy=total_energy, + api_key=api_key +) + +print(formation_energy, e_above_hull) +print() \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..417f879 --- /dev/null +++ b/utils.py @@ -0,0 +1,8 @@ +import logging +from multiprocessing import Process, Manager +import asyncio +from typing import Dict, Any, List +from mp_api.client import MPRester + +logger = logging.getLogger(__name__) +