构建mars_toolkit,删除tools_for_ms

This commit is contained in:
lzy
2025-04-02 12:53:50 +08:00
parent 603304e10f
commit a77c2cd377
73 changed files with 1884 additions and 896 deletions

View File

@@ -0,0 +1,18 @@
"""
Query Module
This module provides query tools for materials science, including:
- Materials Project database queries
- OQMD database queries
- Dify knowledge base retrieval
- Web search
"""
from mars_toolkit.query.mp_query import (
search_material_property_from_material_project,
get_crystal_structures_from_materials_project,
get_mpid_from_formula
)
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
from mars_toolkit.query.web_search import search_online

View File

@@ -0,0 +1,84 @@
"""
Dify Search Module
This module provides functions for retrieving information from local materials science
literature knowledge base using Dify API.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import asyncio
import json
import requests
import codecs
from typing import Dict, Any
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
@llm_tool(
    name="retrieval_from_knowledge_base",
    description="Retrieve information from local materials science literature knowledge base"
)
async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
    """
    Retrieve relevant information from the local materials science literature knowledge base.

    Sends a blocking-mode chat request to the Dify API and extracts a few key
    fields from the JSON reply.

    Args:
        query: Query string, e.g. a material name such as "CsPbBr3"
        topk: Number of results to request (default 3)

    Returns:
        JSON string with the keys "id", "title", "content" and "score".
        On any failure, an error message string is returned instead of raising.
    """
    # Dify chat endpoint
    url = f'{config.DIFY_ROOT_URL}/v1/chat-messages'
    # Request headers carrying the API key and the content type
    headers = {
        'Authorization': f'Bearer {config.DIFY_API_KEY}',
        'Content-Type': 'application/json'
    }
    # Request payload
    data = {
        "inputs": {"topK": topk},     # maximum number of results to return
        "query": query,               # the query string
        "response_mode": "blocking",  # wait for the complete response
        "conversation_id": "",        # no conversation: every query is independent
        "user": "abc-123"             # user identifier
    }

    def _post() -> Dict[str, Any]:
        # A long timeout (1111 seconds) is kept to accommodate slow responses.
        response = requests.post(url, headers=headers, json=data, timeout=1111)
        # Surface HTTP errors instead of parsing an error page as a result.
        response.raise_for_status()
        # The JSON parser already decodes \uXXXX escapes. Running the raw text
        # through 'unicode_escape' first garbles non-ASCII characters and
        # un-escapes quotes/newlines, which yields invalid JSON.
        return response.json()

    try:
        # requests is synchronous; run it in a worker thread so the event loop
        # is not blocked while waiting for the reply.
        result_json = await asyncio.to_thread(_post)
        # Metadata may be missing or null in the reply
        metadata = result_json.get("metadata") or {}
        # Collect the fields of interest
        useful_info = {
            "id": metadata.get("document_id"),         # document ID
            "title": result_json.get("title"),         # document title
            "content": result_json.get("answer", ""),  # the 'answer' field holds the content
            "score": metadata.get("score")             # relevance score
        }
        return json.dumps(useful_info, ensure_ascii=False, indent=2)
    except Exception as e:
        # Tool boundary: report every failure as text rather than raising
        return f"错误: {str(e)}"

View File

@@ -0,0 +1,433 @@
"""
Materials Project Query Module
This module provides functions for querying the Materials Project database,
processing search results, and formatting responses.
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import glob
import json
import asyncio
import logging
import datetime
import os
from multiprocessing import Process, Manager
from typing import Dict, Any, List, Optional
from mp_api.client import MPRester
from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
from mars_toolkit.core.error_handlers import handle_general_error
from mars_toolkit.core.cif_utils import read_cif_txt_file, extract_cif_info, remove_symmetry_equiv_xyz
logger = logging.getLogger(__name__)
def parse_bool(param: str) -> bool | None:
"""
Parse a string parameter into a boolean value.
Args:
param: String parameter to parse (e.g., "true", "false")
Returns:
Boolean value if param is not empty, None otherwise
"""
if not param:
return None
return param.lower() == 'true'
def parse_list(param: str) -> List[str] | None:
"""
Parse a comma-separated string into a list of strings.
Args:
param: Comma-separated string (e.g., "Li,Fe,O")
Returns:
List of strings if param is not empty, None otherwise
"""
if not param:
return None
return param.split(',')
def parse_tuple(param: str) -> tuple[float, float] | None:
"""
Parse a comma-separated string into a tuple of two float values.
Used for range parameters like band_gap, density, etc.
Args:
param: Comma-separated string of two numbers (e.g., "0,3.5")
Returns:
Tuple of two float values if param is valid, None otherwise
"""
if not param:
return None
try:
values = param.split(',')
return (float(values[0]), float(values[1]))
except (ValueError, IndexError):
return None
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
    """
    Convert raw string query parameters into typed search arguments.

    Each recognised key is read from query_params (missing keys read as None)
    and passed through the parser listed for it below. 'magnetic_ordering' is
    forwarded unchanged, and 'chunk_size' is converted with int(), defaulting
    to 5 when absent.

    Args:
        query_params: Mapping of parameter name to its raw string value

    Returns:
        Dictionary containing every recognised key with its parsed value
    """
    def _as_is(value):
        return value

    # (key, parser) pairs, in the order the keys appear in the result
    converters = (
        ('band_gap', parse_tuple),
        ('chemsys', parse_list),
        ('crystal_system', parse_list),
        ('density', parse_tuple),
        ('formation_energy', parse_tuple),
        ('elements', parse_list),
        ('exclude_elements', parse_list),
        ('formula', parse_list),
        ('is_gap_direct', parse_bool),
        ('is_metal', parse_bool),
        ('is_stable', parse_bool),
        ('magnetic_ordering', _as_is),
        ('material_ids', parse_list),
        ('total_energy', parse_tuple),
        ('num_elements', parse_tuple),
        ('volume', parse_tuple),
    )
    parsed = {key: convert(query_params.get(key)) for key, convert in converters}
    parsed['chunk_size'] = int(query_params.get('chunk_size', '5'))
    return parsed
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]] | str:
"""
Process search results from the Materials Project API.
Extracts relevant fields from each document and formats them into a consistent structure.
Returns:
List of processed documents or error message string if an exception occurs
"""
try:
fields = [
'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
'universal_anisotropy', 'theoretical'
]
res = []
for doc in docs:
try:
new_docs = {}
for field_name in fields:
new_docs[field_name] = doc.get(field_name, '')
res.append(new_docs)
except Exception as e:
logger.warning(f"Error processing document: {str(e)}")
continue
return res
except Exception as e:
error_msg = f"Error in process_search_results: {str(e)}"
logger.error(error_msg)
return error_msg
def _summary_doc_to_dict(doc):
    """
    Convert one search hit into a plain dictionary where possible.

    Strategies are tried in order: doc.model_dump(), then doc.dict(), then the
    object's __dict__ (with any "_sa_instance_state" entry removed). If none
    applies, the object itself is returned unchanged.
    """
    try:
        return doc.model_dump()
    except AttributeError:
        pass
    try:
        return doc.dict()
    except AttributeError:
        pass
    if not hasattr(doc, "__dict__"):
        # Last resort: hand the object back as it is
        return doc
    attrs = doc.__dict__
    # Drop the special attribute that may cause serialization problems
    if "_sa_instance_state" in attrs:
        del attrs["_sa_instance_state"]
    return attrs


