166 lines
5.6 KiB
Python
166 lines
5.6 KiB
Python
from fastapi import APIRouter, Body, Query
|
|
from fairchem.core import OCPCalculator
|
|
from ase.optimize import FIRE
|
|
from ase.filters import FrechetCellFilter
|
|
from ase.atoms import Atoms
|
|
from ase.io import read, write
|
|
from pymatgen.core.structure import Structure
|
|
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
|
from pymatgen.io.cif import CifWriter
|
|
import tempfile
|
|
import os
|
|
import boto3
|
|
from constant import FAIRCHEM_MODEL_PATH, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT, FMAX
|
|
from typing import Optional
|
|
import logging
|
|
import datetime
|
|
|
|
# Router grouping every endpoint of this module under the /fairchem prefix.
router = APIRouter(
    prefix="/fairchem",
    tags=["fairchem"],
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shared calculator instance; remains None until init_model() assigns it.
calc = None
|
|
|
|
def init_model():
    """Build the OCP calculator from FAIRCHEM_MODEL_PATH and store it in ``calc``."""
    global calc
    loaded_calculator = OCPCalculator(checkpoint_path=FAIRCHEM_MODEL_PATH)
    calc = loaded_calculator
|
|
|
|
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """Convert raw structure text into an ASE ``Atoms`` object.

    The text is written to a temporary file whose suffix is ``.{input_format}``
    and then handed to ``ase.io.read``; no explicit format is passed, so format
    detection is left to ``read``.

    Args:
        input_format: Format name, used only as the temporary file's suffix.
        content: Structure file content as text.

    Returns:
        The parsed ``Atoms`` object, or ``None`` if writing or parsing failed
        (the failure is logged).
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            suffix=f".{input_format}", mode="w", delete=False
        ) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(content)

        # The file is closed (and flushed) here, so read() sees the full text.
        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        # Always remove the temp file, including when read() raises;
        # delete=False means nothing else will clean it up.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                logger.warning(
                    "Could not remove temporary file %s: %s", tmp_path, cleanup_error
                )
|
|
|
|
def generate_symmetry_cif(structure: Structure) -> str:
    """Generate a symmetrized CIF string for a structure.

    The structure is first refined with ``SpacegroupAnalyzer`` and then
    serialized by ``CifWriter`` with ``symprec=0.1`` and ``refine_struct=True``.

    Args:
        structure: The pymatgen structure to export.

    Returns:
        The CIF file content as text.
    """
    analyzer = SpacegroupAnalyzer(structure)
    refined_structure = analyzer.get_refined_structure()

    cif_writer = CifWriter(refined_structure, symprec=0.1, refine_struct=True)
    # str(CifWriter) is the same text write_file() puts on disk, so no
    # temporary file is needed (the previous delete=False file was never
    # removed and leaked on every call).
    return str(cif_writer)
|
|
|
|
def upload_to_minio(file_path: str, file_name: str) -> str:
    """Upload a file to MinIO and return a presigned download URL.

    Args:
        file_path: Path of the local file to upload.
        file_name: Object key under which the file is stored in ``MINIO_BUCKET``.

    Returns:
        A presigned ``get_object`` URL that expires after 3600 seconds. When an
        internal endpoint is configured, its occurrence in the URL is replaced
        with the public ``MINIO_ENDPOINT``.

    Raises:
        RuntimeError: If the upload or URL generation fails; the original
            exception is chained as the cause.
    """
    # Talk to the internal endpoint when one is configured, else the public one.
    use_internal = bool(INTERNEL_MINIO_ENDPOINT)
    endpoint_url = INTERNEL_MINIO_ENDPOINT if use_internal else MINIO_ENDPOINT

    try:
        minio_client = boto3.client(
            's3',
            endpoint_url=endpoint_url,
            aws_access_key_id=MINIO_ACCESS_KEY,
            aws_secret_access_key=MINIO_SECRET_KEY,
        )

        bucket_name = MINIO_BUCKET
        minio_client.upload_file(
            file_path, bucket_name, file_name, ExtraArgs={"ACL": "private"}
        )

        # Generate the presigned URL.
        url = minio_client.generate_presigned_url(
            'get_object',
            Params={'Bucket': bucket_name, 'Key': file_name},
            ExpiresIn=3600,
        )

        if not use_internal:
            # Nothing to rewrite. Crucially, str.replace("", x) would insert x
            # between every character and corrupt the URL.
            return url

        # NOTE(review): this swaps the host of an already-signed URL; confirm
        # the deployment accepts the signature under the public host name.
        return url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
    except Exception as e:
        logger.exception("Failed to upload to MinIO: %s", e)
        raise RuntimeError(f"MinIO upload failed: {str(e)}") from e
|
|
|
|
from io import StringIO
|
|
import sys
|
|
|
|
def optimize_structure(atoms: Atoms, output_format: str):
    """Relax a crystal structure and publish the result to MinIO.

    The structure is optimized with ``FIRE`` on a ``FrechetCellFilter`` until
    ``FMAX`` is reached, serialized in ``output_format``, and uploaded.

    Args:
        atoms: Structure to optimize; it is modified in place and the module
            calculator is attached to it.
        output_format: Output format name. ``"cif"`` goes through
            ``generate_symmetry_cif``; any other value is used as the file
            suffix passed to ``ase.io.write``.

    Returns:
        A tuple ``(total_energy, content, url, optimization_log)``: the total
        energy reported by ``atoms``, the serialized structure text, the
        presigned download URL, and everything written to stdout during the
        optimization.

    Raises:
        RuntimeError: Propagated from ``upload_to_minio`` when the upload fails.
    """
    from contextlib import redirect_stdout

    # init_model() is only called under __main__ in this module, so the
    # calculator may still be unset when the router is imported by an app.
    if calc is None:
        init_model()
    atoms.calc = calc

    # Capture the optimizer log, which is written to stdout; redirect_stdout
    # restores the previous stream even if the run raises.
    log_capture = StringIO()
    with redirect_stdout(log_capture):
        dyn = FIRE(FrechetCellFilter(atoms))
        dyn.run(fmax=FMAX)
        total_energy = atoms.get_total_energy()
    optimization_log = log_capture.getvalue()

    # Serialize the result; CIF output additionally goes through symmetry handling.
    if output_format == "cif":
        optimized_structure = Structure.from_ase_atoms(atoms)
        content = generate_symmetry_cif(optimized_structure)
    else:
        with tempfile.NamedTemporaryFile(
            suffix=f".{output_format}", mode="w+", delete=False
        ) as render_file:
            render_path = render_file.name
        try:
            # The handle is closed before ASE writes to the path by name.
            write(render_path, atoms)
            with open(render_path, "r") as rendered:
                content = rendered.read()
        finally:
            # delete=False: remove the scratch file ourselves, even on failure.
            os.unlink(render_path)

    # Save the optimized result to a temporary file for upload.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    file_name = f"optimized_structure_{timestamp}.{output_format}"
    upload_file = tempfile.NamedTemporaryFile(
        suffix=f".{output_format}", mode="w", delete=False
    )
    tmp_path = upload_file.name
    try:
        with upload_file:
            upload_file.write(content)
        # Upload to MinIO.
        url = upload_to_minio(tmp_path, file_name)
        return total_energy, content, url, optimization_log
    finally:
        os.unlink(tmp_path)
|
|
|
|
@router.post("/optimize_structure")
async def optimize_structure_endpoint(
    content: str = Body(..., description="Input structure content"),
    input_format: str = Query("cif", description="Input format (cif, poscar, json, xyz)"),
    output_format: str = Query("cif", description="Output format (cif, poscar, json, xyz)")
):
    """Optimize an uploaded crystal structure and return a Markdown report.

    Args:
        content: Structure file content from the request body.
        input_format: Format of ``content``.
        output_format: Format of the optimized structure file.

    Returns:
        A dict with ``"status"`` (``"success"`` or ``"error"``) and ``"data"``:
        on success the formatted report (optimization log, total energy and
        download link), on error a message describing the failure.
    """
    # Both format names end up in temp-file suffixes and in the uploaded object
    # key, so accept only plain tokens (no dots, slashes or whitespace) and
    # reject bad values before any expensive work starts.
    for label, value in (("input", input_format), ("output", output_format)):
        if not value or not all(ch.isalnum() or ch in "-_" for ch in value):
            return {
                "status": "error",
                "data": f"Invalid {label} format: {value!r}"
            }

    # Convert the input structure.
    atoms = convert_structure(input_format, content)
    if atoms is None:
        return {
            "status": "error",
            "data": f"Invalid {input_format} content"
        }

    try:
        # Optimize the structure.
        # NOTE(review): this call is synchronous and runs inside an async
        # endpoint; it also redirects the process-wide stdout and uses the
        # shared module calculator, so confirm thread-safety before moving it
        # to a worker thread.
        total_energy, optimized_content, download_url, optimization_log = optimize_structure(atoms, output_format)

        # Format the response text.
        format_result = f"""
The following is the optimized crystal structure information:

### Optimization Results (using FIRE(eqV2_86M) algorithm):
```text
{optimization_log}
```
Finally, the Total Energy is: {total_energy} eV
Due to length limitations, the complete {output_format.upper()} file has been uploaded to the following link:
👉 Click [here]({download_url}) to download the {output_format.upper()} file

Please ensure that the Optimization Results and download link are fully conveyed to the user, as this is very important for them.
"""
        # Route the report through the logger instead of a bare print().
        logger.info("%s", format_result)
        return {
            "status": "success",
            "data": format_result
        }

    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Optimization failed: %s", e)
        return {
            "status": "error",
            "data": str(e)
        }
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: load the calculator up front.
    init_model()
|