fairchem
This commit is contained in:
@@ -1,74 +1,165 @@
|
||||
from fastapi import APIRouter, Body, Query
|
||||
from fairchem.core import OCPCalculator
|
||||
from ase.optimize import FIRE # Import your optimizer of choice
|
||||
from ase.filters import FrechetCellFilter # to include cell relaxations
|
||||
from ase.io import read
|
||||
from pymatgen.core import Structure
|
||||
from pymatgen.ext.matproj import MPRester
|
||||
from pymatgen.analysis.phase_diagram import PhaseDiagram, PDEntry
|
||||
from pymatgen.entries.computed_entries import ComputedStructureEntry
|
||||
from pymatgen.entries.mixing_scheme import MaterialsProjectDFTMixingScheme
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from ase.io import read, write
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
import tempfile
|
||||
import os
|
||||
import boto3
|
||||
from constant import FAIRCHEM_MODEL_PATH, MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY, MINIO_BUCKET, INTERNEL_MINIO_ENDPOINT, FMAX
|
||||
from typing import Optional
|
||||
import logging
|
||||
import datetime
|
||||
|
||||
# 创建相图并计算形成能与 above hull energy 的函数
|
||||
def calculate_phase_diagram_properties(structure, total_energy, api_key):
|
||||
"""
|
||||
计算化合物的形成能和 above hull energy
|
||||
参数:
|
||||
- formula (str): 化学式 (如 "CsPbBr3")
|
||||
- total_energy (float): 化合物的总能量 (eV)
|
||||
- mpr (MPRester): MPRester 实例
|
||||
router = APIRouter(prefix="/fairchem", tags=["fairchem"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化模型
|
||||
calc = None
|
||||
|
||||
def init_model():
|
||||
global calc
|
||||
calc = OCPCalculator(checkpoint_path=FAIRCHEM_MODEL_PATH)
|
||||
|
||||
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
|
||||
"""将输入内容转换为Atoms对象"""
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
|
||||
tmp_file.write(content)
|
||||
tmp_path = tmp_file.name
|
||||
|
||||
atoms = read(tmp_path)
|
||||
os.unlink(tmp_path)
|
||||
return atoms
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert structure: {str(e)}")
|
||||
return None
|
||||
|
||||
def generate_symmetry_cif(structure: Structure) -> str:
|
||||
"""生成对称性CIF"""
|
||||
analyzer = SpacegroupAnalyzer(structure)
|
||||
structure = analyzer.get_refined_structure()
|
||||
|
||||
返回:
|
||||
- formation_energy (float): 每个原子的形成能 (eV/atom)
|
||||
- e_above_hull (float): 每个原子的 above hull energy (eV/atom)
|
||||
"""
|
||||
chemsys = structure.chemical_system.split("-")
|
||||
formula = structure.reduced_formula
|
||||
with MPRester(api_key) as mpr:
|
||||
# 获取化学系统中所有的相
|
||||
# entries = mpr.get_entries_in_chemsys(elements=chemsys, additional_criteria={"thermo_types": ["GGA_GGA+U"]})
|
||||
entries = mpr.get_entries_in_chemsys(elements=chemsys, additional_criteria={"thermo_types": ["GGA_GGA+U", "R2SCAN"]})
|
||||
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
|
||||
cif_writer = CifWriter(structure, symprec=0.1, refine_struct=True)
|
||||
cif_writer.write_file(tmp_file.name)
|
||||
tmp_file.seek(0)
|
||||
return tmp_file.read()
|
||||
|
||||
# 创建新计算结构的 PDEntry
|
||||
pd_entry = PDEntry(composition=formula, energy=total_energy)
|
||||
# entries.append(pd_entry)
|
||||
def upload_to_minio(file_path: str, file_name: str) -> str:
|
||||
"""上传文件到MinIO并返回预签名URL"""
|
||||
try:
|
||||
minio_client = boto3.client(
|
||||
's3',
|
||||
endpoint_url=MINIO_ENDPOINT if INTERNEL_MINIO_ENDPOINT == "" else INTERNEL_MINIO_ENDPOINT,
|
||||
aws_access_key_id=MINIO_ACCESS_KEY,
|
||||
aws_secret_access_key=MINIO_SECRET_KEY
|
||||
)
|
||||
|
||||
bucket_name = MINIO_BUCKET
|
||||
minio_client.upload_file(file_path, bucket_name, file_name, ExtraArgs={"ACL": "private"})
|
||||
|
||||
# 生成预签名 URL
|
||||
url = minio_client.generate_presigned_url(
|
||||
'get_object',
|
||||
Params={'Bucket': bucket_name, 'Key': file_name},
|
||||
ExpiresIn=3600
|
||||
)
|
||||
return url.replace(INTERNEL_MINIO_ENDPOINT, MINIO_ENDPOINT)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload to MinIO: {str(e)}")
|
||||
raise RuntimeError(f"MinIO upload failed: {str(e)}") from e
|
||||
|
||||
scheme = MaterialsProjectDFTMixingScheme()
|
||||
entries = scheme.process_entries(entries)
|
||||
from io import StringIO
|
||||
import sys
|
||||
|
||||
def optimize_structure(atoms: Atoms, output_format: str):
|
||||
"""优化晶体结构"""
|
||||
atoms.calc = calc
|
||||
|
||||
# 创建相图
|
||||
pd = PhaseDiagram(entries + [pd_entry])
|
||||
# 捕获优化日志
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = log_capture = StringIO()
|
||||
|
||||
# 计算形成能和 above hull energy
|
||||
formation_energy = pd.get_form_energy_per_atom(pd_entry)
|
||||
e_above_hull = pd.get_e_above_hull(pd_entry)
|
||||
try:
|
||||
dyn = FIRE(FrechetCellFilter(atoms))
|
||||
dyn.run(fmax=FMAX)
|
||||
total_energy = atoms.get_total_energy()
|
||||
optimization_log = log_capture.getvalue()
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
|
||||
# 处理对称性
|
||||
if output_format == "cif":
|
||||
optimized_structure = Structure.from_ase_atoms(atoms)
|
||||
content = generate_symmetry_cif(optimized_structure)
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
|
||||
write(tmp_file.name, atoms)
|
||||
tmp_file.seek(0)
|
||||
content = tmp_file.read()
|
||||
|
||||
return formation_energy, e_above_hull
|
||||
# 保存优化结果到临时文件
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
file_name = f"optimized_structure_{timestamp}.{output_format}"
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w", delete=False) as tmp_file:
|
||||
tmp_file.write(content)
|
||||
tmp_path = tmp_file.name
|
||||
try:
|
||||
# 上传到MinIO
|
||||
url = upload_to_minio(tmp_path, file_name)
|
||||
return total_energy, content, url, optimization_log
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@router.post("/optimize_structure")
|
||||
async def optimize_structure_endpoint(
|
||||
content: str = Body(..., description="Input structure content"),
|
||||
input_format: str = Query("cif", description="Input format (cif, poscar, json, xyz)"),
|
||||
output_format: str = Query("cif", description="Output format (cif, poscar, json, xyz)")
|
||||
):
|
||||
# 转换输入结构
|
||||
atoms = convert_structure(input_format, content)
|
||||
if atoms is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"data": f"Invalid {input_format} content"
|
||||
}
|
||||
|
||||
try:
|
||||
# 优化结构
|
||||
total_energy, optimized_content, download_url, optimization_log = optimize_structure(atoms, output_format)
|
||||
|
||||
# 格式化返回结果
|
||||
format_result = f"""
|
||||
The following is the optimized crystal structure information:
|
||||
|
||||
atoms = read("/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/CsPbBr3.cif") # Read in an atoms object or create your own structure
|
||||
calc = OCPCalculator(checkpoint_path="/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/meta_fairchem/eqV2_86M_omat_mp_salex.pt") # Path to downloaded checkpoint
|
||||
atoms.calc = calc
|
||||
dyn = FIRE(FrechetCellFilter(atoms))
|
||||
dyn.run(fmax=0.01)
|
||||
### Optimization Results (using FIRE(eqV2_86M) algorithm):
|
||||
```text
|
||||
{optimization_log}
|
||||
```
|
||||
Finally, the Total Energy is: {total_energy} eV
|
||||
Due to length limitations, the complete {output_format.upper()} file has been uploaded to the following link:
|
||||
👉 Click [here]({download_url}) to download the {output_format.upper()} file
|
||||
|
||||
total_energy = atoms.get_potential_energy()
|
||||
print("Predicted Total Energy: ", total_energy)
|
||||
Please ensure that the Optimization Results and download link are fully conveyed to the user, as this is very important for them.
|
||||
"""
|
||||
print(format_result)
|
||||
return {
|
||||
"status": "success",
|
||||
"data": format_result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Optimization failed: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"data": str(e)
|
||||
}
|
||||
|
||||
# 保存优化后的结构
|
||||
atoms.write("optimized_structure.cif") # 保存为 CIF 文件
|
||||
print("Geometry optimization completed. Optimized structure saved as 'optimized_structure.cif'.")
|
||||
|
||||
# 从 ASE 转换为 Pymatgen 结构
|
||||
optimized_structure = Structure.from_file("optimized_structure.cif")
|
||||
|
||||
api_key = "gfBp2in8qxm9Xm2SwLKFwNxDyZvNTAEt"
|
||||
mpr = MPRester(api_key)
|
||||
print(f"Chemical Formula: {optimized_structure .composition.reduced_formula}")
|
||||
formation_energy, e_above_hull = calculate_phase_diagram_properties(
|
||||
structure=optimized_structure,
|
||||
total_energy=total_energy,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
print(formation_energy, e_above_hull)
|
||||
print()
|
||||
if __name__ == "__main__":
|
||||
init_model()
|
||||
|
||||
Reference in New Issue
Block a user