构建mars_toolkit,删除tools_for_ms
This commit is contained in:
18
mars_toolkit/query/__init__.py
Normal file
18
mars_toolkit/query/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Query Module
|
||||
|
||||
This module provides query tools for materials science, including:
|
||||
- Materials Project database queries
|
||||
- OQMD database queries
|
||||
- Dify knowledge base retrieval
|
||||
- Web search
|
||||
"""
|
||||
|
||||
from mars_toolkit.query.mp_query import (
|
||||
search_material_property_from_material_project,
|
||||
get_crystal_structures_from_materials_project,
|
||||
get_mpid_from_formula
|
||||
)
|
||||
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||
from mars_toolkit.query.web_search import search_online
|
||||
BIN
mars_toolkit/query/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/query/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc
Normal file
BIN
mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc
Normal file
BIN
mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc
Normal file
BIN
mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/web_search.cpython-310.pyc
Normal file
BIN
mars_toolkit/query/__pycache__/web_search.cpython-310.pyc
Normal file
Binary file not shown.
84
mars_toolkit/query/dify_search.py
Normal file
84
mars_toolkit/query/dify_search.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Dify Search Module
|
||||
|
||||
This module provides functions for retrieving information from local materials science
|
||||
literature knowledge base using Dify API.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import requests
|
||||
import codecs
|
||||
from typing import Dict, Any
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
|
||||
@llm_tool(
    name="retrieval_from_knowledge_base",
    description="Retrieve information from local materials science literature knowledge base"
)
async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
    """
    Retrieve relevant information from the local materials science literature knowledge base.

    Sends a blocking-mode chat request to the Dify endpoint configured in
    ``config`` and extracts a few fields from the JSON reply.

    Args:
        query: Query string, e.g. a material name such as "CsPbBr3"
        topk: Number of results to request, 3 by default

    Returns:
        A JSON-formatted string with the keys "id", "title", "content" and
        "score". If the request or the parsing fails, a string starting with
        "错误: " followed by the exception text is returned instead.
    """
    # Dify API endpoint
    url = f'{config.DIFY_ROOT_URL}/v1/chat-messages'

    # Request headers carrying the API key and the content type
    headers = {
        'Authorization': f'Bearer {config.DIFY_API_KEY}',
        'Content-Type': 'application/json'
    }

    # Request payload
    data = {
        "inputs": {"topK": topk},     # maximum number of results to return
        "query": query,               # the query string
        "response_mode": "blocking",  # wait for the complete response
        "conversation_id": "",        # no conversation: every query is independent
        "user": "abc-123"             # user identifier
    }

    try:
        # requests is synchronous; run it in a worker thread so this coroutine
        # does not block the event loop while waiting (timeout: 1111 seconds).
        response = await asyncio.to_thread(
            requests.post, url, headers=headers, json=data, timeout=1111
        )

        # Parse the body as JSON directly. The JSON parser already decodes
        # \uXXXX escapes; running 'unicode_escape' over the raw text first would
        # also un-escape \" and \n inside string values (making the document
        # invalid JSON) and garble non-ASCII characters.
        result_json = response.json()

        # Metadata section of the reply (empty dict when absent)
        metadata = result_json.get("metadata", {})

        # Collect the fields of interest
        useful_info = {
            "id": metadata.get("document_id"),         # document ID
            "title": result_json.get("title"),         # document title
            "content": result_json.get("answer", ""),  # content is taken from the 'answer' field
            "score": metadata.get("score")             # relevance score
        }

        return json.dumps(useful_info, ensure_ascii=False, indent=2)

    except Exception as e:
        # Tool boundary: report any failure as text instead of raising
        return f"错误: {str(e)}"
|
||||
433
mars_toolkit/query/mp_query.py
Normal file
433
mars_toolkit/query/mp_query.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Materials Project Query Module
|
||||
|
||||
This module provides functions for querying the Materials Project database,
|
||||
processing search results, and formatting responses.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import glob
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
import datetime
|
||||
import os
|
||||
from multiprocessing import Process, Manager
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from mp_api.client import MPRester
|
||||
from pymatgen.core import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.error_handlers import handle_general_error
|
||||
from mars_toolkit.core.cif_utils import read_cif_txt_file, extract_cif_info, remove_symmetry_equiv_xyz
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_bool(param: str) -> bool | None:
|
||||
"""
|
||||
Parse a string parameter into a boolean value.
|
||||
|
||||
Args:
|
||||
param: String parameter to parse (e.g., "true", "false")
|
||||
|
||||
Returns:
|
||||
Boolean value if param is not empty, None otherwise
|
||||
"""
|
||||
if not param:
|
||||
return None
|
||||
return param.lower() == 'true'
|
||||
|
||||
def parse_list(param: str) -> List[str] | None:
|
||||
"""
|
||||
Parse a comma-separated string into a list of strings.
|
||||
|
||||
Args:
|
||||
param: Comma-separated string (e.g., "Li,Fe,O")
|
||||
|
||||
Returns:
|
||||
List of strings if param is not empty, None otherwise
|
||||
"""
|
||||
if not param:
|
||||
return None
|
||||
return param.split(',')
|
||||
|
||||
def parse_tuple(param: str) -> tuple[float, float] | None:
|
||||
"""
|
||||
Parse a comma-separated string into a tuple of two float values.
|
||||
|
||||
Used for range parameters like band_gap, density, etc.
|
||||
|
||||
Args:
|
||||
param: Comma-separated string of two numbers (e.g., "0,3.5")
|
||||
|
||||
Returns:
|
||||
Tuple of two float values if param is valid, None otherwise
|
||||
"""
|
||||
if not param:
|
||||
return None
|
||||
try:
|
||||
values = param.split(',')
|
||||
return (float(values[0]), float(values[1]))
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
    """
    Convert raw string query parameters into typed search arguments.

    Each known key is looked up in ``query_params`` (missing keys are read as
    None) and passed through its converter. 'magnetic_ordering' is forwarded
    unchanged and 'chunk_size' is converted with int(), defaulting to 5.

    Args:
        query_params: Mapping of parameter names to their string values

    Returns:
        Dictionary containing every known search key, in a fixed order

    Raises:
        ValueError: If 'chunk_size' is present but is not a valid integer string
    """
    # (key, converter) pairs; a converter of None means "forward unchanged"
    converters = (
        ('band_gap', parse_tuple),
        ('chemsys', parse_list),
        ('crystal_system', parse_list),
        ('density', parse_tuple),
        ('formation_energy', parse_tuple),
        ('elements', parse_list),
        ('exclude_elements', parse_list),
        ('formula', parse_list),
        ('is_gap_direct', parse_bool),
        ('is_metal', parse_bool),
        ('is_stable', parse_bool),
        ('magnetic_ordering', None),
        ('material_ids', parse_list),
        ('total_energy', parse_tuple),
        ('num_elements', parse_tuple),
        ('volume', parse_tuple),
    )

    parsed = {}
    for key, convert in converters:
        raw_value = query_params.get(key)
        parsed[key] = raw_value if convert is None else convert(raw_value)

    parsed['chunk_size'] = int(query_params.get('chunk_size', '5'))
    return parsed
|
||||
|
||||
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]] | str:
|
||||
"""
|
||||
Process search results from the Materials Project API.
|
||||
|
||||
Extracts relevant fields from each document and formats them into a consistent structure.
|
||||
|
||||
Returns:
|
||||
List of processed documents or error message string if an exception occurs
|
||||
"""
|
||||
try:
|
||||
fields = [
|
||||
'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
|
||||
'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
|
||||
'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
|
||||
'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
|
||||
'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
|
||||
'universal_anisotropy', 'theoretical'
|
||||
]
|
||||
|
||||
res = []
|
||||
for doc in docs:
|
||||
try:
|
||||
new_docs = {}
|
||||
for field_name in fields:
|
||||
new_docs[field_name] = doc.get(field_name, '')
|
||||
res.append(new_docs)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing document: {str(e)}")
|
||||
continue
|
||||
return res
|
||||
except Exception as e:
|
||||
error_msg = f"Error in process_search_results: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
def _document_to_dict(doc):
    """Convert one search result into a plain dict when possible, else return it unchanged."""
    # Try the model's own dump methods first: model_dump(), then dict()
    for method_name in ("model_dump", "dict"):
        try:
            return getattr(doc, method_name)()
        except AttributeError:
            continue

    if hasattr(doc, "__dict__"):
        # Copy the attribute dict so the special key is removed from the copy
        # only, not from the live document object
        plain = dict(doc.__dict__)
        plain.pop("_sa_instance_state", None)
        return plain

    # Last resort: hand the document over as it is
    return doc


def _search_worker(queue, api_key, **kwargs):
    """
    Worker function for executing Materials Project API searches.

    Intended to run in a separate process: performs the API call and puts
    exactly one item on the queue -- the list of converted documents (an empty
    list when the search returned nothing) or the exception that was raised.

    Args:
        queue: Queue used to hand the result back to the parent process
        api_key: Materials Project API key
        **kwargs: Search parameters passed to ``mpr.materials.summary.search``
    """
    try:
        # Proxy settings from the configuration; an unset value becomes ''
        os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
        os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''

        # Initialise the MPRester client and run the search
        with MPRester(api_key) as mpr:
            result = mpr.materials.summary.search(**kwargs)

        if result:
            queue.put([_document_to_dict(doc) for doc in result])
        else:
            queue.put([])
    except Exception as e:
        # Report the failure to the parent instead of dying silently
        queue.put(e)
|
||||
|
||||
async def execute_search(search_args: Dict[str, Any], timeout: int = 120) -> List[Dict[str, Any]] | str:
    """
    Execute a search against the Materials Project API.

    The search runs in a separate process so that it can be terminated when it
    exceeds ``timeout``. Waiting for the process happens in a worker thread so
    the event loop stays responsive.

    Args:
        search_args: Dictionary of search parameters (not modified)
        timeout: Maximum time in seconds to wait for the search to complete

    Returns:
        List of document dictionaries from the search results, or an error
        message string if the search times out or fails
    """
    # Imported locally under an explicit name so it cannot be confused with a
    # queue *object*; the stdlib queue module is not imported at file level.
    from queue import Empty

    # Work on a copy so the caller's dictionary is left untouched
    search_args = dict(search_args)

    # Make sure the formula parameter is a list
    if isinstance(search_args.get('formula'), str):
        search_args['formula'] = [search_args['formula']]

    try:
        # The context manager shuts the manager process down on exit
        with Manager() as manager:
            result_queue = manager.Queue()

            worker = Process(
                target=_search_worker,
                args=(result_queue, config.MP_API_KEY),
                kwargs=search_args,
            )
            worker.start()
            # join() blocks, so wait for it off the event loop
            await asyncio.to_thread(worker.join, timeout)

            if worker.is_alive():
                logger.warning(f"Terminating worker process {worker.pid} due to timeout")
                worker.terminate()
                worker.join()
                return f"Request timed out after {timeout} seconds"

            try:
                # The worker has exited: its result is either already queued
                # or will never arrive, so there is nothing to wait for.
                result = result_queue.get_nowait()
            except Empty:
                error_msg = "Search worker exited without returning a result"
                logger.error(error_msg)
                return error_msg

            if isinstance(result, Exception):
                error_msg = f"Error in search worker: {str(result)}"
                logger.error(error_msg)
                return error_msg

            return result

    except Exception as e:
        error_msg = f"Error in execute_search: {str(e)}"
        logger.error(error_msg)
        return error_msg
|
||||
|
||||
@llm_tool(name="search_material_property_from_material_project", description="Search materials in Materials Project database by formula and properties")
async def search_material_property_from_material_project(
    formula: str | list[str],
    chemsys: Optional[str | list[str] | None] = None,
    crystal_system: Optional[str | list[str] | None] = None,
    is_gap_direct: Optional[bool | None] = None,
    is_stable: Optional[bool | None] = None,
) -> str:
    """
    Search materials in Materials Project database.

    Material IDs are resolved from the formula via ``get_mpid_from_formula``;
    the properties of the first ``config.MP_TOPK`` IDs are then read from the
    local property files under ``config.LOCAL_MP_ROOT``.

    NOTE(review): ``chemsys``, ``is_gap_direct`` and ``is_stable`` are accepted
    but do not influence the lookup, and ``crystal_system`` is only validated.
    The original code collected them into a dict that was never used; confirm
    whether filtering by these arguments is still intended.

    Args:
        formula: Chemical formula(s) (e.g., "Fe2O3" or ["ABO3", "Si*"])
        chemsys: Chemical system(s) (e.g., "Li-Fe-O")
        crystal_system: Crystal system(s) (e.g., "Cubic")
        is_gap_direct: Filter for direct band gap materials
        is_stable: Filter for thermodynamically stable materials

    Returns:
        Text containing the JSON formatted material properties, a validation
        message for an unknown crystal system, "No results found, please try
        again." when nothing was found, or the text of an exception
    """
    VALID_CRYSTAL_SYSTEMS = ['Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal', 'Cubic']

    # Validate the crystal system(s); only str and list inputs are checked
    if crystal_system is not None:
        candidates = [crystal_system] if isinstance(crystal_system, str) else crystal_system
        if isinstance(candidates, list) and any(cs not in VALID_CRYSTAL_SYSTEMS for cs in candidates):
            return "Input should be 'Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal' or 'Cubic'"

    # Make sure formula is a list
    if isinstance(formula, str):
        formula = [formula]

    mp_id_list = await get_mpid_from_formula(formula=formula)
    try:
        # Only the first MP_TOPK entries are reported, so only those files are
        # read; a wildcard formula can match a very large number of IDs.
        res = [
            extract_cif_info(config.LOCAL_MP_ROOT + f"/Props/{mp_id}.json", ['all_fields'])
            for mp_id in mp_id_list[:config.MP_TOPK]
        ]

        if not res:
            return "No results found, please try again."

        # Format response with top results
        try:
            # One numbered, delimited JSON section per result
            formatted_results = []
            for i, item in enumerate(res, 1):
                formatted_result = f"[property {i} begin]\n"
                formatted_result += json.dumps(item, indent=2)
                formatted_result += f"\n[property {i} end]\n\n"
                formatted_results.append(formatted_result)

            # Merge all sections into a single string
            res_chunk = "\n\n".join(formatted_results)
            res_template = f"""