def _search_worker(queue, api_key, **kwargs):
    """
    Run one Materials Project summary search and report the outcome on a queue.

    Used as the target of a separate process. Exactly one item is put on the
    queue: a list of converted hits (empty when nothing matched), or the
    exception object if anything went wrong.

    Args:
        queue: Queue used to hand the outcome back to the parent
        api_key: Materials Project API key
        **kwargs: Search parameters forwarded to the summary search
    """
    try:
        import os
        os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
        os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''
        # Open the MPRester client and run the search
        with MPRester(api_key) as mpr:
            hits = mpr.materials.summary.search(**kwargs)
        if not hits:
            queue.put([])
        else:
            queue.put([_summary_doc_to_dict(hit) for hit in hits])
    except Exception as e:
        queue.put(e)
async def execute_search(search_args: Dict[str, Any], timeout: int = 120) -> List[Dict[str, Any]] | str:
    """
    Execute a search against the Materials Project API.

    The search runs in a separate process so that it can be terminated when it
    exceeds the timeout. Failures are reported as strings rather than raised.

    Args:
        search_args: Dictionary of search parameters (not modified)
        timeout: Maximum time in seconds to wait for the search to complete

    Returns:
        List of document dictionaries from the search results, or an error
        message string if the search timed out or failed
    """
    # Function-scope imports: the Empty exception lives in the stdlib `queue`
    # module. The original code looked it up on the local Manager queue proxy,
    # which has no such attribute, so the timeout branch could never match.
    import traceback
    from queue import Empty

    # Work on a copy so the caller's dictionary is left untouched
    search_args = dict(search_args)
    # Make sure the formula argument is a list
    if isinstance(search_args.get('formula'), str):
        search_args['formula'] = [search_args['formula']]

    try:
        # The context manager shuts the manager's server process down on exit;
        # without it every call leaves one extra process behind.
        with Manager() as manager:
            result_queue = manager.Queue()
            p = Process(target=_search_worker, args=(result_queue, config.MP_API_KEY), kwargs=search_args)
            p.start()
            # Wait in a worker thread so the event loop stays responsive
            await asyncio.to_thread(p.join, timeout)

            if p.is_alive():
                logger.warning(f"Terminating worker process {p.pid} due to timeout")
                p.terminate()
                await asyncio.to_thread(p.join)
                return f"Request timed out after {timeout} seconds"

            try:
                # The worker has exited, so its result is either already in
                # the queue or will never arrive; there is nothing to wait for.
                result = result_queue.get_nowait()
            except Empty:
                error_msg = "Failed to retrieve data from queue (timeout)"
                logger.error(error_msg)
                return error_msg

        if isinstance(result, Exception):
            logger.error(f"Error in search worker: {str(result)}")
            # Keep whatever traceback information survived the process hop
            tb_str = ''.join(traceback.format_exception(None, result, result.__traceback__))
            logger.debug(tb_str)
            return f"Error in search worker: {str(result)}"
        return result
    except Exception as e:
        error_msg = f"Error in execute_search: {str(e)}"
        logger.error(error_msg)
        return error_msg
