Build mars_toolkit, remove tools_for_ms

This commit is contained in:
lzy
2025-04-02 12:53:50 +08:00
parent 603304e10f
commit a77c2cd377
73 changed files with 1884 additions and 896 deletions

mars_toolkit/__init__.py Normal file
View File

@@ -0,0 +1,46 @@
"""
Mars Toolkit
A comprehensive toolkit for materials science research, providing tools for:
- Material generation and property prediction
- Structure optimization
- Database queries (Materials Project, OQMD)
- Knowledge base retrieval
- Web search
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
# Core modules
from mars_toolkit.core.config import config
from mars_toolkit.core.utils import setup_logging
# Basic tools
from mars_toolkit.misc.misc_tools import get_current_time
# Compute modules
from mars_toolkit.compute.material_gen import generate_material
from mars_toolkit.compute.property_pred import predict_properties
from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure
# Query modules
from mars_toolkit.query.mp_query import (
search_material_property_from_material_project,
get_crystal_structures_from_materials_project,
get_mpid_from_formula
)
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
from mars_toolkit.query.web_search import search_online
from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
# Initialize logging
setup_logging()
__version__ = "0.1.0"
__all__ = ["llm_tool", "get_tools", "get_tool_schemas"]
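
A minimal usage sketch of the interface exported above, assuming mars_toolkit and its model dependencies are importable; importing the package registers every @llm_tool-decorated function:

import json
from mars_toolkit import get_tools, get_tool_schemas

# List the registered tools and dump their JSON schemas
# (the schemas can be passed to an LLM API's function-calling/tools parameter).
print("Registered tools:", sorted(get_tools().keys()))
print(json.dumps(get_tool_schemas(), indent=2))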

Binary file not shown.

View File

@@ -0,0 +1,12 @@
"""
Compute Module
This module provides computational tools for materials science, including:
- Material generation
- Property prediction
- Structure optimization
"""
from mars_toolkit.compute.material_gen import generate_material
from mars_toolkit.compute.property_pred import predict_properties
from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure

View File

@@ -0,0 +1,423 @@
"""
Material Generation Module
This module provides functions for generating crystal structures with optional property constraints.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import ast
import json
import logging
import tempfile
import os
import datetime
import asyncio
import zipfile
import shutil
import re
from pathlib import Path
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
# Import paths updated for the mars_toolkit package layout
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
# Use the standalone mattergen_wrapper module
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from mattergen_wrapper import generator
CrystalGenerator = generator.CrystalGenerator
from mattergen.common.data.types import TargetProperty
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
from mattergen.common.utils.data_classes import (
PRETRAINED_MODEL_NAME,
MatterGenCheckpointInfo,
)
logger = logging.getLogger(__name__)
def format_cif_content(content):
"""
Format CIF content by removing unnecessary headers and organizing each CIF file.
Args:
content: String containing CIF content, possibly with PK headers
Returns:
Formatted string with each CIF file properly labeled and formatted
"""
# If the content is empty, return an empty string
if not content or content.strip() == '':
return ''
# Remove everything from the PK (zip) header up to the first _chemical_formula_structural
content = re.sub(r'PK.*?(?=_chemical_formula_structural)', '', content, flags=re.DOTALL)
# Remove trailing PK content that contains no _chemical_formula_structural
content = re.sub(r'PK[^_]*$', '', content, flags=re.DOTALL)
content = re.sub(r'PK.*?(?!.*_chemical_formula_structural)$', '', content, flags=re.DOTALL)
# Split the content into separate CIF files using _chemical_formula_structural as the delimiter,
# keeping that field inside each CIF block
cif_blocks = []
# Find the positions of every _chemical_formula_structural occurrence
formula_positions = [m.start() for m in re.finditer(r'_chemical_formula_structural', content)]
# If no _chemical_formula_structural is found, return an empty string
if not formula_positions:
return ''
# Split into CIF blocks
for i in range(len(formula_positions)):
start_pos = formula_positions[i]
# For the last block, the end position is the end of the string
end_pos = formula_positions[i+1] if i < len(formula_positions)-1 else len(content)
cif_block = content[start_pos:end_pos].strip()
# Extract the formula value
formula_match = re.search(r'_chemical_formula_structural\s+(\S+)', cif_block)
if formula_match:
formula = formula_match.group(1)
cif_blocks.append((formula, cif_block))
# Format the output
result = []
for i, (formula, cif_content) in enumerate(cif_blocks, 1):
formatted = f"[cif {i} begin]\ndata_{formula}\n{cif_content}\n[cif {i} end]"
result.append(formatted)
return "\n".join(result)
def convert_values(data_str):
"""
Convert a JSON string into a Python object.
Args:
data_str: JSON string
Returns:
Parsed data, or the original string if parsing fails
"""
try:
data = json.loads(data_str)
except json.JSONDecodeError:
return data_str  # If it cannot be parsed as JSON, return the original string
return data
def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
"""
Preprocess a property value based on its name, converting it to the appropriate type.
Args:
property_name: Name of the property
property_value: Value of the property (can be string, float, or int)
Returns:
Tuple of (property_name, processed_value)
Raises:
ValueError: If the property value is invalid for the given property name
"""
valid_properties = [
"dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
"energy_above_hull", "formation_energy_per_atom", "space_group",
"hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap"
]
if property_name not in valid_properties:
raise ValueError(f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}")
# Process property_value if it's a string
if isinstance(property_value, str):
try:
# Try to convert string to float for numeric properties
if property_name != "chemical_system":
property_value = float(property_value)
except ValueError:
# If conversion fails, keep as string (for chemical_system)
pass
# Handle special cases for properties that need specific types
if property_name == "chemical_system":
if isinstance(property_value, (int, float)):
logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
property_value = str(property_value)
elif property_name == "space_group" :
space_group = property_value
if space_group < 1 or space_group > 230:
raise ValueError(f"Invalid space_group value: {space_group}. Must be an integer between 1 and 230.")
return property_name, property_value
def main(
output_path: str,
pretrained_name: PRETRAINED_MODEL_NAME | None = None,
model_path: str | None = None,
batch_size: int = 2,
num_batches: int = 1,
config_overrides: list[str] | None = None,
checkpoint_epoch: Literal["best", "last"] | int = "last",
properties_to_condition_on: TargetProperty | None = None,
sampling_config_path: str | None = None,
sampling_config_name: str = "default",
sampling_config_overrides: list[str] | None = None,
record_trajectories: bool = True,
diffusion_guidance_factor: float | None = None,
strict_checkpoint_loading: bool = True,
target_compositions: list[dict[str, int]] | None = None,
):
"""
Evaluate diffusion model against molecular metrics.
Args:
model_path: Path to DiffusionLightningModule checkpoint directory.
output_path: Path to output directory.
config_overrides: Overrides for the model config, e.g., `model.num_layers=3 model.hidden_dim=128`.
properties_to_condition_on: Property value to draw conditional sampling with respect to. When this value is an empty dictionary (default), unconditional samples are drawn.
sampling_config_path: Path to the sampling config file. (default: None, in which case we use `DEFAULT_SAMPLING_CONFIG_PATH` from explorers.common.utils.utils.py)
sampling_config_name: Name of the sampling config (corresponds to `{sampling_config_path}/{sampling_config_name}.yaml` on disk). (default: default)
sampling_config_overrides: Overrides for the sampling config, e.g., `condition_loader_partial.batch_size=32`.
checkpoint_epoch: Epoch to load from the checkpoint. Either "best", "last", or an integer epoch. (default: "last")
record_trajectories: Whether to record the trajectories of the generated structures. (default: True)
strict_checkpoint_loading: Whether to raise an exception when not all parameters from the checkpoint can be matched to the model.
target_compositions: List of dictionaries with target compositions to condition on. Each dictionary should have the form `{element: number_of_atoms}`. If None, the target compositions are not conditioned on.
Only supported for models trained for crystal structure prediction (CSP) (default: None)
NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
"""
assert (
pretrained_name is not None or model_path is not None
), "Either pretrained_name or model_path must be provided."
assert (
pretrained_name is None or model_path is None
), "Only one of pretrained_name or model_path can be provided."
if not os.path.exists(output_path):
os.makedirs(output_path)
sampling_config_overrides = sampling_config_overrides or []
config_overrides = config_overrides or []
properties_to_condition_on = properties_to_condition_on or {}
target_compositions = target_compositions or []
if pretrained_name is not None:
checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
pretrained_name, config_overrides=config_overrides
)
else:
checkpoint_info = MatterGenCheckpointInfo(
model_path=Path(model_path).resolve(),
load_epoch=checkpoint_epoch,
config_overrides=config_overrides,
strict_checkpoint_loading=strict_checkpoint_loading,
)
_sampling_config_path = Path(sampling_config_path) if sampling_config_path is not None else None
generator = CrystalGenerator(
checkpoint_info=checkpoint_info,
properties_to_condition_on=properties_to_condition_on,
batch_size=batch_size,
num_batches=num_batches,
sampling_config_name=sampling_config_name,
sampling_config_path=_sampling_config_path,
sampling_config_overrides=sampling_config_overrides,
record_trajectories=record_trajectories,
diffusion_guidance_factor=(
diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
),
target_compositions_dict=target_compositions,
)
generator.generate(output_dir=Path(output_path))
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
async def generate_material(
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
batch_size: int = 2,
num_batches: int = 1,
diffusion_guidance_factor: float = 2.0
) -> str:
"""
Generate crystal structures with optional property constraints.
This unified function can generate materials in three modes:
1. Unconditional generation (no properties specified)
2. Single property conditional generation (one property specified)
3. Multi-property conditional generation (multiple properties specified)
Args:
properties: Optional property constraints. Can be:
- None or empty dict for unconditional generation
- Dict with single key-value pair for single property conditioning
- Dict with multiple key-value pairs for multi-property conditioning
Valid property names include: "dft_band_gap", "chemical_system", etc.
batch_size: Number of structures per batch
num_batches: Number of batches to generate
diffusion_guidance_factor: Controls adherence to target properties
Returns:
Descriptive text with generated crystal structures in CIF format
"""
# Use the results directory from the configuration
output_dir = config.MATTERGENMODEL_RESULT_PATH
# Handle string input (if provided)
if isinstance(properties, str):
try:
properties = json.loads(properties)
except json.JSONDecodeError:
raise ValueError(f"Invalid properties JSON string: {properties}")
# Default to an empty dict if None
properties = properties or {}
# Dispatch on the generation mode implied by the properties
if not properties:
# Unconditional generation
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
properties_to_condition_on = None
generation_type = "unconditional"
property_description = "unconditionally"
else:
# Conditional generation (single or multiple properties)
properties_to_condition_on = {}
# Preprocess each property
for property_name, property_value in properties.items():
_, processed_value = preprocess_property(property_name, property_value)
properties_to_condition_on[property_name] = processed_value
# Select the model to use based on the properties
if len(properties) == 1:
# Single-property conditioning
property_name = list(properties.keys())[0]
property_to_model = {
"dft_mag_density": "dft_mag_density",
"dft_bulk_modulus": "dft_bulk_modulus",
"dft_shear_modulus": "dft_shear_modulus",
"energy_above_hull": "energy_above_hull",
"formation_energy_per_atom": "formation_energy_per_atom",
"space_group": "space_group",
"hhi_score": "hhi_score",
"ml_bulk_modulus": "ml_bulk_modulus",
"chemical_system": "chemical_system",
"dft_band_gap": "dft_band_gap"
}
model_dir = property_to_model.get(property_name, property_name)
generation_type = "single_property"
property_description = f"conditioned on {property_name} = {properties[property_name]}"
else:
# Multi-property conditioning
property_keys = set(properties.keys())
if property_keys == {"dft_mag_density", "hhi_score"}:
model_dir = "dft_mag_density_hhi_score"
elif property_keys == {"chemical_system", "energy_above_hull"}:
model_dir = "chemical_system_energy_above_hull"
else:
# If no dedicated multi-property model exists, fall back to the first property's model
first_property = list(properties.keys())[0]
model_dir = first_property
generation_type = "multi_property"
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
# Build the full model path
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
# Check whether the model directory exists
if not os.path.exists(model_path):
# If the specific model is missing, fall back to the base model
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
# Call main with the appropriate arguments
main(
output_path=output_dir,
model_path=model_path,
batch_size=batch_size,
num_batches=num_batches,
properties_to_condition_on=properties_to_condition_on,
record_trajectories=True,
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0
)
# Collect file contents in a dictionary
result_dict = {}
# Define the output file paths
cif_zip_path = os.path.join(output_dir, "generated_crystals_cif.zip")
xyz_file_path = os.path.join(output_dir, "generated_crystals.extxyz")
trajectories_zip_path = os.path.join(output_dir, "generated_trajectories.zip")
# Read the zipped CIF output
if os.path.exists(cif_zip_path):
with open(cif_zip_path, 'rb') as f:
result_dict['cif_content'] = f.read()
# Build a descriptive prompt according to the generation type
if generation_type == "unconditional":
title = "Generated Material Structures"
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
elif generation_type == "single_property":
property_name = list(properties.keys())[0]
property_value = properties[property_name]
title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
else: # multi_property
title = "Generated Material Structures Conditioned on Multiple Properties"
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
# Assemble the full prompt
prompt = f"""
# {title}
This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.
{'' if generation_type == 'unconditional' else f'''
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
'''}
## CIF Files (Crystallographic Information Files)
- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools
```
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
```
{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""
# Clean up the output files after they have been read
try:
if os.path.exists(cif_zip_path):
os.remove(cif_zip_path)
if os.path.exists(xyz_file_path):
os.remove(xyz_file_path)
if os.path.exists(trajectories_zip_path):
os.remove(trajectories_zip_path)
except Exception as e:
logger.warning(f"Error cleaning up files: {e}")
return prompt
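
A minimal calling sketch for generate_material, assuming the MatterGen checkpoints under config.MATTERGENMODEL_ROOT and a writable results directory are available:

import asyncio
from mars_toolkit.compute.material_gen import generate_material

# Conditional generation: two structures per batch, targeting a 1.5 eV DFT band gap.
report = asyncio.run(generate_material(
    properties={"dft_band_gap": 1.5},
    batch_size=2,
    num_batches=1,
    diffusion_guidance_factor=2.0,
))
print(report)  # descriptive text containing [cif i begin]...[cif i end] blocks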

View File

@@ -0,0 +1,72 @@
"""
Property Prediction Module
This module provides functions for predicting properties of crystal structures.
"""
import asyncio
import torch
import numpy as np
from ase.units import GPa
from mattersim.forcefield import MatterSimCalculator
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.compute.structure_opt import convert_structure
@llm_tool(
name="predict_properties",
description="Predict energy, forces, and stress of crystal structures based on CIF string",
)
async def predict_properties(cif_content: str) -> str:
"""
Use MatterSim to predict energy, forces, and stress of crystal structures.
Args:
cif_content: Crystal structure string in CIF format
Returns:
String containing prediction results
"""
# Run the potentially blocking work in a thread via asyncio.to_thread
def run_prediction():
# Convert the CIF string into an ASE Atoms object with convert_structure
structure = convert_structure("cif", cif_content)
if structure is None:
return "Unable to parse CIF string. Please check if the format is correct."
# Select the device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Compute properties with MatterSimCalculator
structure.calc = MatterSimCalculator(device=device)
# Get energy, forces, and stress directly
energy = structure.get_potential_energy()
forces = structure.get_forces()
stresses = structure.get_stress(voigt=False)
# Energy per atom
num_atoms = len(structure)
energy_per_atom = energy / num_atoms
# Stress in eV/A^3 and GPa
stresses_ev_a3 = stresses
stresses_gpa = stresses / GPa
# Build the returned report
prompt = f"""
## {structure.get_chemical_formula()} Crystal Structure Property Prediction Results
Prediction results using the provided CIF structure:
- Total Energy (eV): {energy}
- Energy per Atom (eV/atom): {energy_per_atom:.4f}
- Forces (eV/Angstrom): {forces[0]} # Forces on the first atom
- Stress (GPa): {stresses_gpa[0][0]} # First component of the stress tensor
- Stress (eV/A^3): {stresses_ev_a3[0][0]} # First component of the stress tensor
"""
return prompt
# Run the prediction asynchronously
return await asyncio.to_thread(run_prediction)
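
A minimal calling sketch, assuming a MatterSim checkpoint is available to MatterSimCalculator; the CIF below is a hypothetical two-atom cubic toy cell used only to illustrate the input format:

import asyncio
from mars_toolkit.compute.property_pred import predict_properties

toy_cif = """data_NaCl_toy
_symmetry_space_group_name_H-M   'P 1'
_cell_length_a   5.64
_cell_length_b   5.64
_cell_length_c   5.64
_cell_angle_alpha   90
_cell_angle_beta    90
_cell_angle_gamma   90
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Na1  Na  0.0  0.0  0.0
Cl1  Cl  0.5  0.5  0.5
"""
# Returns a text report with total energy, energy per atom, forces, and stress.
print(asyncio.run(predict_properties(toy_cif)))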

View File

@@ -0,0 +1,192 @@
"""
Structure Optimization Module
This module provides functions for optimizing crystal structures using the FairChem model.
"""
import asyncio
from io import StringIO
import sys
import tempfile
import os
import logging
from typing import Optional, Dict, Any
from ase.io import read, write
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
from mars_toolkit.core.error_handlers import handle_general_error
logger = logging.getLogger(__name__)
# FairChem model calculator (lazily initialized)
calc = None
def init_model():
"""Initialize the FairChem model."""
global calc
if calc is not None:
return
try:
from fairchem.core import OCPCalculator
calc = OCPCalculator(checkpoint_path=config.FAIRCHEM_MODEL_PATH)
logger.info("FairChem model initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize FairChem model: {str(e)}")
raise
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
"""
Convert the input content into an ASE Atoms object.
Args:
input_format: Input format (cif, xyz, vasp, etc.)
content: Structure content string
Returns:
ASE Atoms object, or None if the conversion fails
"""
try:
with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name
atoms = read(tmp_path)
os.unlink(tmp_path)
return atoms
except Exception as e:
logger.error(f"Failed to convert structure: {str(e)}")
return None
def generate_symmetry_cif(structure: Structure) -> str:
"""
Generate a symmetrized CIF.
Args:
structure: Pymatgen Structure object
Returns:
CIF-formatted string
"""
analyzer = SpacegroupAnalyzer(structure)
structure_refined = analyzer.get_refined_structure()
with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
cif_writer.write_file(tmp_file.name)
tmp_file.seek(0)
content = tmp_file.read()
os.unlink(tmp_file.name)
return content
def optimize_structure(atoms: Atoms, output_format: str) -> str:
"""
Optimize a crystal structure.
Args:
atoms: ASE Atoms object
output_format: Output format (cif, xyz, vasp, etc.)
Returns:
Formatted string containing the optimization results
"""
atoms.calc = calc
try:
# Capture stdout produced during the optimization
temp_output = StringIO()
original_stdout = sys.stdout
sys.stdout = temp_output
# Run the optimization
dyn = FIRE(FrechetCellFilter(atoms))
dyn.run(fmax=config.FMAX)
# Restore stdout and collect the optimization log
sys.stdout = original_stdout
optimization_log = temp_output.getvalue()
temp_output.close()
# Total energy of the relaxed structure
total_energy = atoms.get_potential_energy()
# Post-process the optimized structure
if output_format == "cif":
optimized_structure = Structure.from_ase_atoms(atoms)
content = generate_symmetry_cif(optimized_structure)
content = remove_symmetry_equiv_xyz(content)
else:
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
write(tmp_file.name, atoms)
tmp_file.seek(0)
content = tmp_file.read()
os.unlink(tmp_file.name)
# Format the returned result
format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**
#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
return format_result
except Exception as e:
logger.error(f"Failed to optimize structure: {str(e)}")
raise e
@llm_tool(name="optimize_crystal_structure",
description="Optimize crystal structure using FairChem model")
async def optimize_crystal_structure(
content: str,
input_format: str = "cif",
output_format: str = "cif"
) -> str:
"""
Optimize crystal structure using FairChem model.
Args:
content: Crystal structure content string
input_format: Input format (cif, xyz, vasp)
output_format: Output format (cif, xyz, vasp)
Returns:
Optimized structure with energy and optimization log
"""
# Make sure the model is initialized
if calc is None:
init_model()
# Run the potentially blocking work in a thread via asyncio.to_thread
def run_optimization():
# Convert the input structure
atoms = convert_structure(input_format, content)
if atoms is None:
raise ValueError(f"Unable to convert the provided {input_format} content. Please check that the format is correct.")
# Optimize the structure
return optimize_structure(atoms, output_format)
try:
# Return the result directly, or hand any exception to the error handler
return await asyncio.to_thread(run_optimization)
except Exception as e:
return handle_general_error(e)
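
A minimal calling sketch, assuming the eqV2 checkpoint configured as config.FAIRCHEM_MODEL_PATH exists and that a structure file named input.cif (hypothetical) is on disk:

import asyncio
from mars_toolkit.compute.structure_opt import optimize_crystal_structure

with open("input.cif") as f:  # hypothetical input file
    cif_text = f.read()
result = asyncio.run(optimize_crystal_structure(cif_text, input_format="cif", output_format="cif"))
print(result)  # total energy, FIRE log, and the symmetrized CIF of the relaxed structure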

View File

@@ -0,0 +1,13 @@
"""
Core Module
This module provides core functionality for the Mars Toolkit.
"""
from mars_toolkit.core.config import config
from mars_toolkit.core.utils import settings, setup_logging
from mars_toolkit.core.error_handlers import (
handle_minio_error, handle_http_error,
handle_validation_error, handle_general_error
)
from mars_toolkit.core.llm_tools import llm_tool

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,121 @@
"""
CIF Utilities Module
This module provides basic functions for handling CIF (Crystallographic Information File) files,
which are commonly used in materials science for representing crystal structures.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import json
import logging
logger = logging.getLogger(__name__)
def read_cif_txt_file(file_path):
"""
Read the CIF file and return its content.
Args:
file_path: Path to the CIF file
Returns:
String content of the CIF file or None if an error occurs
"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
logger.error(f"Error reading file {file_path}: {e}")
return None
def extract_cif_info(path: str, fields_name: list):
"""
Extract specific fields from the CIF description JSON file.
Args:
path: Path to the JSON file containing CIF information
fields_name: List of field categories to extract. Use 'all_fields' to extract all fields.
Other options include 'basic_fields', 'energy_electronic_fields', 'metal_magentic_fields'
Returns:
Dictionary containing the extracted fields
"""
basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites']
selected_fields = []
if fields_name[0] == 'all_fields':
selected_fields = basic_fields + energy_electronic_fields + metal_magentic_fields
else:
for field in fields_name:
selected_fields.extend(locals().get(field, []))
with open(path, 'r') as f:
docs = json.load(f)
new_docs = {}
for field_name in selected_fields:
new_docs[field_name] = docs.get(field_name, '')
return new_docs
def remove_symmetry_equiv_xyz(cif_content):
"""
Remove symmetry operations section from CIF file content.
This is often useful when working with CIF files in certain visualization tools
or when focusing on the basic structure without symmetry operations.
Args:
cif_content: CIF file content string
Returns:
Cleaned CIF content string with symmetry operations removed
"""
lines = cif_content.split('\n')
output_lines = []
i = 0
while i < len(lines):
line = lines[i].strip()
# Detect the start of a loop
if line == 'loop_':
# Look ahead to check whether this is a symmetry loop
next_lines = []
j = i + 1
while j < len(lines) and lines[j].strip().startswith('_'):
next_lines.append(lines[j].strip())
j += 1
# Check whether the loop contains the symmetry-operation tag
if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
# Skip the entire loop block
while i < len(lines):
if i + 1 >= len(lines):
break
next_line = lines[i + 1].strip()
# Stop when the next loop or data block is reached
if next_line == 'loop_' or next_line.startswith('data_'):
break
# Stop when the atom-site section is reached
if next_line.startswith('_atom_site_'):
break
i += 1
else:
# Not a symmetry loop; keep the loop_ line
output_lines.append(lines[i])
else:
# Not the start of a loop; keep the line as-is
output_lines.append(lines[i])
i += 1
return '\n'.join(output_lines)
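
A minimal sketch of remove_symmetry_equiv_xyz on a hypothetical CIF fragment; the symmetry loop is dropped while the atom-site loop is kept:

from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz

cif_fragment = """data_example
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
loop_
_atom_site_label
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Fe1 0.0 0.0 0.0
"""
# Prints the fragment without the _symmetry_equiv_pos_as_xyz loop.
print(remove_symmetry_equiv_xyz(cif_fragment))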

View File

@@ -0,0 +1,59 @@
"""
Configuration Module
This module provides configuration settings for the Mars Toolkit.
It includes API keys, endpoints, paths, and other configuration parameters.
"""
from typing import Dict, Any
class Config:
"""Configuration class for Mars Toolkit"""
# Materials Project
MP_API_KEY = 'PMASAg256b814q3OaSRWeVc7MKx4mlKI'
MP_ENDPOINT = 'https://api.materialsproject.org/'
MP_TOPK = 3
LOCAL_MP_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/'
# Proxy
HTTP_PROXY = 'http://192.168.168.1:20171'
HTTPS_PROXY = 'http://192.168.168.1:20171'
# FairChem
FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
FMAX = 0.05
# MatterGen
MATTERGENMODEL_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
MATTERGENMODEL_RESULT_PATH = 'results/'
# Dify
DIFY_ROOT_URL = 'http://192.168.191.101:6080'
DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'
# Searxng
SEARXNG_HOST="http://192.168.191.101:40032/"
# Visualization
VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'
@classmethod
def as_dict(cls) -> Dict[str, Any]:
"""Return all configuration settings as a dictionary"""
return {
key: value for key, value in cls.__dict__.items()
if not key.startswith('__') and not callable(value)
}
@classmethod
def update(cls, **kwargs):
"""Update configuration settings"""
for key, value in kwargs.items():
if hasattr(cls, key):
setattr(cls, key, value)
# Create a global instance for easy access
config = Config()
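
A minimal sketch of how the Config singleton above can be inspected and overridden at runtime:

from mars_toolkit.core.config import config

config.update(MP_TOPK=5, FMAX=0.02)  # override class-level defaults (only known keys are set)
print(config.MP_TOPK)                # 5
print(config.as_dict())              # dump the configuration class attributes as a dict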

View File

@@ -0,0 +1,55 @@
"""
Error Handlers Module
This module provides error handling utilities for the Mars Toolkit.
It includes functions for handling various types of errors that may occur
during toolkit operations.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
from fastapi import HTTPException
from typing import Any, Dict
import logging
logger = logging.getLogger(__name__)
class APIError(HTTPException):
"""自定义API错误类"""
def __init__(self, status_code: int, detail: Any = None):
super().__init__(status_code=status_code, detail=detail)
logger.error(f"API Error: {status_code} - {detail}")
def handle_minio_error(e: Exception) -> Dict[str, str]:
"""处理MinIO相关错误"""
logger.error(f"MinIO operation failed: {str(e)}")
return {
"status": "error",
"data": f"MinIO operation failed: {str(e)}"
}
def handle_http_error(e: Exception) -> Dict[str, str]:
"""处理HTTP请求错误"""
logger.error(f"HTTP request failed: {str(e)}")
return {
"status": "error",
"data": f"HTTP request failed: {str(e)}"
}
def handle_validation_error(e: Exception) -> Dict[str, str]:
"""处理数据验证错误"""
logger.error(f"Validation failed: {str(e)}")
return {
"status": "error",
"data": f"Validation failed: {str(e)}"
}
def handle_general_error(e: Exception) -> Dict[str, str]:
"""处理通用错误"""
logger.error(f"Unexpected error: {str(e)}")
return {
"status": "error",
"data": f"Unexpected error: {str(e)}"
}

View File

@@ -0,0 +1,213 @@
"""
LLM Tools Module
This module provides decorators and utilities for defining, registering, and managing LLM tools.
It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving
registered tools for use with LLM APIs.
"""
import asyncio
import inspect
import json
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args
import docstring_parser
from pydantic import BaseModel, create_model, Field
# Registry to store all registered tools
_TOOL_REGISTRY = {}
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
"""
Decorator to mark a function as an LLM tool.
This decorator registers the function as an LLM tool, generates a JSON schema for it,
and makes it available for retrieval through the get_tools function.
Args:
name: Optional custom name for the tool. If not provided, the function name will be used.
description: Optional custom description for the tool. If not provided, the function's
docstring will be used.
Returns:
The decorated function with additional attributes for LLM tool functionality.
Example:
@llm_tool(name="weather_lookup", description="Get current weather for a location")
def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
'''Get weather information for a specific location.'''
# Implementation...
return {"temperature": 22.5, "conditions": "sunny"}
"""
# Handle case when decorator is used without parentheses: @llm_tool
if callable(name):
func = name
name = None
description = None
return _llm_tool_impl(func, name, description)
# Handle case when decorator is used with parentheses: @llm_tool() or @llm_tool(name="xyz")
def decorator(func: Callable) -> Callable:
return _llm_tool_impl(func, name, description)
return decorator
def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
"""Implementation of the llm_tool decorator."""
# Get function signature and docstring
sig = inspect.signature(func)
doc = inspect.getdoc(func) or ""
parsed_doc = docstring_parser.parse(doc)
# Determine tool name
tool_name = name or func.__name__
# Determine tool description
tool_description = description or doc
# Create parameter properties for JSON schema
properties = {}
required = []
for param_name, param in sig.parameters.items():
# Skip self parameter for methods
if param_name == "self":
continue
param_type = param.annotation
param_default = None if param.default is inspect.Parameter.empty else param.default
param_required = param.default is inspect.Parameter.empty
# Get parameter description from docstring if available
param_desc = ""
for param_doc in parsed_doc.params:
if param_doc.arg_name == param_name:
param_desc = param_doc.description
break
# Handle Annotated types
if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
args = get_args(param_type)
param_type = args[0] # The actual type
if len(args) > 1 and isinstance(args[1], str):
param_desc = args[1] # The description
# Create property for parameter
param_schema = {
"type": _get_json_type(param_type),
"description": param_desc,
"title": param_name.replace("_", " ").title()
}
# Add default value if available
if param_default is not None:
param_schema["default"] = param_default
properties[param_name] = param_schema
# Add to required list if no default value
if param_required:
required.append(param_name)
# Create JSON schema
schema = {
"type": "function",
"function": {
"name": tool_name,
"description": tool_description,
"parameters": {
"type": "object",
"properties": properties,
"required": required
}
}
}
# Create Pydantic model for args schema
field_definitions = {}
for param_name, param in sig.parameters.items():
if param_name == "self":
continue
param_type = param.annotation
param_default = ... if param.default is inspect.Parameter.empty else param.default
# Handle Annotated types
if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
args = get_args(param_type)
param_type = args[0]
description = args[1] if len(args) > 1 and isinstance(args[1], str) else ""
field_definitions[param_name] = (param_type, Field(default=param_default, description=description))
else:
field_definitions[param_name] = (param_type, Field(default=param_default))
# Create args schema model
model_name = f"{tool_name.title().replace('_', '')}Schema"
args_schema = create_model(model_name, **field_definitions)
# Create an async or sync wrapper depending on whether the original function is a coroutine function
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)
else:
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
# Attach metadata to function
wrapper.is_llm_tool = True
wrapper.tool_name = tool_name
wrapper.tool_description = tool_description
wrapper.json_schema = schema
wrapper.args_schema = args_schema
# Register the tool
_TOOL_REGISTRY[tool_name] = wrapper
return wrapper
def get_tools() -> Dict[str, Callable]:
"""
Get all registered LLM tools.
Returns:
A dictionary mapping tool names to their corresponding functions.
"""
return _TOOL_REGISTRY
def get_tool_schemas() -> List[Dict[str, Any]]:
"""
Get JSON schemas for all registered LLM tools.
Returns:
A list of JSON schemas for all registered tools, suitable for use with LLM APIs.
"""
return [tool.json_schema for tool in _TOOL_REGISTRY.values()]
def _get_json_type(python_type: Any) -> str:
"""
Convert Python type to JSON schema type.
Args:
python_type: Python type annotation
Returns:
Corresponding JSON schema type as string
"""
if python_type is str:
return "string"
elif python_type is int:
return "integer"
elif python_type is float:
return "number"
elif python_type is bool:
return "boolean"
elif python_type is list or python_type is List:
return "array"
elif python_type is dict or python_type is Dict:
return "object"
else:
# Default to string for complex types
return "string"

View File

@@ -0,0 +1,75 @@
import os
import boto3
import logging
import logging.config
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
logger = logging.getLogger(__name__)
class Settings(BaseSettings):
# Material Project
mp_api_key: Optional[str] = Field(None, env="MP_API_KEY")
mp_endpoint: Optional[str] = Field(None, env="MP_ENDPOINT")
mp_topk: Optional[int] = Field(3, env="MP_TOPK")
# Proxy
http_proxy: Optional[str] = Field(None, env="HTTP_PROXY")
https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY")
# FairChem
fairchem_model_path: Optional[str] = Field(None, env="FAIRCHEM_MODEL_PATH")
fmax: Optional[float] = Field(0.05, env="FMAX")
# MinIO
minio_endpoint: Optional[str] = Field(None, env="MINIO_ENDPOINT")
internal_minio_endpoint: Optional[str] = Field(None, env="INTERNAL_MINIO_ENDPOINT")
minio_access_key: Optional[str] = Field(None, env="MINIO_ACCESS_KEY")
minio_secret_key: Optional[str] = Field(None, env="MINIO_SECRET_KEY")
minio_bucket: Optional[str] = Field("mars-toolkit", env="MINIO_BUCKET")
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
def setup_logging():
"""配置日志记录"""
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
log_file_path = os.path.join(parent_dir, 'mars_toolkit.log')
logging.config.dictConfig({
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'standard': {
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
'datefmt': '%Y-%m-%d %H:%M:%S'
},
},
'handlers': {
'console': {
'level': 'INFO',
'class': 'logging.StreamHandler',
'formatter': 'standard'
},
'file': {
'level': 'DEBUG',
'class': 'logging.handlers.RotatingFileHandler',
'filename': log_file_path,
'maxBytes': 10485760, # 10MB
'backupCount': 5,
'formatter': 'standard'
}
},
'loggers': {
'': {
'handlers': ['console', 'file'],
'level': 'INFO',
'propagate': True
}
}
})
# Initialize settings
settings = Settings()

View File

@@ -0,0 +1,7 @@
"""
Basic Module
This module provides basic utility functions for the Mars Toolkit.
"""
from mars_toolkit.misc.misc_tools import get_current_time

Binary file not shown.

View File

@@ -0,0 +1,29 @@
"""
General Tools Module
This module provides basic utility functions that are not specific to materials science.
"""
import asyncio
from datetime import datetime
import pytz
from typing import Annotated
from mars_toolkit.core.llm_tools import llm_tool
@llm_tool(name="get_current_time", description="Get current date and time in specified timezone")
async def get_current_time(timezone: str = "UTC") -> str:
"""Returns the current date and time in the specified timezone.
Args:
timezone: Timezone name (e.g., UTC, Asia/Shanghai, America/New_York)
Returns:
Formatted date and time string
"""
try:
tz = pytz.timezone(timezone)
current_time = datetime.now(tz)
return f"The current {timezone} time is: {current_time.strftime('%Y-%m-%d %H:%M:%S %Z')}"
except pytz.exceptions.UnknownTimeZoneError:
return f"Unknown timezone: {timezone}. Please use a valid timezone such as 'UTC', 'Asia/Shanghai', etc."

View File

@@ -0,0 +1,18 @@
"""
Query Module
This module provides query tools for materials science, including:
- Materials Project database queries
- OQMD database queries
- Dify knowledge base retrieval
- Web search
"""
from mars_toolkit.query.mp_query import (
search_material_property_from_material_project,
get_crystal_structures_from_materials_project,
get_mpid_from_formula
)
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
from mars_toolkit.query.web_search import search_online

View File

@@ -0,0 +1,84 @@
"""
Dify Search Module
This module provides functions for retrieving information from local materials science
literature knowledge base using Dify API.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import asyncio
import json
import requests
import codecs
from typing import Dict, Any
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
@llm_tool(
name="retrieval_from_knowledge_base",
description="Retrieve information from local materials science literature knowledge base"
)
async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
"""
检索本地材料科学文献知识库中的相关信息
Args:
query: 查询字符串,如材料名称"CsPbBr3"
topk: 返回结果数量默认3条
Returns:
包含文档ID、标题和相关性分数的字典
"""
# 设置Dify API的URL端点
url = f'{config.DIFY_ROOT_URL}/v1/chat-messages'
# 配置请求头包含API密钥和内容类型
headers = {
'Authorization': f'Bearer {config.DIFY_API_KEY}',
'Content-Type': 'application/json'
}
# 准备请求数据
data = {
"inputs": {"topK": topk}, # 设置返回的最大结果数量
"query": query, # 设置查询字符串
"response_mode": "blocking", # 使用阻塞模式,等待并获取完整响应
"conversation_id": "", # 不使用会话ID每次都是独立查询
"user": "abc-123" # 用户标识符
}
try:
# 发送POST请求到Dify API并获取响应
# 设置较长的超时时间(1111秒)以处理可能的长时间响应
response = requests.post(url, headers=headers, json=data, timeout=1111)
# 获取响应文本
response_text = response.text
# 解码响应文本中的Unicode转义序列
response_text = codecs.decode(response_text, 'unicode_escape')
# 将响应文本解析为JSON对象
result_json = json.loads(response_text)
# 从响应中提取元数据
metadata = result_json.get("metadata", {})
# 构建包含关键信息的结果字典
useful_info = {
"id": metadata.get("document_id"), # 文档ID
"title": result_json.get("title"), # 文档标题
"content": result_json.get("answer", ""), # 内容字段,使用'answer'字段存储内容
"score": metadata.get("score") # 相关性分数
}
# 返回提取的有用信息
return json.dumps(useful_info, ensure_ascii=False, indent=2)
except Exception as e:
# 捕获并处理所有可能的异常,返回错误信息
return f"错误: {str(e)}"

View File

@@ -0,0 +1,433 @@
"""
Materials Project Query Module
This module provides functions for querying the Materials Project database,
processing search results, and formatting responses.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import glob
import json
import asyncio
import logging
import datetime
import os
from multiprocessing import Process, Manager
from queue import Empty
from typing import Dict, Any, List, Optional
from mp_api.client import MPRester
from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
from mars_toolkit.core.error_handlers import handle_general_error
from mars_toolkit.core.cif_utils import read_cif_txt_file, extract_cif_info, remove_symmetry_equiv_xyz
logger = logging.getLogger(__name__)
def parse_bool(param: str) -> bool | None:
"""
Parse a string parameter into a boolean value.
Args:
param: String parameter to parse (e.g., "true", "false")
Returns:
Boolean value if param is not empty, None otherwise
"""
if not param:
return None
return param.lower() == 'true'
def parse_list(param: str) -> List[str] | None:
"""
Parse a comma-separated string into a list of strings.
Args:
param: Comma-separated string (e.g., "Li,Fe,O")
Returns:
List of strings if param is not empty, None otherwise
"""
if not param:
return None
return param.split(',')
def parse_tuple(param: str) -> tuple[float, float] | None:
"""
Parse a comma-separated string into a tuple of two float values.
Used for range parameters like band_gap, density, etc.
Args:
param: Comma-separated string of two numbers (e.g., "0,3.5")
Returns:
Tuple of two float values if param is valid, None otherwise
"""
if not param:
return None
try:
values = param.split(',')
return (float(values[0]), float(values[1]))
except (ValueError, IndexError):
return None
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
"""
Parse search parameters from query parameters.
Converts string query parameters into appropriate types for the Materials Project API.
"""
return {
'band_gap': parse_tuple(query_params.get('band_gap')),
'chemsys': parse_list(query_params.get('chemsys')),
'crystal_system': parse_list(query_params.get('crystal_system')),
'density': parse_tuple(query_params.get('density')),
'formation_energy': parse_tuple(query_params.get('formation_energy')),
'elements': parse_list(query_params.get('elements')),
'exclude_elements': parse_list(query_params.get('exclude_elements')),
'formula': parse_list(query_params.get('formula')),
'is_gap_direct': parse_bool(query_params.get('is_gap_direct')),
'is_metal': parse_bool(query_params.get('is_metal')),
'is_stable': parse_bool(query_params.get('is_stable')),
'magnetic_ordering': query_params.get('magnetic_ordering'),
'material_ids': parse_list(query_params.get('material_ids')),
'total_energy': parse_tuple(query_params.get('total_energy')),
'num_elements': parse_tuple(query_params.get('num_elements')),
'volume': parse_tuple(query_params.get('volume')),
'chunk_size': int(query_params.get('chunk_size', '5'))
}
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]] | str:
"""
Process search results from the Materials Project API.
Extracts relevant fields from each document and formats them into a consistent structure.
Returns:
List of processed documents or error message string if an exception occurs
"""
try:
fields = [
'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
'universal_anisotropy', 'theoretical'
]
res = []
for doc in docs:
try:
new_docs = {}
for field_name in fields:
new_docs[field_name] = doc.get(field_name, '')
res.append(new_docs)
except Exception as e:
logger.warning(f"Error processing document: {str(e)}")
continue
return res
except Exception as e:
error_msg = f"Error in process_search_results: {str(e)}"
logger.error(error_msg)
return error_msg
def _search_worker(queue, api_key, **kwargs):
"""
Worker function for executing Materials Project API searches.
Runs in a separate process to perform the actual API call and puts results in the queue.
Args:
queue: Multiprocessing queue for returning results
api_key: Materials Project API key
**kwargs: Search parameters to pass to the API
"""
try:
import os
import traceback
os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''
# Initialize the MPRester client
with MPRester(api_key) as mpr:
result = mpr.materials.summary.search(**kwargs)
# Check the results
if result:
# Convert the documents using the safest method available
processed_results = []
for doc in result:
try:
# Try the model_dump method first (pydantic v2)
processed_doc = doc.model_dump()
processed_results.append(processed_doc)
except AttributeError:
# If model_dump is unavailable, try the dict method (pydantic v1)
try:
processed_doc = doc.dict()
processed_results.append(processed_doc)
except AttributeError:
# If dict is also unavailable, fall back to the instance __dict__
if hasattr(doc, "__dict__"):
processed_doc = doc.__dict__
# Remove special attributes that could break serialization
if "_sa_instance_state" in processed_doc:
del processed_doc["_sa_instance_state"]
processed_results.append(processed_doc)
else:
# Last resort: use the document object directly
processed_results.append(doc)
queue.put(processed_results)
else:
queue.put([])
except Exception as e:
queue.put(e)
async def execute_search(search_args: Dict[str, Any], timeout: int = 120) -> List[Dict[str, Any]] | str:
"""
Execute a search against the Materials Project API.
Runs the search in a separate process to handle potential timeouts and returns the results.
Args:
search_args: Dictionary of search parameters
timeout: Maximum time in seconds to wait for the search to complete
Returns:
List of document dictionaries from the search results or error message string if an exception occurs
"""
# Make sure the formula parameter is a list
if 'formula' in search_args and isinstance(search_args['formula'], str):
search_args['formula'] = [search_args['formula']]
manager = Manager()
queue = manager.Queue()
try:
p = Process(target=_search_worker, args=(queue, config.MP_API_KEY), kwargs=search_args)
p.start()
p.join(timeout=timeout)
if p.is_alive():
logger.warning(f"Terminating worker process {p.pid} due to timeout")
p.terminate()
p.join()
error_msg = f"Request timed out after {timeout} seconds"
return error_msg
try:
result = queue.get(timeout=timeout)
if isinstance(result, Exception):
logger.error(f"Error in search worker: {str(result)}")
if hasattr(result, "__traceback__"):
import traceback
tb_str = ''.join(traceback.format_exception(None, result, result.__traceback__))
return f"Error in search worker: {str(result)}"
return result
except Empty:  # standard-library queue.Empty raised by the manager queue's get() on timeout
error_msg = "Failed to retrieve data from queue (timeout)"
logger.error(error_msg)
return error_msg
except Exception as e:
error_msg = f"Error in execute_search: {str(e)}"
logger.error(error_msg)
return error_msg
@llm_tool(name="search_material_property_from_material_project", description="Search materials in Materials Project database by formula and properties")
async def search_material_property_from_material_project(
formula: str | list[str],
chemsys: Optional[str | list[str] | None] = None,
crystal_system: Optional[str | list[str] | None] = None,
is_gap_direct: Optional[bool | None] = None,
is_stable: Optional[bool | None] = None,
) -> str:
"""
Search materials in Materials Project database.
Args:
formula: Chemical formula(s) (e.g., "Fe2O3" or ["ABO3", "Si*"])
chemsys: Chemical system(s) (e.g., "Li-Fe-O")
crystal_system: Crystal system(s) (e.g., "Cubic")
is_gap_direct: Filter for direct band gap materials
is_stable: Filter for thermodynamically stable materials
Returns:
JSON formatted material properties data
"""
# Valid crystal systems
VALID_CRYSTAL_SYSTEMS = ['Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal', 'Cubic']
# Validate the crystal_system parameter
if crystal_system is not None:
if isinstance(crystal_system, str):
if crystal_system not in VALID_CRYSTAL_SYSTEMS:
return "Input should be 'Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal' or 'Cubic'"
elif isinstance(crystal_system, list):
for cs in crystal_system:
if cs not in VALID_CRYSTAL_SYSTEMS:
return "Input should be 'Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal' or 'Cubic'"
# Make sure formula is a list
if isinstance(formula, str):
formula = [formula]
params = {
"chemsys": chemsys,
"crystal_system": crystal_system,
"formula": formula,
"is_gap_direct": is_gap_direct,
"is_stable": is_stable,
"chunk_size": 5,
}
# Filter out None values
params = {k: v for k, v in params.items() if v is not None}
mp_id_list = await get_mpid_from_formula(formula=formula)
try:
res=[]
for mp_id in mp_id_list:
crystal_props = extract_cif_info(config.LOCAL_MP_ROOT+f"/Props/{mp_id}.json", ['all_fields'])
res.append(crystal_props)
if len(res) == 0:
return "No results found, please try again."
# Format response with top results
try:
# Build an indexed JSON block for each result
formatted_results = []
for i, item in enumerate(res[:config.MP_TOPK], 1):
formatted_result = f"[property {i} begin]\n"
formatted_result += json.dumps(item, indent=2)
formatted_result += f"\n[property {i} end]\n\n"
formatted_results.append(formatted_result)
# Join all results into a single string
res_chunk = "\n\n".join(formatted_results)
res_template = f"""
Here are the search results from the Materials Project database:
Due to length limitations, only the top {config.MP_TOPK} results are shown below:\n
{res_chunk}
If you need more results, please modify your search criteria or try different query parameters.
"""
return res_template
except Exception as format_error:
logger.error(f"Error formatting results: {str(format_error)}")
return str(format_error)
except Exception as e:
logger.error(f"Error in search_material_property_from_material_project: {str(e)}")
return str(e)
@llm_tool(name="get_crystal_structures_from_materials_project", description="Get symmetrized crystal structures CIF data from Materials Project database by chemical formula")
async def get_crystal_structures_from_materials_project(
formulas: list[str],
conventional_unit_cell: bool = True,
symprec: float = 0.1
) -> str:
"""
Get crystal structures from Materials Project database by chemical formula and apply symmetrization.
Args:
formulas: List of chemical formulas (e.g., ["Fe2O3", "SiO2", "TiO2"])
conventional_unit_cell: Whether to return conventional unit cell (True) or primitive cell (False)
symprec: Precision parameter for symmetrization
Returns:
Formatted text containing symmetrized CIF data
"""
result={}
mp_id_list=await get_mpid_from_formula(formula=formulas)
for i,mp_id in enumerate(mp_id_list):
cif_file = glob.glob(config.LOCAL_MP_ROOT+f"/MPDatasets/{mp_id}.cif")[0]
structure = Structure.from_file(cif_file)
# Convert to the conventional unit cell if requested
if conventional_unit_cell:
structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure()
# Symmetrize the structure
sga = SpacegroupAnalyzer(structure, symprec=symprec)
symmetrized_structure = sga.get_refined_structure()
# Generate CIF data with CifWriter
cif_writer = CifWriter(symmetrized_structure, symprec=symprec, refine_struct=True)
cif_data = str(cif_writer)
# Remove the symmetry-operations section from the CIF data
cif_data = remove_symmetry_equiv_xyz(cif_data)
cif_data = cif_data.replace('# generated using pymatgen', "")
# Generate a unique key
formula = structure.composition.reduced_formula
key = f"{formula}_{i}"
result[key] = cif_data
# Keep only the first config.MP_TOPK results
if len(result) >= config.MP_TOPK:
break
try:
prompt = f"""
# Materials Project Symmetrized Crystal Structure Data
Below are symmetrized crystal structure data for {len(result)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
"""
for i, (key, cif_data) in enumerate(result.items(), 1):
prompt += f"[cif {i} begin]\n"
prompt += cif_data
prompt += f"\n[cif {i} end]\n\n"
prompt += """
## Usage Instructions
1. You can copy the above CIF data and save it as .cif files
2. Open these files with crystal structure visualization software (such as VESTA, Mercury, Avogadro, etc.)
3. These structures can be used for further material analysis, simulation, or visualization
CIF files contain complete structural information of crystals, including cell parameters, atomic coordinates, symmetry, etc.
Symmetrization helps identify and optimize crystal symmetry, making the structure more standardized and accurate.
"""
return prompt
except Exception as format_error:
logger.error(f"Error formatting crystal structures: {str(format_error)}")
return str(format_error)
@llm_tool(name="get_mpid_from_formula", description="Get material IDs (mpid) from Materials Project database by chemical formula")
async def get_mpid_from_formula(formula: str | List[str]) -> List[str]:
"""
Get material IDs (mpid) from Materials Project database by chemical formula.
Returns mpids for the lowest energy structures.
Args:
formula: Chemical formula or list of formulas (e.g., "Fe2O3")
Returns:
List of material IDs
"""
os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''
id_list = []
try:
with MPRester(config.MP_API_KEY) as mpr:
docs = mpr.materials.summary.search(formula=formula)
for doc in docs:
id_list.append(doc.material_id)
return id_list
except Exception as e:
logger.error(f"Error getting mpid from formula: {str(e)}")
return []
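
A minimal calling sketch, assuming a valid MP_API_KEY and that the local property/CIF cache under config.LOCAL_MP_ROOT has been prepared:

import asyncio
from mars_toolkit.query.mp_query import (
    get_mpid_from_formula,
    search_material_property_from_material_project,
)

mp_ids = asyncio.run(get_mpid_from_formula("Fe2O3"))
print(mp_ids[:3])  # material IDs returned by the Materials Project search
summary = asyncio.run(search_material_property_from_material_project(formula="Fe2O3", is_stable=True))
print(summary)     # top-k property blocks read from the local JSON cache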

View File

@@ -0,0 +1,105 @@
"""
OQMD Query Module
This module provides functions for querying the Open Quantum Materials Database (OQMD).
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import logging
import httpx
import pandas as pd
from bs4 import BeautifulSoup
from io import StringIO
from typing import Annotated
from mars_toolkit.core.llm_tools import llm_tool
logger = logging.getLogger(__name__)
@llm_tool(name="fetch_chemical_composition_from_OQMD", description="Fetch material data for a chemical composition from OQMD database")
async def fetch_chemical_composition_from_OQMD(
composition: Annotated[str, "Chemical formula (e.g., Fe2O3, LiFePO4)"]
) -> str:
"""
Fetch material data for a chemical composition from OQMD database.
Args:
composition: Chemical formula (e.g., Fe2O3, LiFePO4)
Returns:
Formatted text with material information and property tables
"""
# Fetch data from OQMD
url = f"https://www.oqmd.org/materials/composition/{composition}"
try:
async with httpx.AsyncClient(timeout=100.0) as client:
response = await client.get(url)
response.raise_for_status()
# Validate response content
if not response.text or len(response.text) < 100:
raise ValueError("Invalid response content from OQMD API")
# Parse HTML data
html = response.text
soup = BeautifulSoup(html, 'html.parser')
# Parse basic data
basic_data = []
h1_element = soup.find('h1')
if h1_element:
basic_data.append(h1_element.text.strip())
else:
basic_data.append(f"Material: {composition}")
for script in soup.find_all('p'):
if script:
combined_text = ""
for element in script.contents:
if hasattr(element, 'name') and element.name == 'a' and 'href' in element.attrs:
url = "https://www.oqmd.org" + element['href']
combined_text += f"[{element.text.strip()}]({url}) "
elif hasattr(element, 'text'):
combined_text += element.text.strip() + " "
else:
combined_text += str(element).strip() + " "
basic_data.append(combined_text.strip())
# Parse table data
table_data = ""
table = soup.find('table')
if table:
try:
df = pd.read_html(StringIO(str(table)))[0]
df = df.fillna('')
df = df.replace([float('inf'), float('-inf')], '')
table_data = df.to_markdown(index=False)
except Exception as e:
logger.error(f"Error parsing table: {str(e)}")
table_data = "Error parsing table data"
# Integrate data into a single text
combined_text = "\n\n".join(basic_data)
if table_data:
combined_text += "\n\n## Material Properties Table\n\n" + table_data
return combined_text
except httpx.HTTPStatusError as e:
logger.error(f"OQMD API request failed: {str(e)}")
return f"Error: OQMD API request failed - {str(e)}"
except httpx.TimeoutException:
logger.error("OQMD API request timed out")
return "Error: OQMD API request timed out"
except httpx.NetworkError as e:
logger.error(f"Network error occurred: {str(e)}")
return f"Error: Network error occurred - {str(e)}"
except ValueError as e:
logger.error(f"Invalid response content: {str(e)}")
return f"Error: Invalid response content - {str(e)}"
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
return f"Error: Unexpected error occurred - {str(e)}"

View File

@@ -0,0 +1,77 @@
"""
Web Search Module
This module provides functions for searching information on the web.
"""
import asyncio
from typing import Annotated, Dict, Any, List
from langchain_community.utilities import SearxSearchWrapper
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
@llm_tool(name="search_online", description="Search scientific information online and return results as a string")
async def search_online(
query: Annotated[str, "Search term"],
num_results: Annotated[int, "Number of results (1-20)"] = 5
) -> str:
"""
Searches for scientific information online and returns results as a formatted string.
Args:
query: Search term for scientific content
num_results: Number of results to return (1-20)
Returns:
Formatted string with search results (titles, snippets, links)
"""
# Make sure num_results is an integer
try:
num_results = int(num_results)
except (TypeError, ValueError):
num_results = 5
# Parameter validation
if num_results < 1:
num_results = 1
elif num_results > 20:
num_results = 20
# Initialize search wrapper
search = SearxSearchWrapper(
searx_host=config.SEARXNG_HOST,
categories=["science",],
k=num_results
)
# Execute search in a separate thread to avoid blocking the event loop
# since SearxSearchWrapper doesn't have native async support
loop = asyncio.get_event_loop()
raw_results = await loop.run_in_executor(
None,
lambda: search.results(query, language=['en','zh'], num_results=num_results)
)
# Transform results into structured format
formatted_results = []
for result in raw_results:
formatted_results.append({
"title": result.get("title", ""),
"snippet": result.get("snippet", ""),
"link": result.get("link", ""),
"source": result.get("source", "")
})
# Convert the results to a formatted string
result_str = f"Search Results for '{query}' ({len(formatted_results)} items):\n\n"
for i, result in enumerate(formatted_results, 1):
result_str += f"Result {i}:\n"
result_str += f"Title: {result['title']}\n"
result_str += f"Summary: {result['snippet']}\n"
result_str += f"Link: {result['link']}\n"
result_str += f"Source: {result['source']}\n\n"
return result_str
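
A minimal calling sketch, assuming the SearxNG instance configured as config.SEARXNG_HOST is reachable:

import asyncio
from mars_toolkit.query.web_search import search_online

# Prints numbered results with title, snippet, link, and source.
print(asyncio.run(search_online("perovskite solar cell stability", num_results=3)))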

View File