构建mars_toolkit,删除tools_for_ms
This commit is contained in:
46
mars_toolkit/__init__.py
Normal file
46
mars_toolkit/__init__.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Mars Toolkit
|
||||
|
||||
A comprehensive toolkit for materials science research, providing tools for:
|
||||
- Material generation and property prediction
|
||||
- Structure optimization
|
||||
- Database queries (Materials Project, OQMD)
|
||||
- Knowledge base retrieval
|
||||
- Web search
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
# Core modules
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.utils import setup_logging
|
||||
|
||||
# Basic tools
|
||||
from mars_toolkit.misc.misc_tools import get_current_time
|
||||
|
||||
# Compute modules
|
||||
from mars_toolkit.compute.material_gen import generate_material
|
||||
from mars_toolkit.compute.property_pred import predict_properties
|
||||
from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure
|
||||
|
||||
# Query modules
|
||||
from mars_toolkit.query.mp_query import (
|
||||
search_material_property_from_material_project,
|
||||
get_crystal_structures_from_materials_project,
|
||||
get_mpid_from_formula
|
||||
)
|
||||
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||
from mars_toolkit.query.web_search import search_online
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
|
||||
|
||||
|
||||
|
||||
# Initialize logging
|
||||
setup_logging()
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__all__ = ["llm_tool", "get_tools", "get_tool_schemas"]
|
||||
BIN
mars_toolkit/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
12
mars_toolkit/compute/__init__.py
Normal file
12
mars_toolkit/compute/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Compute Module
|
||||
|
||||
This module provides computational tools for materials science, including:
|
||||
- Material generation
|
||||
- Property prediction
|
||||
- Structure optimization
|
||||
"""
|
||||
|
||||
from mars_toolkit.compute.material_gen import generate_material
|
||||
from mars_toolkit.compute.property_pred import predict_properties
|
||||
from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure
|
||||
BIN
mars_toolkit/compute/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/compute/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc
Normal file
BIN
mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/compute/__pycache__/property_pred.cpython-310.pyc
Normal file
BIN
mars_toolkit/compute/__pycache__/property_pred.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc
Normal file
BIN
mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc
Normal file
Binary file not shown.
423
mars_toolkit/compute/material_gen.py
Normal file
423
mars_toolkit/compute/material_gen.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Material Generation Module
|
||||
|
||||
This module provides functions for generating crystal structures with optional property constraints.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
import os
|
||||
import datetime
|
||||
import asyncio
|
||||
import zipfile
|
||||
import shutil
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
|
||||
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from ase.io import read, write
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
# 导入路径已更新
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
|
||||
# 使用mattergen_wrapper
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
|
||||
from mattergen_wrapper import generator
|
||||
CrystalGenerator = generator.CrystalGenerator
|
||||
from mattergen.common.data.types import TargetProperty
|
||||
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
|
||||
from mattergen.common.utils.data_classes import (
|
||||
PRETRAINED_MODEL_NAME,
|
||||
MatterGenCheckpointInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def format_cif_content(content):
    """
    Strip archive residue from raw CIF text and label each CIF entry.

    The input typically comes from reading a zip archive's bytes as text, so
    "PK" archive headers may precede or trail the actual CIF data. Entries are
    delimited by their `_chemical_formula_structural` field.

    Args:
        content: String containing CIF content, possibly with PK headers.

    Returns:
        Formatted string with each CIF file wrapped in `[cif N begin]` /
        `[cif N end]` markers, or '' when no CIF data is found.
    """
    if not content or not content.strip():
        return ''

    # Remove archive ("PK") header bytes preceding the first CIF field.
    content = re.sub(r'PK.*?(?=_chemical_formula_structural)', '', content, flags=re.DOTALL)

    # Remove trailing PK residue that carries no further CIF data.
    content = re.sub(r'PK[^_]*$', '', content, flags=re.DOTALL)
    content = re.sub(r'PK.*?(?!.*_chemical_formula_structural)$', '', content, flags=re.DOTALL)

    # Each occurrence of _chemical_formula_structural starts a new CIF entry;
    # the field itself is kept inside its entry.
    starts = [m.start() for m in re.finditer(r'_chemical_formula_structural', content)]
    if not starts:
        return ''

    # Pair each start with the next start (or end of string) and collect
    # (formula, block) tuples.
    labelled = []
    for begin, end in zip(starts, starts[1:] + [len(content)]):
        block = content[begin:end].strip()
        found = re.search(r'_chemical_formula_structural\s+(\S+)', block)
        if found:
            labelled.append((found.group(1), block))

    pieces = [
        f"[cif {idx} begin]\ndata_{formula}\n{block}\n[cif {idx} end]"
        for idx, (formula, block) in enumerate(labelled, 1)
    ]
    return "\n".join(pieces)
|
||||
|
||||
|
||||
def convert_values(data_str):
    """
    Parse a JSON string into Python data.

    Args:
        data_str: JSON-encoded string.

    Returns:
        The decoded object, or the original string unchanged when it is not
        valid JSON (callers treat that as "pass through as-is").
    """
    try:
        return json.loads(data_str)
    except json.JSONDecodeError:
        # Not JSON — hand the raw string back to the caller.
        return data_str
|
||||
|
||||
|
||||
def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
    """
    Preprocess a property value based on its name, converting it to the appropriate type.

    Numeric-valued properties given as strings are coerced to float;
    ``chemical_system`` is coerced to str; ``space_group`` is range-checked.

    Args:
        property_name: Name of the property
        property_value: Value of the property (can be string, float, or int)

    Returns:
        Tuple of (property_name, processed_value)

    Raises:
        ValueError: If the property name is unknown or the value is invalid
            for the given property name.
    """
    valid_properties = [
        "dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
        "energy_above_hull", "formation_energy_per_atom", "space_group",
        "hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap"
    ]

    if property_name not in valid_properties:
        raise ValueError(f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}")

    # Coerce string values to float for all numeric properties.
    if isinstance(property_value, str):
        try:
            if property_name != "chemical_system":
                property_value = float(property_value)
        except ValueError:
            # Keep as string; validated per-property below.
            pass

    # Per-property type/range validation.
    if property_name == "chemical_system":
        if isinstance(property_value, (int, float)):
            logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
            property_value = str(property_value)
    elif property_name == "space_group":
        # Bug fix: a non-numeric string used to fall through to the `<`
        # comparison and raise TypeError; raise the documented ValueError.
        if not isinstance(property_value, (int, float)):
            raise ValueError(f"Invalid space_group value: {property_value!r}. Must be numeric.")
        if property_value < 1 or property_value > 230:
            raise ValueError(f"Invalid space_group value: {property_value}. Must be an integer between 1 and 230.")

    return property_name, property_value
|
||||
|
||||
|
||||
def main(
    output_path: str,
    pretrained_name: PRETRAINED_MODEL_NAME | None = None,
    model_path: str | None = None,
    batch_size: int = 2,
    num_batches: int = 1,
    config_overrides: list[str] | None = None,
    checkpoint_epoch: Literal["best", "last"] | int = "last",
    properties_to_condition_on: TargetProperty | None = None,
    sampling_config_path: str | None = None,
    sampling_config_name: str = "default",
    sampling_config_overrides: list[str] | None = None,
    record_trajectories: bool = True,
    diffusion_guidance_factor: float | None = None,
    strict_checkpoint_loading: bool = True,
    target_compositions: list[dict[str, int]] | None = None,
):
    """
    Run MatterGen crystal-structure generation and write the results to disk.

    Exactly one of ``pretrained_name`` and ``model_path`` must be given: the
    former loads a checkpoint from the Hugging Face hub, the latter a local
    checkpoint directory.

    Args:
        output_path: Path to output directory (created if missing).
        pretrained_name: Name of a pretrained model on the HF hub.
        model_path: Path to a DiffusionLightningModule checkpoint directory.
        batch_size: Number of structures generated per batch.
        num_batches: Number of batches to generate.
        config_overrides: Overrides for the model config, e.g., `model.num_layers=3 model.hidden_dim=128`.
        checkpoint_epoch: Epoch to load from a local checkpoint: "best", "last", or an epoch number. Only used with ``model_path``.
        properties_to_condition_on: Property values to condition sampling on. When empty/None, unconditional samples are drawn.
        sampling_config_path: Path to the sampling config file (default: None, which uses the library's default sampling config path).
        sampling_config_name: Name of the sampling config (corresponds to `{sampling_config_path}/{sampling_config_name}.yaml` on disk).
        sampling_config_overrides: Overrides for the sampling config, e.g., `condition_loader_partial.batch_size=32`.
        record_trajectories: Whether to record the trajectories of the generated structures.
        diffusion_guidance_factor: Strength of classifier-free guidance toward the conditioned properties; None is treated as 0.0 (no guidance).
        strict_checkpoint_loading: Whether to raise when not all checkpoint parameters can be matched to the model.
        target_compositions: List of dicts `{element: number_of_atoms}` to condition on; only supported for CSP-trained models. (default: None)

    NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
    """
    assert (
        pretrained_name is not None or model_path is not None
    ), "Either pretrained_name or model_path must be provided."
    assert (
        pretrained_name is None or model_path is None
    ), "Only one of pretrained_name or model_path can be provided."

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    # Normalize optional collection arguments to empty defaults.
    sampling_config_overrides = sampling_config_overrides or []
    config_overrides = config_overrides or []
    properties_to_condition_on = properties_to_condition_on or {}
    target_compositions = target_compositions or []

    # Resolve the checkpoint either from the HF hub or a local directory.
    if pretrained_name is not None:
        checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
            pretrained_name, config_overrides=config_overrides
        )
    else:
        checkpoint_info = MatterGenCheckpointInfo(
            model_path=Path(model_path).resolve(),
            load_epoch=checkpoint_epoch,
            config_overrides=config_overrides,
            strict_checkpoint_loading=strict_checkpoint_loading,
        )
    _sampling_config_path = Path(sampling_config_path) if sampling_config_path is not None else None
    generator = CrystalGenerator(
        checkpoint_info=checkpoint_info,
        properties_to_condition_on=properties_to_condition_on,
        batch_size=batch_size,
        num_batches=num_batches,
        sampling_config_name=sampling_config_name,
        sampling_config_path=_sampling_config_path,
        sampling_config_overrides=sampling_config_overrides,
        record_trajectories=record_trajectories,
        diffusion_guidance_factor=(
            diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
        ),
        target_compositions_dict=target_compositions,
    )
    generator.generate(output_dir=Path(output_path))
|
||||
|
||||
|
||||
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
async def generate_material(
    properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
    batch_size: int = 2,
    num_batches: int = 1,
    diffusion_guidance_factor: float = 2.0
) -> str:
    """
    Generate crystal structures with optional property constraints.

    This unified function can generate materials in three modes:
    1. Unconditional generation (no properties specified)
    2. Single property conditional generation (one property specified)
    3. Multi-property conditional generation (multiple properties specified)

    Args:
        properties: Optional property constraints. Can be:
            - None or empty dict for unconditional generation
            - Dict with single key-value pair for single property conditioning
            - Dict with multiple key-value pairs for multi-property conditioning
            Valid property names include: "dft_band_gap", "chemical_system", etc.
            A JSON string encoding such a dict is also accepted.
        batch_size: Number of structures per batch
        num_batches: Number of batches to generate
        diffusion_guidance_factor: Controls adherence to target properties

    Returns:
        Descriptive text with generated crystal structures in CIF format

    Raises:
        ValueError: If ``properties`` is a string that is not valid JSON, or
            contains an invalid property name/value.
    """
    # Use the result directory from the shared configuration.
    output_dir = config.MATTERGENMODEL_RESULT_PATH

    # Accept a JSON string for properties (e.g. when called by an LLM).
    if isinstance(properties, str):
        try:
            properties = json.loads(properties)
        except json.JSONDecodeError:
            raise ValueError(f"Invalid properties JSON string: {properties}")

    # Default to an empty dict when None.
    properties = properties or {}

    # Choose generation mode and model checkpoint from the given properties.
    if not properties:
        # Unconditional generation uses the base model.
        model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
        properties_to_condition_on = None
        generation_type = "unconditional"
        property_description = "unconditionally"
    else:
        # Conditional generation (single- or multi-property).
        properties_to_condition_on = {}

        # Coerce/validate every property value (see preprocess_property).
        for property_name, property_value in properties.items():
            _, processed_value = preprocess_property(property_name, property_value)
            properties_to_condition_on[property_name] = processed_value

        # Pick the fine-tuned model directory matching the conditioning.
        if len(properties) == 1:
            # Single-property conditioning: model dir is named after the property.
            property_name = list(properties.keys())[0]
            property_to_model = {
                "dft_mag_density": "dft_mag_density",
                "dft_bulk_modulus": "dft_bulk_modulus",
                "dft_shear_modulus": "dft_shear_modulus",
                "energy_above_hull": "energy_above_hull",
                "formation_energy_per_atom": "formation_energy_per_atom",
                "space_group": "space_group",
                "hhi_score": "hhi_score",
                "ml_bulk_modulus": "ml_bulk_modulus",
                "chemical_system": "chemical_system",
                "dft_band_gap": "dft_band_gap"
            }
            model_dir = property_to_model.get(property_name, property_name)
            generation_type = "single_property"
            property_description = f"conditioned on {property_name} = {properties[property_name]}"
        else:
            # Multi-property conditioning: only two dedicated combos exist.
            property_keys = set(properties.keys())
            if property_keys == {"dft_mag_density", "hhi_score"}:
                model_dir = "dft_mag_density_hhi_score"
            elif property_keys == {"chemical_system", "energy_above_hull"}:
                model_dir = "chemical_system_energy_above_hull"
            else:
                # No dedicated multi-property model: fall back to the first property's model.
                first_property = list(properties.keys())[0]
                model_dir = first_property
            generation_type = "multi_property"
            property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"

        # Build the full model checkpoint path.
        model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)

        # Fall back to the base model when the specific checkpoint is absent.
        if not os.path.exists(model_path):
            logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
            model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")

    # Run generation with the chosen parameters (writes artifacts to output_dir).
    main(
        output_path=output_dir,
        model_path=model_path,
        batch_size=batch_size,
        num_batches=num_batches,
        properties_to_condition_on=properties_to_condition_on,
        record_trajectories=True,
        diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0
    )

    # Collect the generated artifacts.
    result_dict = {}

    # Artifact paths produced by the generator.
    cif_zip_path = os.path.join(output_dir, "generated_crystals_cif.zip")
    xyz_file_path = os.path.join(output_dir, "generated_crystals.extxyz")
    trajectories_zip_path = os.path.join(output_dir, "generated_trajectories.zip")

    # Read the CIF zip as raw bytes; format_cif_content later strips the
    # zip "PK" headers from the decoded text.
    if os.path.exists(cif_zip_path):
        with open(cif_zip_path, 'rb') as f:
            result_dict['cif_content'] = f.read()

    # Build a descriptive title/description for the generation mode.
    if generation_type == "unconditional":
        title = "Generated Material Structures"
        description = "These structures were generated unconditionally, meaning no specific properties were targeted."
    elif generation_type == "single_property":
        property_name = list(properties.keys())[0]
        property_value = properties[property_name]
        title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
        description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
    else:  # multi_property
        title = "Generated Material Structures Conditioned on Multiple Properties"
        description = "These structures were generated with multi-property conditioning, targeting the specified property values."

    # Assemble the full response text.
    prompt = f"""
# {title}

This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.

{'' if generation_type == 'unconditional' else f'''
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
'''}

## CIF Files (Crystallographic Information Files)

- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools

```
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
```

{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""

    # Remove the artifacts once their content has been captured.
    try:
        if os.path.exists(cif_zip_path):
            os.remove(cif_zip_path)
        if os.path.exists(xyz_file_path):
            os.remove(xyz_file_path)
        if os.path.exists(trajectories_zip_path):
            os.remove(trajectories_zip_path)
    except Exception as e:
        logger.warning(f"Error cleaning up files: {e}")
    return prompt
|
||||
72
mars_toolkit/compute/property_pred.py
Normal file
72
mars_toolkit/compute/property_pred.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Property Prediction Module
|
||||
|
||||
This module provides functions for predicting properties of crystal structures.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
import numpy as np
|
||||
from ase.units import GPa
|
||||
from mattersim.forcefield import MatterSimCalculator
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.compute.structure_opt import convert_structure
|
||||
|
||||
@llm_tool(
    name="predict_properties",
    description="Predict energy, forces, and stress of crystal structures based on CIF string",
)
async def predict_properties(cif_content: str) -> str:
    """
    Predict energy, forces, and stress of a crystal structure with MatterSim.

    Args:
        cif_content: Crystal structure string in CIF format

    Returns:
        Human-readable report string with the prediction results
    """
    def _predict():
        # Turn the CIF text into an ASE Atoms object.
        atoms = convert_structure("cif", cif_content)
        if atoms is None:
            return "Unable to parse CIF string. Please check if the format is correct."

        # Prefer GPU when one is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        atoms.calc = MatterSimCalculator(device=device)

        # Query energy, forces, and the full (non-Voigt) stress tensor.
        energy = atoms.get_potential_energy()
        forces = atoms.get_forces()
        stresses = atoms.get_stress(voigt=False)

        # Derived quantities: per-atom energy and stress in both unit systems.
        energy_per_atom = energy / len(atoms)
        stresses_ev_a3 = stresses
        stresses_gpa = stresses / GPa

        return f"""
## {atoms.get_chemical_formula()} Crystal Structure Property Prediction Results

Prediction results using the provided CIF structure:

- Total Energy (eV): {energy}
- Energy per Atom (eV/atom): {energy_per_atom:.4f}
- Forces (eV/Angstrom): {forces[0]} # Forces on the first atom
- Stress (GPa): {stresses_gpa[0][0]} # First component of the stress tensor
- Stress (eV/A^3): {stresses_ev_a3[0][0]} # First component of the stress tensor

"""

    # Run the blocking ML inference off the event loop.
    return await asyncio.to_thread(_predict)
|
||||
192
mars_toolkit/compute/structure_opt.py
Normal file
192
mars_toolkit/compute/structure_opt.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Structure Optimization Module
|
||||
|
||||
This module provides functions for optimizing crystal structures using the FairChem model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from io import StringIO
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from ase.io import read, write
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.error_handlers import handle_general_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化FairChem模型
|
||||
calc = None
|
||||
|
||||
def init_model():
    """Initialize the module-level FairChem calculator; later calls are no-ops."""
    global calc
    if calc is None:
        try:
            # Imported lazily so the module loads even without fairchem installed.
            from fairchem.core import OCPCalculator
            calc = OCPCalculator(checkpoint_path=config.FAIRCHEM_MODEL_PATH)
            logger.info("FairChem model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize FairChem model: {str(e)}")
            raise
|
||||
|
||||
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """
    Convert a structure string into an ASE Atoms object.

    The content is written to a temp file whose suffix matches the format so
    ASE's reader can infer the right parser.

    Args:
        input_format: Input format (cif, xyz, vasp, ...)
        content: Structure content string

    Returns:
        ASE Atoms object, or None if the conversion fails
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            tmp_file.write(content)
            tmp_path = tmp_file.name

        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        # Bug fix: always remove the temp file — the original unlinked it only
        # on success, leaking it whenever read() raised.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
||||
|
||||
def generate_symmetry_cif(structure: Structure) -> str:
    """
    Generate a symmetry-refined CIF string for a structure.

    The structure is refined via SpacegroupAnalyzer and serialized through a
    temp file because CifWriter only writes to paths.

    Args:
        structure: Pymatgen Structure object

    Returns:
        CIF-format string of the refined structure
    """
    analyzer = SpacegroupAnalyzer(structure)
    structure_refined = analyzer.get_refined_structure()

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
            tmp_path = tmp_file.name
            cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
            cif_writer.write_file(tmp_path)
            tmp_file.seek(0)
            content = tmp_file.read()
        return content
    finally:
        # Bug fix: remove the temp file even when CifWriter raises — the
        # original only unlinked it on the success path.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
||||
|
||||
def optimize_structure(atoms: Atoms, output_format: str) -> str:
    """
    Optimize a crystal structure with the shared FairChem calculator.

    Runs FIRE relaxation under a FrechetCellFilter (cell and positions relax
    together), capturing the optimizer's stdout log for the report.

    Args:
        atoms: ASE Atoms object
        output_format: Output format (cif, xyz, vasp, ...)

    Returns:
        Formatted string with the total energy, optimizer log, and the
        optimized structure content

    Raises:
        Exception: Re-raises any failure from the optimizer or serialization.
    """
    atoms.calc = calc

    try:
        # Capture the optimizer's console output.
        temp_output = StringIO()
        original_stdout = sys.stdout
        sys.stdout = temp_output
        try:
            # FIRE + variable-cell filter relaxes atoms and lattice jointly.
            dyn = FIRE(FrechetCellFilter(atoms))
            dyn.run(fmax=config.FMAX)
        finally:
            # Bug fix: restore stdout even when the optimizer raises — the
            # original left sys.stdout redirected on failure.
            sys.stdout = original_stdout
        optimization_log = temp_output.getvalue()
        temp_output.close()

        total_energy = atoms.get_potential_energy()

        # Serialize the optimized structure.
        if output_format == "cif":
            # Symmetry-refined CIF with the symmetry-ops loop stripped.
            optimized_structure = Structure.from_ase_atoms(atoms)
            content = generate_symmetry_cif(optimized_structure)
            content = remove_symmetry_equiv_xyz(content)

        else:
            with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
                write(tmp_file.name, atoms)
                tmp_file.seek(0)
                content = tmp_file.read()

            os.unlink(tmp_file.name)

        # Assemble the formatted report.
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**

#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
        return format_result
    except Exception as e:
        logger.error(f"Failed to optimize structure: {str(e)}")
        raise e
|
||||
|
||||
@llm_tool(name="optimize_crystal_structure",
          description="Optimize crystal structure using FairChem model")
async def optimize_crystal_structure(
    content: str,
    input_format: str = "cif",
    output_format: str = "cif"
) -> str:
    """
    Optimize crystal structure using FairChem model.

    Args:
        content: Crystal structure content string
        input_format: Input format (cif, xyz, vasp)
        output_format: Output format (cif, xyz, vasp)

    Returns:
        Optimized structure with energy and optimization log, or a formatted
        error message when optimization fails
    """
    # Lazily initialize the shared FairChem calculator on first use.
    if calc is None:
        init_model()

    def _blocking_optimize():
        # Parse the input, then run the (blocking) optimization.
        atoms = convert_structure(input_format, content)
        if atoms is None:
            raise ValueError(f"无法转换输入的{input_format}格式内容,请检查格式是否正确")
        return optimize_structure(atoms, output_format)

    try:
        # Off-load the blocking work so the event loop stays responsive.
        return await asyncio.to_thread(_blocking_optimize)
    except Exception as e:
        return handle_general_error(e)
|
||||
13
mars_toolkit/core/__init__.py
Normal file
13
mars_toolkit/core/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Core Module
|
||||
|
||||
This module provides core functionality for the Mars Toolkit.
|
||||
"""
|
||||
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.utils import settings, setup_logging
|
||||
from mars_toolkit.core.error_handlers import (
|
||||
handle_minio_error, handle_http_error,
|
||||
handle_validation_error, handle_general_error
|
||||
)
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
BIN
mars_toolkit/core/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/config.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/config.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/error_handlers.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/error_handlers.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/utils.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
121
mars_toolkit/core/cif_utils.py
Normal file
121
mars_toolkit/core/cif_utils.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
CIF Utilities Module
|
||||
|
||||
This module provides basic functions for handling CIF (Crystallographic Information File) files,
|
||||
which are commonly used in materials science for representing crystal structures.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def read_cif_txt_file(file_path):
    """
    Read a CIF (or any text) file and return its content.

    Args:
        file_path: Path to the CIF file

    Returns:
        String content of the file, or None when it cannot be read
        (the error is logged rather than raised)
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            return handle.read()
    except Exception as e:
        logger.error(f"Error reading file {file_path}: {e}")
        return None
|
||||
|
||||
def extract_cif_info(path: str, fields_name: list) -> dict:
    """Extract selected property fields from a CIF-description JSON file.

    Args:
        path: Path to the JSON file containing CIF information.
        fields_name: List of field-category names to extract.  Use
            ``['all_fields']`` to extract every known field; other options are
            ``'basic_fields'``, ``'energy_electronic_fields'`` and
            ``'metal_magentic_fields'`` (historical spelling kept for backward
            compatibility).  Unknown category names are ignored; an empty list
            yields an empty result.

    Returns:
        Dictionary mapping each selected field name to its value from the
        JSON document ('' when the document lacks the field).
    """
    # BUG FIX: the original looked categories up via ``locals().get(field, [])``,
    # which silently couples the public category names to local variable names
    # and is an anti-pattern; an explicit mapping is used instead.
    field_groups = {
        'basic_fields': [
            'formula_pretty', 'chemsys', 'composition', 'elements',
            'symmetry', 'nsites', 'volume', 'density',
        ],
        'energy_electronic_fields': [
            'formation_energy_per_atom', 'energy_above_hull', 'is_stable',
            'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct',
        ],
        # Key spelling ("magentic") preserved: it is part of the public API.
        'metal_magentic_fields': [
            'is_metal', 'is_magnetic', 'ordering',
            'total_magnetization', 'num_magnetic_sites',
        ],
    }

    # Guard against an empty list (the original raised IndexError here).
    if fields_name and fields_name[0] == 'all_fields':
        selected_fields = [f for group in field_groups.values() for f in group]
    else:
        selected_fields = []
        for field in fields_name:
            selected_fields.extend(field_groups.get(field, []))

    with open(path, 'r') as f:
        docs = json.load(f)

    # Missing fields default to '' so callers always see every requested key.
    return {field_name: docs.get(field_name, '') for field_name in selected_fields}
|
||||
|
||||
def remove_symmetry_equiv_xyz(cif_content):
    """Remove the symmetry-operations loop from CIF file content.

    This is often useful when working with CIF files in certain visualization
    tools, or when focusing on the basic structure without symmetry
    operations.

    Args:
        cif_content: CIF file content string.

    Returns:
        Cleaned CIF content string with the ``_symmetry_equiv_pos_as_xyz``
        loop (header tags and data rows) removed.
    """
    lines = cif_content.split('\n')
    output_lines = []

    i = 0
    while i < len(lines):
        line = lines[i].strip()

        # Detect the start of a loop block.
        if line == 'loop_':
            # Peek ahead at the contiguous tag lines (those starting with '_')
            # to decide whether this is the symmetry loop.
            next_lines = []
            j = i + 1
            while j < len(lines) and lines[j].strip().startswith('_'):
                next_lines.append(lines[j].strip())
                j += 1

            # Does the loop header contain the symmetry-operations tag?
            if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
                # Skip the entire loop block: advance i until the line AFTER i
                # is the next loop/data block or the atom-site section.
                # NOTE(review): the 'loop_' line itself and every line of the
                # block are dropped because nothing is appended here and the
                # trailing `i += 1` below steps past the last skipped line.
                while i < len(lines):
                    if i + 1 >= len(lines):
                        break

                    next_line = lines[i + 1].strip()
                    # Reached the next loop or data block.
                    if next_line == 'loop_' or next_line.startswith('data_'):
                        break

                    # Reached the atom-positions section.
                    if next_line.startswith('_atom_site_'):
                        break

                    i += 1
            else:
                # Not the symmetry loop -- keep the 'loop_' line.
                output_lines.append(lines[i])
        else:
            # Not a loop start -- keep the line unchanged.
            output_lines.append(lines[i])

        i += 1

    return '\n'.join(output_lines)
|
||||
59
mars_toolkit/core/config.py
Normal file
59
mars_toolkit/core/config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Configuration Module
|
||||
|
||||
This module provides configuration settings for the Mars Toolkit.
|
||||
It includes API keys, endpoints, paths, and other configuration parameters.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
class Config:
    """Central configuration for the Mars Toolkit.

    All settings are plain class attributes; a shared instance is exported as
    ``config`` at module level.

    NOTE(review): live API keys and host-specific absolute paths are
    hard-coded below. Committing credentials to source control is a security
    risk -- these should be loaded from environment variables / a .env file
    and the exposed keys rotated.
    """

    # Materials Project API access
    MP_API_KEY = 'PMASAg256b814q3OaSRWeVc7MKx4mlKI'
    MP_ENDPOINT = 'https://api.materialsproject.org/'
    # Default number of top results returned from MP queries
    MP_TOPK = 3

    # Local mirror of Materials Project CIF files
    LOCAL_MP_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/'

    # Outbound HTTP(S) proxy used for external API calls
    HTTP_PROXY = 'http://192.168.168.1:20171'
    HTTPS_PROXY = 'http://192.168.168.1:20171'

    # FairChem structure-relaxation model checkpoint
    FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
    # Force-convergence threshold for relaxation -- presumably eV/Angstrom; TODO confirm units
    FMAX = 0.05

    # MatterGen generative-model checkpoints and result output location
    MATTERGENMODEL_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
    MATTERGENMODEL_RESULT_PATH = 'results/'

    # Dify knowledge-base service
    DIFY_ROOT_URL = 'http://192.168.191.101:6080'
    DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'

    # SearXNG web-search service
    SEARXNG_HOST="http://192.168.191.101:40032/"

    # Visualization output directory
    VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'

    @classmethod
    def as_dict(cls) -> Dict[str, Any]:
        """Return all configuration settings as a dictionary.

        NOTE(review): on Python < 3.10 classmethod objects are not
        ``callable()`` and would leak into the result -- verify the target
        Python version.
        """
        return {
            key: value for key, value in cls.__dict__.items()
            if not key.startswith('__') and not callable(value)
        }

    @classmethod
    def update(cls, **kwargs):
        """Update configuration settings in place.

        Only keys that already exist on the class are applied; unknown keys
        are silently ignored.
        """
        for key, value in kwargs.items():
            if hasattr(cls, key):
                setattr(cls, key, value)


# Create a global instance for easy access
config = Config()
|
||||
55
mars_toolkit/core/error_handlers.py
Normal file
55
mars_toolkit/core/error_handlers.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Error Handlers Module
|
||||
|
||||
This module provides error handling utilities for the Mars Toolkit.
|
||||
It includes functions for handling various types of errors that may occur
|
||||
during toolkit operations.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException
|
||||
from typing import Any, Dict
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class APIError(HTTPException):
    """Custom API error that logs itself when constructed.

    Extends FastAPI's ``HTTPException`` so it can be raised directly from
    route handlers; the status code and detail are also logged at ERROR
    level for server-side visibility.
    """

    def __init__(self, status_code: int, detail: Any = None):
        """
        Args:
            status_code: HTTP status code to return to the client.
            detail: Optional error payload forwarded to the client.
        """
        super().__init__(status_code=status_code, detail=detail)
        logger.error(f"API Error: {status_code} - {detail}")
|
||||
|
||||
def handle_minio_error(e: Exception) -> Dict[str, str]:
    """Convert a MinIO failure into the toolkit's standard error payload.

    The failure is logged at ERROR level before being returned.

    Args:
        e: The exception raised by the MinIO operation.

    Returns:
        A ``{"status": "error", "data": <message>}`` dictionary.
    """
    message = f"MinIO operation failed: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}
|
||||
|
||||
def handle_http_error(e: Exception) -> Dict[str, str]:
    """Convert an HTTP-request failure into the toolkit's standard error payload.

    The failure is logged at ERROR level before being returned.

    Args:
        e: The exception raised by the HTTP request.

    Returns:
        A ``{"status": "error", "data": <message>}`` dictionary.
    """
    message = f"HTTP request failed: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}
|
||||
|
||||
def handle_validation_error(e: Exception) -> Dict[str, str]:
    """Convert a data-validation failure into the toolkit's standard error payload.

    The failure is logged at ERROR level before being returned.

    Args:
        e: The exception raised during validation.

    Returns:
        A ``{"status": "error", "data": <message>}`` dictionary.
    """
    message = f"Validation failed: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}
|
||||
|
||||
def handle_general_error(e: Exception) -> Dict[str, str]:
    """Convert any unexpected failure into the toolkit's standard error payload.

    The failure is logged at ERROR level before being returned.

    Args:
        e: The unexpected exception.

    Returns:
        A ``{"status": "error", "data": <message>}`` dictionary.
    """
    message = f"Unexpected error: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}
|
||||
213
mars_toolkit/core/llm_tools.py
Normal file
213
mars_toolkit/core/llm_tools.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
LLM Tools Module
|
||||
|
||||
This module provides decorators and utilities for defining, registering, and managing LLM tools.
|
||||
It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving
|
||||
registered tools for use with LLM APIs.
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args
|
||||
import docstring_parser
|
||||
from pydantic import BaseModel, create_model, Field
|
||||
|
||||
# Registry to store all registered tools
|
||||
_TOOL_REGISTRY = {}
|
||||
|
||||
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
    """Decorator that registers a function as an LLM tool.

    Registers the function in the tool registry, generates a JSON schema for
    it, and makes it retrievable through ``get_tools`` / ``get_tool_schemas``.
    Works both bare (``@llm_tool``) and parameterized (``@llm_tool(...)``).

    Args:
        name: Optional custom tool name; defaults to the function name.
        description: Optional custom description; defaults to the function's
            docstring.

    Returns:
        The decorated function, augmented with LLM-tool metadata attributes.

    Example:
        @llm_tool(name="weather_lookup", description="Get current weather for a location")
        def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
            '''Get weather information for a specific location.'''
            return {"temperature": 22.5, "conditions": "sunny"}
    """
    # Bare usage (@llm_tool without parentheses): the decorated function
    # itself arrives in the ``name`` argument.
    if callable(name):
        return _llm_tool_impl(name, None, None)

    # Parameterized usage: return a decorator closing over name/description.
    def decorator(func: Callable) -> Callable:
        return _llm_tool_impl(func, name, description)

    return decorator
|
||||
|
||||
def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
    """Implementation of the llm_tool decorator.

    Builds an OpenAI-style function-calling JSON schema and a pydantic args
    model from the function's signature and docstring, wraps the function
    (preserving sync/async behavior), attaches the metadata, and registers
    the wrapper in the module-level registry.

    Args:
        func: The function being decorated.
        name: Optional override for the tool name (defaults to func.__name__).
        description: Optional override for the tool description (defaults to
            the function's docstring).

    Returns:
        The wrapper function with ``is_llm_tool``, ``tool_name``,
        ``tool_description``, ``json_schema`` and ``args_schema`` attributes.
    """
    # Get function signature and docstring
    sig = inspect.signature(func)
    doc = inspect.getdoc(func) or ""
    parsed_doc = docstring_parser.parse(doc)

    # Determine tool name
    tool_name = name or func.__name__

    # Determine tool description
    tool_description = description or doc

    # Create parameter properties for JSON schema
    properties = {}
    required = []

    for param_name, param in sig.parameters.items():
        # Skip self parameter for methods
        if param_name == "self":
            continue

        param_type = param.annotation
        param_default = None if param.default is inspect.Parameter.empty else param.default
        param_required = param.default is inspect.Parameter.empty

        # Get parameter description from docstring if available
        param_desc = ""
        for param_doc in parsed_doc.params:
            if param_doc.arg_name == param_name:
                param_desc = param_doc.description
                break

        # Handle Annotated types: unwrap to the real type and pick up a string
        # metadata argument as the description.
        # NOTE(review): relying on get_origin(...).__name__ == "Annotated" is
        # version-sensitive (special forms expose _name on some Python
        # versions) -- confirm on the target interpreter.
        if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
            args = get_args(param_type)
            param_type = args[0]  # The actual type
            if len(args) > 1 and isinstance(args[1], str):
                param_desc = args[1]  # The description

        # Create property for parameter
        param_schema = {
            "type": _get_json_type(param_type),
            "description": param_desc,
            "title": param_name.replace("_", " ").title()
        }

        # Add default value if available
        if param_default is not None:
            param_schema["default"] = param_default

        properties[param_name] = param_schema

        # Add to required list if no default value
        if param_required:
            required.append(param_name)

    # Create JSON schema (OpenAI function-calling format)
    schema = {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required
            }
        }
    }

    # Create Pydantic model for args schema.  ``...`` (Ellipsis) marks a
    # required field in pydantic.
    field_definitions = {}
    for param_name, param in sig.parameters.items():
        if param_name == "self":
            continue

        param_type = param.annotation
        param_default = ... if param.default is inspect.Parameter.empty else param.default

        # Handle Annotated types (same unwrapping as above)
        if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
            args = get_args(param_type)
            param_type = args[0]
            # NOTE(review): this rebinding shadows the ``description``
            # parameter; harmless here because the parameter was already
            # consumed into tool_description above.
            description = args[1] if len(args) > 1 and isinstance(args[1], str) else ""
            field_definitions[param_name] = (param_type, Field(default=param_default, description=description))
        else:
            field_definitions[param_name] = (param_type, Field(default=param_default))

    # Create args schema model
    model_name = f"{tool_name.title().replace('_', '')}Schema"
    args_schema = create_model(model_name, **field_definitions)

    # Create a wrapper of matching kind (async functions get an async wrapper
    # so the decorated tool remains awaitable).
    if asyncio.iscoroutinefunction(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
    else:
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

    # Attach metadata to function
    wrapper.is_llm_tool = True
    wrapper.tool_name = tool_name
    wrapper.tool_description = tool_description
    wrapper.json_schema = schema
    wrapper.args_schema = args_schema

    # Register the tool
    _TOOL_REGISTRY[tool_name] = wrapper

    return wrapper
|
||||
|
||||
def get_tools() -> Dict[str, Callable]:
    """
    Get all registered LLM tools.

    Note: this returns the live registry (not a copy); mutating the returned
    dictionary affects subsequent lookups.

    Returns:
        A dictionary mapping tool names to their corresponding functions.
    """
    return _TOOL_REGISTRY
|
||||
|
||||
def get_tool_schemas() -> List[Dict[str, Any]]:
    """Collect the JSON schemas of every registered LLM tool.

    Returns:
        A list of JSON schemas, in registration order, suitable for passing
        to LLM tool-calling APIs.
    """
    schemas = []
    for registered_tool in _TOOL_REGISTRY.values():
        schemas.append(registered_tool.json_schema)
    return schemas
|
||||
|
||||
def _get_json_type(python_type: Any) -> str:
    """
    Convert Python type to JSON schema type.

    Handles the JSON primitives plus bare and parameterized container
    annotations (``list`` / ``List[int]`` -> "array", ``dict`` /
    ``Dict[str, Any]`` -> "object"); tuples and sets are also reported as
    arrays.  Anything unrecognized (unions, custom classes, ...) falls back
    to "string", matching the previous behavior.

    Args:
        python_type: Python type annotation

    Returns:
        Corresponding JSON schema type as string
    """
    primitives = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
    }
    # Exact identity match on primitives first (bool must not be treated as int).
    for base, json_type in primitives.items():
        if python_type is base:
            return json_type

    # Generalization over the original: unwrap parameterized generics such as
    # list[int] or Dict[str, Any] down to their runtime origin type.
    origin = get_origin(python_type)
    candidate = origin if origin is not None else python_type

    if candidate in (list, tuple, set, frozenset) or python_type is List:
        return "array"
    if candidate is dict or python_type is Dict:
        return "object"

    # Default to string for complex/unknown types.
    return "string"
|
||||
75
mars_toolkit/core/utils.py
Normal file
75
mars_toolkit/core/utils.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
import boto3
|
||||
import logging
|
||||
import logging.config
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Settings(BaseSettings):
    """Environment-driven settings, loaded from the process environment and
    an optional ``.env`` file.

    NOTE(review): the ``env=`` keyword on Field and the inner ``Config``
    class follow the pydantic v1 / early pydantic-settings API; confirm the
    installed pydantic-settings version honors them.
    """

    # Materials Project API access
    mp_api_key: Optional[str] = Field(None, env="MP_API_KEY")
    mp_endpoint: Optional[str] = Field(None, env="MP_ENDPOINT")
    mp_topk: Optional[int] = Field(3, env="MP_TOPK")

    # Outbound HTTP(S) proxy
    http_proxy: Optional[str] = Field(None, env="HTTP_PROXY")
    https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY")

    # FairChem relaxation model
    fairchem_model_path: Optional[str] = Field(None, env="FAIRCHEM_MODEL_PATH")
    fmax: Optional[float] = Field(0.05, env="FMAX")

    # MinIO object storage
    minio_endpoint: Optional[str] = Field(None, env="MINIO_ENDPOINT")
    internal_minio_endpoint: Optional[str] = Field(None, env="INTERNAL_MINIO_ENDPOINT")
    minio_access_key: Optional[str] = Field(None, env="MINIO_ACCESS_KEY")
    minio_secret_key: Optional[str] = Field(None, env="MINIO_SECRET_KEY")
    minio_bucket: Optional[str] = Field("mars-toolkit", env="MINIO_BUCKET")

    class Config:
        # Values are also read from a .env file next to the working directory.
        env_file = ".env"
        env_file_encoding = "utf-8"
|
||||
|
||||
def setup_logging():
    """Configure console and rotating-file logging for the whole toolkit.

    The log file ``mars_toolkit.log`` is written to the package's parent
    directory (two levels above this module) and rotates at 10 MB, keeping
    five backups.  Console output is INFO+, the file captures DEBUG+.
    """
    package_parent = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
    log_file_path = os.path.join(package_parent, 'mars_toolkit.log')

    logging_config = {
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'standard': {
                'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                'datefmt': '%Y-%m-%d %H:%M:%S',
            },
        },
        'handlers': {
            'console': {
                'level': 'INFO',
                'class': 'logging.StreamHandler',
                'formatter': 'standard',
            },
            'file': {
                'level': 'DEBUG',
                'class': 'logging.handlers.RotatingFileHandler',
                'filename': log_file_path,
                'maxBytes': 10485760,  # 10 MB per file before rotation
                'backupCount': 5,
                'formatter': 'standard',
            },
        },
        'loggers': {
            # Root logger: everything propagates to both handlers.
            '': {
                'handlers': ['console', 'file'],
                'level': 'INFO',
                'propagate': True,
            },
        },
    }

    logging.config.dictConfig(logging_config)
|
||||
|
||||
# Initialize the shared settings instance (values are loaded from the
# environment / .env file at import time).
settings = Settings()
|
||||
7
mars_toolkit/misc/__init__.py
Normal file
7
mars_toolkit/misc/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Basic Module
|
||||
|
||||
This module provides basic utility functions for the Mars Toolkit.
|
||||
"""
|
||||
|
||||
from mars_toolkit.misc.misc_tools import get_current_time
|
||||
BIN
mars_toolkit/misc/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/misc/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/misc/__pycache__/general_tools.cpython-310.pyc
Normal file
BIN
mars_toolkit/misc/__pycache__/general_tools.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/misc/__pycache__/misc_tools.cpython-310.pyc
Normal file
BIN
mars_toolkit/misc/__pycache__/misc_tools.cpython-310.pyc
Normal file
Binary file not shown.
29
mars_toolkit/misc/misc_tools.py
Normal file
29
mars_toolkit/misc/misc_tools.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
General Tools Module
|
||||
|
||||
This module provides basic utility functions that are not specific to materials science.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
from typing import Annotated
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
|
||||
@llm_tool(name="get_current_time", description="Get current date and time in specified timezone")
async def get_current_time(timezone: str = "UTC") -> str:
    """Return the current date and time in the given timezone.

    Args:
        timezone: IANA timezone name (e.g., UTC, Asia/Shanghai, America/New_York)

    Returns:
        A human-readable timestamp string, or an error message when the
        timezone name is not recognized.
    """
    try:
        zone = pytz.timezone(timezone)
    except pytz.exceptions.UnknownTimeZoneError:
        return f"Unknown timezone: {timezone}. Please use a valid timezone such as 'UTC', 'Asia/Shanghai', etc."
    now = datetime.now(zone)
    return f"The current {timezone} time is: {now.strftime('%Y-%m-%d %H:%M:%S %Z')}"
|
||||
18
mars_toolkit/query/__init__.py
Normal file
18
mars_toolkit/query/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Query Module
|
||||
|
||||
This module provides query tools for materials science, including:
|
||||
- Materials Project database queries
|
||||
- OQMD database queries
|
||||
- Dify knowledge base retrieval
|
||||
- Web search
|
||||
"""
|
||||
|
||||
from mars_toolkit.query.mp_query import (
|
||||
search_material_property_from_material_project,
|
||||
get_crystal_structures_from_materials_project,
|
||||
get_mpid_from_formula
|
||||
)
|
||||
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||
from mars_toolkit.query.web_search import search_online
|
||||
BIN
mars_toolkit/query/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/query/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc
Normal file
BIN
mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc
Normal file
BIN
mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc
Normal file
BIN
mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/web_search.cpython-310.pyc
Normal file
BIN
mars_toolkit/query/__pycache__/web_search.cpython-310.pyc
Normal file
Binary file not shown.
84
mars_toolkit/query/dify_search.py
Normal file
84
mars_toolkit/query/dify_search.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Dify Search Module
|
||||
|
||||
This module provides functions for retrieving information from local materials science
|
||||
literature knowledge base using Dify API.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import requests
|
||||
import codecs
|
||||
from typing import Dict, Any
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
|
||||
@llm_tool(
    name="retrieval_from_knowledge_base",
    description="Retrieve information from local materials science literature knowledge base"
)
async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
    """Retrieve relevant entries from the local materials-science literature
    knowledge base via the Dify chat-messages API.

    Args:
        query: Query string, e.g. a material name such as "CsPbBr3".
        topk: Maximum number of results to return (default 3).

    Returns:
        A JSON-formatted string with keys ``id`` (document id), ``title``,
        ``content`` (the answer text) and ``score`` (relevance), or an error
        message string on failure.
    """
    # Dify chat-messages endpoint
    url = f'{config.DIFY_ROOT_URL}/v1/chat-messages'

    # API key + JSON content type
    headers = {
        'Authorization': f'Bearer {config.DIFY_API_KEY}',
        'Content-Type': 'application/json'
    }

    data = {
        "inputs": {"topK": topk},      # maximum number of results
        "query": query,                # the query string
        "response_mode": "blocking",   # wait for the complete response
        "conversation_id": "",         # stateless: every call is independent
        "user": "abc-123"              # fixed user identifier
    }

    try:
        # Long timeout because knowledge-base retrieval can be slow.
        response = requests.post(url, headers=headers, json=data, timeout=1111)

        # BUG FIX: the previous implementation ran the raw response text
        # through codecs.decode(..., 'unicode_escape'), which mangles any
        # non-ASCII (e.g. Chinese) content and can corrupt the JSON by
        # unescaping quotes inside string values.  ``requests`` already
        # decodes the body with the correct charset, so parse it directly.
        result_json = response.json()

        # Extract metadata from the response.
        metadata = result_json.get("metadata", {})

        # Build the result dictionary with the key information.
        useful_info = {
            "id": metadata.get("document_id"),        # document id
            "title": result_json.get("title"),        # document title
            "content": result_json.get("answer", ""), # answer text holds the content
            "score": metadata.get("score")            # relevance score
        }

        return json.dumps(useful_info, ensure_ascii=False, indent=2)

    except Exception as e:
        # Catch all failures (network, JSON parsing, ...) and report them.
        return f"错误: {str(e)}"
|
||||
433
mars_toolkit/query/mp_query.py
Normal file
433
mars_toolkit/query/mp_query.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Materials Project Query Module
|
||||
|
||||
This module provides functions for querying the Materials Project database,
|
||||
processing search results, and formatting responses.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import glob
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
import datetime
|
||||
import os
|
||||
from multiprocessing import Process, Manager
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from mp_api.client import MPRester
|
||||
from pymatgen.core import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.error_handlers import handle_general_error
|
||||
from mars_toolkit.core.cif_utils import read_cif_txt_file, extract_cif_info, remove_symmetry_equiv_xyz
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_bool(param: str) -> bool | None:
|
||||
"""
|
||||
Parse a string parameter into a boolean value.
|
||||
|
||||
Args:
|
||||
param: String parameter to parse (e.g., "true", "false")
|
||||
|
||||
Returns:
|
||||
Boolean value if param is not empty, None otherwise
|
||||
"""
|
||||
if not param:
|
||||
return None
|
||||
return param.lower() == 'true'
|
||||
|
||||
def parse_list(param: str) -> List[str] | None:
|
||||
"""
|
||||
Parse a comma-separated string into a list of strings.
|
||||
|
||||
Args:
|
||||
param: Comma-separated string (e.g., "Li,Fe,O")
|
||||
|
||||
Returns:
|
||||
List of strings if param is not empty, None otherwise
|
||||
"""
|
||||
if not param:
|
||||
return None
|
||||
return param.split(',')
|
||||
|
||||
def parse_tuple(param: str) -> tuple[float, float] | None:
|
||||
"""
|
||||
Parse a comma-separated string into a tuple of two float values.
|
||||
|
||||
Used for range parameters like band_gap, density, etc.
|
||||
|
||||
Args:
|
||||
param: Comma-separated string of two numbers (e.g., "0,3.5")
|
||||
|
||||
Returns:
|
||||
Tuple of two float values if param is valid, None otherwise
|
||||
"""
|
||||
if not param:
|
||||
return None
|
||||
try:
|
||||
values = param.split(',')
|
||||
return (float(values[0]), float(values[1]))
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
    """Convert raw string query parameters into typed search arguments for
    the Materials Project API.

    Args:
        query_params: Mapping of raw query-string parameters.

    Returns:
        Dictionary of typed search arguments; missing parameters map to
        None, except ``chunk_size`` which defaults to 5.
    """
    # (key, parser) pairs; insertion order mirrors the original layout so the
    # resulting dictionary iterates identically.
    converters = (
        ('band_gap', parse_tuple),
        ('chemsys', parse_list),
        ('crystal_system', parse_list),
        ('density', parse_tuple),
        ('formation_energy', parse_tuple),
        ('elements', parse_list),
        ('exclude_elements', parse_list),
        ('formula', parse_list),
        ('is_gap_direct', parse_bool),
        ('is_metal', parse_bool),
        ('is_stable', parse_bool),
    )
    parsed: Dict[str, Any] = {key: fn(query_params.get(key)) for key, fn in converters}

    # Parameters with bespoke handling, in their original positions.
    parsed['magnetic_ordering'] = query_params.get('magnetic_ordering')
    parsed['material_ids'] = parse_list(query_params.get('material_ids'))
    for range_key in ('total_energy', 'num_elements', 'volume'):
        parsed[range_key] = parse_tuple(query_params.get(range_key))
    parsed['chunk_size'] = int(query_params.get('chunk_size', '5'))

    return parsed
|
||||
|
||||
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]] | str:
|
||||
"""
|
||||
Process search results from the Materials Project API.
|
||||
|
||||
Extracts relevant fields from each document and formats them into a consistent structure.
|
||||
|
||||
Returns:
|
||||
List of processed documents or error message string if an exception occurs
|
||||
"""
|
||||
try:
|
||||
fields = [
|
||||
'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
|
||||
'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
|
||||
'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
|
||||
'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
|
||||
'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
|
||||
'universal_anisotropy', 'theoretical'
|
||||
]
|
||||
|
||||
res = []
|
||||
for doc in docs:
|
||||
try:
|
||||
new_docs = {}
|
||||
for field_name in fields:
|
||||
new_docs[field_name] = doc.get(field_name, '')
|
||||
res.append(new_docs)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing document: {str(e)}")
|
||||
continue
|
||||
return res
|
||||
except Exception as e:
|
||||
error_msg = f"Error in process_search_results: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
def _search_worker(queue, api_key, **kwargs):
    """
    Worker function for executing Materials Project API searches.

    Runs in a separate process to perform the actual API call and puts
    results in the queue.  On success a list of plain dictionaries is
    queued; on failure the exception object itself is queued so the parent
    process can detect and report it.

    Args:
        queue: Multiprocessing queue for returning results
        api_key: Materials Project API key
        **kwargs: Search parameters to pass to the API
    """
    try:
        import os
        import traceback
        # Route MP API traffic through the configured proxy (if any).
        os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
        os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''

        # Initialize the MPRester client.
        with MPRester(api_key) as mpr:
            result = mpr.materials.summary.search(**kwargs)

        # Check the results.
        if result:
            # Convert each document to a plain dict so it can be pickled back
            # through the queue; fall through the serialization APIs exposed
            # by different mp_api/pydantic versions.
            processed_results = []
            for doc in result:
                try:
                    # pydantic v2 style
                    processed_doc = doc.model_dump()
                    processed_results.append(processed_doc)
                except AttributeError:
                    # pydantic v1 style
                    try:
                        processed_doc = doc.dict()
                        processed_results.append(processed_doc)
                    except AttributeError:
                        # Last resorts: raw __dict__, then the object itself.
                        if hasattr(doc, "__dict__"):
                            processed_doc = doc.__dict__
                            # Drop attributes known to break serialization.
                            if "_sa_instance_state" in processed_doc:
                                del processed_doc["_sa_instance_state"]
                            processed_results.append(processed_doc)
                        else:
                            # Final fallback: queue the document as-is.
                            processed_results.append(doc)

            queue.put(processed_results)
        else:
            queue.put([])
    except Exception as e:
        # Ship the exception to the parent process instead of dying silently.
        queue.put(e)
|
||||
|
||||
async def execute_search(search_args: Dict[str, Any], timeout: int = 120) -> List[Dict[str, Any]] | str:
    """
    Execute a search against the Materials Project API.

    Runs the search in a separate worker process so that a hung API call can
    be terminated after ``timeout`` seconds instead of blocking forever.

    Args:
        search_args: Dictionary of search parameters
        timeout: Maximum time in seconds to wait for the search to complete

    Returns:
        List of document dictionaries from the search results, or an error
        message string if the search times out or fails.
    """
    # Stdlib exception raised by Queue.get on timeout.  BUG FIX: the original
    # code wrote ``except queue.Empty`` where ``queue`` was the Manager queue
    # *proxy instance* -- that object has no ``Empty`` attribute, so matching
    # the clause would raise AttributeError instead of producing the intended
    # timeout message.
    from queue import Empty

    # The MP API expects `formula` to be a list, even for a single formula.
    if 'formula' in search_args and isinstance(search_args['formula'], str):
        search_args['formula'] = [search_args['formula']]

    manager = Manager()
    result_queue = manager.Queue()

    try:
        p = Process(target=_search_worker, args=(result_queue, config.MP_API_KEY), kwargs=search_args)
        p.start()
        p.join(timeout=timeout)

        if p.is_alive():
            logger.warning(f"Terminating worker process {p.pid} due to timeout")
            p.terminate()
            p.join()
            return f"Request timed out after {timeout} seconds"

        try:
            result = result_queue.get(timeout=timeout)
        except Empty:
            error_msg = "Failed to retrieve data from queue (timeout)"
            logger.error(error_msg)
            return error_msg

        # The worker ships exceptions back through the queue; report them.
        if isinstance(result, Exception):
            logger.error(f"Error in search worker: {str(result)}")
            if result.__traceback__ is not None:
                import traceback
                # The original built this traceback string and discarded it;
                # log it at DEBUG so the root cause is recoverable.
                logger.debug(''.join(traceback.format_exception(None, result, result.__traceback__)))
            return f"Error in search worker: {str(result)}"

        return result

    except Exception as e:
        error_msg = f"Error in execute_search: {str(e)}"
        logger.error(error_msg)
        return error_msg
|
||||
|
||||
@llm_tool(name="search_material_property_from_material_project", description="Search materials in Materials Project database by formula and properties")
async def search_material_property_from_material_project(
    formula: str | list[str],
    chemsys: Optional[str | list[str] | None] = None,
    crystal_system: Optional[str | list[str] | None] = None,
    is_gap_direct: Optional[bool | None] = None,
    is_stable: Optional[bool | None] = None,
) -> str:
    """
    Search materials in Materials Project database.

    Resolves the formula(s) to material IDs, then reads the cached property
    JSON files from the local Materials Project mirror and formats the top
    ``config.MP_TOPK`` results.

    Args:
        formula: Chemical formula(s) (e.g., "Fe2O3" or ["ABO3", "Si*"])
        chemsys: Chemical system(s) (e.g., "Li-Fe-O")
        crystal_system: Crystal system(s) (e.g., "Cubic")
        is_gap_direct: Filter for direct band gap materials
        is_stable: Filter for thermodynamically stable materials

    Returns:
        JSON formatted material properties data, or an error message string
    """
    VALID_CRYSTAL_SYSTEMS = ['Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal', 'Cubic']

    # Validate crystal_system once for both the single-string and list forms.
    if crystal_system is not None:
        systems = [crystal_system] if isinstance(crystal_system, str) else crystal_system
        if isinstance(systems, list) and any(cs not in VALID_CRYSTAL_SYSTEMS for cs in systems):
            return "Input should be 'Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal' or 'Cubic'"

    # Normalize formula to a list.
    if isinstance(formula, str):
        formula = [formula]

    # NOTE(review): chemsys / is_gap_direct / is_stable are accepted for
    # interface compatibility but are not applied here — the lookup goes
    # through the local mirror by mp-id only.  (The original built a filtered
    # ``params`` dict from them and then never used it; that dead code has
    # been removed.)  Confirm whether these filters should be honored.
    mp_id_list = await get_mpid_from_formula(formula=formula)
    try:
        results = []
        for mp_id in mp_id_list:
            # Read the cached per-material property JSON from the local mirror.
            crystal_props = extract_cif_info(config.LOCAL_MP_ROOT + f"/Props/{mp_id}.json", ['all_fields'])
            results.append(crystal_props)

        if not results:
            return "No results found, please try again."

        # Format response with top results
        try:
            # Wrap each result in numbered begin/end markers for the LLM.
            formatted_results = []
            for i, item in enumerate(results[:config.MP_TOPK], 1):
                formatted_result = f"[property {i} begin]\n"
                formatted_result += json.dumps(item, indent=2)
                formatted_result += f"\n[property {i} end]\n\n"
                formatted_results.append(formatted_result)

            res_chunk = "\n\n".join(formatted_results)
            res_template = f"""
            Here are the search results from the Materials Project database:
            Due to length limitations, only the top {config.MP_TOPK} results are shown below:\n
            {res_chunk}
            If you need more results, please modify your search criteria or try different query parameters.
            """
            return res_template
        except Exception as format_error:
            logger.error(f"Error formatting results: {str(format_error)}")
            return str(format_error)

    except Exception as e:
        logger.error(f"Error in search_material_property_from_material_project: {str(e)}")
        return str(e)
|
||||
|
||||
@llm_tool(name="get_crystal_structures_from_materials_project", description="Get symmetrized crystal structures CIF data from Materials Project database by chemical formula")
async def get_crystal_structures_from_materials_project(
    formulas: list[str],
    conventional_unit_cell: bool = True,
    symprec: float = 0.1
) -> str:
    """
    Get crystal structures from Materials Project database by chemical formula and apply symmetrization.

    Resolves the formulas to material IDs, loads each local CIF file,
    symmetrizes the structure with SpacegroupAnalyzer, and returns the
    resulting CIF blocks wrapped in a formatted prompt.  Materials whose
    local CIF is missing or fails to process are skipped with a log entry
    instead of aborting the whole request.

    Args:
        formulas: List of chemical formulas (e.g., ["Fe2O3", "SiO2", "TiO2"])
        conventional_unit_cell: Whether to return conventional unit cell (True) or primitive cell (False)
        symprec: Precision parameter for symmetrization

    Returns:
        Formatted text containing symmetrized CIF data
    """
    result = {}
    mp_id_list = await get_mpid_from_formula(formula=formulas)

    for i, mp_id in enumerate(mp_id_list):
        try:
            # The original indexed glob.glob(...)[0] directly, which raised
            # IndexError whenever a local CIF was missing — skip instead.
            matches = glob.glob(config.LOCAL_MP_ROOT + f"/MPDatasets/{mp_id}.cif")
            if not matches:
                logger.warning(f"No local CIF file found for {mp_id}, skipping")
                continue
            structure = Structure.from_file(matches[0])

            # Convert to the conventional standard cell if requested.
            if conventional_unit_cell:
                structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure()

            # Symmetrize the structure.
            sga = SpacegroupAnalyzer(structure, symprec=symprec)
            symmetrized_structure = sga.get_refined_structure()

            # Generate CIF text via CifWriter.
            cif_writer = CifWriter(symmetrized_structure, symprec=symprec, refine_struct=True)
            cif_data = str(cif_writer)

            # Strip symmetry-operation listings and the pymatgen banner.
            cif_data = remove_symmetry_equiv_xyz(cif_data)
            cif_data = cif_data.replace('# generated using pymatgen', "")

            # Build a unique key per structure.
            formula = structure.composition.reduced_formula
            result[f"{formula}_{i}"] = cif_data
        except Exception as struct_error:
            # One bad structure must not lose the rest of the results.
            logger.error(f"Error processing structure {mp_id}: {str(struct_error)}")
            continue

        # Keep only the first config.MP_TOPK results.
        if len(result) >= config.MP_TOPK:
            break

    try:
        prompt = f"""
        # Materials Project Symmetrized Crystal Structure Data

        Below are symmetrized crystal structure data for {len(result)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
        These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
        """

        for i, (key, cif_data) in enumerate(result.items(), 1):
            prompt += f"[cif {i} begin]\n"
            prompt += cif_data
            prompt += f"\n[cif {i} end]\n\n"

        prompt += """
        ## Usage Instructions

        1. You can copy the above CIF data and save it as .cif files
        2. Open these files with crystal structure visualization software (such as VESTA, Mercury, Avogadro, etc.)
        3. These structures can be used for further material analysis, simulation, or visualization

        CIF files contain complete structural information of crystals, including cell parameters, atomic coordinates, symmetry, etc.
        Symmetrization helps identify and optimize crystal symmetry, making the structure more standardized and accurate.
        """
        return prompt

    except Exception as format_error:
        logger.error(f"Error formatting crystal structures: {str(format_error)}")
        return str(format_error)
|
||||
|
||||
@llm_tool(name="get_mpid_from_formula", description="Get material IDs (mpid) from Materials Project database by chemical formula")
async def get_mpid_from_formula(formula: str) -> List[str]:
    """
    Get material IDs (mpid) from Materials Project database by chemical formula.
    Returns mpids for the lowest energy structures.

    Args:
        formula: Chemical formula (e.g., "Fe2O3")

    Returns:
        List of material IDs (empty list when the lookup fails)
    """
    # Route Materials Project API traffic through the configured proxy, if any.
    os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
    os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''
    try:
        with MPRester(config.MP_API_KEY) as mpr:
            summaries = mpr.materials.summary.search(formula=formula)
            return [entry.material_id for entry in summaries]
    except Exception as e:
        logger.error(f"Error getting mpid from formula: {str(e)}")
        return []
|
||||
105
mars_toolkit/query/oqmd_query.py
Normal file
105
mars_toolkit/query/oqmd_query.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
OQMD Query Module
|
||||
|
||||
This module provides functions for querying the Open Quantum Materials Database (OQMD).
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
import pandas as pd
|
||||
from bs4 import BeautifulSoup
|
||||
from io import StringIO
|
||||
from typing import Annotated
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _parse_basic_info(soup, composition: str) -> list:
    """Extract the page heading and descriptive paragraphs, rendering links as markdown."""
    basic_data = []
    h1_element = soup.find('h1')
    if h1_element:
        basic_data.append(h1_element.text.strip())
    else:
        basic_data.append(f"Material: {composition}")

    for paragraph in soup.find_all('p'):
        if not paragraph:
            continue
        combined_text = ""
        for element in paragraph.contents:
            if hasattr(element, 'name') and element.name == 'a' and 'href' in element.attrs:
                # Use a dedicated name for the link target — the original
                # reassigned ``url`` here, shadowing the request URL.
                link_url = "https://www.oqmd.org" + element['href']
                combined_text += f"[{element.text.strip()}]({link_url}) "
            elif hasattr(element, 'text'):
                combined_text += element.text.strip() + " "
            else:
                combined_text += str(element).strip() + " "
        basic_data.append(combined_text.strip())
    return basic_data


def _parse_property_table(soup) -> str:
    """Render the first HTML table on the page as markdown ('' when absent)."""
    table = soup.find('table')
    if not table:
        return ""
    try:
        df = pd.read_html(StringIO(str(table)))[0]
        df = df.fillna('')
        df = df.replace([float('inf'), float('-inf')], '')
        return df.to_markdown(index=False)
    except Exception as e:
        logger.error(f"Error parsing table: {str(e)}")
        return "Error parsing table data"


@llm_tool(name="fetch_chemical_composition_from_OQMD", description="Fetch material data for a chemical composition from OQMD database")
async def fetch_chemical_composition_from_OQMD(
    composition: Annotated[str, "Chemical formula (e.g., Fe2O3, LiFePO4)"]
) -> str:
    """
    Fetch material data for a chemical composition from OQMD database.

    Downloads the OQMD composition page, then extracts the heading,
    descriptive paragraphs and the property table into a single markdown
    document.

    Args:
        composition: Chemical formula (e.g., Fe2O3, LiFePO4)

    Returns:
        Formatted text with material information and property tables,
        or an "Error: ..." string on failure
    """
    url = f"https://www.oqmd.org/materials/composition/{composition}"
    try:
        async with httpx.AsyncClient(timeout=100.0) as client:
            response = await client.get(url)
            response.raise_for_status()

            # Validate response content — very short bodies indicate an
            # error page rather than real material data.
            if not response.text or len(response.text) < 100:
                raise ValueError("Invalid response content from OQMD API")

            soup = BeautifulSoup(response.text, 'html.parser')

            basic_data = _parse_basic_info(soup, composition)
            table_data = _parse_property_table(soup)

            # Integrate data into a single text
            combined_text = "\n\n".join(basic_data)
            if table_data:
                combined_text += "\n\n## Material Properties Table\n\n" + table_data

            return combined_text

    except httpx.HTTPStatusError as e:
        logger.error(f"OQMD API request failed: {str(e)}")
        return f"Error: OQMD API request failed - {str(e)}"
    except httpx.TimeoutException:
        logger.error("OQMD API request timed out")
        return "Error: OQMD API request timed out"
    except httpx.NetworkError as e:
        logger.error(f"Network error occurred: {str(e)}")
        return f"Error: Network error occurred - {str(e)}"
    except ValueError as e:
        logger.error(f"Invalid response content: {str(e)}")
        return f"Error: Invalid response content - {str(e)}"
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return f"Error: Unexpected error occurred - {str(e)}"
|
||||
77
mars_toolkit/query/web_search.py
Normal file
77
mars_toolkit/query/web_search.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Web Search Module
|
||||
|
||||
This module provides functions for searching information on the web.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Annotated, Dict, Any, List
|
||||
|
||||
from langchain_community.utilities import SearxSearchWrapper
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
|
||||
@llm_tool(name="search_online", description="Search scientific information online and return results as a string")
async def search_online(
    query: Annotated[str, "Search term"],
    num_results: Annotated[int, "Number of results (1-20)"] = 5
) -> str:
    """
    Searches for scientific information online and returns results as a formatted string.

    Args:
        query: Search term for scientific content
        num_results: Number of results to return; coerced to int and clamped to 1-20

    Returns:
        Formatted string with search results (titles, snippets, links)
    """
    # Coerce num_results to int (LLM tool calls may pass it as a string);
    # fall back to the default of 5 on bad input.
    try:
        num_results = int(num_results)
    except (TypeError, ValueError):
        num_results = 5

    # Clamp to the supported 1-20 range.
    num_results = max(1, min(num_results, 20))

    # Initialize search wrapper
    search = SearxSearchWrapper(
        searx_host=config.SEARXNG_HOST,
        categories=["science",],
        k=num_results
    )

    # SearxSearchWrapper has no native async support, so run it in the
    # default executor to avoid blocking the event loop.  get_running_loop()
    # is the supported way to obtain the loop inside a coroutine —
    # get_event_loop() is deprecated here since Python 3.10.
    loop = asyncio.get_running_loop()
    raw_results = await loop.run_in_executor(
        None,
        lambda: search.results(query, language=['en','zh'], num_results=num_results)
    )

    # Normalize raw results into a predictable structure.
    formatted_results = [
        {
            "title": item.get("title", ""),
            "snippet": item.get("snippet", ""),
            "link": item.get("link", ""),
            "source": item.get("source", ""),
        }
        for item in raw_results
    ]

    # Convert the results to a formatted string
    result_str = f"Search Results for '{query}' ({len(formatted_results)} items):\n\n"

    for i, result in enumerate(formatted_results, 1):
        result_str += f"Result {i}:\n"
        result_str += f"Title: {result['title']}\n"
        result_str += f"Summary: {result['snippet']}\n"
        result_str += f"Link: {result['link']}\n"
        result_str += f"Source: {result['source']}\n\n"

    return result_str
|
||||
0
mars_toolkit/visualization/__init__.py
Normal file
0
mars_toolkit/visualization/__init__.py
Normal file
Reference in New Issue
Block a user