434 lines
17 KiB
Python
434 lines
17 KiB
Python
"""
|
|
Materials Project Query Module
|
|
|
|
This module provides functions for querying the Materials Project database,
|
|
processing search results, and formatting responses.
|
|
|
|
Author: Yutang LI
|
|
Institution: SIAT-MIC
|
|
Contact: yt.li2@siat.ac.cn
|
|
"""
|
|
|
|
import glob
|
|
import json
|
|
import asyncio
|
|
import logging
|
|
import datetime
|
|
import os
|
|
from multiprocessing import Process, Manager
|
|
from typing import Dict, Any, List, Optional
|
|
|
|
from mp_api.client import MPRester
|
|
from pymatgen.core import Structure
|
|
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
|
from pymatgen.io.cif import CifWriter
|
|
|
|
from mars_toolkit.core.llm_tools import llm_tool
|
|
from mars_toolkit.core.config import config
|
|
from mars_toolkit.core.error_handlers import handle_general_error
|
|
from mars_toolkit.core.cif_utils import read_cif_txt_file, extract_cif_info, remove_symmetry_equiv_xyz
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
def parse_bool(param: str) -> bool | None:
|
|
"""
|
|
Parse a string parameter into a boolean value.
|
|
|
|
Args:
|
|
param: String parameter to parse (e.g., "true", "false")
|
|
|
|
Returns:
|
|
Boolean value if param is not empty, None otherwise
|
|
"""
|
|
if not param:
|
|
return None
|
|
return param.lower() == 'true'
|
|
|
|
def parse_list(param: str) -> List[str] | None:
|
|
"""
|
|
Parse a comma-separated string into a list of strings.
|
|
|
|
Args:
|
|
param: Comma-separated string (e.g., "Li,Fe,O")
|
|
|
|
Returns:
|
|
List of strings if param is not empty, None otherwise
|
|
"""
|
|
if not param:
|
|
return None
|
|
return param.split(',')
|
|
|
|
def parse_tuple(param: str) -> tuple[float, float] | None:
|
|
"""
|
|
Parse a comma-separated string into a tuple of two float values.
|
|
|
|
Used for range parameters like band_gap, density, etc.
|
|
|
|
Args:
|
|
param: Comma-separated string of two numbers (e.g., "0,3.5")
|
|
|
|
Returns:
|
|
Tuple of two float values if param is valid, None otherwise
|
|
"""
|
|
if not param:
|
|
return None
|
|
try:
|
|
values = param.split(',')
|
|
return (float(values[0]), float(values[1]))
|
|
except (ValueError, IndexError):
|
|
return None
|
|
|
|
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
    """
    Parse search parameters from query parameters.

    Converts string query parameters into appropriate types for the
    Materials Project API. Missing or empty parameters become None, except
    ``chunk_size`` which falls back to 5.

    Args:
        query_params: Mapping of parameter name to its raw string value.

    Returns:
        Dictionary with one entry per supported search parameter.

    Raises:
        ValueError: If ``chunk_size`` is present, non-empty, and not an integer.
    """
    get = query_params.get

    # An empty or explicitly-None chunk_size is treated as "not provided",
    # consistent with every other parameter, instead of failing in int().
    raw_chunk_size = get('chunk_size')
    chunk_size = 5 if raw_chunk_size in (None, '') else int(raw_chunk_size)

    return {
        'band_gap': parse_tuple(get('band_gap')),
        'chemsys': parse_list(get('chemsys')),
        'crystal_system': parse_list(get('crystal_system')),
        'density': parse_tuple(get('density')),
        'formation_energy': parse_tuple(get('formation_energy')),
        'elements': parse_list(get('elements')),
        'exclude_elements': parse_list(get('exclude_elements')),
        'formula': parse_list(get('formula')),
        'is_gap_direct': parse_bool(get('is_gap_direct')),
        'is_metal': parse_bool(get('is_metal')),
        'is_stable': parse_bool(get('is_stable')),
        'magnetic_ordering': get('magnetic_ordering'),
        'material_ids': parse_list(get('material_ids')),
        'total_energy': parse_tuple(get('total_energy')),
        # NOTE(review): parse_tuple yields floats; the API presumably expects
        # integer bounds for num_elements - confirm before relying on it.
        'num_elements': parse_tuple(get('num_elements')),
        'volume': parse_tuple(get('volume')),
        'chunk_size': chunk_size,
    }