Here are the search results from the Materials Project database:
Due to length limitations, only the top {config.MP_TOPK} results are shown below:\n
{res_chunk}
If you need more results, please modify your search criteria or try different query parameters.
"""
            return res_template
        except Exception as format_error:
            logger.error(f"Error formatting results: {str(format_error)}")
            return str(format_error)

    except Exception as e:
        logger.error(f"Error in search_material_property_from_material_project: {str(e)}")
        return str(e)
|
||||
|
||||
@llm_tool(name="get_crystal_structures_from_materials_project", description="Get symmetrized crystal structures CIF data from Materials Project database by chemical formula")
async def get_crystal_structures_from_materials_project(
    formulas: list[str],
    conventional_unit_cell: bool = True,
    symprec: float = 0.1
) -> str:
    """
    Get crystal structures from Materials Project database by chemical formula and apply symmetrization.

    Material IDs are resolved via ``get_mpid_from_formula``; each structure is
    read from the local CIF file ``<config.LOCAL_MP_ROOT>/MPDatasets/<mp_id>.cif``.
    A material whose CIF file is missing or cannot be processed is logged and
    skipped. At most ``config.MP_TOPK`` structures are returned.

    Args:
        formulas: List of chemical formulas (e.g., ["Fe2O3", "SiO2", "TiO2"])
        conventional_unit_cell: Whether to convert to the conventional standard
            structure (True) or keep the cell as read from the file (False)
        symprec: Precision parameter for symmetrization

    Returns:
        Formatted text containing symmetrized CIF data
    """
    result = {}
    mp_id_list = await get_mpid_from_formula(formula=formulas)

    for i, mp_id in enumerate(mp_id_list):
        try:
            matches = glob.glob(config.LOCAL_MP_ROOT + f"/MPDatasets/{mp_id}.cif")
            if not matches:
                # No local copy of this structure: skip it instead of failing
                logger.warning(f"No local CIF file found for {mp_id}, skipping")
                continue

            structure = Structure.from_file(matches[0])

            # Convert to the conventional standard structure if requested
            if conventional_unit_cell:
                structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure()

            # Symmetrize the structure
            sga = SpacegroupAnalyzer(structure, symprec=symprec)
            symmetrized_structure = sga.get_refined_structure()

            # Generate the CIF text with CifWriter
            cif_writer = CifWriter(symmetrized_structure, symprec=symprec, refine_struct=True)
            cif_data = str(cif_writer)

            # Remove the symmetry-operation section and the generator comment
            cif_data = remove_symmetry_equiv_xyz(cif_data)
            cif_data = cif_data.replace('# generated using pymatgen', "")

            # Unique key: reduced formula plus the position in the ID list
            formula = structure.composition.reduced_formula
            key = f"{formula}_{i}"
        except Exception as structure_error:
            # One unreadable structure must not abort the whole request
            logger.error(f"Error processing crystal structure {mp_id}: {str(structure_error)}")
            continue

        result[key] = cif_data

        # Keep only the first config.MP_TOPK results
        if len(result) >= config.MP_TOPK:
            break

    try:
        prompt = f"""