@llm_tool(name="search_material_property_from_material_project", description="Search materials in Materials Project database by formula and properties")
async def search_material_property_from_material_project(
    formula: str | list[str],
    chemsys: Optional[str | list[str] | None] = None,
    crystal_system: Optional[str | list[str] | None] = None,
    is_gap_direct: Optional[bool | None] = None,
    is_stable: Optional[bool | None] = None,
) -> str:
    """
    Search materials in Materials Project database.

    Material IDs are looked up by formula, then the properties of each hit are
    read from the local property files under config.LOCAL_MP_ROOT.

    Args:
        formula: Chemical formula(s) (e.g., "Fe2O3" or ["ABO3", "Si*"])
        chemsys: Chemical system(s) (e.g., "Li-Fe-O")
        crystal_system: Crystal system(s) (e.g., "Cubic")
        is_gap_direct: Filter for direct band gap materials
        is_stable: Filter for thermodynamically stable materials

    Returns:
        Text containing up to config.MP_TOPK JSON formatted property records,
        or a message string when validation fails, nothing is found, or an
        error occurs
    """
    VALID_CRYSTAL_SYSTEMS = ['Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal', 'Cubic']

    # Validate the crystal system argument (a single name or a list of names)
    if crystal_system is not None:
        requested = [crystal_system] if isinstance(crystal_system, str) else crystal_system
        if isinstance(requested, list) and any(cs not in VALID_CRYSTAL_SYSTEMS for cs in requested):
            return "Input should be 'Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal' or 'Cubic'"

    # Make sure formula is a list
    if isinstance(formula, str):
        formula = [formula]

    # NOTE(review): chemsys, crystal_system, is_gap_direct and is_stable are
    # accepted (and crystal_system validated) but are not applied to the
    # lookup below: only `formula` selects the materials. The previous code
    # built a `params` dict from them and never used it. TODO: apply these
    # filters once the schema of the local property files is confirmed.
    mp_id_list = await get_mpid_from_formula(formula=formula)
    try:
        res = []
        for mp_id in mp_id_list:
            crystal_props = extract_cif_info(config.LOCAL_MP_ROOT+f"/Props/{mp_id}.json", ['all_fields'])
            res.append(crystal_props)
            # Only the first MP_TOPK records are reported, so stop reading
            # property files once enough have been collected.
            if len(res) >= config.MP_TOPK:
                break

        if not res:
            return "No results found, please try again."

        # Format response with top results
        try:
            # Wrap every record in numbered begin/end markers
            formatted_results = []
            for i, item in enumerate(res[:config.MP_TOPK], 1):
                formatted_result = f"[property {i} begin]\n"
                formatted_result += json.dumps(item, indent=2)
                formatted_result += f"\n[property {i} end]\n\n"
                formatted_results.append(formatted_result)

            # Merge all records into one string
            res_chunk = "\n\n".join(formatted_results)
            res_template = f"""
Here are the search results from the Materials Project database:
Due to length limitations, only the top {config.MP_TOPK} results are shown below:\n
{res_chunk}
If you need more results, please modify your search criteria or try different query parameters.
"""
            return res_template
        except Exception as format_error:
            logger.error(f"Error formatting results: {str(format_error)}")
            return str(format_error)
    except Exception as e:
        logger.error(f"Error in search_material_property_from_material_project: {str(e)}")
        return str(e)