|
|
|
|
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]] | str:
|
|
"""
|
|
Process search results from the Materials Project API.
|
|
|
|
Extracts relevant fields from each document and formats them into a consistent structure.
|
|
|
|
Returns:
|
|
List of processed documents or error message string if an exception occurs
|
|
"""
|
|
try:
|
|
fields = [
|
|
'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
|
|
'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
|
|
'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
|
|
'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
|
|
'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
|
|
'universal_anisotropy', 'theoretical'
|
|
]
|
|
|
|
res = []
|
|
for doc in docs:
|
|
try:
|
|
new_docs = {}
|
|
for field_name in fields:
|
|
new_docs[field_name] = doc.get(field_name, '')
|
|
res.append(new_docs)
|
|
except Exception as e:
|
|
logger.warning(f"Error processing document: {str(e)}")
|
|
continue
|
|
return res
|
|
except Exception as e:
|
|
error_msg = f"Error in process_search_results: {str(e)}"
|
|
logger.error(error_msg)
|
|
return error_msg
|
|
|
|
def _doc_to_dict(doc):
    """Convert one search result document into a plain dictionary when possible."""
    # Preferred: model_dump(); an AttributeError means the method is unavailable.
    try:
        return doc.model_dump()
    except AttributeError:
        pass

    # Fallback: dict() method.
    try:
        return doc.dict()
    except AttributeError:
        pass

    # Fallback: the instance attribute dictionary. Work on a copy so the
    # original document object is not modified.
    if hasattr(doc, "__dict__"):
        data = dict(doc.__dict__)
        # Drop the special attribute that may cause serialization problems.
        data.pop("_sa_instance_state", None)
        return data

    # Last resort: pass the document through unchanged.
    return doc


def _search_worker(queue, api_key, **kwargs):
    """
    Worker function for executing Materials Project API searches.

    Runs in a separate process to perform the actual API call and puts the
    outcome in the queue: a list of converted documents (empty when the
    search returned nothing), or the exception instance if anything failed.

    Args:
        queue: Multiprocessing queue for returning results
        api_key: Materials Project API key
        **kwargs: Search parameters to pass to the API
    """
    try:
        os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
        os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''

        # Initialize the MPRester client
        with MPRester(api_key) as mpr:
            result = mpr.materials.summary.search(**kwargs)
            queue.put([_doc_to_dict(doc) for doc in result] if result else [])
    except Exception as e:
        # Hand the failure back to the parent instead of dying silently.
        queue.put(e)
|
|
|
|
async def execute_search(search_args: Dict[str, Any], timeout: int = 120) -> List[Dict[str, Any]] | str:
    """
    Execute a search against the Materials Project API.

    Runs the search in a separate process so that a hanging request can be
    terminated after ``timeout`` seconds, and returns the results.

    Args:
        search_args: Dictionary of search parameters. It is not modified; a
            string ``formula`` is wrapped in a list on an internal copy.
        timeout: Maximum time in seconds to wait for the search to complete

    Returns:
        List of document dictionaries from the search results or error
        message string if an exception occurs
    """
    # Imported locally under its own name: a variable called ``queue`` would
    # shadow the standard-library module and break ``except queue.Empty``.
    from queue import Empty

    # Ensure the formula parameter is a list, without touching the caller's dict.
    worker_kwargs = dict(search_args)
    if isinstance(worker_kwargs.get('formula'), str):
        worker_kwargs['formula'] = [worker_kwargs['formula']]

    try:
        # The context manager shuts the manager process down on every exit path.
        with Manager() as manager:
            result_queue = manager.Queue()
            worker = Process(
                target=_search_worker,
                args=(result_queue, config.MP_API_KEY),
                kwargs=worker_kwargs,
            )
            worker.start()

            timed_out = False
            try:
                # Wait in a helper thread so the event loop stays responsive.
                await asyncio.to_thread(worker.join, timeout)
                if worker.is_alive():
                    logger.warning(f"Terminating worker process {worker.pid} due to timeout")
                    timed_out = True
            finally:
                # Also reached on cancellation: never leave the worker running.
                if worker.is_alive():
                    worker.terminate()
                    worker.join()

            if timed_out:
                return f"Request timed out after {timeout} seconds"

            try:
                # The worker has exited, so its result (if any) is already in
                # the managed queue; waiting longer cannot produce one.
                result = result_queue.get_nowait()
            except Empty:
                error_msg = "Failed to retrieve data from queue (timeout)"
                logger.error(error_msg)
                return error_msg

            if isinstance(result, Exception):
                logger.error(f"Error in search worker: {str(result)}")
                return f"Error in search worker: {str(result)}"

            return result
    except Exception as e:
        error_msg = f"Error in execute_search: {str(e)}"
        logger.error(error_msg)
        return error_msg
