Files
mars_toolkit/database/material_project_router.py
2025-01-06 10:03:38 +08:00

222 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import asyncio
import datetime
import json
import logging
import math
import os
import queue
import uuid
from enum import Enum
from multiprocessing import Manager, Process
from typing import Any, Dict, List

import boto3
from fastapi import APIRouter, Request
from mp_api.client import MPRester

from constant import (
    HTTP_PROXY,
    HTTPS_PROXY,
    INTERNEL_MINIO_ENDPOINT,
    MINIO_ACCESS_KEY,
    MINIO_BUCKET,
    MINIO_ENDPOINT,
    MINIO_SECRET_KEY,
    MP_API_KEY,
    MP_ENDPOINT,
    TIME_OUT,
    TOPK_RESULT,
)
# Logger named after this module so its records can be traced back here.
logger = logging.getLogger(__name__)

# Router for every endpoint in this module: mounted under "/mp" and tagged
# "Material Project".
router = APIRouter(
    prefix="/mp",
    tags=["Material Project"],
)
@router.get("/search")
async def search_from_material_project(request: Request):
    """Search the Materials Project and return the top results as a chat-ready message.

    The request's query parameters are translated by ``parse_search_parameters`` and
    executed by ``execute_search``. When results exist, the complete list is uploaded
    to MinIO. The response then contains the first ``TOPK_RESULT`` results as a JSON
    code block, plus a presigned download link valid for one hour.

    Returns:
        The string ``'Material Project API CANNOT Be None'`` when ``MP_API_KEY`` is
        unset or empty. Otherwise a dict ``{"status": "success" | "error", "data": str}``.
    """
    # Log the incoming request.
    logger.info(f"Received request: {request.method} {request.url}")
    logger.info(f"Query parameters: {request.query_params}")
    # Without an API key no search can succeed, so fail before doing any work.
    if MP_API_KEY is None or MP_API_KEY == '':
        return 'Material Project API CANNOT Be None'
    try:
        # Parsing sits inside the try so that a malformed parameter (e.g. a
        # non-integer chunk_size) yields an error payload, not an unhandled exception.
        search_args = parse_search_parameters(request.query_params)
        docs = await execute_search(search_args)
        res = process_search_results(docs)
    except asyncio.TimeoutError:
        logger.error(f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again.")
        return {
            "status": "error",
            "data": f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again."
        }
    except Exception as e:
        # Errors re-raised from the search worker and parameter errors are reported
        # in the same payload shape as the other error paths.
        logger.exception(f"Material Project search failed: {str(e)}")
        return {"status": "error", "data": f"Material Project search failed: {str(e)}"}

    if len(res) == 0:
        return {"status": "success", "data": "No results found, please try again."}

    try:
        # boto3 calls block; run them in a worker thread so the event loop keeps
        # serving other requests during the upload.
        url = await asyncio.to_thread(_upload_results_to_minio, res)
    except Exception as e:
        logger.error(f"Failed to upload to MinIO: {str(e)}")
        return {
            "status": "error",
            "data": f"Failed to upload results to MinIO: {str(e)}"
        }

    # Only the first TOPK_RESULT results are inlined to keep the response short;
    # the full list is available through the download link.
    res_chunk = "```json\n" + json.dumps(res[:TOPK_RESULT], indent=2) + "\n```"
    res_template = f"""
好的,以下是用户的查询结果:
由于返回长度的限制,我们只能返回前{TOPK_RESULT}个结果。如下:
{res_chunk}
如果用户需要更多的结果,请提示用户修改查询条件,或者尝试使用其他查询参数。
同时我们将全部的的查询结果上传到MinIO中请你提示用户可以通过以下链接下载
[Download]({url})
"""
    return {"status": "success", "data": res_template}


def _upload_results_to_minio(res: List[Dict[str, Any]]) -> str:
    """Upload ``res`` to ``MINIO_BUCKET`` as JSON and return a presigned GET URL.

    The object is written straight from memory (no temporary file on disk) under a
    per-request unique key, and the URL expires after 3600 seconds. This function
    blocks on network I/O, so async code should call it via ``asyncio.to_thread``.
    """
    # Second-resolution timestamps collide for concurrent requests; the random
    # suffix keeps each request's object (and download link) distinct.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    file_name = f"mp_search_results_{timestamp}_{uuid.uuid4().hex[:8]}.json"
    internal_endpoint = INTERNEL_MINIO_ENDPOINT or ""
    minio_client = boto3.client(
        's3',
        # Connect through the internal endpoint when one is configured.
        endpoint_url=internal_endpoint or MINIO_ENDPOINT,
        aws_access_key_id=MINIO_ACCESS_KEY,
        aws_secret_access_key=MINIO_SECRET_KEY
    )
    minio_client.put_object(
        Bucket=MINIO_BUCKET,
        Key=file_name,
        Body=json.dumps(res, indent=2).encode("utf-8"),
        ContentType="application/json",
        ACL="private",
    )
    url = minio_client.generate_presigned_url(
        'get_object',
        Params={'Bucket': MINIO_BUCKET, 'Key': file_name},
        ExpiresIn=3600
    )
    if internal_endpoint:
        # Hand the user a URL on the public endpoint. This is guarded because
        # str.replace("", x) inserts x between every character and would corrupt the URL.
        # NOTE(review): if the URL is SigV4-signed, the host is part of the
        # signature; confirm links still validate after this rewrite.
        url = url.replace(internal_endpoint, MINIO_ENDPOINT)
    return url
def parse_bool(param: str) -> bool | None:
if not param:
return None
return param.lower() == 'true'
def parse_list(param: str) -> List[str] | None:
if not param:
return None
return param.split(',')
def parse_tuple(param: str) -> tuple[float, float] | None:
if not param:
return None
try:
values = param.split(',')
return (float(values[0]), float(values[1]))
except (ValueError, IndexError):
return None
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
    """Parse raw query-string values into keyword arguments for the summary search.

    Conventions:
    - Range filters are written "low,high".
    - List filters are comma-separated.
    - Boolean filters are written "true"/"false".

    Absent filters map to None, and malformed ranges are dropped (None).
    ``chunk_size`` defaults to 5 when it is absent or empty.

    Raises:
        ValueError: if ``chunk_size`` is given but is not an integer.
    """
    num_elements = parse_tuple(query_params.get('num_elements'))
    if num_elements is not None:
        # Element counts are whole numbers. Send an integer range rather than floats
        # such as 2.0, rounding inward so that e.g. 1.5 means "at least 2".
        try:
            num_elements = (math.ceil(num_elements[0]), math.floor(num_elements[1]))
        except (ValueError, OverflowError):
            # nan/inf cannot bound an element count; drop the filter like other bad ranges.
            num_elements = None
    # An empty "?chunk_size=" falls back to the default instead of failing in int('').
    chunk_size = int(query_params.get('chunk_size') or '5')
    return {
        'band_gap': parse_tuple(query_params.get('band_gap')),
        'chemsys': parse_list(query_params.get('chemsys')),
        'crystal_system': parse_list(query_params.get('crystal_system')),
        'density': parse_tuple(query_params.get('density')),
        'formation_energy': parse_tuple(query_params.get('formation_energy')),
        'elements': parse_list(query_params.get('elements')),
        'exclude_elements': parse_list(query_params.get('exclude_elements')),
        'formula': parse_list(query_params.get('formula')),
        'is_gap_direct': parse_bool(query_params.get('is_gap_direct')),
        'is_metal': parse_bool(query_params.get('is_metal')),
        'is_stable': parse_bool(query_params.get('is_stable')),
        # NOTE(review): passed through as a raw string; confirm the search API
        # accepts a string here rather than requiring an enum member.
        'magnetic_ordering': query_params.get('magnetic_ordering'),
        'material_ids': parse_list(query_params.get('material_ids')),
        'total_energy': parse_tuple(query_params.get('total_energy')),
        'num_elements': num_elements,
        'volume': parse_tuple(query_params.get('volume')),
        'chunk_size': chunk_size
    }
def _to_jsonable(value: Any) -> Any:
    """Recursively replace Enum members (values and dict keys) with their ``.value``."""
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, dict):
        return {
            (key.value if isinstance(key, Enum) else key): _to_jsonable(item)
            for key, item in value.items()
        }
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value


