Files
mars_toolkit/model/fairchem_router.py
2025-01-05 17:43:06 +08:00

166 lines
5.6 KiB
Python

from fastapi import APIRouter, Body, Query
from fairchem.core import OCPCalculator
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
import tempfile
import os
import boto3
from constant import FAIRCHEM_MODEL_PATH, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT, FMAX
from typing import Optional
import logging
import datetime
# FastAPI router exposing the FAIRChem endpoints under /fairchem.
router = APIRouter(prefix="/fairchem", tags=["fairchem"])
logger = logging.getLogger(__name__)
# Shared FAIRChem calculator used by every optimization request; it stays
# None until init_model() loads the checkpoint.
calc: Optional[OCPCalculator] = None
def init_model():
    """Load the FAIRChem checkpoint into the module-level ``calc``.

    The checkpoint path comes from ``FAIRCHEM_MODEL_PATH``. Calling this again
    replaces the current calculator with a freshly loaded one.
    """
    global calc
    loaded_calculator = OCPCalculator(checkpoint_path=FAIRCHEM_MODEL_PATH)
    calc = loaded_calculator
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """Parse structure text into an ASE ``Atoms`` object.

    The text is written to a temporary file whose suffix is ``input_format``
    and loaded with ``ase.io.read``. The suffix lets ASE infer which parser
    to use.

    Args:
        input_format: File extension for the temporary file (e.g. ``cif``).
        content: Raw contents of the structure file.

    Returns:
        The parsed ``Atoms``, or ``None`` if writing or parsing failed. The
        error is logged in that case.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            # Record the path before writing so a failed write is cleaned up too.
            tmp_path = tmp_file.name
            tmp_file.write(content)
        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        # Always remove the temp file. The previous version skipped the unlink
        # whenever read() raised, leaking one file per malformed request.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                logger.warning("Failed to remove temporary file %s", tmp_path)
def generate_symmetry_cif(structure: Structure) -> str:
    """Return CIF text for ``structure`` after symmetry refinement.

    The structure is first refined with ``SpacegroupAnalyzer`` at its default
    tolerance. It is then rendered by ``CifWriter`` with ``symprec=0.1`` and
    ``refine_struct=True``, so the CIF carries symmetry information.

    Args:
        structure: pymatgen structure to serialise.

    Returns:
        The CIF file contents as a string.
    """
    refined_structure = SpacegroupAnalyzer(structure).get_refined_structure()
    cif_writer = CifWriter(refined_structure, symprec=0.1, refine_struct=True)
    # CifWriter.write_file() only writes str(writer) to disk. Rendering in
    # memory avoids the delete=False temp file the old version leaked on
    # every call, as well as the read-back through a second file handle.
    return str(cif_writer)
def upload_to_minio(file_path: str, file_name: str) -> str:
    """Upload a local file to MinIO and return a presigned download URL.

    Args:
        file_path: Path of the local file to upload.
        file_name: Object key to store it under in ``MINIO_BUCKET``.

    Returns:
        A presigned ``get_object`` URL that expires after 3600 seconds. When
        ``INTERNEL_MINIO_ENDPOINT`` is set, the client talks to that endpoint,
        and the internal endpoint text in the URL is swapped for
        ``MINIO_ENDPOINT``.

    Raises:
        RuntimeError: If creating the client, uploading, or presigning fails.
    """
    internal_endpoint = INTERNEL_MINIO_ENDPOINT
    try:
        minio_client = boto3.client(
            's3',
            endpoint_url=MINIO_ENDPOINT if internal_endpoint == "" else internal_endpoint,
            aws_access_key_id=MINIO_ACCESS_KEY,
            aws_secret_access_key=MINIO_SECRET_KEY
        )
        bucket_name = MINIO_BUCKET
        minio_client.upload_file(file_path, bucket_name, file_name, ExtraArgs={"ACL": "private"})
        # Generate the presigned URL.
        url = minio_client.generate_presigned_url(
            'get_object',
            Params={'Bucket': bucket_name, 'Key': file_name},
            ExpiresIn=3600
        )
    except Exception as e:
        logger.error(f"Failed to upload to MinIO: {str(e)}")
        raise RuntimeError(f"MinIO upload failed: {str(e)}") from e
    # str.replace("", x) inserts x between every character, so rewriting the
    # host unconditionally corrupted the URL when no internal endpoint is
    # configured. Only rewrite it when one is actually set.
    if internal_endpoint:
        # NOTE(review): if the client signs with SigV4, the host is part of the
        # signature and rewriting it afterwards may invalidate the URL. Consider
        # presigning with a client built on MINIO_ENDPOINT instead. Confirm.
        url = url.replace(internal_endpoint, MINIO_ENDPOINT)
    return url
from io import StringIO
import sys
def _relax_with_fire(atoms: Atoms):
    """Relax ``atoms`` in place; return ``(total_energy, fire_log_text)``."""
    if calc is None:
        raise RuntimeError("FAIRChem model is not initialized; call init_model() first")
    atoms.calc = calc
    log_buffer = StringIO()
    # Give FIRE its own log stream rather than swapping the process-wide
    # sys.stdout, which would also capture or hide unrelated output.
    dyn = FIRE(FrechetCellFilter(atoms), logfile=log_buffer)
    dyn.run(fmax=FMAX)
    return atoms.get_total_energy(), log_buffer.getvalue()


def _atoms_to_text(atoms: Atoms, output_format: str) -> str:
    """Serialize ``atoms`` with ``ase.io.write`` (format inferred from suffix)."""
    fd, tmp_path = tempfile.mkstemp(suffix=f".{output_format}")
    os.close(fd)
    try:
        write(tmp_path, atoms)
        with open(tmp_path) as fh:
            return fh.read()
    finally:
        # The previous version never removed this file.
        os.unlink(tmp_path)


def _upload_text(content: str, output_format: str) -> str:
    """Write ``content`` to a temp file, upload it, and return the presigned URL."""
    import uuid

    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    # Timestamps have one-second resolution, so two workers finishing in the
    # same second would overwrite each other's object. A random suffix keeps
    # the keys unique.
    file_name = f"optimized_structure_{timestamp}_{uuid.uuid4().hex[:8]}.{output_format}"
    with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w", delete=False) as tmp_file:
        tmp_file.write(content)
        tmp_path = tmp_file.name
    try:
        return upload_to_minio(tmp_path, file_name)
    finally:
        os.unlink(tmp_path)


def optimize_structure(atoms: Atoms, output_format: str):
    """Relax a crystal structure with the FAIRChem calculator and publish it.

    Positions and cell are relaxed in place with FIRE on a
    ``FrechetCellFilter`` until the ``fmax=FMAX`` criterion is met. The relaxed
    structure is then serialized to ``output_format``; CIF output goes through
    ``generate_symmetry_cif``. Finally it is uploaded to MinIO. All temporary
    files are removed.

    Args:
        atoms: Structure to relax; modified in place.
        output_format: Output file extension, e.g. ``cif``, ``poscar`` or ``xyz``.

    Returns:
        ``(total_energy, content, url, optimization_log)``:
        ``atoms.get_total_energy()`` after relaxation, the serialized
        structure text, a presigned download URL, and FIRE's step log.

    Raises:
        RuntimeError: If ``init_model()`` has not loaded the model, or the
            upload fails. Other ASE/pymatgen errors propagate unchanged.
    """
    total_energy, optimization_log = _relax_with_fire(atoms)
    if output_format == "cif":
        content = generate_symmetry_cif(Structure.from_ase_atoms(atoms))
    else:
        content = _atoms_to_text(atoms, output_format)
    url = _upload_text(content, output_format)
    return total_energy, content, url, optimization_log
@router.post("/optimize_structure")
async def optimize_structure_endpoint(
    content: str = Body(..., description="Input structure content"),
    input_format: str = Query("cif", description="Input format (cif, poscar, json, xyz)"),
    output_format: str = Query("cif", description="Output format (cif, poscar, json, xyz)")
):
    """Relax a submitted structure and return a Markdown report.

    Args:
        content: Raw text of the input structure file (request body).
        input_format: Extension used to parse ``content``.
        output_format: Extension of the relaxed structure file to upload.

    Returns:
        ``{"status": "success", "data": report}``, where ``report`` contains
        FIRE's log, the total energy and a download link. On failure it returns
        ``{"status": "error", "data": message}``.
    """
    # Parse the input structure.
    atoms = convert_structure(input_format, content)
    if atoms is None:
        return {
            "status": "error",
            "data": f"Invalid {input_format} content"
        }
    # NOTE(review): optimize_structure() is synchronous and CPU-heavy, so it
    # blocks the event loop for the whole relaxation. Do not simply offload it
    # to a thread: the module-level calculator is shared by every request.
    try:
        total_energy, _optimized_content, download_url, optimization_log = optimize_structure(atoms, output_format)
        # Build the report returned to the caller.
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
```text
{optimization_log}
```
Finally, the Total Energy is: {total_energy} eV
Due to length limitations, the complete {output_format.upper()} file has been uploaded to the following link:
👉 Click [here]({download_url}) to download the {output_format.upper()} file
Please ensure that the Optimization Results and download link are fully conveyed to the user, as this is very important for them.
"""
        # Log through the module logger instead of printing to stdout.
        logger.info(format_result)
        return {
            "status": "success",
            "data": format_result
        }
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Optimization failed: %s", e)
        return {
            "status": "error",
            "data": str(e)
        }
if __name__ == "__main__":
    # The guarded block only loads the checkpoint; it does not start a server.
    # NOTE(review): presumably the application that mounts ``router`` calls
    # init_model() at startup. Confirm this, since optimize_structure() needs
    # ``calc`` to be set.
    init_model()