"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import json
import asyncio
import logging
import datetime
import os
from multiprocessing import Process, Manager
from queue import Empty
from typing import Dict, Any, List

from mp_api.client import MPRester

from utils import settings, handle_minio_upload
from error_handlers import handle_general_error

logger = logging.getLogger(__name__)


def parse_bool(param: str | None) -> bool | None:
    """Parse a 'true'/'false' query value; return None if the value is missing."""
    if not param:
        return None
    return param.lower() == 'true'


def parse_list(param: str | None) -> List[str] | None:
    """Parse a comma-separated query value into a list of strings."""
    if not param:
        return None
    return param.split(',')


def parse_tuple(param: str | None) -> tuple[float, float] | None:
    """Parse a 'min,max' query value into a (float, float) range tuple."""
    if not param:
        return None
    try:
        values = param.split(',')
        return (float(values[0]), float(values[1]))
    except (ValueError, IndexError):
        return None


def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
    """Parse search parameters from the raw query-string values."""
    return {
        'band_gap': parse_tuple(query_params.get('band_gap')),
        'chemsys': parse_list(query_params.get('chemsys')),
        'crystal_system': parse_list(query_params.get('crystal_system')),
        'density': parse_tuple(query_params.get('density')),
        'formation_energy': parse_tuple(query_params.get('formation_energy')),
        'elements': parse_list(query_params.get('elements')),
        'exclude_elements': parse_list(query_params.get('exclude_elements')),
        'formula': parse_list(query_params.get('formula')),
        'is_gap_direct': parse_bool(query_params.get('is_gap_direct')),
        'is_metal': parse_bool(query_params.get('is_metal')),
        'is_stable': parse_bool(query_params.get('is_stable')),
        'magnetic_ordering': query_params.get('magnetic_ordering'),
        'material_ids': parse_list(query_params.get('material_ids')),
        'total_energy': parse_tuple(query_params.get('total_energy')),
        'num_elements': parse_tuple(query_params.get('num_elements')),
        'volume': parse_tuple(query_params.get('volume')),
        'chunk_size': int(query_params.get('chunk_size', '5')),
    }
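
# Example (illustrative values only): a query string such as
#   ?elements=Si,O&band_gap=0.5,3.0&is_stable=true
# arrives here as {'elements': 'Si,O', 'band_gap': '0.5,3.0', 'is_stable': 'true'}
# and is parsed to {'elements': ['Si', 'O'], 'band_gap': (0.5, 3.0), 'is_stable': True, ...},
# with unspecified keys left as None and chunk_size defaulting to 5.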


def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Process search results, keeping only the whitelisted summary fields."""
    fields = [
        'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
        'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
        'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
        'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
        'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
        'universal_anisotropy', 'theoretical'
    ]
    res = []
    for doc in docs:
        try:
            new_doc = {field_name: doc.get(field_name, '') for field_name in fields}
            res.append(new_doc)
        except Exception as e:
            logger.warning(f"Error processing document: {str(e)}")
            continue
    return res
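
# Note: fields missing from a document are filled with '' rather than dropped,
# so every processed result carries the same set of keys.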


async def execute_search(search_args: Dict[str, Any], timeout: int = 30) -> List[Dict[str, Any]]:
    """Run the MP search in a separate worker process so it can be terminated on timeout."""
    manager = Manager()
    queue = manager.Queue()
    p = Process(target=_search_worker, args=(queue, settings.mp_api_key), kwargs=search_args)
    p.start()
    logger.info(f"Started worker process with PID: {p.pid}")
    p.join(timeout=timeout)
    if p.is_alive():
        logger.warning(f"Terminating worker process {p.pid} due to timeout")
        p.terminate()
        p.join()
        raise asyncio.TimeoutError(f"Request timed out after {timeout} seconds")
    try:
        if queue.empty():
            logger.warning("Queue is empty after process completion")
        else:
            logger.info("Queue contains data, retrieving...")
        # Always attempt the get: if the worker produced nothing, Empty is
        # raised after the timeout instead of leaving `result` unbound.
        result = queue.get(timeout=15)
    except Empty:
        logger.error("Failed to retrieve data from queue")
        raise RuntimeError("Failed to retrieve data from worker process")
    if isinstance(result, Exception):
        logger.error(f"Error in search worker: {str(result)}")
        raise result
    logger.info(f"Successfully retrieved {len(result)} documents")
    return result
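
# Usage sketch (assumption: the request-handler wiring lives elsewhere in this
# service; the variable names below are illustrative, not part of this module):
#
#     search_args = parse_search_parameters(query_params)
#     docs = await execute_search(search_args, timeout=30)
#     summaries = process_search_results(docs)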


def _search_worker(queue, api_key, **kwargs):
    """Search worker: runs in a child process and puts the result list (or the raised exception) on the queue."""
    try:
        # Route MPRester traffic through the configured proxy, if any.
        os.environ['HTTP_PROXY'] = settings.http_proxy or ''
        os.environ['HTTPS_PROXY'] = settings.https_proxy or ''
        mpr = MPRester(api_key, endpoint=settings.mp_endpoint)
        result = mpr.materials.summary.search(**kwargs)
        queue.put([doc.model_dump() for doc in result])
    except Exception as e:
        queue.put(e)