|
|
|
|
@llm_tool(name="search_material_property_from_material_project", description="Search materials in Materials Project database by formula and properties")
async def search_material_property_from_material_project(
    formula: str | list[str],
    chemsys: Optional[str | list[str] | None] = None,
    crystal_system: Optional[str | list[str] | None] = None,
    is_gap_direct: Optional[bool | None] = None,
    is_stable: Optional[bool | None] = None,
) -> str:
    """
    Search materials in Materials Project database.

    Material IDs are looked up by ``formula``; the property record of each
    of the first ``config.MP_TOPK`` IDs is then read from
    ``<config.LOCAL_MP_ROOT>/Props/<mp_id>.json``.

    Args:
        formula: Chemical formula(s) (e.g., "Fe2O3" or ["ABO3", "Si*"])
        chemsys: Chemical system(s) (e.g., "Li-Fe-O")
        crystal_system: Crystal system(s) (e.g., "Cubic")
        is_gap_direct: Filter for direct band gap materials
        is_stable: Filter for thermodynamically stable materials
    Returns:
        JSON formatted material properties data, or a plain message when the
        input is invalid, nothing is found, or an error occurs
    """
    VALID_CRYSTAL_SYSTEMS = ['Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal', 'Cubic']

    # Validate the crystal system argument (a single name or a list of names).
    if crystal_system is not None:
        candidates = [crystal_system] if isinstance(crystal_system, str) else crystal_system
        if isinstance(candidates, list) and any(cs not in VALID_CRYSTAL_SYSTEMS for cs in candidates):
            return "Input should be 'Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal' or 'Cubic'"

    # Ensure formula is a list.
    if isinstance(formula, str):
        formula = [formula]

    # NOTE(review): chemsys, crystal_system, is_gap_direct and is_stable are
    # accepted (and crystal_system validated) but are not applied to the
    # lookup below - only `formula` selects the materials. The previous code
    # collected them into a dictionary that was never used. TODO: confirm
    # whether these filters should be forwarded to the search.
    mp_id_list = await get_mpid_from_formula(formula=formula)

    try:
        # Only the first MP_TOPK entries are ever reported, so only those
        # property files are read.
        records = [
            extract_cif_info(config.LOCAL_MP_ROOT + f"/Props/{mp_id}.json", ['all_fields'])
            for mp_id in mp_id_list[:config.MP_TOPK]
        ]
    except Exception as e:
        logger.error(f"Error in search_material_property_from_material_project: {str(e)}")
        return str(e)

    if not records:
        return "No results found, please try again."

    # Format response with top results
    try:
        # Wrap each record in numbered begin/end markers.
        formatted_results = []
        for i, item in enumerate(records, 1):
            formatted_result = f"[property {i} begin]\n"
            formatted_result += json.dumps(item, indent=2)
            formatted_result += f"\n[property {i} end]\n\n"
            formatted_results.append(formatted_result)

        # Merge all results into a single string.
        res_chunk = "\n\n".join(formatted_results)
        res_template = f"""
Here are the search results from the Materials Project database:
Due to length limitations, only the top {config.MP_TOPK} results are shown below:\n
{res_chunk}
If you need more results, please modify your search criteria or try different query parameters.
"""
        return res_template
    except Exception as format_error:
        logger.error(f"Error formatting results: {str(format_error)}")
        return str(format_error)
