构建mars_toolkit,删除tools_for_ms

This commit is contained in:
lzy
2025-04-02 12:53:50 +08:00
parent 603304e10f
commit a77c2cd377
73 changed files with 1884 additions and 896 deletions

View File

@@ -190,7 +190,7 @@ def worker(data, output_file_path):
# 将所有格式化后的结果连接起来
final_result = "\n\n\n".join(formatted_results)
data['obeservation']=final_result
data['observation']=final_result
# print("#"*50,"start","#"*50)
# print(data['obeservation'])
# print("#"*50,'end',"#"*50)
@@ -199,7 +199,7 @@ def worker(data, output_file_path):
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data) # obeservation . data
writer.write(data) # observation . data
return f"Processed successfully"
except Exception as e:

1172
mars_toolkit.log Normal file

File diff suppressed because it is too large Load Diff

46
mars_toolkit/__init__.py Normal file
View File

@@ -0,0 +1,46 @@
"""
Mars Toolkit
A comprehensive toolkit for materials science research, providing tools for:
- Material generation and property prediction
- Structure optimization
- Database queries (Materials Project, OQMD)
- Knowledge base retrieval
- Web search
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
# Core modules
from mars_toolkit.core.config import config
from mars_toolkit.core.utils import setup_logging
# Basic tools
from mars_toolkit.misc.misc_tools import get_current_time
# Compute modules
from mars_toolkit.compute.material_gen import generate_material
from mars_toolkit.compute.property_pred import predict_properties
from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure
# Query modules
from mars_toolkit.query.mp_query import (
search_material_property_from_material_project,
get_crystal_structures_from_materials_project,
get_mpid_from_formula
)
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
from mars_toolkit.query.web_search import search_online
from mars_toolkit.core.llm_tools import llm_tool, get_tools, get_tool_schemas
# Configure logging as a side effect of importing the package.
setup_logging()
# Package version string.
__version__ = "0.1.0"
# Names exported by `from mars_toolkit import *`; the individual tool functions
# imported above are intentionally not listed here.
__all__ = ["llm_tool", "get_tools", "get_tool_schemas"]

Binary file not shown.

View File

@@ -0,0 +1,12 @@
"""
Compute Module
This module provides computational tools for materials science, including:
- Material generation
- Property prediction
- Structure optimization
"""
from mars_toolkit.compute.material_gen import generate_material
from mars_toolkit.compute.property_pred import predict_properties
from mars_toolkit.compute.structure_opt import optimize_crystal_structure, convert_structure

View File

@@ -1,8 +1,13 @@
"""
Material Generation Module
This module provides functions for generating crystal structures with optional property constraints.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import ast
import json
import logging
@@ -12,8 +17,10 @@ import datetime
import asyncio
import zipfile
import shutil
import re
from pathlib import Path
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
@@ -21,14 +28,15 @@ from ase.io import read, write
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
# Use our wrapper module instead of direct imports
# 导入路径已更新
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
# 使用mattergen_wrapper
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
#import mattergen_wrapper
# Access the modules through the wrapper
# The generator module is re-exported as an attribute, not a submodule
from mattergen_wrapper import generator
CrystalGenerator = generator.CrystalGenerator
from mattergen.common.data.types import TargetProperty
@@ -37,12 +45,8 @@ from mattergen.common.utils.data_classes import (
PRETRAINED_MODEL_NAME,
MatterGenCheckpointInfo,
)
logger = logging.getLogger(__name__)
from tools_for_ms.services_tools.Configs import *
from tools_for_ms.llm_tools import llm_tool
import json
import re
logger = logging.getLogger(__name__)
def format_cif_content(content):
@@ -100,9 +104,16 @@ def format_cif_content(content):
return "\n".join(result)
def convert_values(data_str):
# 将字符串转换为字典
"""
将字符串转换为字典
Args:
data_str: JSON字符串
Returns:
解析后的数据如果解析失败则返回原字符串
"""
try:
data = json.loads(data_str)
except json.JSONDecodeError:
@@ -265,38 +276,38 @@ async def generate_material(
Returns:
Descriptive text with generated crystal structures in CIF format
"""
# Use the configured results directory
output_dir = MATTERGENMODEL_RESULT_PATH
# 使用配置中的结果目录
output_dir = config.MATTERGENMODEL_RESULT_PATH
# Handle string input if provided
# 处理字符串输入(如果提供)
if isinstance(properties, str):
try:
properties = json.loads(properties)
except json.JSONDecodeError:
raise ValueError(f"Invalid properties JSON string: {properties}")
# Default to empty dict if None
# 如果为None默认为空字典
properties = properties or {}
# Process properties based on generation mode
# 根据生成模式处理属性
if not properties:
# Unconditional generation
model_path = os.path.join(MATTERGENMODEL_ROOT, "mattergen_base")
# 无条件生成
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
properties_to_condition_on = None
generation_type = "unconditional"
property_description = "unconditionally"
else:
# Conditional generation (single or multi-property)
# 条件生成(单属性或多属性)
properties_to_condition_on = {}
# Process each property
# 处理每个属性
for property_name, property_value in properties.items():
_, processed_value = preprocess_property(property_name, property_value)
properties_to_condition_on[property_name] = processed_value
# Determine which model to use based on properties
# 根据属性确定使用哪个模型
if len(properties) == 1:
# Single property conditioning
# 单属性条件
property_name = list(properties.keys())[0]
property_to_model = {
"dft_mag_density": "dft_mag_density",
@@ -314,29 +325,29 @@ async def generate_material(
generation_type = "single_property"
property_description = f"conditioned on {property_name} = {properties[property_name]}"
else:
# Multi-property conditioning
# 多属性条件
property_keys = set(properties.keys())
if property_keys == {"dft_mag_density", "hhi_score"}:
model_dir = "dft_mag_density_hhi_score"
elif property_keys == {"chemical_system", "energy_above_hull"}:
model_dir = "chemical_system_energy_above_hull"
else:
# If no specific multi-property model exists, use the first property's model
# 如果没有特定的多属性模型,使用第一个属性的模型
first_property = list(properties.keys())[0]
model_dir = first_property
generation_type = "multi_property"
property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"
# Construct the full model path
model_path = os.path.join(MATTERGENMODEL_ROOT, model_dir)
# 构建完整的模型路径
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
# Check if the model directory exists
# 检查模型目录是否存在
if not os.path.exists(model_path):
# Fallback to base model if specific model doesn't exist
# 如果特定模型不存在,回退到基础模型
logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
model_path = os.path.join(MATTERGENMODEL_ROOT, "mattergen_base")
model_path = os.path.join(config.MATTERGENMODEL_ROOT, "mattergen_base")
# Call the main function with appropriate parameters
# 使用适当的参数调用main函数
main(
output_path=output_dir,
model_path=model_path,
@@ -347,20 +358,20 @@ async def generate_material(
diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0
)
# Create a dictionary to store the file contents
# 创建字典存储文件内容
result_dict = {}
# Define file paths
# 定义文件路径
cif_zip_path = os.path.join(output_dir, "generated_crystals_cif.zip")
xyz_file_path = os.path.join(output_dir, "generated_crystals.extxyz")
trajectories_zip_path = os.path.join(output_dir, "generated_trajectories.zip")
# Read the CIF zip file
# 读取CIF压缩文件
if os.path.exists(cif_zip_path):
with open(cif_zip_path, 'rb') as f:
result_dict['cif_content'] = f.read()
# Create a descriptive prompt based on generation type
# 根据生成类型创建描述性提示
if generation_type == "unconditional":
title = "Generated Material Structures"
description = "These structures were generated unconditionally, meaning no specific properties were targeted."
@@ -373,7 +384,7 @@ async def generate_material(
title = "Generated Material Structures Conditioned on Multiple Properties"
description = "These structures were generated with multi-property conditioning, targeting the specified property values."
# Create the full prompt
# 创建完整的提示
prompt = f"""
# {title}
@@ -399,7 +410,7 @@ This data contains {batch_size * num_batches} crystal structures generated by th
You can use these structures for materials discovery, property prediction, or further analysis.
"""
# Clean up the files (delete them after reading)
# 清理文件(读取后删除)
try:
if os.path.exists(cif_zip_path):
os.remove(cif_zip_path)

View File

@@ -1,10 +1,17 @@
from ..llm_tools import llm_tool
"""
Property Prediction Module
This module provides functions for predicting properties of crystal structures.
"""
import asyncio
import torch
import numpy as np
from ase.units import GPa
from mattersim.forcefield import MatterSimCalculator
import asyncio
from .fairchem_tools import convert_structure
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.compute.structure_opt import convert_structure
@llm_tool(
name="predict_properties",

View File

@@ -0,0 +1,192 @@
"""
Structure Optimization Module
This module provides functions for optimizing crystal structures using the FairChem model.
"""
import asyncio
from io import StringIO
import sys
import tempfile
import os
import logging
from typing import Optional, Dict, Any
from ase.io import read, write
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from mars_toolkit.core.cif_utils import remove_symmetry_equiv_xyz
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
from mars_toolkit.core.error_handlers import handle_general_error
logger = logging.getLogger(__name__)
# Module-level FairChem calculator; stays None until init_model() succeeds.
calc = None


def init_model():
    """
    Load the FairChem calculator into the module-level ``calc`` (idempotent).

    Does nothing when ``calc`` is already set. Otherwise imports
    ``OCPCalculator`` lazily and builds it from ``config.FAIRCHEM_MODEL_PATH``.

    Raises:
        Exception: Any error raised while importing or constructing the
            calculator is logged and then re-raised unchanged.
    """
    global calc
    if calc is None:
        try:
            # Imported here so the dependency is only needed when the model is used.
            from fairchem.core import OCPCalculator

            calc = OCPCalculator(checkpoint_path=config.FAIRCHEM_MODEL_PATH)
            logger.info("FairChem model initialized successfully")
        except Exception as exc:
            logger.error(f"Failed to initialize FairChem model: {str(exc)}")
            raise
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """
    Convert a structure string into an ASE ``Atoms`` object.

    The content is written to a temporary file whose suffix is
    ``.{input_format}`` and parsed with ``ase.io.read``. ``read`` is called
    without an explicit ``format`` argument, so the suffix is the only format
    hint it receives.

    Args:
        input_format: Input format name, used as the temp-file suffix
            (cif, xyz, vasp, ...).
        content: Structure content string.

    Returns:
        The parsed ``Atoms`` object, or ``None`` if writing or parsing fails
        (the error is logged).
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            # Record the path before writing so cleanup also covers a failed write.
            tmp_path = tmp_file.name
            tmp_file.write(content)
        # The handle is closed here, so the content is flushed before parsing.
        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        # delete=False means the file must be removed manually, on success and failure alike.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                logger.warning(f"Failed to remove temporary file {tmp_path}: {str(cleanup_error)}")
def generate_symmetry_cif(structure: Structure) -> str:
    """
    Generate a symmetrized CIF string for a structure.

    The structure is refined with ``SpacegroupAnalyzer.get_refined_structure``
    and then written with ``CifWriter(symprec=0.1, refine_struct=True)`` to a
    temporary file, whose text is returned.

    Args:
        structure: Pymatgen ``Structure`` object.

    Returns:
        The CIF content as a string.
    """
    analyzer = SpacegroupAnalyzer(structure)
    structure_refined = analyzer.get_refined_structure()

    # Reserve a unique path and close the handle straight away, so the writer
    # below is the only handle open on the file.
    with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
        tmp_path = tmp_file.name

    try:
        cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
        cif_writer.write_file(tmp_path)
        with open(tmp_path, "r") as cif_file:
            return cif_file.read()
    finally:
        # delete=False: remove the file even when writing or reading fails.
        os.unlink(tmp_path)
def optimize_structure(atoms: Atoms, output_format: str) -> str:
    """
    Relax a crystal structure and return a formatted report.

    The module-level ``calc`` is attached to ``atoms``, then a FIRE optimizer
    is run over a ``FrechetCellFilter`` of the atoms until ``config.FMAX`` is
    reached. Everything printed to stdout during the run is captured and
    included in the report.

    Args:
        atoms: ASE ``Atoms`` object; modified in place by the optimization.
        output_format: Output format (cif, xyz, vasp, ...). ``"cif"`` goes
            through ``generate_symmetry_cif`` and ``remove_symmetry_equiv_xyz``;
            any other value is written with ``ase.io.write`` via a temp file.

    Returns:
        A formatted string with the total energy, the optimization log and
        the optimized structure content.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    atoms.calc = calc
    try:
        # Capture what the optimizer prints to stdout.
        # NOTE(review): this swaps the process-wide sys.stdout, and
        # optimize_crystal_structure calls this function via asyncio.to_thread,
        # so output printed by other threads during the run is captured too.
        temp_output = StringIO()
        original_stdout = sys.stdout
        sys.stdout = temp_output
        try:
            # Run the optimization.
            dyn = FIRE(FrechetCellFilter(atoms))
            dyn.run(fmax=config.FMAX)
        finally:
            # Always restore stdout, even when the optimizer raises.
            sys.stdout = original_stdout
        optimization_log = temp_output.getvalue()
        temp_output.close()

        # Total energy of the relaxed structure.
        total_energy = atoms.get_potential_energy()

        # Serialize the optimized structure.
        if output_format == "cif":
            optimized_structure = Structure.from_ase_atoms(atoms)
            content = generate_symmetry_cif(optimized_structure)
            content = remove_symmetry_equiv_xyz(content)
        else:
            # Reserve a path, close the handle, then let ASE write to it.
            with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
                tmp_path = tmp_file.name
            try:
                write(tmp_path, atoms)
                with open(tmp_path, "r") as result_file:
                    content = result_file.read()
            finally:
                # delete=False: remove the file on failure as well as success.
                os.unlink(tmp_path)

        # Build the report returned to the caller.
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**
#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content}
```
"""
        return format_result
    except Exception as e:
        logger.error(f"Failed to optimize structure: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise
@llm_tool(name="optimize_crystal_structure",
          description="Optimize crystal structure using FairChem model")
async def optimize_crystal_structure(
    content: str,
    input_format: str = "cif",
    output_format: str = "cif"
) -> str:
    """
    Optimize crystal structure using FairChem model.

    Args:
        content: Crystal structure content string
        input_format: Input format (cif, xyz, vasp)
        output_format: Output format (cif, xyz, vasp)

    Returns:
        Optimized structure with energy and optimization log. On any failure,
        including model initialization, the value returned by
        ``handle_general_error`` is returned instead of raising.
    """
    def run_optimization() -> str:
        # Parse the input into an ASE Atoms object.
        atoms = convert_structure(input_format, content)
        if atoms is None:
            raise ValueError(f"无法转换输入的{input_format}格式内容,请检查格式是否正确")
        # Relax the structure and format the report.
        return optimize_structure(atoms, output_format)

    try:
        # Initialize the model inside the try block so a load failure goes
        # through the same error handler as every other failure.
        if calc is None:
            init_model()
        # Run the blocking optimization in a worker thread.
        return await asyncio.to_thread(run_optimization)
    except Exception as e:
        return handle_general_error(e)

View File

@@ -0,0 +1,13 @@
"""
Core Module
This module provides core functionality for the Mars Toolkit.
"""
from mars_toolkit.core.config import config
from mars_toolkit.core.utils import settings, setup_logging
from mars_toolkit.core.error_handlers import (
handle_minio_error, handle_http_error,
handle_validation_error, handle_general_error
)
from mars_toolkit.core.llm_tools import llm_tool

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,121 @@
"""
CIF Utilities Module
This module provides basic functions for handling CIF (Crystallographic Information File) files,
which are commonly used in materials science for representing crystal structures.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import json
import logging
logger = logging.getLogger(__name__)
def read_cif_txt_file(file_path):
    """
    Return the text of a CIF file, decoded as UTF-8.

    Args:
        file_path: Path to the CIF file

    Returns:
        The file content as a string, or None if reading fails
        (the error is logged).
    """
    content = None
    try:
        with open(file_path, mode='r', encoding='utf-8') as cif_file:
            content = cif_file.read()
    except Exception as exc:
        logger.error(f"Error reading file {file_path}: {exc}")
    return content
def extract_cif_info(path: str, fields_name: list):
    """
    Extract specific fields from the CIF description JSON file.

    Args:
        path: Path to the JSON file containing CIF information
        fields_name: List of field categories to extract. If the first entry
            is 'all_fields', every category is extracted. Other options are
            'basic_fields', 'energy_electronic_fields' and
            'metal_magentic_fields'. Unknown category names are ignored, and
            an empty or None list selects nothing.

    Returns:
        Dictionary mapping each selected field name to its value in the JSON
        document, or to '' when the document has no such key.
    """
    # Explicit category table. A locals() lookup would also resolve unrelated
    # local names such as 'path' and turn them into bogus field names.
    field_groups = {
        'basic_fields': ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density'],
        'energy_electronic_fields': ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct'],
        'metal_magentic_fields': ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites'],
    }

    requested = list(fields_name) if fields_name else []
    selected_fields = []
    if requested and requested[0] == 'all_fields':
        for group_fields in field_groups.values():
            selected_fields.extend(group_fields)
    else:
        for group_name in requested:
            selected_fields.extend(field_groups.get(group_name, []))

    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(path, 'r', encoding='utf-8') as f:
        docs = json.load(f)

    return {field_name: docs.get(field_name, '') for field_name in selected_fields}
def remove_symmetry_equiv_xyz(cif_content):
    """
    Remove symmetry operations section from CIF file content.

    A ``loop_`` block is dropped when one of its header tags contains
    ``_symmetry_equiv_pos_as_xyz`` or its current-dictionary replacement
    ``_space_group_symop_operation_xyz``. The ``loop_`` line, its header tags
    and its data rows are removed; everything else is kept verbatim.

    Args:
        cif_content: CIF file content string

    Returns:
        Cleaned CIF content string with symmetry operations removed
    """
    symmetry_tags = ('_symmetry_equiv_pos_as_xyz', '_space_group_symop_operation_xyz')
    lines = cif_content.split('\n')
    total = len(lines)
    output_lines = []
    i = 0
    while i < total:
        if lines[i].strip() != 'loop_':
            # Not the start of a loop: keep the line as is.
            output_lines.append(lines[i])
            i += 1
            continue

        # Collect the header tags that directly follow 'loop_'.
        j = i + 1
        header_tags = []
        while j < total and lines[j].strip().startswith('_'):
            header_tags.append(lines[j].strip())
            j += 1

        if not any(marker in tag for tag in header_tags for marker in symmetry_tags):
            # Some other loop: keep 'loop_'; its tags and rows are handled
            # as ordinary lines by the following iterations.
            output_lines.append(lines[i])
            i += 1
            continue

        # Symmetry loop: skip its data rows. The loop ends at the next loop,
        # the next data block, or ANY data item (line starting with '_'), so
        # items that follow the symmetry loop are preserved.
        while j < total:
            candidate = lines[j].strip()
            if candidate == 'loop_' or candidate.startswith(('data_', '_')):
                break
            j += 1
        i = j
    return '\n'.join(output_lines)

View File

@@ -0,0 +1,59 @@
"""
Configuration Module
This module provides configuration settings for the Mars Toolkit.
It includes API keys, endpoints, paths, and other configuration parameters.
"""
from typing import Dict, Any
class Config:
    """Configuration class for Mars Toolkit.

    Settings are plain class attributes. ``as_dict`` returns them as a
    dictionary and ``update`` overwrites existing ones at class level, so the
    change is visible through every reference, including the module-level
    ``config`` instance.
    """

    # NOTE(security): the API keys below are committed in source control.
    # They should be rotated and loaded from the environment or a secrets store.

    # Materials Project
    MP_API_KEY = 'PMASAg256b814q3OaSRWeVc7MKx4mlKI'
    MP_ENDPOINT = 'https://api.materialsproject.org/'
    MP_TOPK = 3
    LOCAL_MP_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/'

    # Proxy
    HTTP_PROXY = 'http://192.168.168.1:20171'
    HTTPS_PROXY = 'http://192.168.168.1:20171'

    # FairChem
    FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
    FMAX = 0.05

    # MatterGen
    MATTERGENMODEL_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
    MATTERGENMODEL_RESULT_PATH = 'results/'

    # Dify
    DIFY_ROOT_URL = 'http://192.168.191.101:6080'
    DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'

    # Searxng
    SEARXNG_HOST = "http://192.168.191.101:40032/"

    # Visualization
    VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'

    @classmethod
    def as_dict(cls) -> Dict[str, Any]:
        """Return all configuration settings as a dictionary.

        Only attributes defined directly on this class are included. Dunder
        attributes, callables and method descriptors are excluded.
        """
        return {
            key: value
            for key, value in cls.__dict__.items()
            if not key.startswith('__')
            and not callable(value)
            # classmethod objects stored in __dict__ are not callable, so the
            # callable() test alone would let as_dict/update through.
            and not isinstance(value, (classmethod, staticmethod, property))
        }

    @classmethod
    def update(cls, **kwargs):
        """Update configuration settings.

        Args:
            **kwargs: Setting names and their new values. Names that are not
                already attributes of the class are ignored.
        """
        for key, value in kwargs.items():
            if hasattr(cls, key):
                setattr(cls, key, value)


# Create a global instance for easy access
config = Config()

View File

@@ -1,4 +1,10 @@
"""
Error Handlers Module
This module provides error handling utilities for the Mars Toolkit.
It includes functions for handling various types of errors that may occur
during toolkit operations.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn

View File

@@ -1,8 +1,7 @@
import os
import boto3
import logging
import logging.config
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
@@ -34,36 +33,11 @@ class Settings(BaseSettings):
env_file = ".env"
env_file_encoding = "utf-8"
def get_minio_client(settings: Settings):
"""获取MinIO客户端"""
return boto3.client(
's3',
endpoint_url=settings.internal_minio_endpoint or settings.minio_endpoint,
aws_access_key_id=settings.minio_access_key,
aws_secret_access_key=settings.minio_secret_key
)
def handle_minio_upload(file_path: str, file_name: str) -> str:
"""统一处理MinIO上传"""
try:
client = get_minio_client(settings)
client.upload_file(file_path, settings.minio_bucket, file_name, ExtraArgs={"ACL": "private"})
# 生成预签名 URL
url = client.generate_presigned_url(
'get_object',
Params={'Bucket': settings.minio_bucket, 'Key': file_name},
ExpiresIn=3600
)
return url.replace(settings.internal_minio_endpoint or "", settings.minio_endpoint)
except Exception as e:
from tools_for_ms.services_tools.error_handlers import handle_minio_error
return handle_minio_error(e)
def setup_logging():
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
log_file_path = os.path.join(parent_dir, 'mars_toolkit.log')
"""配置日志记录"""
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
log_file_path = os.path.join(parent_dir, 'mars_toolkit.log')
logging.config.dictConfig({
'version': 1,
'disable_existing_loggers': False,

View File

@@ -0,0 +1,7 @@
"""
Basic Module
This module provides basic utility functions for the Mars Toolkit.
"""
from mars_toolkit.misc.misc_tools import get_current_time

Binary file not shown.

View File

@@ -0,0 +1,29 @@
"""
General Tools Module
This module provides basic utility functions that are not specific to materials science.
"""
import asyncio
from datetime import datetime
import pytz
from typing import Annotated
from mars_toolkit.core.llm_tools import llm_tool
@llm_tool(name="get_current_time", description="Get current date and time in specified timezone")
async def get_current_time(timezone: str = "UTC") -> str:
    """Report the current date and time in the requested timezone.

    Args:
        timezone: Timezone name (e.g., UTC, Asia/Shanghai, America/New_York)

    Returns:
        A sentence containing the formatted date and time, or a message
        naming the timezone if pytz does not recognise it.
    """
    try:
        zone = pytz.timezone(timezone)
    except pytz.exceptions.UnknownTimeZoneError:
        return f"Unknown timezone: {timezone}. Please use a valid timezone such as 'UTC', 'Asia/Shanghai', etc."

    stamp = datetime.now(zone).strftime('%Y-%m-%d %H:%M:%S %Z')
    return f"The current {timezone} time is: {stamp}"

View File

@@ -0,0 +1,18 @@
"""
Query Module
This module provides query tools for materials science, including:
- Materials Project database queries
- OQMD database queries
- Dify knowledge base retrieval
- Web search
"""
from mars_toolkit.query.mp_query import (
search_material_property_from_material_project,
get_crystal_structures_from_materials_project,
get_mpid_from_formula
)
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
from mars_toolkit.query.web_search import search_online

View File

@@ -1,9 +1,22 @@
"""
Dify Search Module
This module provides functions for retrieving information from local materials science
literature knowledge base using Dify API.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import asyncio
import json
from .Configs import DIFY_API_KEY
import requests
import codecs
from ..llm_tools import llm_tool
from typing import Dict, Any
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
@llm_tool(
name="retrieval_from_knowledge_base",
@@ -13,19 +26,19 @@ async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
"""
检索本地材料科学文献知识库中的相关信息
输入:
Args:
query: 查询字符串如材料名称"CsPbBr3"
topk: 返回结果数量默认3条
输出:
Returns:
包含文档ID标题和相关性分数的字典
"""
# 设置Dify API的URL端点
url = 'http://192.168.191.101:6080/v1/chat-messages'
url = f'{config.DIFY_ROOT_URL}/v1/chat-messages'
# 配置请求头包含API密钥和内容类型
headers = {
'Authorization': f'Bearer {DIFY_API_KEY}', # 使用配置文件中的API密钥
'Authorization': f'Bearer {config.DIFY_API_KEY}',
'Content-Type': 'application/json'
}
@@ -45,11 +58,9 @@ async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
# 获取响应文本
response_text = response.text
useful_results = [] # 初始化结果列表(当前未使用)
# 解码响应文本中的Unicode转义序列
response_text = codecs.decode(response_text, 'unicode_escape')
print(response_text) # 打印完整响应用于调试
# 将响应文本解析为JSON对象
result_json = json.loads(response_text)
@@ -61,20 +72,13 @@ async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
useful_info = {
"id": metadata.get("document_id"), # 文档ID
"title": result_json.get("title"), # 文档标题
"content": None, # 内容字段设为空,注意:原字典使用'answer'字段存储内容
"metadata": None, # 元数据字段设为空
"embedding": None, # 嵌入向量字段设为空
"content": result_json.get("answer", ""), # 内容字段使用'answer'字段存储内容
"score": metadata.get("score") # 相关性分数
}
# 返回提取的有用信息
return useful_info
return json.dumps(useful_info, ensure_ascii=False, indent=2)
except Exception as e:
# 捕获并处理所有可能的异常,返回错误信息
return f"错误: {str(e)}"
# 当脚本直接运行时的测试代码
if __name__ == "__main__":
# 使用示例查询"CsPbBr3"测试函数
print(asyncio.run(retrieval_from_knowledge_base('CsPbBr3')))

View File

@@ -1,9 +1,8 @@
"""
Materials Project API Service Tools
Materials Project Query Module
This module provides functions for querying the Materials Project database,
processing search results, and formatting responses. It includes a LLM tool
for integration with large language models.
processing search results, and formatting responses.
Author: Yutang LI
Institution: SIAT-MIC
@@ -23,97 +22,11 @@ from mp_api.client import MPRester
from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from ..services_tools import Configs
from ..utils import settings, handle_minio_upload
from .error_handlers import handle_general_error
from ..llm_tools import llm_tool
def read_cif_txt_file(file_path):
"""Read the markdown file and return its content."""
try:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
print(f"Error reading file {file_path}: {e}")
return None
def get_extra_cif_info(path: str, fields_name: list):
"""Extract specific fields from the CIF description."""
basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites']
# metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'total_magnetization_normalized_vol', 'total_magnetization_normalized_formula_units', 'num_magnetic_sites', 'num_unique_magnetic_sites', 'types_of_magnetic_species', "decomposes_to"]
selected_fields = []
if fields_name[0] == 'all_fields':
selected_fields = basic_fields + energy_electronic_fields + metal_magentic_fields
# selected_fields = energy_electronic_fields + metal_magentic_fields
else:
for field in fields_name:
selected_fields.extend(locals().get(field, []))
with open(path, 'r') as f:
docs = json.load(f)
new_docs = {}
for field_name in selected_fields:
new_docs[field_name] = docs.get(field_name, '')
# new_docs['structure'] = {"lattice": docs['structure']['lattice']}
return new_docs
def remove_symmetry_equiv_xyz(cif_content):
"""
Remove symmetry operations section from CIF file content
Args:
cif_content: CIF file content string
Returns:
Cleaned CIF content string
"""
lines = cif_content.split('\n')
output_lines = []
i = 0
while i < len(lines):
line = lines[i].strip()
# 检测循环开始
if line == 'loop_':
# 查看下一行,检查是否是对称性循环
next_lines = []
j = i + 1
while j < len(lines) and lines[j].strip().startswith('_'):
next_lines.append(lines[j].strip())
j += 1
# 检查是否包含对称性操作标签
if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
# 跳过整个循环块
while i < len(lines):
if i + 1 >= len(lines):
break
next_line = lines[i + 1].strip()
# 检查是否到达下一个循环或数据块
if next_line == 'loop_' or next_line.startswith('data_'):
break
# 检查是否到达原子位置部分
if next_line.startswith('_atom_site_'):
break
i += 1
else:
# 不是对称性循环保留loop_行
output_lines.append(lines[i])
else:
# 非循环开始行,直接保留
output_lines.append(lines[i])
i += 1
return '\n'.join(output_lines)
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
from mars_toolkit.core.error_handlers import handle_general_error
from mars_toolkit.core.cif_utils import read_cif_txt_file, extract_cif_info, remove_symmetry_equiv_xyz
logger = logging.getLogger(__name__)
@@ -218,17 +131,14 @@ def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]] |
new_docs[field_name] = doc.get(field_name, '')
res.append(new_docs)
except Exception as e:
# logger.warning(f"Error processing document: {str(e)}")
logger.warning(f"Error processing document: {str(e)}")
continue
return res
except Exception as e:
error_msg = f"Error in process_search_results: {str(e)}"
# logger.error(error_msg)
import traceback
# logger.error(traceback.format_exc())
logger.error(error_msg)
return error_msg
def _search_worker(queue, api_key, **kwargs):
"""
Worker function for executing Materials Project API searches.
@@ -243,24 +153,15 @@ def _search_worker(queue, api_key, **kwargs):
try:
import os
import traceback
os.environ['HTTP_PROXY'] = Configs.HTTP_PROXY or ''
os.environ['HTTPS_PROXY'] = Configs.HTTPS_PROXY or ''
os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''
# 初始化 MPRester 客户端
with MPRester(api_key) as mpr:
# print(f"MPRester initialized with endpoint:")
# print("Executing search...")
result = mpr.materials.summary.search(**kwargs)
# print(f"Search completed, result type: {type(result)}")
# 检查结果
if result:
# print(f"Number of results: {len(result)}")
# print(f"First result type: {type(result[0])}")
# 尝试使用更安全的方式处理结果
processed_results = []
for doc in result:
@@ -285,17 +186,12 @@ def _search_worker(queue, api_key, **kwargs):
# 最后的尝试,直接使用 doc
processed_results.append(doc)
# print(f"Processed {len(processed_results)} results")
queue.put(processed_results)
else:
# print("No results found")
queue.put([])
except Exception as e:
# print(f"Error in _search_worker: {str(e)}")
# print(traceback.format_exc())
queue.put(e)
async def execute_search(search_args: Dict[str, Any], timeout: int = 120) -> List[Dict[str, Any]] | str:
"""
Execute a search against the Materials Project API.
@@ -309,78 +205,44 @@ async def execute_search(search_args: Dict[str, Any], timeout: int = 120) -> Lis
Returns:
List of document dictionaries from the search results or error message string if an exception occurs
"""
# print(f"Starting execute_search with args: {search_args}")
# 确保 formula 参数是列表类型
if 'formula' in search_args and isinstance(search_args['formula'], str):
search_args['formula'] = [search_args['formula']]
# print(f"Converted formula to list in execute_search: {search_args['formula']}")
manager = Manager()
queue = manager.Queue()
try:
p = Process(target=_search_worker, args=(queue, Configs.MP_API_KEY), kwargs=search_args)
p = Process(target=_search_worker, args=(queue, config.MP_API_KEY), kwargs=search_args)
p.start()
# logger.info(f"Started worker process with PID: {p.pid}")
# print(f"Waiting for process {p.pid} to complete (timeout: {timeout}s)...")
p.join(timeout=timeout)
if p.is_alive():
# logger.warning(f"Terminating worker process {p.pid} due to timeout")
# print(f"Process {p.pid} timed out, terminating...")
logger.warning(f"Terminating worker process {p.pid} due to timeout")
p.terminate()
p.join()
error_msg = f"Request timed out after {timeout} seconds"
return error_msg
# print("Process completed, retrieving results from queue...")
try:
if queue.empty():
# logger.warning("Queue is empty after process completion")
# print("Warning: Queue is empty after process completion")
pass
else:
# logger.info("Queue contains data, retrieving...")
# print("Queue contains data, retrieving...")
pass
result = queue.get(timeout=timeout)
# print(f"Result type: {type(result)}")
if isinstance(result, Exception):
# logger.error(f"Error in search worker: {str(result)}")
# print(f"Error in search worker: {str(result)}")
# 尝试获取更详细的错误信息
logger.error(f"Error in search worker: {str(result)}")
if hasattr(result, "__traceback__"):
import traceback
tb_str = ''.join(traceback.format_exception(None, result, result.__traceback__))
# print(f"Error traceback: {tb_str}")
return f"Error in search worker: {str(result)}"
if isinstance(result, list):
# print(f"Successfully retrieved {len(result)} documents")
# logger.info(f"Successfully retrieved {len(result)} documents")
pass
else:
# print(f"Result is not a list, but {type(result)}")
pass
return result
except queue.Empty:
error_msg = "Failed to retrieve data from queue (timeout)"
# logger.error(error_msg)
# print(error_msg)
logger.error(error_msg)
return error_msg
except Exception as e:
error_msg = f"Error in execute_search: {str(e)}"
# logger.error(error_msg)
# print(error_msg)
import traceback
# print(traceback.format_exc())
logger.error(error_msg)
return error_msg
@llm_tool(name="search_material_property_from_material_project", description="Search materials in Materials Project database by formula and properties")
@@ -403,8 +265,6 @@ async def search_material_property_from_material_project(
Returns:
JSON formatted material properties data
"""
# print(f"search_material_property_from_material_project called with formula: {formula}, type: {type(formula)}")
# 验证晶系参数
VALID_CRYSTAL_SYSTEMS = ['Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal', 'Cubic']
@@ -421,9 +281,6 @@ async def search_material_property_from_material_project(
# 确保 formula 是列表类型
if isinstance(formula, str):
formula = [formula]
# print(f"Converted formula to list: {formula}")
params = {
"chemsys": chemsys,
@@ -432,37 +289,25 @@ async def search_material_property_from_material_project(
"is_gap_direct": is_gap_direct,
"is_stable": is_stable,
"chunk_size": 5,
}
# Filter out None values
params = {k: v for k, v in params.items() if v is not None}
# print("Parameters after filtering:", params)
mp_id_list = await get_mpid_from_formula(formula=formula)
try:
res=[]
# Execute search against Materials Project API
#docs = await execute_search(params)
# for mp_id in id_list:
for mp_id in mp_id_list:
crystal_props = get_extra_cif_info(Configs.LOCAL_MP_PROPERTY_ROOT+f"{mp_id}.json", ['all_fields'])
crystal_props = extract_cif_info(config.LOCAL_MP_ROOT+f"/Props/{mp_id}.json", ['all_fields'])
res.append(crystal_props)
#res = process_search_results(docs)
# print(f"Processed {len(res)} results")
if len(res) == 0:
# print("No results found")
return "No results found, please try again."
# Format response with top results
# print(f"Formatting top {Configs.MP_TOPK} results")
try:
# 创建包含索引的JSON结果
formatted_results = []
for i, item in enumerate(res[:Configs.MP_TOPK], 1):
for i, item in enumerate(res[:config.MP_TOPK], 1):
formatted_result = f"[property {i} begin]\n"
formatted_result += json.dumps(item, indent=2)
formatted_result += f"\n[property {i} end]\n\n"
@@ -472,25 +317,19 @@ async def search_material_property_from_material_project(
res_chunk = "\n\n".join(formatted_results)
res_template = f"""
Here are the search results from the Materials Project database:
Due to length limitations, only the top {Configs.MP_TOPK} results are shown below:\n
Due to length limitations, only the top {config.MP_TOPK} results are shown below:\n
{res_chunk}
If you need more results, please modify your search criteria or try different query parameters.
"""
# print("Successfully formatted results")
return res_template
except Exception as format_error:
# print(f"Error formatting results: {str(format_error)}")
import traceback
# print(traceback.format_exc())
logger.error(f"Error formatting results: {str(format_error)}")
return str(format_error)
except Exception as e:
# print(f"Error in search_material_property_from_material_project: {str(e)}")
import traceback
# print(traceback.format_exc())
logger.error(f"Error in search_material_property_from_material_project: {str(e)}")
return str(e)
@llm_tool(name="get_crystal_structures_from_materials_project", description="Get symmetrized crystal structures CIF data from Materials Project database by chemical formula")
async def get_crystal_structures_from_materials_project(
formulas: list[str],
@@ -508,46 +347,11 @@ async def get_crystal_structures_from_materials_project(
Returns:
Formatted text containing symmetrized CIF data
"""
# 确保 formulas 是列表类型
# if isinstance(formulas, str):
# formulas = [formulas]
# try:
# # 构建搜索参数
# search_args = {
# "formula": formulas,
# "fields": ["material_id",]
# }
# # 使用execute_search函数查询晶体结构信息
# docs = await execute_search(search_args, timeout=60)
# if isinstance(docs, str):
# # 如果返回的是字符串,说明发生了错误
# return f"获取晶体结构时出错: {docs}"
# if not docs:
# return "未找到指定化学式的晶体结构数据。"
# 处理结果
# result = {}
# for i, doc in enumerate(docs):
# try:
# # 获取材料ID和结构
# material_id = doc.get('material_id')
# structure_data = doc.get('structure')
# if not structure_data:
# continue
# # 将结构数据转换为pymatgen Structure对象
result={}
mp_id_list=await get_mpid_from_formula(formula=formulas)
for i,mp_id in enumerate(mp_id_list):
cif_file = glob.glob(f"/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/MPDatasets/{mp_id}.cif")[0]
#print('111',cif_file)
cif_file = glob.glob(config.LOCAL_MP_ROOT+f"/MPDatasets/{mp_id}.cif")[0]
structure = Structure.from_file(cif_file)
# 如果需要常规单元格
if conventional_unit_cell:
@@ -570,28 +374,24 @@ async def get_crystal_structures_from_materials_project(
result[key] = cif_data
# 只保留前Configs.MP_TOPK个结果
if len(result) >= Configs.MP_TOPK:
# 只保留前config.MP_TOPK个结果
if len(result) >= config.MP_TOPK:
break
# except Exception as e:
# continue
# 格式化响应
try:
prompt = f"""
prompt = f"""
# Materials Project Symmetrized Crystal Structure Data
Below are symmetrized crystal structure data for {len(result)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
"""
for i, (key, cif_data) in enumerate(result.items(), 1):
prompt += f"[cif {i} begin]\n"
prompt += cif_data
prompt += f"\n[cif {i} end]\n\n"
for i, (key, cif_data) in enumerate(result.items(), 1):
prompt += f"[cif {i} begin]\n"
prompt += cif_data
prompt += f"\n[cif {i} end]\n\n"
prompt += """
prompt += """
## Usage Instructions
1. You can copy the above CIF data and save it as .cif files
@@ -601,15 +401,14 @@ These structures have been analyzed and optimized for symmetry using SpacegroupA
CIF files contain complete structural information of crystals, including cell parameters, atomic coordinates, symmetry, etc.
Symmetrization helps identify and optimize crystal symmetry, making the structure more standardized and accurate.
"""
return prompt
return prompt
except Exception as format_error:
import traceback
logger.error(f"Error formatting crystal structures: {str(format_error)}")
return str(format_error)
@llm_tool(name="get_mpid_from_formula", description="Get material IDs (mpid) from Materials Project database by chemical formula")
async def get_mpid_from_formula(formula: str) -> str:
async def get_mpid_from_formula(formula: str) -> List[str]:
"""
Get material IDs (mpid) from Materials Project database by chemical formula.
Returns mpids for the lowest energy structures.
@@ -618,27 +417,17 @@ async def get_mpid_from_formula(formula: str) -> str:
formula: Chemical formula (e.g., "Fe2O3")
Returns:
Formatted text containing material IDs
List of material IDs
"""
# 确保 formula 是列表类型,因为 _search_by_formula_worker 需要列表输入
os.environ['HTTP_PROXY'] = Configs.HTTP_PROXY or ''
os.environ['HTTPS_PROXY'] = Configs.HTTPS_PROXY or ''
os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''
id_list = []
with MPRester(Configs.MP_API_KEY) as mpr:
docs = mpr.materials.summary.search(formula=formula)#这里设定搜索条件id list =[]for doc in docs:#获取材料索引号id list.append(doc.material id)
for doc in docs:
id_list.append(doc.material_id)
return id_list
# cif_description_list= []
# cif_information_list=[]
# crystal_props_list=[]
# #print("mp_id",id_list)
# for mp_id in id_list:
# cif_description = read_cif_txt_file('/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Text-bond/{}.txt'.format(mp_id))
# cif_information = read_cif_txt_file('/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Symmetry_MPDatasets/{}_symmetrized.cif'.format(mp_id))
# cif_information = cif_information.replace('# generated using pymatgen\n', '')
# crystal_props = get_extra_cif_info("/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Props/{}.json".format(mp_id), ['all_fields'])
# cif_description_list.append(cif_description)
# cif_information_list.append(cif_information)
try:
with MPRester(config.MP_API_KEY) as mpr:
docs = mpr.materials.summary.search(formula=formula)
for doc in docs:
id_list.append(doc.material_id)
return id_list
except Exception as e:
logger.error(f"Error getting mpid from formula: {str(e)}")
return []

View File

@@ -1,3 +1,13 @@
"""
OQMD Query Module
This module provides functions for querying the Open Quantum Materials Database (OQMD).
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import logging
import httpx
import pandas as pd
@@ -5,13 +15,12 @@ from bs4 import BeautifulSoup
from io import StringIO
from typing import Annotated
from ..llm_tools import llm_tool
from ..llm_tools import *
from mars_toolkit.core.llm_tools import llm_tool
logger = logging.getLogger(__name__)
@llm_tool(name="fetch_chemical_composition_from_OQMD", description="Fetch material data for a chemical composition from OQMD database")
async def fetch_chemical_composition_from_OQMD (
async def fetch_chemical_composition_from_OQMD(
composition: Annotated[str, "Chemical formula (e.g., Fe2O3, LiFePO4)"]
) -> str:
"""

View File

@@ -1,48 +1,16 @@
"""
Web Search Module
This module provides functions for searching information on the web.
"""
import asyncio
from datetime import datetime
from typing import Annotated, Dict, Any
import pytz
from typing import Annotated, Dict, Any, List
from langchain_community.utilities import SearxSearchWrapper
from .llm_tools import llm_tool
# Json Schema 将函数转化为大模型能够理解的格式因为大模型训练时调用函数相关的数据使用Json Schema的格式
#1. 使用@tool装饰器装饰函数。
#2. 使用Annotated为参数添加描述。
#3. 完善函数的docstring以明确工具的功能。 让模型在调用函数时能清楚每个模块的功能
# @tool
# def online_search(
# query: Annotated[str, "The search term to find scientific content in English"]
# ) -> str:
# """Searches scientific information on the Internet and returns results in English."""
# search = SearxSearchWrapper(
# searx_host="http://192.168.191.101:40032/",
# categories=["science"],
# k=20
# )
# return search.run(query, language='es', num_results=2)
@llm_tool(name="get_current_time", description="Get current date and time in specified timezone")
async def get_current_time(timezone: str = "UTC") -> str:
    """Report the current date and time in a named timezone.

    Args:
        timezone: Timezone name (e.g., UTC, Asia/Shanghai, America/New_York)

    Returns:
        A sentence containing the formatted date and time, or a message
        asking for a valid timezone when ``pytz`` does not know the name.
    """
    # Resolve the zone first; an unknown name is the only failure handled here.
    try:
        zone = pytz.timezone(timezone)
    except pytz.exceptions.UnknownTimeZoneError:
        return f"Unknown timezone: {timezone}. Please use a valid timezone such as 'UTC', 'Asia/Shanghai', etc."

    stamp = datetime.now(zone).strftime('%Y-%m-%d %H:%M:%S %Z')
    return f"The current {timezone} time is: {stamp}"
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
@llm_tool(name="search_online", description="Search scientific information online and return results as a string")
async def search_online(
@@ -73,7 +41,7 @@ async def search_online(
# Initialize search wrapper
search = SearxSearchWrapper(
searx_host="http://192.168.191.101:40032/",
searx_host=config.SEARXNG_HOST,
categories=["science",],
k=num_results
)
@@ -107,17 +75,3 @@ async def search_online(
result_str += f"Source: {result['source']}\n\n"
return result_str
# 让大模型可以根据函数名,直接调用函数
# tool_map = {
# "online_search": online_search,
# "get_current_time": get_current_time,
# }
#####要用时得加修饰符@tool为了实现异步不用@tool
#tools = [online_search,get_current_time]
#tools_json_shcema = [convert_to_openai_function_format(tool.args_schema.model_json_schema()) for tool in tools]

View File

View File

@@ -1,6 +1,8 @@
import asyncio
import json
from rich.console import Console
console = Console()
async def test_tool(tool_name: str) -> str:
"""
@@ -16,27 +18,27 @@ async def test_tool(tool_name: str) -> str:
print(f"开始测试工具: {tool_name}")
if tool_name == "get_current_time":
from tools_for_ms.basic_tools import get_current_time
from mars_toolkit.misc.misc_tools import get_current_time
result = await get_current_time(timezone="Asia/Shanghai")
elif tool_name == "search_online":
from tools_for_ms.basic_tools import search_online
#from tools_for_ms.basic_tools import search_online
from mars_toolkit.query.web_search import search_online
result = await search_online(query="material science", num_results=2)
elif tool_name == "search_material_property_from_material_project":
from tools_for_ms.services_tools.mp_tools import search_material_property_from_material_project
from mars_toolkit.query.mp_query import search_material_property_from_material_project
result = await search_material_property_from_material_project(formula="Fe2O3")
elif tool_name == "get_crystal_structures_from_materials_project":
from tools_for_ms.services_tools.mp_tools import get_crystal_structures_from_materials_project
result = await get_crystal_structures_from_materials_project(
formulas=["Fe2O3"])
from mars_toolkit.query.mp_query import get_crystal_structures_from_materials_project
result = await get_crystal_structures_from_materials_project(formulas=["Fe2O3"])
elif tool_name == "get_mpid_from_formula":
from tools_for_ms.query_tools.mp_tools import get_mpid_from_formula
result = await get_mpid_from_formula(['Fe2O3'])
from mars_toolkit.query.mp_query import get_mpid_from_formula
result = await get_mpid_from_formula(formula=["Fe2O3"])
elif tool_name == "optimize_crystal_structure":
from tools_for_ms.services_tools.fairchem_tools import optimize_crystal_structure
from mars_toolkit.compute.structure_opt import optimize_crystal_structure
# 使用一个简单的CIF字符串作为测试输入
simple_cif = """
data_simple
@@ -58,21 +60,21 @@ async def test_tool(tool_name: str) -> str:
result = await optimize_crystal_structure(content=simple_cif, input_format="cif")
elif tool_name == "generate_material":
from tools_for_ms.services_tools.mattergen_tools import generate_material
from mars_toolkit.compute.material_gen import generate_material
# 使用简单的属性约束进行测试
result = await generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1)
elif tool_name == "fetch_chemical_composition_from_OQMD":
from tools_for_ms.services_tools.oqmd_tools import fetch_chemical_composition_from_OQMD
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
result = await fetch_chemical_composition_from_OQMD(composition="Fe2O3")
elif tool_name == "retrieval_from_knowledge_base":
from tools_for_ms.query_tools.search_dify import retrieval_from_knowledge_base
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
result = await retrieval_from_knowledge_base(query="CsPbBr3", topk=3)
elif tool_name == "predict_properties":
from tools_for_ms.services_tools.mattersim_tools import predict_properties
# 使用一个简单的硅钻石结构CIF字符串作为测试输入
from mars_toolkit.compute.property_pred import predict_properties
# 使用一个简单的CsPbBr3结构CIF字符串作为测试输入
_cif = """
# generated using pymatgen
data_CsPbBr3
@@ -120,52 +122,14 @@ loop_
Br Br17 1 0.79480800 0.20538600 0.47431200 1
Br Br18 1 0.95175500 0.49370800 0.75000000 1
Br Br19 1 0.45175500 0.00629200 0.25000000 1
"""
result = await predict_properties(cif_content=_cif)
# elif tool_name == "visualize_cif":
# from tools_for_ms.services_tools.cif_visualization_tools import visualize_cif
# # 使用一个简单的CIF字符串作为测试输入
# simple_cif = """
# data_CdEu2NEu
# _chemical_formula_structural CdEu2NEu
# _chemical_formula_sum "Cd1 Eu3 N1"
# _cell_length_a 5.114863465543178
# _cell_length_b 5.110721509244114
# _cell_length_c 5.113552093505859
# _cell_angle_alpha 90.02261043268513
# _cell_angle_beta 90.00946914658029
# _cell_angle_gamma 89.99314499504335
# _space_group_name_H-M_alt "P 1"
# _space_group_IT_number 1
# loop_
# _space_group_symop_operation_xyz
# 'x, y, z'
# loop_
# _atom_site_type_symbol
# _atom_site_label
# _atom_site_symmetry_multiplicity
# _atom_site_fract_x
# _atom_site_fract_y
# _atom_site_fract_z
# _atom_site_occupancy
# Cd Cd1 1.0 0.6641489863395691 0.6804293394088744 0.3527604341506958 1.0000
# Eu Eu1 1.0 0.1641521006822586 0.18045939505100247 0.35262206196784973 1.0000
# Eu Eu2 1.0 0.16385404765605927 0.6803322434425354 0.8526210784912109 1.0000
# N N1 1.0 0.16389326751232147 0.1804375052452087 0.8527467250823975 1.0000
# Eu Eu3 1.0 0.664197564125061 0.1803932040929794 0.8526203036308289 1.0000
# """
# result = await visualize_cif(cif_content=simple_cif)
# else:
# return f"未知工具: {tool_name}"
else:
return f"未知工具: {tool_name}"
print(f"工具 {tool_name} 测试完成")
return f"工具 {tool_name} 测试成功,返回结果类型: {type(result)},返回的结果{result}"
return f"工具 {tool_name} 测试成功,返回结果类型: {type(result)}, 返回的结果: {result}"
except Exception as e:
import traceback
@@ -173,18 +137,44 @@ loop_
return f"工具 {tool_name} 测试失败: {str(e)}\n{error_details}"
if __name__ == "__main__":
def print_tool_schemas():
    """Print the JSON schema of every registered tool function.

    For each schema returned by ``mars_toolkit.get_tool_schemas()`` the tool
    name, its description and its parameters (each flagged as required or
    optional) are written to the module-level ``console``.
    """
    import mars_toolkit

    all_schemas = mars_toolkit.get_tool_schemas()
    console.print("[bold green]已注册的工具函数列表:[/bold green]")

    for index, schema in enumerate(all_schemas, 1):
        function_spec = schema['function']
        console.print(f"[bold cyan]工具 {index}:[/bold cyan] {function_spec['name']}")
        console.print(f"[bold yellow]描述:[/bold yellow] {function_spec['description']}")
        console.print("[bold magenta]参数:[/bold magenta]")

        parameters = function_spec['parameters']
        for param_name, param_info in parameters['properties'].items():
            # A parameter is optional unless it is listed under "required".
            if param_name in parameters.get('required', []):
                required = "必需"
            else:
                required = "可选"
            console.print(f" - [bold]{param_name}[/bold] ({required}): {param_info.get('description', '无描述')}")

        # Blank line between tools.
        console.print("")
# 查询工具
# search_material_property_from_material_project 在material project中通过化学式查询材料性质 ✅
# get_crystal_structures_from_materials_project 在material project中通过化学式查询晶体性质✅
# fetch_chemical_composition_from_OQMD 在OQMD中通过化学式查询获取化学组成✅
# search_online
# 生成内容的工具
# optimize_crystal_structure 使用fairchem 优化晶体结构✅
# predict_properties 使用mattersim 预测晶体性质 ✅
# generate_material 使用matter 预测晶体性质✅
# 测试 MatterSim 工具
tool_name ='search_material_property_from_material_project'
if __name__ == "__main__":
# 打印所有工具函数的模式
#print_tool_schemas()
# 测试工具函数列表
tools_to_test = [
"get_current_time", # 基础工具
"search_online", # 网络搜索工具
"search_material_property_from_material_project", # 材料项目查询工具
"get_crystal_structures_from_materials_project", # 晶体结构查询工具
"get_mpid_from_formula", # 材料ID查询工具
"optimize_crystal_structure", # 晶体结构优化工具
"generate_material", # 材料生成工具
"fetch_chemical_composition_from_OQMD", # OQMD查询工具
"retrieval_from_knowledge_base", # 知识库检索工具
"predict_properties" # 属性预测工具
]
# 选择要测试的工具
tool_name = tools_to_test[5] # 测试 search_online 工具
# 运行测试
result = asyncio.run(test_tool(tool_name))
print(result)
console.print(f"[bold blue]测试结果:[/bold blue]")
console.print(result)

View File

@@ -1,16 +0,0 @@
"""
Tools package for LLM function calling.
This package provides utilities for defining, registering, and managing LLM tools.
"""
from .llm_tools import llm_tool, get_tools, get_tool_schemas
from .basic_tools import *
from .services_tools.oqmd_tools import fetch_chemical_composition_from_OQMD
from .services_tools.mp_tools import search_material_property_from_material_project,get_crystal_structures_from_materials_project
from .services_tools.search_dify import retrieval_from_knowledge_base
from .services_tools.fairchem_tools import optimize_crystal_structure
#from .services_tools.mattergen_tools import generate_material
#from .services_tools.mattersim_tools import predict_properties
__all__ = ["llm_tool", "get_tools", "get_tool_schemas"]

View File

@@ -1,23 +0,0 @@
# Static configuration constants for the materials-science tool services.
#
# NOTE(review): MP_API_KEY and DIFY_API_KEY below are literal credentials
# committed to version control. They should be treated as leaked: rotate them
# and load the values from the environment or an untracked secrets file.
# NOTE(review): the filesystem paths and proxy/service addresses are
# machine-specific literals; they presumably only resolve on the original
# development host -- confirm before reusing.

# Materials Project API access.
MP_API_KEY='PMASAg256b814q3OaSRWeVc7MKx4mlKI'
MP_ENDPOINT='https://api.materialsproject.org/'
# Number of results kept when formatting Materials Project query output.
MP_TOPK = 3
# Directory prefix joined with "<mp_id>.json" to locate local property files.
LOCAL_MP_PROPERTY_ROOT='/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Props/'
# Proxy
# The trailing comments on the next two lines are alternative proxy addresses.
HTTP_PROXY='http://192.168.168.1:20171' #'http://127.0.0.1:7897' #192.168.191.101:20171
HTTPS_PROXY='http://192.168.168.1:20171'#'http://127.0.0.1:7897'
# Checkpoint file handed to the FairChem OCPCalculator.
FAIRCHEM_MODEL_PATH='/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
# Value passed as the `fmax` argument of the structure optimizer's run().
FMAX=0.05
# MatterGen checkpoint directory and result output path.
MATTERGENMODEL_ROOT='/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
MATTERGENMODEL_RESULT_PATH='results/'
# Dify service base URL and API key.
DIFY_ROOT_URL='http://192.168.191.101:6080'
DIFY_API_KEY='app-IKZrS1RqIyurPSzR73mz6XSA'
# Output directory for CIF visualizations.
VIZ_CIF_OUTPUT_ROOT='/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'

View File

@@ -1,386 +0,0 @@
# 输入content-> 转换content生成atomase能处理的格式->用matgen生成优化后的结构-再生成对称性cif。
from io import StringIO
import sys
import tempfile
from ase.io import read, write
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
import os
from .. import llm_tool
os.environ["PYTHONWARNINGS"] = "ignore"
# 或者更精细的控制
os.environ["PYTHONWARNINGS"] = "ignore::DeprecationWarning"
from typing import Optional
import logging
from pymatgen.core.structure import Structure
from .error_handlers import handle_general_error
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
logger = logging.getLogger(__name__)
from pymatgen.io.cif import CifWriter
from ase.atoms import Atoms
from .Configs import*
calc = None
# def init_model():
# """初始化FairChem模型"""
# global calc
# try:
# from fairchem.core import OCPCalculator
# calc = OCPCalculator(checkpoint_path= FAIRCHEM_MODEL_PATH)
# logger.info("FairChem model initialized successfully")
# except Exception as e:
# logger.error(f"Failed to initialize FairChem model: {str(e)}")
# raise
# init_model()
# # 格式转化
# def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
# '''example:
# input_format = "xyz" cif vasp 等等
# content = """5
# H2O molecule with an extra oxygen and hydrogen
# O 0.0 0.0 0.0
# H 0.0 0.0 0.9
# H 0.0 0.9 0.0
# O 1.0 0.0 0.0
# H 1.0 0.0 0.9
# return Atoms(symbols='OH2OH', pbc=False)
# """
# '''
# """将输入内容转换为Atoms对象"""
# try:
# with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
# tmp_file.write(content)
# tmp_path = tmp_file.name
# atoms = read(tmp_path)
# os.unlink(tmp_path)
# return atoms
# except Exception as e:
# logger.error(f"Failed to convert structure: {str(e)}")
# return None
# def generate_symmetry_cif(structure: Structure) -> str:
# """生成对称性CIF"""
# analyzer = SpacegroupAnalyzer(structure)
# structure = analyzer.get_refined_structure()
# with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
# cif_writer = CifWriter(structure, symprec=0.1, refine_struct=True)
# cif_writer.write_file(tmp_file.name)
# tmp_file.seek(0)
# return tmp_file.read()
# def optimize_structure(atoms: Atoms, output_format: str):
# """优化晶体结构"""
# atoms.calc = calc
# try:
# import io
# # from contextlib import redirect_stdout
# # # 创建StringIO对象捕获输出
# # f = io.StringIO()
# # dyn = FIRE(FrechetCellFilter(atoms))
# # # 同时捕获并输出到控制台
# # with redirect_stdout(f):
# # dyn.run(fmax=FMAX)
# # # 获取捕获的日志
# # optimization_log = f.getvalue()
# temp_output = StringIO()
# # 保存原始的stdout
# original_stdout = sys.stdout
# # 重定向stdout到StringIO对象
# sys.stdout = temp_output
# dyn = FIRE(FrechetCellFilter(atoms))
# dyn.run(fmax=FMAX)
# sys.stdout = original_stdout
# output_string = temp_output.getvalue()
# temp_output.close()
# optimization_log = output_string
# # 同时输出到控制台
# total_energy = atoms.get_potential_energy()
# except Exception as e:
# return handle_general_error(e)
# #atoms.get_potential_energy() 函数解析
# # atoms.get_potential_energy() 是 ASE (Atomic Simulation Environment) 中 Atoms 对象的一个方法,用于获取原子系统的势能(或总能量)。
# # 功能与用途
# # 获取能量 :返回原子系统的计算总能量,通常以电子伏特 (eV) 为单位。
# # 用途
# # 评估结构稳定性(能量越低的结构通常越稳定)
# # 计算反应能垒和反应能
# # 分析能量随结构变化的趋势
# # 作为结构优化的目标函数
# # 计算分子或材料的吸附能、形成能等
# # 工作原理
# # 计算引擎依赖
# # 该方法不会自行计算能量,而是从附加到 Atoms 对象的计算器 (calculator) 获取能量
# # 需要先给 Atoms 对象设置一个计算器(如 VASP、Quantum ESPRESSO、GPAW 等)
# # 执行机制
# # 如果能量已计算过且原子结构未改变,则返回缓存值
# # 否则会触发计算器执行能量计算
# # 处理对称性
# if output_format == "cif":
# optimized_structure = Structure.from_ase_atoms(atoms)
# content = generate_symmetry_cif(optimized_structure)
# #print('xxx',content)
# #print('yyy',total_energy)
# # 格式化返回结果
# format_result = f"""
# The following is the optimized crystal structure information:
# ### Optimization Results (using FIRE(eqV2_86M) algorithm):
# **Total Energy: {total_energy} eV**
# #### Optimizing Log:
# ```text
# {optimization_log}
# ```
# ### Optimized {output_format.upper()} Content:
# ```{content}
# {optimized_structure[:300]}
# ```
# """
# print("output_log",format_result)
# Sample input for a manual smoke test: CIF text (generated using pymatgen)
# describing an H2O cell with formula sum H24 O12 in space group P 1.
# These are module-level assignments, so they are evaluated every time the
# module is imported. The only reference to them visible here is the
# commented-out convert_structure/optimize_structure call right below.
input_format = "cif"
content = """
data_H2O
_symmetry_space_group_name_H-M 'P 1'
_cell_length_a 7.60356659
_cell_length_b 7.60356659
_cell_length_c 7.14296200
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 120.00000516
_symmetry_Int_Tables_number 1
_chemical_formula_structural H2O
_chemical_formula_sum 'H24 O12'
_cell_volume 357.63799926
_cell_formula_units_Z 12
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
H H0 1 0.33082300 0.33082300 0.69642800 1
H H1 1 0.66917700 0.00000000 0.69642800 1
H H2 1 0.00000000 0.66917700 0.69642800 1
H H3 1 0.66917700 0.66917700 0.19642800 1
H H4 1 0.33082300 0.00000000 0.19642800 1
H H5 1 0.00000000 0.33082300 0.19642800 1
H H6 1 0.45234700 0.45234700 0.51064600 1
H H7 1 0.54765300 0.00000000 0.51064600 1
H H8 1 0.00000000 0.54765300 0.51064600 1
H H9 1 0.54765300 0.54765300 0.01064600 1
H H10 1 0.45234700 0.00000000 0.01064600 1
H H11 1 0.00000000 0.45234700 0.01064600 1
H H12 1 0.78617100 0.66371600 0.47884700 1
H H13 1 0.33628400 0.12245500 0.47884700 1
H H14 1 0.87754500 0.21382900 0.47884700 1
H H15 1 0.66371600 0.78617100 0.47884700 1
H H16 1 0.12245500 0.33628400 0.47884700 1
H H17 1 0.21382900 0.87754500 0.47884700 1
H H18 1 0.21382900 0.33628400 0.97884700 1
H H19 1 0.66371600 0.87754500 0.97884700 1
H H20 1 0.12245500 0.78617100 0.97884700 1
H H21 1 0.33628400 0.21382900 0.97884700 1
H H22 1 0.87754500 0.66371600 0.97884700 1
H H23 1 0.78617100 0.12245500 0.97884700 1
O O24 1 0.32664200 0.32664200 0.55565800 1
O O25 1 0.67335800 0.00000000 0.55565800 1
O O26 1 0.00000000 0.67335800 0.55565800 1
O O27 1 0.67335800 0.67335800 0.05565800 1
O O28 1 0.32664200 0.00000000 0.05565800 1
O O29 1 0.00000000 0.32664200 0.05565800 1
O O30 1 0.66060500 0.66060500 0.42957500 1
O O31 1 0.33939500 0.00000000 0.42957500 1
O O32 1 0.00000000 0.33939500 0.42957500 1
O O33 1 0.33939500 0.33939500 0.92957500 1
O O34 1 0.66060500 0.00000000 0.92957500 1
O O35 1 0.00000000 0.66060500 0.92957500 1
"""
# atoms=convert_structure(input_format=input_format,content=content)
# optimize_structure(atoms=atoms,output_format='cif')
# 添加新的异步LLM工具包装optimize_structure功能
import asyncio
from io import StringIO
import sys
import tempfile
from ase.io import read, write
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
import os
from typing import Optional, Dict, Any
from ase.atoms import Atoms
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
import logging
logger = logging.getLogger(__name__)
# 初始化FairChem模型
calc = None
def init_model():
    """Load the FairChem calculator into the module-level ``calc`` once.

    Does nothing when ``calc`` is already set. Otherwise an ``OCPCalculator``
    is built from the checkpoint at ``FAIRCHEM_MODEL_PATH`` and stored in
    ``calc``.

    Raises:
        Exception: Whatever the imports or the calculator construction
            raised; the failure is logged and then re-raised unchanged.
    """
    global calc
    if calc is None:
        try:
            # Both imports live inside the function, so they only run the
            # first time a calculator is actually needed.
            from fairchem.core import OCPCalculator
            from tools_for_ms.services_tools.Configs import FAIRCHEM_MODEL_PATH

            calc = OCPCalculator(checkpoint_path=FAIRCHEM_MODEL_PATH)
            logger.info("FairChem model initialized successfully")
        except Exception as init_error:
            logger.error(f"Failed to initialize FairChem model: {str(init_error)}")
            raise
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """Convert structure text into an ASE ``Atoms`` object.

    The text is written to a temporary file whose suffix is ``input_format``
    and read back with ``ase.io.read``. No explicit format is passed to
    ``read``, so the reader has to infer it (presumably from the suffix).

    Args:
        input_format: Extension used for the temporary file (e.g. "cif").
        content: Contents of the structure file.

    Returns:
        The parsed ``Atoms`` object, or ``None`` when writing or parsing
        failed (the failure is logged).
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            # Record the path before writing so cleanup also covers a failed write.
            tmp_path = tmp_file.name
            tmp_file.write(content)
        # The file is closed (and therefore flushed) before it is parsed.
        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        # The file is created with delete=False, so it must be removed by hand.
        # Doing it here (instead of only after a successful read) stops
        # unparsable inputs from leaving temp files behind.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {tmp_path}: {cleanup_error}")
def generate_symmetry_cif(structure: Structure) -> str:
    """Return symmetrized CIF text for ``structure``.

    The structure is refined with ``SpacegroupAnalyzer`` and then serialised
    by ``CifWriter`` with ``symprec=0.1`` and ``refine_struct=True``.

    Args:
        structure: pymatgen structure to symmetrize.

    Returns:
        The CIF document as a string.
    """
    analyzer = SpacegroupAnalyzer(structure)
    structure_refined = analyzer.get_refined_structure()
    cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
    # Serialise in memory. The previous implementation wrote to a
    # NamedTemporaryFile(delete=False) and read it back, leaving one stray
    # file on disk per call.
    return str(cif_writer)
def optimize_structure(atoms: Atoms, output_format: str) -> str:
    """Relax a structure with FIRE and return a formatted text report.

    The module-level ``calc`` is attached to ``atoms``, the structure is
    optimized with ``FIRE`` on a ``FrechetCellFilter`` until ``FMAX`` is
    reached, and the optimizer's stdout output is captured for the report.

    Args:
        atoms: Structure to optimize; it is modified in place.
        output_format: "cif" produces symmetrized CIF text; any other value
            is used as the file suffix for ``ase.io.write``.

    Returns:
        A Markdown-style report with the total energy, the optimizer log and
        the first 300 characters of the optimized structure file.

    Raises:
        Exception: Any failure during optimization or export is logged and
            re-raised unchanged.
    """
    from contextlib import redirect_stdout

    atoms.calc = calc
    try:
        from tools_for_ms.services_tools.Configs import FMAX

        # Capture what the optimizer prints. The context manager restores
        # sys.stdout even when the run raises; the manual swap used before
        # left stdout pointing at the buffer after any failure.
        log_buffer = StringIO()
        with redirect_stdout(log_buffer):
            # Built inside the redirect, as before: the optimizer presumably
            # binds its log stream when it is constructed -- TODO confirm.
            dyn = FIRE(FrechetCellFilter(atoms))
            dyn.run(fmax=FMAX)
        optimization_log = log_buffer.getvalue()

        total_energy = atoms.get_potential_energy()

        if output_format == "cif":
            optimized_structure = Structure.from_ase_atoms(atoms)
            content = generate_symmetry_cif(optimized_structure)
        else:
            # Reserve a path with the right suffix, then let ASE write to it.
            with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
                tmp_path = tmp_file.name
            try:
                write(tmp_path, atoms)
                with open(tmp_path, "r") as exported:
                    content = exported.read()
            finally:
                # delete=False means nobody else removes this file.
                os.unlink(tmp_path)

        # NOTE(review): only the first 300 characters of the structure are
        # included below -- confirm this truncation is intended.
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**
#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content[:300]}
```
"""
        return format_result
    except Exception as e:
        logger.error(f"Failed to optimize structure: {str(e)}")
        raise
@llm_tool(name="optimize_crystal_structure",
description="Optimize crystal structure using FairChem model")
async def optimize_crystal_structure(
    content: str,
    input_format: str = "cif",
    output_format: str = "cif"
) -> Dict[str, Any]:
    """
    Optimize crystal structure using FairChem model.

    Args:
        content: Crystal structure content string
        input_format: Input format (cif, xyz, vasp)
        output_format: Output format (cif, xyz, vasp)

    Returns:
        Optimized structure with energy and optimization log

    Raises:
        ValueError: If ``content`` cannot be converted to a structure.

    NOTE(review): the signature is annotated ``Dict[str, Any]``, but the value
    handed back is whatever ``optimize_structure`` returns, which is the
    formatted report string.
    """
    # Make sure the calculator exists before any work is scheduled.
    if calc is None:
        init_model()

    def _convert_then_optimize():
        # Runs in a worker thread: parse the input, then relax it.
        parsed_atoms = convert_structure(input_format, content)
        if parsed_atoms is not None:
            return optimize_structure(parsed_atoms, output_format)
        raise ValueError(f"无法转换输入的{input_format}格式内容,请检查格式是否正确")

    # asyncio.to_thread keeps the blocking optimization off the event loop;
    # its result is returned as-is and its exceptions propagate to the caller.
    optimization_report = await asyncio.to_thread(_convert_then_optimize)
    return optimization_report