# Materials Project Symmetrized Crystal Structure Data

Below are symmetrized crystal structure data for {len(result)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
"""

        for i, (key, cif_data) in enumerate(result.items(), 1):
            prompt += f"[cif {i} begin]\n"
            prompt += cif_data
            prompt += f"\n[cif {i} end]\n\n"

        prompt += """
## Usage Instructions

1. You can copy the above CIF data and save it as .cif files
2. Open these files with crystal structure visualization software (such as VESTA, Mercury, Avogadro, etc.)
3. These structures can be used for further material analysis, simulation, or visualization

CIF files contain complete structural information of crystals, including cell parameters, atomic coordinates, symmetry, etc.
Symmetrization helps identify and optimize crystal symmetry, making the structure more standardized and accurate.
"""
        return prompt

    except Exception as format_error:
        logger.error(f"Error formatting crystal structures: {str(format_error)}")
        return str(format_error)
|
||||
|
||||
@llm_tool(name="get_mpid_from_formula", description="Get material IDs (mpid) from Materials Project database by chemical formula")
async def get_mpid_from_formula(formula: str | list[str]) -> List[str]:
    """
    Get material IDs (mpid) from Materials Project database by chemical formula.

    The IDs are returned in the order delivered by the summary search; no
    additional sorting (e.g. by energy) is applied here.

    Args:
        formula: Chemical formula (e.g., "Fe2O3") or a list of formulas

    Returns:
        List of material IDs as strings; an empty list if the query fails
    """
    # Proxy settings from the configuration; an unset value becomes ''
    os.environ['HTTP_PROXY'] = config.HTTP_PROXY or ''
    os.environ['HTTPS_PROXY'] = config.HTTPS_PROXY or ''

    def _fetch_ids() -> List[str]:
        """Run the (synchronous) summary search and collect the IDs."""
        with MPRester(config.MP_API_KEY) as mpr:
            # Only the ID is used, so request just that field
            docs = mpr.materials.summary.search(formula=formula, fields=["material_id"])
        # str() makes the entries match the declared List[str] return type
        return [str(doc.material_id) for doc in docs]

    try:
        # MPRester is synchronous; keep the event loop free while it runs
        return await asyncio.to_thread(_fetch_ids)
    except Exception as e:
        logger.error(f"Error getting mpid from formula: {str(e)}")
        return []
|
||||
105
mars_toolkit/query/oqmd_query.py
Normal file
105
mars_toolkit/query/oqmd_query.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
OQMD Query Module
|
||||
|
||||
This module provides functions for querying the Open Quantum Materials Database (OQMD).
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
import pandas as pd
|
||||
from bs4 import BeautifulSoup
|
||||
from io import StringIO
|
||||
from typing import Annotated
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@llm_tool(name="fetch_chemical_composition_from_OQMD", description="Fetch material data for a chemical composition from OQMD database")
async def fetch_chemical_composition_from_OQMD(
    composition: Annotated[str, "Chemical formula (e.g., Fe2O3, LiFePO4)"]
) -> str:
    """
    Fetch material data for a chemical composition from OQMD database.

    Downloads the OQMD composition page, extracts the heading, the text of
    every paragraph (links rendered as Markdown links) and the first table
    (rendered as a Markdown table).

    Args:
        composition: Chemical formula (e.g., Fe2O3, LiFePO4)

    Returns:
        Formatted text with material information and property tables, or a
        string starting with "Error: " if the request or the parsing fails
    """
    # Imported locally so this function does not rely on new file-level imports
    from urllib.parse import quote, urljoin

    base_url = "https://www.oqmd.org"
    # Percent-encode the composition so characters such as '/', '?' or spaces
    # cannot alter the request path
    url = f"{base_url}/materials/composition/{quote(composition, safe='')}"
    try:
        async with httpx.AsyncClient(timeout=100.0) as client:
            response = await client.get(url)
            response.raise_for_status()

        # Validate response content
        html = response.text
        if not html or len(html) < 100:
            raise ValueError("Invalid response content from OQMD API")

        # Parse HTML data
        soup = BeautifulSoup(html, 'html.parser')

        # Basic data: the heading, or a fallback title when there is none
        h1_element = soup.find('h1')
        basic_data = [h1_element.text.strip() if h1_element else f"Material: {composition}"]

        for paragraph in soup.find_all('p'):
            pieces = []
            for element in paragraph.contents:
                if getattr(element, 'name', None) == 'a' and 'href' in element.attrs:
                    # urljoin copes with root-relative, relative and absolute hrefs
                    link = urljoin(base_url, element['href'])
                    pieces.append(f"[{element.text.strip()}]({link})")
                elif hasattr(element, 'text'):
                    pieces.append(element.text.strip())
                else:
                    pieces.append(str(element).strip())
            paragraph_text = " ".join(pieces).strip()
            # Skip paragraphs that carry no text
            if paragraph_text:
                basic_data.append(paragraph_text)

        # Table data: only the first table of the page is used
        table_data = ""
        table = soup.find('table')
        if table:
            try:
                df = pd.read_html(StringIO(str(table)))[0]
                df = df.fillna('')
                df = df.replace([float('inf'), float('-inf')], '')
                table_data = df.to_markdown(index=False)
            except Exception as e:
                logger.error(f"Error parsing table: {str(e)}")
                table_data = "Error parsing table data"

        # Integrate data into a single text
        combined_text = "\n\n".join(basic_data)
        if table_data:
            combined_text += "\n\n## Material Properties Table\n\n" + table_data

        return combined_text

    except httpx.HTTPStatusError as e:
        logger.error(f"OQMD API request failed: {str(e)}")
        return f"Error: OQMD API request failed - {str(e)}"
    except httpx.TimeoutException:
        logger.error("OQMD API request timed out")
        return "Error: OQMD API request timed out"
    except httpx.NetworkError as e:
        logger.error(f"Network error occurred: {str(e)}")
        return f"Error: Network error occurred - {str(e)}"
    except ValueError as e:
        logger.error(f"Invalid response content: {str(e)}")
        return f"Error: Invalid response content - {str(e)}"
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return f"Error: Unexpected error occurred - {str(e)}"
|
||||
77
mars_toolkit/query/web_search.py
Normal file
77
mars_toolkit/query/web_search.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Web Search Module
|
||||
|
||||
This module provides functions for searching information on the web.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Annotated, Dict, Any, List
|
||||
|
||||
from langchain_community.utilities import SearxSearchWrapper
|
||||
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
from mars_toolkit.core.config import config
|
||||
|
||||
@llm_tool(name="search_online", description="Search scientific information online and return results as a string")
async def search_online(
    query: Annotated[str, "Search term"],
    num_results: Annotated[int, "Number of results (1-20)"] = 5
) -> str:
    """
    Run a science-category web search and report the hits as plain text.

    Args:
        query: Search term for scientific content
        num_results: Number of results to return; values that cannot be
            converted to int fall back to 5, and the value is clamped to 1-20

    Returns:
        Formatted string with one numbered section per hit (title, summary,
        link, source), preceded by a header line with the number of hits
    """
    # Coerce the requested count to an int, then clamp it into [1, 20]
    try:
        wanted = int(num_results)
    except (TypeError, ValueError):
        wanted = 5
    wanted = max(1, min(wanted, 20))

    searcher = SearxSearchWrapper(
        searx_host=config.SEARXNG_HOST,
        categories=["science",],
        k=wanted
    )

    def _run_search():
        """Blocking search call, executed on the default executor."""
        return searcher.results(query, language=['en','zh'], num_results=wanted)

    # The wrapper call is synchronous, so it is handed to the loop's default
    # executor instead of being called directly from the coroutine
    raw_results = await asyncio.get_event_loop().run_in_executor(None, _run_search)

    # Keep only the fields that are reported, defaulting each to ""
    entries = [
        {
            "title": hit.get("title", ""),
            "snippet": hit.get("snippet", ""),
            "link": hit.get("link", ""),
            "source": hit.get("source", ""),
        }
        for hit in raw_results
    ]

    # Header first, then one text section per entry
    sections = [f"Search Results for '{query}' ({len(entries)} items):\n\n"]
    for index, entry in enumerate(entries, 1):
        sections.append(
            f"Result {index}:\n"
            f"Title: {entry['title']}\n"
            f"Summary: {entry['snippet']}\n"
            f"Link: {entry['link']}\n"
            f"Source: {entry['source']}\n\n"
        )

    return "".join(sections)
|
||||
Reference in New Issue
Block a user