@llm_tool(name="get_crystal_structures_from_materials_project", description="Get symmetrized crystal structures CIF data from Materials Project database by chemical formula")
async def get_crystal_structures_from_materials_project(
    formulas: list[str],
    conventional_unit_cell: bool = True,
    symprec: float = 0.1
) -> str:
    """
    Get crystal structures from Materials Project database by chemical formula and apply symmetrization.

    Material IDs are looked up by formula; the matching CIF files are read from
    the local dataset under config.LOCAL_MP_ROOT. Materials whose CIF file is
    missing locally, or whose structure cannot be processed, are logged and
    skipped instead of aborting the whole request.

    Args:
        formulas: List of chemical formulas (e.g., ["Fe2O3", "SiO2", "TiO2"])
        conventional_unit_cell: Whether to return conventional unit cell (True) or primitive cell (False)
        symprec: Precision parameter for symmetrization

    Returns:
        Formatted text containing up to config.MP_TOPK symmetrized CIF blocks,
        or a message string when no structure could be produced
    """
    result = {}
    mp_id_list = await get_mpid_from_formula(formula=formulas)
    for i, mp_id in enumerate(mp_id_list):
        cif_matches = glob.glob(config.LOCAL_MP_ROOT+f"/MPDatasets/{mp_id}.cif")
        if not cif_matches:
            # Previously an IndexError here aborted the whole tool call
            logger.warning(f"No local CIF file found for {mp_id}, skipping")
            continue
        try:
            structure = Structure.from_file(cif_matches[0])
            # Switch to the conventional standard cell if requested
            if conventional_unit_cell:
                structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure()
            # Symmetrize the structure
            sga = SpacegroupAnalyzer(structure, symprec=symprec)
            symmetrized_structure = sga.get_refined_structure()
            # Generate the CIF text with CifWriter
            cif_writer = CifWriter(symmetrized_structure, symprec=symprec, refine_struct=True)
            cif_data = str(cif_writer)
            # Remove the symmetry operation section from the CIF text
            cif_data = remove_symmetry_equiv_xyz(cif_data)
            cif_data = cif_data.replace('# generated using pymatgen', "")
            # The index keeps keys unique when several hits share a formula
            formula = structure.composition.reduced_formula
            key = f"{formula}_{i}"
        except Exception as e:
            logger.warning(f"Error processing structure {mp_id}: {str(e)}")
            continue
        result[key] = cif_data
        # Keep only the first config.MP_TOPK results
        if len(result) >= config.MP_TOPK:
            break

    if not result:
        return "No results found, please try again."

    try:
        prompt = f"""
# Materials Project Symmetrized Crystal Structure Data
Below are symmetrized crystal structure data for {len(result)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
"""
        for i, (key, cif_data) in enumerate(result.items(), 1):
            prompt += f"[cif {i} begin]\n"
            prompt += cif_data
            prompt += f"\n[cif {i} end]\n\n"
        prompt += """
## Usage Instructions
1. You can copy the above CIF data and save it as .cif files
2. Open these files with crystal structure visualization software (such as VESTA, Mercury, Avogadro, etc.)
3. These structures can be used for further material analysis, simulation, or visualization
CIF files contain complete structural information of crystals, including cell parameters, atomic coordinates, symmetry, etc.
Symmetrization helps identify and optimize crystal symmetry, making the structure more standardized and accurate.
"""
        return prompt
    except Exception as format_error:
        logger.error(f"Error formatting crystal structures: {str(format_error)}")
        return str(format_error)
