mattergen转服务
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -1,12 +1,3 @@
|
||||
"""
|
||||
Material Generation Module
|
||||
|
||||
This module provides functions for generating crystal structures with optional property constraints.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
@@ -276,148 +267,16 @@ async def generate_material(
|
||||
Returns:
|
||||
Descriptive text with generated crystal structures in CIF format
|
||||
"""
|
||||
# 使用配置中的结果目录
|
||||
output_dir = config.MATTERGENMODEL_RESULT_PATH
|
||||
# 导入MatterGenService
|
||||
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||
|
||||
# 处理字符串输入(如果提供)
|
||||
if isinstance(properties, str):
|
||||
try:
|
||||
properties = json.loads(properties)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid properties JSON string: {properties}")
|
||||
# 获取MatterGenService实例
|
||||
service = MatterGenService.get_instance()
|
||||
|
||||
# 如果为None,默认为空字典
|
||||
properties = properties or {}
|
||||
|
||||
# 根据生成模式处理属性
|
||||
if not properties:
|
||||
# 无条件生成
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
||||
properties_to_condition_on = None
|
||||
generation_type = "unconditional"
|
||||
property_description = "unconditionally"
|
||||
else:
|
||||
# 条件生成(单属性或多属性)
|
||||
properties_to_condition_on = {}
|
||||
|
||||
# 处理每个属性
|
||||
for property_name, property_value in properties.items():
|
||||
_, processed_value = preprocess_property(property_name, property_value)
|
||||
properties_to_condition_on[property_name] = processed_value
|
||||
|
||||
# 根据属性确定使用哪个模型
|
||||
if len(properties) == 1:
|
||||
# 单属性条件
|
||||
property_name = list(properties.keys())[0]
|
||||
property_to_model = {
|
||||
"dft_mag_density": "dft_mag_density",
|
||||
"dft_bulk_modulus": "dft_bulk_modulus",
|
||||
"dft_shear_modulus": "dft_shear_modulus",
|
||||
"energy_above_hull": "energy_above_hull",
|
||||
"formation_energy_per_atom": "formation_energy_per_atom",
|
||||
"space_group": "space_group",
|
||||
"hhi_score": "hhi_score",
|
||||
"ml_bulk_modulus": "ml_bulk_modulus",
|
||||
"chemical_system": "chemical_system",
|
||||
"dft_band_gap": "dft_band_gap"
|
||||
}
|
||||
model_dir = property_to_model.get(property_name, property_name)
|
||||
generation_type = "single_property"
|
||||
property_description = f"conditioned on {property_name} = {properties[property_name]}"
|
||||
else:
|
||||
# 多属性条件
|
||||
property_keys = set(properties.keys())
|
||||
if property_keys == {"dft_mag_density", "hhi_score"}:
|
||||
model_dir = "dft_mag_density_hhi_score"
|
||||
elif property_keys == {"chemical_system", "energy_above_hull"}:
|
||||
model_dir = "chemical_system_energy_above_hull"
|
||||
else:
|
||||
# 如果没有特定的多属性模型,使用第一个属性的模型
|
||||
first_property = list(properties.keys())[0]
|
||||
model_dir = first_property
|
||||
generation_type = "multi_property"
|
||||
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
|
||||
|
||||
# 构建完整的模型路径
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
|
||||
|
||||
# 检查模型目录是否存在
|
||||
if not os.path.exists(model_path):
|
||||
# 如果特定模型不存在,回退到基础模型
|
||||
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
|
||||
|
||||
# 使用适当的参数调用main函数
|
||||
main(
|
||||
output_path=output_dir,
|
||||
model_path=model_path,
|
||||
# 使用服务生成材料
|
||||
return service.generate(
|
||||
properties=properties,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
properties_to_condition_on=properties_to_condition_on,
|
||||
record_trajectories=True,
|
||||
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0
|
||||
diffusion_guidance_factor=diffusion_guidance_factor
|
||||
)
|
||||
|
||||
# 创建字典存储文件内容
|
||||
result_dict = {}
|
||||
|
||||
# 定义文件路径
|
||||
cif_zip_path = os.path.join(output_dir, "generated_crystals_cif.zip")
|
||||
xyz_file_path = os.path.join(output_dir, "generated_crystals.extxyz")
|
||||
trajectories_zip_path = os.path.join(output_dir, "generated_trajectories.zip")
|
||||
|
||||
# 读取CIF压缩文件
|
||||
if os.path.exists(cif_zip_path):
|
||||
with open(cif_zip_path, 'rb') as f:
|
||||
result_dict['cif_content'] = f.read()
|
||||
|
||||
# 根据生成类型创建描述性提示
|
||||
if generation_type == "unconditional":
|
||||
title = "Generated Material Structures"
|
||||
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
|
||||
elif generation_type == "single_property":
|
||||
property_name = list(properties.keys())[0]
|
||||
property_value = properties[property_name]
|
||||
title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
|
||||
description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
|
||||
else: # multi_property
|
||||
title = "Generated Material Structures Conditioned on Multiple Properties"
|
||||
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
|
||||
|
||||
# 创建完整的提示
|
||||
prompt = f"""
|
||||
# {title}
|
||||
|
||||
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
|
||||
|
||||
{'' if generation_type == 'unconditional' else f'''
|
||||
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
|
||||
the generation adheres to the specified property values. Higher values produce samples that more
|
||||
closely match the target properties but may reduce diversity.
|
||||
'''}
|
||||
|
||||
## CIF Files (Crystallographic Information Files)
|
||||
|
||||
- Standard format for crystallographic structures
|
||||
- Contains unit cell parameters, atomic positions, and symmetry information
|
||||
- Used by crystallographic software and visualization tools
|
||||
|
||||
```
|
||||
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
|
||||
```
|
||||
|
||||
{description}
|
||||
You can use these structures for materials discovery, property prediction, or further analysis.
|
||||
"""
|
||||
|
||||
# 清理文件(读取后删除)
|
||||
try:
|
||||
if os.path.exists(cif_zip_path):
|
||||
os.remove(cif_zip_path)
|
||||
if os.path.exists(xyz_file_path):
|
||||
os.remove(xyz_file_path)
|
||||
if os.path.exists(trajectories_zip_path):
|
||||
os.remove(trajectories_zip_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning up files: {e}")
|
||||
return prompt
|
||||
|
||||
@@ -23,7 +23,6 @@ from pymatgen.io.cif import CifWriter
|
||||
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.error_handlers import handle_general_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -189,4 +188,4 @@ async def optimize_crystal_structure(
|
||||
# 直接返回结果或抛出异常
|
||||
return await asyncio.to_thread(run_optimization)
|
||||
except Exception as e:
|
||||
return handle_general_error(e)
|
||||
return str(e)
|
||||
|
||||
Reference in New Issue
Block a user