This commit is contained in:
2025-01-05 17:43:06 +08:00
parent f214f51e12
commit 5380ee5f9e
4 changed files with 169 additions and 64 deletions

View File

@@ -1,74 +1,165 @@
from fastapi import APIRouter, Body, Query
from fairchem.core import OCPCalculator
from ase.optimize import FIRE # Import your optimizer of choice
from ase.filters import FrechetCellFilter # to include cell relaxations
from ase.io import read
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester
from pymatgen.analysis.phase_diagram import PhaseDiagram, PDEntry
from pymatgen.entries.computed_entries import ComputedStructureEntry
from pymatgen.entries.mixing_scheme import MaterialsProjectDFTMixingScheme
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
import tempfile
import os
import boto3
from constant import FAIRCHEM_MODEL_PATH, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT, FMAX
from typing import Optional
import logging
import datetime
# 创建相图并计算形成能与 above hull energy 的函数
def calculate_phase_diagram_properties(structure, total_energy, api_key):
"""
计算化合物的形成能和 above hull energy
参数:
- formula (str): 化学式 (如 "CsPbBr3")
- total_energy (float): 化合物的总能量 (eV)
- mpr (MPRester): MPRester 实例
router = APIRouter(prefix="/fairchem", tags=["fairchem"])
logger = logging.getLogger(__name__)
# 初始化模型
calc = None
def init_model():
global calc
calc = OCPCalculator(checkpoint_path=FAIRCHEM_MODEL_PATH)
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
"""将输入内容转换为Atoms对象"""
try:
with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name
atoms = read(tmp_path)
os.unlink(tmp_path)
return atoms
except Exception as e:
logger.error(f"Failed to convert structure: {str(e)}")
return None
def generate_symmetry_cif(structure: Structure) -> str:
"""生成对称性CIF"""
analyzer = SpacegroupAnalyzer(structure)
structure = analyzer.get_refined_structure()
返回:
- formation_energy (float): 每个原子的形成能 (eV/atom)
- e_above_hull (float): 每个原子的 above hull energy (eV/atom)
"""
chemsys = structure.chemical_system.split("-")
formula = structure.reduced_formula
with MPRester(api_key) as mpr:
# 获取化学系统中所有的相
# entries = mpr.get_entries_in_chemsys(elements=chemsys, additional_criteria={"thermo_types": ["GGA_GGA+U"]})
entries = mpr.get_entries_in_chemsys(elements=chemsys, additional_criteria={"thermo_types": ["GGA_GGA+U", "R2SCAN"]})
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
cif_writer = CifWriter(structure, symprec=0.1, refine_struct=True)
cif_writer.write_file(tmp_file.name)
tmp_file.seek(0)
return tmp_file.read()
# 创建新计算结构的 PDEntry
pd_entry = PDEntry(composition=formula, energy=total_energy)
# entries.append(pd_entry)
def upload_to_minio(file_path: str, file_name: str) -> str:
"""上传文件到MinIO并返回预签名URL"""
try:
minio_client = boto3.client(
's3',
endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT,
aws_access_key_id=MINIO_ACCESS_KEY,
aws_secret_access_key=MINIO_SECRET_KEY
)
bucket_name = MINIO_BUCKET
minio_client.upload_file(file_path, bucket_name, file_name, ExtraArgs={"ACL": "private"})
# 生成预签名 URL
url = minio_client.generate_presigned_url(
'get_object',
Params={'Bucket': bucket_name, 'Key': file_name},
ExpiresIn=3600
)
return url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
except Exception as e:
logger.error(f"Failed to upload to MinIO: {str(e)}")
raise RuntimeError(f"MinIO upload failed: {str(e)}") from e
scheme = MaterialsProjectDFTMixingScheme()
entries = scheme.process_entries(entries)
from io import StringIO
import sys
def optimize_structure(atoms: Atoms, output_format: str):
"""优化晶体结构"""
atoms.calc = calc
# 创建相图
pd = PhaseDiagram(entries + [pd_entry])
# 捕获优化日志
old_stdout = sys.stdout
sys.stdout = log_capture = StringIO()
# 计算形成能和 above hull energy
formation_energy = pd.get_form_energy_per_atom(pd_entry)
e_above_hull = pd.get_e_above_hull(pd_entry)
try:
dyn = FIRE(FrechetCellFilter(atoms))
dyn.run(fmax=FMAX)
total_energy = atoms.get_total_energy()
optimization_log = log_capture.getvalue()
finally:
sys.stdout = old_stdout
# 处理对称性
if output_format == "cif":
optimized_structure = Structure.from_ase_atoms(atoms)
content = generate_symmetry_cif(optimized_structure)
else:
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
write(tmp_file.name, atoms)
tmp_file.seek(0)
content = tmp_file.read()
return formation_energy, e_above_hull
# 保存优化结果到临时文件
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
file_name = f"optimized_structure_{timestamp}.{output_format}"
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w", delete=False) as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name
try:
# 上传到MinIO
url = upload_to_minio(tmp_path, file_name)
return total_energy, content, url, optimization_log
finally:
os.unlink(tmp_path)
@router.post("/optimize_structure")
async def optimize_structure_endpoint(
content: str = Body(..., description="Input structure content"),
input_format: str = Query("cif", description="Input format (cif, poscar, json, xyz)"),
output_format: str = Query("cif", description="Output format (cif, poscar, json, xyz)")
):
# 转换输入结构
atoms = convert_structure(input_format, content)
if atoms is None:
return {
"status": "error",
"data": f"Invalid {input_format} content"
}
try:
# 优化结构
total_energy, optimized_content, download_url, optimization_log = optimize_structure(atoms, output_format)
# 格式化返回结果
format_result = f"""
The following is the optimized crystal structure information:
atoms = read("/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/CsPbBr3.cif") # Read in an atoms object or create your own structure
calc = OCPCalculator(checkpoint_path="/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/meta_fairchem/eqV2_86M_omat_mp_salex.pt") # Path to downloaded checkpoint
atoms.calc = calc
dyn = FIRE(FrechetCellFilter(atoms))
dyn.run(fmax=0.01)
### Optimization Results (using FIRE(eqV2_86M) algorithm):
```text
{optimization_log}
```
Finally, the Total Energy is: {total_energy} eV
Due to length limitations, the complete {output_format.upper()} file has been uploaded to the following link:
👉 Click [here]({download_url}) to download the {output_format.upper()} file
total_energy = atoms.get_potential_energy()
print("Predicted Total Energy: ", total_energy)
Please ensure that the Optimization Results and download link are fully conveyed to the user, as this is very important for them.
"""
print(format_result)
return {
"status": "success",
"data": format_result
}
except Exception as e:
logger.error(f"Optimization failed: {str(e)}")
return {
"status": "error",
"data": str(e)
}
# 保存优化后的结构
atoms.write("optimized_structure.cif") # 保存为 CIF 文件
print("Geometry optimization completed. Optimized structure saved as 'optimized_structure.cif'.")
# 从 ASE 转换为 Pymatgen 结构
optimized_structure = Structure.from_file("optimized_structure.cif")
api_key = "gfBp2in8qxm9Xm2SwLKFwNxDyZvNTAEt"
mpr = MPRester(api_key)
print(f"Chemical Formula: {optimized_structure .composition.reduced_formula}")
formation_energy, e_above_hull = calculate_phase_diagram_properties(
structure=optimized_structure,
total_energy=total_energy,
api_key=api_key
)
print(formation_energy, e_above_hull)
print()
if __name__ == "__main__":
init_model()