222 lines
7.9 KiB
Python
"""
|
||
Author: Yutang LI
|
||
Institution: SIAT-MIC
|
||
Contact: yt.li2@siat.ac.cn
|
||
"""
|
||
|
||
import os
|
||
import boto3
|
||
from fastapi import APIRouter, Request
|
||
import json
|
||
import asyncio
|
||
import logging
|
||
import datetime
|
||
from mp_api.client import MPRester
|
||
from multiprocessing import Process, Manager
|
||
from typing import Dict, Any, List
|
||
from constant import MP_ENDPOINT, MP_API_KEY, TIME_OUT, TOPK_RESULT, HTTP_PROXY, HTTPS_PROXY, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT
|
||
|
||
# Router for the Materials Project endpoints, mounted under the /mp prefix.
router = APIRouter(prefix="/mp", tags=["Material Project"])

# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
||
@router.get("/search")
|
||
async def search_from_material_project(request: Request):
|
||
# 打印请求日志
|
||
logger.info(f"Received request: {request.method} {request.url}")
|
||
logger.info(f"Query parameters: {request.query_params}")
|
||
|
||
# 解析查询参数
|
||
search_args = parse_search_parameters(request.query_params)
|
||
|
||
# 检查API key
|
||
if MP_API_KEY is None or MP_API_KEY == '':
|
||
return 'Material Project API CANNOT Be None'
|
||
|
||
try:
|
||
# 执行搜索
|
||
docs = await execute_search(search_args)
|
||
|
||
# 处理搜索结果
|
||
res = process_search_results(docs)
|
||
url = ""
|
||
# 返回结果
|
||
if len(res) == 0:
|
||
return {"status": "success", "data": "No results found, please try again."}
|
||
|
||
else:
|
||
# 上传结果到MinIO
|
||
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||
file_name = f"mp_search_results_{timestamp}.json"
|
||
|
||
try:
|
||
minio_client = boto3.client(
|
||
's3',
|
||
endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT,
|
||
aws_access_key_id=MINIO_ACCESS_KEY,
|
||
aws_secret_access_key=MINIO_SECRET_KEY
|
||
)
|
||
|
||
# 将结果写入临时文件
|
||
with open(file_name, 'w') as f:
|
||
json.dump(res, f, indent=2)
|
||
|
||
# 上传到MinIO
|
||
minio_client.upload_file(file_name, MINIO_BUCKET, file_name, ExtraArgs={"ACL": "private"})
|
||
|
||
# 生成预签名URL
|
||
url = minio_client.generate_presigned_url(
|
||
'get_object',
|
||
Params={'Bucket': MINIO_BUCKET, 'Key': file_name},
|
||
ExpiresIn=3600
|
||
)
|
||
url = url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
|
||
|
||
# 删除临时文件
|
||
os.remove(file_name)
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to upload to MinIO: {str(e)}")
|
||
return {
|
||
"status": "error",
|
||
"data": f"Failed to upload results to MinIO: {str(e)}"
|
||
}
|
||
|
||
# 格式化返回结果
|
||
res_chunk = "```json\n" + json.dumps(res[:TOPK_RESULT], indent=2) + "\n```"
|
||
res_template = f"""
|
||
好的,以下是用户的查询结果:
|
||
由于返回长度的限制,我们只能返回前{TOPK_RESULT}个结果。如下:
|
||
{res_chunk}
|
||
如果用户需要更多的结果,请提示用户修改查询条件,或者尝试使用其他查询参数。
|
||
同时我们将全部的的查询结果上传到MinIO中,请你提示用户可以通过以下链接下载:
|
||
[Download]({url})
|
||
"""
|
||
return {"status": "success", "data": res_template}
|
||
|
||
except asyncio.TimeoutError:
|
||
logger.error(f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again.")
|
||
return {
|
||
"status": "error",
|
||
"data": f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again."
|
||
}
|
||
|
||
|
||
def parse_bool(param: str) -> bool | None:
|
||
if not param:
|
||
return None
|
||
return param.lower() == 'true'
|
||
|
||
|
||
def parse_list(param: str) -> List[str] | None:
|
||
if not param:
|
||
return None
|
||
return param.split(',')
|
||
|
||
|
||
def parse_tuple(param: str) -> tuple[float, float] | None:
|
||
if not param:
|
||
return None
|
||
try:
|
||
values = param.split(',')
|
||
return (float(values[0]), float(values[1]))
|
||
except (ValueError, IndexError):
|
||
return None
|
||
|
||
|
||
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
    """Translate raw query-string parameters into MPRester search kwargs.

    Absent or malformed range/list/bool filters become None so the worker
    simply omits them.  ``chunk_size`` falls back to 5 when missing or not a
    valid integer — previously a non-numeric value raised an unhandled
    ValueError (this function is called before the endpoint's try block).
    """
    # Robust chunk_size parsing: never let a bad value escape as a 500.
    try:
        chunk_size = int(query_params.get('chunk_size', '5'))
    except (TypeError, ValueError):
        chunk_size = 5

    return {
        'band_gap': parse_tuple(query_params.get('band_gap')),
        'chemsys': parse_list(query_params.get('chemsys')),
        'crystal_system': parse_list(query_params.get('crystal_system')),
        'density': parse_tuple(query_params.get('density')),
        'formation_energy': parse_tuple(query_params.get('formation_energy')),
        'elements': parse_list(query_params.get('elements')),
        'exclude_elements': parse_list(query_params.get('exclude_elements')),
        'formula': parse_list(query_params.get('formula')),
        'is_gap_direct': parse_bool(query_params.get('is_gap_direct')),
        'is_metal': parse_bool(query_params.get('is_metal')),
        'is_stable': parse_bool(query_params.get('is_stable')),
        'magnetic_ordering': query_params.get('magnetic_ordering'),
        'material_ids': parse_list(query_params.get('material_ids')),
        'total_energy': parse_tuple(query_params.get('total_energy')),
        'num_elements': parse_tuple(query_params.get('num_elements')),
        'volume': parse_tuple(query_params.get('volume')),
        'chunk_size': chunk_size
    }
||
|
||
|
||
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Project each raw document onto the whitelisted summary fields.

    Fields missing from a document default to ''.  A document that cannot be
    processed is logged and skipped rather than aborting the whole batch.
    """
    fields = [
        'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
        'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
        'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
        'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
        'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
        'universal_anisotropy', 'theoretical'
    ]

    projected = []
    for doc in docs:
        try:
            projected.append({name: doc.get(name, '') for name in fields})
        except Exception as e:
            logger.warning(f"Error processing document: {str(e)}")
            continue
    return projected
||
|
||
|
||
async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) -> List[Dict[str, Any]]:
    """Run the Materials Project search in a separate process with a timeout.

    The worker (``_search_worker``) puts either the list of document dicts or
    the exception it hit onto a manager queue.

    Raises:
        asyncio.TimeoutError: the worker exceeded ``timeout`` seconds.
        RuntimeError: the worker finished but no result could be retrieved.
        Exception: any exception raised inside the worker is re-raised here.
    """
    # BUG FIX: the original wrote `except queue.Empty:` where `queue` was the
    # Queue *instance*, shadowing the stdlib `queue` module — that lookup
    # raises AttributeError instead of handling the get() timeout.  Import
    # the exception explicitly and rename the local.
    from queue import Empty

    manager = Manager()
    result_queue = manager.Queue()

    worker = Process(target=_search_worker, args=(result_queue, MP_API_KEY), kwargs=search_args)
    worker.start()

    logger.info(f"Started worker process with PID: {worker.pid}")
    # NOTE(review): Process.join blocks the event loop for up to `timeout`
    # seconds; consider loop.run_in_executor if this endpoint sees load.
    worker.join(timeout=timeout)

    if worker.is_alive():
        logger.warning(f"Terminating worker process {worker.pid} due to timeout")
        worker.terminate()
        worker.join()
        raise asyncio.TimeoutError(f"Request timed out after {timeout} seconds")

    try:
        if result_queue.empty():
            logger.warning("Queue is empty after process completion")
        else:
            logger.info("Queue contains data, retrieving...")

        result = result_queue.get(timeout=15)
    except Empty:
        logger.error("Failed to retrieve data from queue")
        raise RuntimeError("Failed to retrieve data from worker process")

    # The worker ships exceptions back through the queue; re-raise them here.
    if isinstance(result, Exception):
        logger.error(f"Error in search worker: {str(result)}")
        raise result

    logger.info(f"Successfully retrieved {len(result)} documents")
    return result
|
||
|
||
def _search_worker(queue, api_key, **kwargs):
    """Worker-process entry point: run an MP summary search.

    Puts the list of document dicts on `queue` on success, or the raised
    exception object on failure so the parent can re-raise it.
    """
    try:
        import os
        # Route MP traffic through the configured proxies.
        os.environ['HTTP_PROXY'] = HTTP_PROXY
        os.environ['HTTPS_PROXY'] = HTTPS_PROXY

        client = MPRester(api_key, endpoint=MP_ENDPOINT)
        documents = client.materials.summary.search(**kwargs)
        queue.put([doc.model_dump() for doc in documents])
    except Exception as exc:
        # Ship the exception back so the parent process can surface it.
        queue.put(exc)