Refactor code

2025-01-06 14:54:41 +08:00
parent c2417fec25
commit c7d2d482da
13 changed files with 519 additions and 439 deletions

92
services/__init__.py Normal file
@@ -0,0 +1,92 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import logging
import tempfile
import os
import datetime
from typing import Optional
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from utils import settings, handle_minio_upload
logger = logging.getLogger(__name__)
# Model calculator, initialized lazily by init_model()
calc = None
def init_model():
    """Initialize the FairChem model."""
    global calc
    try:
        from fairchem.core import OCPCalculator
        calc = OCPCalculator(checkpoint_path=settings.fairchem_model_path)
        logger.info("FairChem model initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize FairChem model: {str(e)}")
        raise
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """Convert the input content into an ASE Atoms object."""
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            tmp_file.write(content)
            tmp_path = tmp_file.name
        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
def generate_symmetry_cif(structure: Structure) -> str:
    """Generate a symmetry-refined CIF string."""
    analyzer = SpacegroupAnalyzer(structure)
    structure = analyzer.get_refined_structure()
    with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
        cif_writer = CifWriter(structure, symprec=0.1, refine_struct=True)
        cif_writer.write_file(tmp_file.name)
        tmp_file.seek(0)
        content = tmp_file.read()
    os.unlink(tmp_file.name)
    return content
def optimize_structure(atoms: Atoms, output_format: str):
    """Optimize the crystal structure and upload the result to MinIO."""
    atoms.calc = calc
    tmp_path = None
    try:
        dyn = FIRE(FrechetCellFilter(atoms))
        dyn.run(fmax=settings.fmax)
        total_energy = atoms.get_total_energy()
        # Apply symmetry refinement for CIF output
        if output_format == "cif":
            optimized_structure = Structure.from_ase_atoms(atoms)
            content = generate_symmetry_cif(optimized_structure)
        else:
            with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
                write(tmp_file.name, atoms)
                tmp_file.seek(0)
                content = tmp_file.read()
            os.unlink(tmp_file.name)
        # Save the optimized result to a temporary file
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        file_name = f"optimized_structure_{timestamp}.{output_format}"
        with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w", delete=False) as tmp_file:
            tmp_file.write(content)
            tmp_path = tmp_file.name
        # Upload to MinIO
        url = handle_minio_upload(tmp_path, file_name)
        return total_energy, content, url
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
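A minimal usage sketch for these helpers (illustrative only; it assumes settings.fairchem_model_path, settings.fmax, and the MinIO credentials behind handle_minio_upload are configured, and "example.cif" is a hypothetical input file):

from services import init_model, convert_structure, optimize_structure

init_model()  # load the OCPCalculator once per process
with open("example.cif") as f:  # hypothetical input file
    atoms = convert_structure("cif", f.read())
if atoms is not None:
    energy, content, url = optimize_structure(atoms, "cif")
    print(f"E = {energy:.3f} eV -> {url}")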

128
services/mp_service.py Normal file

@@ -0,0 +1,128 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import json
import asyncio
import logging
import datetime
from multiprocessing import Process, Manager
from queue import Empty
from typing import Dict, Any, List
from mp_api.client import MPRester
from utils import settings, handle_minio_upload
from error_handlers import handle_general_error
logger = logging.getLogger(__name__)
def parse_bool(param: str | None) -> bool | None:
    if not param:
        return None
    return param.lower() == 'true'
def parse_list(param: str | None) -> List[str] | None:
    if not param:
        return None
    return param.split(',')
def parse_tuple(param: str | None) -> tuple[float, float] | None:
    if not param:
        return None
    try:
        values = param.split(',')
        return (float(values[0]), float(values[1]))
    except (ValueError, IndexError):
        return None
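# Illustrative behavior of the parsers above (inputs are assumed query-string values):
#   parse_bool("True")     -> True
#   parse_bool(None)       -> None
#   parse_list("Si,O")     -> ['Si', 'O']
#   parse_tuple("0.5,3.0") -> (0.5, 3.0)
#   parse_tuple("oops")    -> None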
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
    """Parse search parameters from the query string."""
    return {
        'band_gap': parse_tuple(query_params.get('band_gap')),
        'chemsys': parse_list(query_params.get('chemsys')),
        'crystal_system': parse_list(query_params.get('crystal_system')),
        'density': parse_tuple(query_params.get('density')),
        'formation_energy': parse_tuple(query_params.get('formation_energy')),
        'elements': parse_list(query_params.get('elements')),
        'exclude_elements': parse_list(query_params.get('exclude_elements')),
        'formula': parse_list(query_params.get('formula')),
        'is_gap_direct': parse_bool(query_params.get('is_gap_direct')),
        'is_metal': parse_bool(query_params.get('is_metal')),
        'is_stable': parse_bool(query_params.get('is_stable')),
        'magnetic_ordering': query_params.get('magnetic_ordering'),
        'material_ids': parse_list(query_params.get('material_ids')),
        'total_energy': parse_tuple(query_params.get('total_energy')),
        'num_elements': parse_tuple(query_params.get('num_elements')),
        'volume': parse_tuple(query_params.get('volume')),
        'chunk_size': int(query_params.get('chunk_size', '5'))
    }
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Reduce raw documents to a fixed set of summary fields."""
    fields = [
        'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
        'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
        'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
        'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
        'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
        'universal_anisotropy', 'theoretical'
    ]
    res = []
    for doc in docs:
        try:
            new_doc = {field_name: doc.get(field_name, '') for field_name in fields}
            res.append(new_doc)
        except Exception as e:
            logger.warning(f"Error processing document: {str(e)}")
            continue
    return res
async def execute_search(search_args: Dict[str, Any], timeout: int = 30) -> List[Dict[str, Any]]:
    """Run the Materials Project search in a worker process, with a timeout."""
    manager = Manager()
    result_queue = manager.Queue()
    p = Process(target=_search_worker, args=(result_queue, settings.mp_api_key), kwargs=search_args)
    p.start()
    logger.info(f"Started worker process with PID: {p.pid}")
    p.join(timeout=timeout)
    if p.is_alive():
        logger.warning(f"Terminating worker process {p.pid} due to timeout")
        p.terminate()
        p.join()
        raise asyncio.TimeoutError(f"Request timed out after {timeout} seconds")
    try:
        if result_queue.empty():
            logger.warning("Queue is empty after process completion")
            raise RuntimeError("Worker process returned no data")
        logger.info("Queue contains data, retrieving...")
        result = result_queue.get(timeout=15)
    except Empty:
        logger.error("Failed to retrieve data from queue")
        raise RuntimeError("Failed to retrieve data from worker process")
    if isinstance(result, Exception):
        logger.error(f"Error in search worker: {str(result)}")
        raise result
    logger.info(f"Successfully retrieved {len(result)} documents")
    return result
def _search_worker(result_queue, api_key, **kwargs):
    """Worker process that performs the MP search."""
    try:
        import os
        os.environ['HTTP_PROXY'] = settings.http_proxy or ''
        os.environ['HTTPS_PROXY'] = settings.https_proxy or ''
        mpr = MPRester(api_key, endpoint=settings.mp_endpoint)
        result = mpr.materials.summary.search(**kwargs)
        result_queue.put([doc.model_dump() for doc in result])
    except Exception as e:
        result_queue.put(e)
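A minimal call sketch (illustrative; the query values are assumptions, and settings.mp_api_key must point at a valid Materials Project API key):

import asyncio
from services.mp_service import parse_search_parameters, execute_search, process_search_results

params = {'elements': 'Si,O', 'band_gap': '0.5,3.0', 'is_stable': 'true'}
search_args = {k: v for k, v in parse_search_parameters(params).items() if v is not None}
docs = asyncio.run(execute_search(search_args, timeout=60))
print(process_search_results(docs)[:2])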

191
services/oqmd_service.py Normal file

@@ -0,0 +1,191 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import datetime
import logging
import os
import httpx
import pandas as pd
from bs4 import BeautifulSoup
from PIL import Image
from playwright.async_api import async_playwright
from io import StringIO
from utils import settings, handle_minio_upload
logger = logging.getLogger(__name__)
async def fetch_oqmd_data(composition: str) -> str:
    """Fetch composition data from OQMD."""
    url = f"https://www.oqmd.org/materials/composition/{composition}"
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(url)
            response.raise_for_status()
            # Validate the response content
            if not response.text or len(response.text) < 100:
                raise ValueError("Invalid response content from OQMD API")
            return response.text
    except httpx.HTTPStatusError as e:
        logger.error(f"OQMD API request failed: {str(e)}")
        raise
    except httpx.TimeoutException:
        logger.error("OQMD API request timed out")
        raise
    except httpx.NetworkError as e:
        logger.error(f"Network error occurred: {str(e)}")
        raise
    except ValueError as e:
        logger.error(f"Invalid response content: {str(e)}")
        raise
def parse_oqmd_html(html: str) -> tuple[list, str, list]:
    """Parse the OQMD HTML into basic data, a markdown table, and chart scripts."""
    soup = BeautifulSoup(html, 'html.parser')
    # Parse the basic data
    basic_data = []
    basic_data.append(soup.find('h1').text.strip())
    for para in soup.find_all('p'):
        combined_text = ""
        for element in para.contents:
            if element.name == 'a':
                url = "https://www.oqmd.org" + element['href']
                combined_text += f"[{element.text.strip()}]({url}) "
            else:
                combined_text += element.text.strip() + " "
        basic_data.append(combined_text.strip())
    # Parse the table data
    table_data = ""
    table = soup.find('table')
    if table:
        df = pd.read_html(StringIO(str(table)))[0]
        df = df.fillna('')
        df = df.replace([float('inf'), float('-inf')], '')
        table_data = df.to_markdown(index=False)
    # Extract the JavaScript chart data
    phase_data = []
    for script in soup.find_all('script'):
        if script.string and '$(function()' in script.string:
            phase_data.append({
                'type': script.get('type', 'text/javascript'),
                'content': script.string.strip()
            })
    return basic_data, table_data, phase_data
async def render_and_save_charts(script_data: list) -> str:
    """Render the phase-diagram charts and upload the stitched image to MinIO."""
    browser = None
    temp_files = []
    try:
        # Initialize Playwright
        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            page = await browser.new_page()
            # Build an HTML page that embeds the extracted JavaScript
            html_content = """
            <!DOCTYPE html>
            <html lang="en">
            <head>
                <meta charset="UTF-8">
                <meta name="viewport" content="width=device-width, initial-scale=1.0">
                <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
                <script src="https://cdn.jsdelivr.net/npm/jquery.flot@0.8.3/jquery.flot.js"></script>
                <title>Phase Diagram</title>
            </head>
            <body>
                <div class="diagram">
                    <div id="placeholder" width="200" height="400" style="direction: ltr; position: absolute; left: 550px; top: 0px; width: 200px; height: 400px;"></div>
                    <script>
                        {placeholder_content}
                    </script>
                    <div id="phasediagram" width="500" height="400" style="direction: ltr; position: absolute; left: 0px; top: 0px; width: 500px; height: 400px;"></div>
                    <script>
                        {phasediagram_content}
                    </script>
                </div>
            </body>
            </html>
            """
            html_content = html_content.format(
                placeholder_content=script_data[0]['content'],
                phasediagram_content=script_data[1]['content'])
            await page.set_content(html_content)
            await page.wait_for_timeout(5000)
            # Screenshot the two charts separately
            placeholder = page.locator('#placeholder')
            placeholder_box = await placeholder.bounding_box()
            await page.screenshot(
                path="placeholder.png",
                clip={
                    'x': placeholder_box['x'],
                    'y': placeholder_box['y'],
                    'width': placeholder_box['width'] + 40,
                    'height': placeholder_box['height'] + 40
                }
            )
            phasediagram = page.locator('#phasediagram')
            phasediagram_box = await phasediagram.bounding_box()
            await page.screenshot(
                path="phasediagram.png",
                clip={
                    'x': phasediagram_box['x'],
                    'y': phasediagram_box['y'],
                    'width': phasediagram_box['width'] + 40,
                    'height': phasediagram_box['height'] + 40
                }
            )
            await browser.close()
            browser = None  # already closed; skip the close in finally
        # Stitch the two screenshots side by side
        try:
            img1 = Image.open("placeholder.png")
            temp_files.append("placeholder.png")
            img2 = Image.open("phasediagram.png")
            temp_files.append("phasediagram.png")
            new_img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
            new_img.paste(img2, (0, 0))
            new_img.paste(img1, (img2.width, 0))
            timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            file_name = f"oqmd_phase_diagram_{timestamp}.png"
            new_img.save(file_name)
            temp_files.append(file_name)
        except Exception as e:
            logger.error(f"Failed to process images: {str(e)}")
            raise RuntimeError(f"Image processing failed: {str(e)}") from e
        # Upload to MinIO
        url = handle_minio_upload(file_name, file_name)
        return url
    except Exception as e:
        logger.error(f"Failed to render and save charts: {str(e)}")
        raise
    finally:
        # Clean up temporary files
        for temp_file in temp_files:
            try:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
            except Exception as e:
                logger.warning(f"Failed to remove temporary file {temp_file}: {str(e)}")
        # Ensure the browser is closed even on failure
        if browser:
            try:
                await browser.close()
            except Exception as e:
                logger.warning(f"Failed to close browser: {str(e)}")
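
A minimal end-to-end sketch (illustrative; "Al2O3" is an example composition, and the MinIO settings used by handle_minio_upload are assumed to be configured):

import asyncio
from services.oqmd_service import fetch_oqmd_data, parse_oqmd_html, render_and_save_charts

async def demo():
    html = await fetch_oqmd_data("Al2O3")  # hypothetical composition
    basic_data, table_data, phase_data = parse_oqmd_html(html)
    print("\n".join(basic_data))
    print(table_data)
    if len(phase_data) >= 2:  # the renderer expects two chart scripts
        print(await render_and_save_charts(phase_data))

asyncio.run(demo())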