def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Project each search document onto a fixed set of summary fields.

    A field missing from a document becomes ''. Enum members are replaced by
    their values, because plain Enum instances are not JSON-serializable and the
    result is later passed to ``json.dumps``. A document that cannot be processed
    is logged and skipped. A None ``docs`` is treated as empty.
    """
    fields = (
        'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
        'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
        'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
        'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
        'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
        'universal_anisotropy', 'theoretical',
    )
    res = []
    for doc in docs or ():
        try:
            res.append({name: _to_jsonable(doc.get(name, '')) for name in fields})
        except Exception as e:
            # Best effort: one malformed document should not discard the rest.
            logger.warning(f"Error processing document: {str(e)}")
    return res
async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) -> List[Dict[str, Any]]:
    """Run the Materials Project search in a child process with a hard timeout.

    Args:
        search_args: Keyword arguments forwarded to ``_search_worker``.
        timeout: Seconds to wait for the child before terminating it.

    Returns:
        The list of document dicts produced by the worker.

    Raises:
        asyncio.TimeoutError: the worker was still running after ``timeout`` seconds.
        RuntimeError: the worker exited without delivering a result.
        Exception: any exception the worker caught and sent back is re-raised.
    """
    # Joining the child blocks. Doing it in a worker thread keeps the event loop
    # free to serve other requests while the search runs.
    return await asyncio.to_thread(_run_search_in_subprocess, search_args, timeout)


def _run_search_in_subprocess(search_args: Dict[str, Any], timeout: float) -> List[Dict[str, Any]]:
    """Blocking body of ``execute_search``: spawn the worker, wait, and collect its result."""
    # The context manager shuts the Manager's server process down on every exit path.
    with Manager() as manager:
        result_queue = manager.Queue()
        p = Process(target=_search_worker, args=(result_queue, MP_API_KEY), kwargs=search_args)
        p.start()
        logger.info(f"Started worker process with PID: {p.pid}")
        p.join(timeout=timeout)
        if p.is_alive():
            logger.warning(f"Terminating worker process {p.pid} due to timeout")
            p.terminate()
            p.join()
            raise asyncio.TimeoutError(f"Request timed out after {timeout} seconds")
        # The worker has already exited, so an empty queue means it never delivered
        # a result (e.g. it crashed). Fail explicitly instead of using an unbound result.
        if result_queue.empty():
            logger.warning("Queue is empty after process completion")
            raise RuntimeError("Failed to retrieve data from worker process")
        logger.info("Queue contains data, retrieving...")
        try:
            result = result_queue.get(timeout=15)
        except queue.Empty:
            logger.error("Failed to retrieve data from queue")
            raise RuntimeError("Failed to retrieve data from worker process") from None
    if isinstance(result, Exception):
        logger.error(f"Error in search worker: {str(result)}")
        raise result
    logger.info(f"Successfully retrieved {len(result)} documents")
    return result
def _search_worker(queue, api_key, **kwargs):
    """Child-process entry point that runs the Materials Project summary search.

    On success it puts one item on ``queue``: the list of documents as plain dicts
    (``model_dump()`` of each). On failure it puts the exception instead.

    Args:
        queue: Queue shared with the parent process.
        api_key: Materials Project API key.
        **kwargs: Filters forwarded to ``mpr.materials.summary.search``.
    """
    try:
        # Proxies are set in this (child) process's environment only. A None
        # setting is skipped because os.environ values must be strings.
        if HTTP_PROXY is not None:
            os.environ['HTTP_PROXY'] = HTTP_PROXY
        if HTTPS_PROXY is not None:
            os.environ['HTTPS_PROXY'] = HTTPS_PROXY
        # Use the client as a context manager so it is closed when the search finishes.
        with MPRester(api_key, endpoint=MP_ENDPOINT) as mpr:
            result = mpr.materials.summary.search(**kwargs)
        queue.put([doc.model_dump() for doc in result])
    except Exception as e:
        try:
            queue.put(e)
        except Exception:
            # The exception travels to the parent by pickling. If that fails, send a
            # plain summary so the parent does not find an empty queue.
            queue.put(RuntimeError(f"{type(e).__name__}: {e}"))