@llm_tool(name="get_mpid_from_formula", description="Get material IDs (mpid) from Materials Project database by chemical formula")
async def get_mpid_from_formula(formula: str) -> List[str]:
    """
    Get material IDs (mpid) from Materials Project database by chemical formula.

    The IDs are returned in the order the summary search yields them; no
    energy-based sorting or filtering is applied here.

    Args:
        formula: Chemical formula (e.g., "Fe2O3"). The value is forwarded
            unchanged to the summary search; callers in this module also pass
            a list of formulas.

    Returns:
        List of material IDs as strings; an empty list if the query fails
    """
    os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
    os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''

    def _query_ids() -> List[str]:
        with MPRester(config.MP_API_KEY) as mpr:
            # Only the ID is needed, so do not download full summary documents
            docs = mpr.materials.summary.search(formula=formula, fields=["material_id"])
            # str() so the result matches the annotated List[str]
            return [str(doc.material_id) for doc in docs]

    try:
        # MPRester is synchronous; run it in a worker thread so the event
        # loop is not blocked for the duration of the request.
        return await asyncio.to_thread(_query_ids)
    except Exception as e:
        logger.error(f"Error getting mpid from formula: {str(e)}")
        return []

View File

@@ -0,0 +1,105 @@
"""
OQMD Query Module
This module provides functions for querying the Open Quantum Materials Database (OQMD).
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import logging
import httpx
import pandas as pd
from bs4 import BeautifulSoup
from io import StringIO
from typing import Annotated
from mars_toolkit.core.llm_tools import llm_tool
logger = logging.getLogger(__name__)
@llm_tool(name="fetch_chemical_composition_from_OQMD", description="Fetch material data for a chemical composition from OQMD database")
async def fetch_chemical_composition_from_OQMD(
    composition: Annotated[str, "Chemical formula (e.g., Fe2O3, LiFePO4)"]
) -> str:
    """
    Fetch material data for a chemical composition from OQMD database.

    Downloads the OQMD composition page, then extracts the heading, the
    paragraph text (links rendered as markdown) and the first table on the
    page (rendered as a markdown table).

    Args:
        composition: Chemical formula (e.g., Fe2O3, LiFePO4)

    Returns:
        Formatted text with material information and property tables, or a
        string starting with "Error:" if the request or parsing fails
    """
    # Function-scope import so the block brings its own dependency
    from urllib.parse import quote, urljoin

    oqmd_root = "https://www.oqmd.org"
    try:
        # The composition comes from the caller; percent-encode it so that
        # characters such as '/', '?' or '#' cannot alter the request path.
        url = f"{oqmd_root}/materials/composition/{quote(composition.strip(), safe='')}"

        async with httpx.AsyncClient(timeout=100.0) as client:
            response = await client.get(url)
            response.raise_for_status()

            # Validate response content
            if not response.text or len(response.text) < 100:
                raise ValueError("Invalid response content from OQMD API")

            # Parse HTML data
            html = response.text
            soup = BeautifulSoup(html, 'html.parser')

            # Parse basic data
            basic_data = []
            h1_element = soup.find('h1')
            if h1_element:
                basic_data.append(h1_element.text.strip())
            else:
                basic_data.append(f"Material: {composition}")

            for paragraph in soup.find_all('p'):
                pieces = []
                for element in paragraph.contents:
                    if hasattr(element, 'name') and element.name == 'a' and 'href' in element.attrs:
                        # urljoin handles both site-relative and absolute hrefs
                        link = urljoin(oqmd_root, element['href'])
                        pieces.append(f"[{element.text.strip()}]({link})")
                    elif hasattr(element, 'text'):
                        pieces.append(element.text.strip())
                    else:
                        pieces.append(str(element).strip())
                paragraph_text = " ".join(pieces).strip()
                # Skip paragraphs that contain no text at all
                if paragraph_text:
                    basic_data.append(paragraph_text)

            # Parse table data
            table_data = ""
            table = soup.find('table')
            if table:
                try:
                    df = pd.read_html(StringIO(str(table)))[0]
                    df = df.fillna('')
                    df = df.replace([float('inf'), float('-inf')], '')
                    table_data = df.to_markdown(index=False)
                except Exception as e:
                    logger.error(f"Error parsing table: {str(e)}")
                    table_data = "Error parsing table data"

            # Integrate data into a single text
            combined_text = "\n\n".join(basic_data)
            if table_data:
                combined_text += "\n\n## Material Properties Table\n\n" + table_data
            return combined_text
    except httpx.HTTPStatusError as e:
        logger.error(f"OQMD API request failed: {str(e)}")
        return f"Error: OQMD API request failed - {str(e)}"
    except httpx.TimeoutException:
        logger.error("OQMD API request timed out")
        return "Error: OQMD API request timed out"
    except httpx.NetworkError as e:
        logger.error(f"Network error occurred: {str(e)}")
        return f"Error: Network error occurred - {str(e)}"
    except ValueError as e:
        logger.error(f"Invalid response content: {str(e)}")
        return f"Error: Invalid response content - {str(e)}"
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return f"Error: Unexpected error occurred - {str(e)}"

View File

@@ -0,0 +1,77 @@
"""
Web Search Module
This module provides functions for searching information on the web.
"""
import asyncio
from typing import Annotated, Dict, Any, List
from langchain_community.utilities import SearxSearchWrapper
from mars_toolkit.core.llm_tools import llm_tool
from mars_toolkit.core.config import config
@llm_tool(name="search_online", description="Search scientific information online and return results as a string")
async def search_online(
    query: Annotated[str, "Search term"],
    num_results: Annotated[int, "Number of results (1-20)"] = 5
) -> str:
    """
    Searches for scientific information online and returns results as a formatted string.

    Args:
        query: Search term for scientific content
        num_results: Number of results to return; values that are not
            integers fall back to 5 and the value is clamped to 1-20

    Returns:
        Formatted string with search results (titles, snippets, links), or a
        string starting with "Error:" if the search itself fails
    """
    # Make sure num_results is an integer
    try:
        num_results = int(num_results)
    except (TypeError, ValueError):
        num_results = 5
    # Clamp to the supported range
    num_results = max(1, min(num_results, 20))

    try:
        # Initialize search wrapper
        search = SearxSearchWrapper(
            searx_host=config.SEARXNG_HOST,
            categories=["science",],
            k=num_results
        )
        # SearxSearchWrapper.results is synchronous here, so run it in the
        # default executor to keep the event loop free. get_running_loop() is
        # the recommended accessor inside a coroutine.
        loop = asyncio.get_running_loop()
        raw_results = await loop.run_in_executor(
            None,
            lambda: search.results(query, language=['en','zh'], num_results=num_results)
        )
    except Exception as e:
        # Report failures as text, like the other query tools in this package
        return f"Error: Online search failed - {str(e)}"

    # Render every hit; fields missing from a hit are shown as empty
    sections = [f"Search Results for '{query}' ({len(raw_results)} items):\n\n"]
    for i, result in enumerate(raw_results, 1):
        sections.append(
            f"Result {i}:\n"
            f"Title: {result.get('title', '')}\n"
            f"Summary: {result.get('snippet', '')}\n"
            f"Link: {result.get('link', '')}\n"
            f"Source: {result.get('source', '')}\n\n"
        )
    return "".join(sections)