|
|
|
|
@llm_tool(name="get_crystal_structures_from_materials_project", description="Get symmetrized crystal structures CIF data from Materials Project database by chemical formula")
async def get_crystal_structures_from_materials_project(
    formulas: list[str],
    conventional_unit_cell: bool = True,
    symprec: float = 0.1
) -> str:
    """
    Get crystal structures from Materials Project database by chemical formula and apply symmetrization.

    Material IDs are looked up by formula; each structure is read from
    ``<config.LOCAL_MP_ROOT>/MPDatasets/<mp_id>.cif``. A structure whose
    file is missing or that fails to load/symmetrize is logged and skipped.
    At most ``config.MP_TOPK`` structures are returned.

    Args:
        formulas: List of chemical formulas (e.g., ["Fe2O3", "SiO2", "TiO2"])
        conventional_unit_cell: Whether to return conventional unit cell (True) or primitive cell (False)
        symprec: Precision parameter for symmetrization

    Returns:
        Formatted text containing symmetrized CIF data, or a plain message
        when no structure could be produced or formatting fails
    """
    result = {}
    mp_id_list = await get_mpid_from_formula(formula=formulas)

    for i, mp_id in enumerate(mp_id_list):
        try:
            matches = glob.glob(config.LOCAL_MP_ROOT + f"/MPDatasets/{mp_id}.cif")
            if not matches:
                # The ID is known remotely but absent from the local dataset.
                logger.warning(f"No local CIF file found for {mp_id}, skipping")
                continue
            structure = Structure.from_file(matches[0])

            # Convert to the conventional standard cell when requested.
            # NOTE(review): when conventional_unit_cell is False this step is
            # skipped, but the refinement steps below still run - confirm
            # that the output is really a primitive cell in that case.
            if conventional_unit_cell:
                structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure()

            # Symmetrize the structure.
            sga = SpacegroupAnalyzer(structure, symprec=symprec)
            symmetrized_structure = sga.get_refined_structure()

            # Generate the CIF text with CifWriter.
            cif_writer = CifWriter(symmetrized_structure, symprec=symprec, refine_struct=True)
            cif_data = str(cif_writer)

            # Remove the symmetry-operation section from the CIF text.
            cif_data = remove_symmetry_equiv_xyz(cif_data)
            cif_data = cif_data.replace('# generated using pymatgen', "")

            # Build a unique key from the reduced formula and position.
            formula = structure.composition.reduced_formula
            result[f"{formula}_{i}"] = cif_data
        except Exception as e:
            # One bad structure must not abort the whole request.
            logger.warning(f"Failed to process structure {mp_id}: {str(e)}")
            continue

        # Keep only the first config.MP_TOPK results.
        if len(result) >= config.MP_TOPK:
            break

    if not result:
        return "No results found, please try again."

    try:
        parts = [f"""
# Materials Project Symmetrized Crystal Structure Data

Below are symmetrized crystal structure data for {len(result)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
"""]

        for i, cif_data in enumerate(result.values(), 1):
            parts.append(f"[cif {i} begin]\n")
            parts.append(cif_data)
            parts.append(f"\n[cif {i} end]\n\n")

        parts.append("""
## Usage Instructions

1. You can copy the above CIF data and save it as .cif files
2. Open these files with crystal structure visualization software (such as VESTA, Mercury, Avogadro, etc.)
3. These structures can be used for further material analysis, simulation, or visualization

CIF files contain complete structural information of crystals, including cell parameters, atomic coordinates, symmetry, etc.
Symmetrization helps identify and optimize crystal symmetry, making the structure more standardized and accurate.
""")
        return "".join(parts)

    except Exception as format_error:
        logger.error(f"Error formatting crystal structures: {str(format_error)}")
        return str(format_error)
|
|
|
|
@llm_tool(name="get_mpid_from_formula", description="Get material IDs (mpid) from Materials Project database by chemical formula")
async def get_mpid_from_formula(formula: str) -> List[str]:
    """
    Get material IDs (mpid) from Materials Project database by chemical formula.

    IDs are returned in the order the API delivers them; no energy-based
    ordering or filtering is applied here.

    Args:
        formula: Chemical formula (e.g., "Fe2O3"). Callers in this module
            also pass a list of formulas, which is forwarded unchanged.

    Returns:
        List of material IDs, or an empty list if the lookup fails
    """
    os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
    os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''
    try:
        with MPRester(config.MP_API_KEY) as mpr:
            # Only the ID is used, so request just that field instead of the
            # full summary documents.
            docs = mpr.materials.summary.search(formula=formula, fields=["material_id"])
        return [doc.material_id for doc in docs]
    except Exception as e:
        logger.error(f"Error getting mpid from formula: {str(e)}")
        return []
|