修改mp的实现

This commit is contained in:
2025-01-06 10:03:38 +08:00
parent 5380ee5f9e
commit c2417fec25
3 changed files with 72 additions and 8 deletions

View File

@@ -4,14 +4,17 @@ Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import os
import boto3
from fastapi import APIRouter, Request
import json
import asyncio
import logging
import datetime
from mp_api.client import MPRester
from multiprocessing import Process, Manager
from typing import Dict, Any, List
from constant import MP_ENDPOINT, MP_API_KEY, TIME_OUT, TOPK_RESULT
from constant import MP_ENDPOINT, MP_API_KEY, TIME_OUT, TOPK_RESULT, HTTP_PROXY, HTTPS_PROXY, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT
router = APIRouter(prefix="/mp", tags=["Material Project"])
@@ -37,15 +40,67 @@ async def search_from_material_project(request: Request):
# 处理搜索结果
res = process_search_results(docs)
url = ""
# 返回结果
if len(res) >= TOPK_RESULT:
return json.dumps(res[:TOPK_RESULT], indent=2)
return json.dumps(res, indent=2)
if len(res) == 0:
return {"status": "success", "data": "No results found, please try again."}
else:
# 上传结果到MinIO
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
file_name = f"mp_search_results_{timestamp}.json"
try:
minio_client = boto3.client(
's3',
endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT,
aws_access_key_id=MINIO_ACCESS_KEY,
aws_secret_access_key=MINIO_SECRET_KEY
)
# 将结果写入临时文件
with open(file_name, 'w') as f:
json.dump(res, f, indent=2)
# 上传到MinIO
minio_client.upload_file(file_name, MINIO_BUCKET, file_name, ExtraArgs={"ACL": "private"})
# 生成预签名URL
url = minio_client.generate_presigned_url(
'get_object',
Params={'Bucket': MINIO_BUCKET, 'Key': file_name},
ExpiresIn=3600
)
url = url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
# 删除临时文件
os.remove(file_name)
except Exception as e:
logger.error(f"Failed to upload to MinIO: {str(e)}")
return {
"status": "error",
"data": f"Failed to upload results to MinIO: {str(e)}"
}
# 格式化返回结果
res_chunk = "```json\n" + json.dumps(res[:TOPK_RESULT], indent=2) + "\n```"
res_template = f"""
好的,以下是用户的查询结果:
由于返回长度的限制,我们只能返回前{TOPK_RESULT}个结果。如下:
{res_chunk}
如果用户需要更多的结果,请提示用户修改查询条件,或者尝试使用其他查询参数。
同时我们将全部的的查询结果上传到MinIO中请你提示用户可以通过以下链接下载
[Download]({url})
"""
return {"status": "success", "data": res_template}
except asyncio.TimeoutError:
logger.error(f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again.")
return {"error": f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again."}
return {
"status": "error",
"data": f"Request timed out after {TIME_OUT} seconds, please simplify your query and try again."
}
def parse_bool(param: str) -> bool | None:
@@ -156,6 +211,9 @@ async def execute_search(search_args: Dict[str, Any], timeout: int = TIME_OUT) -
def _search_worker(queue, api_key, **kwargs):
"""搜索工作线程"""
try:
import os
os.environ['HTTP_PROXY'] = HTTP_PROXY
os.environ['HTTPS_PROXY'] = HTTPS_PROXY
mpr = MPRester(api_key, endpoint=MP_ENDPOINT)
result = mpr.materials.summary.search(**kwargs)
queue.put([doc.model_dump() for doc in result])