构建mars_toolkit,删除tools_for_ms
This commit is contained in:
13
mars_toolkit/core/__init__.py
Normal file
13
mars_toolkit/core/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Core Module
|
||||
|
||||
This module provides core functionality for the Mars Toolkit.
|
||||
"""
|
||||
|
||||
from mars_toolkit.core.config import config
|
||||
from mars_toolkit.core.utils import settings, setup_logging
|
||||
from mars_toolkit.core.error_handlers import (
|
||||
handle_minio_error, handle_http_error,
|
||||
handle_validation_error, handle_general_error
|
||||
)
|
||||
from mars_toolkit.core.llm_tools import llm_tool
|
||||
BIN
mars_toolkit/core/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/config.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/config.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/error_handlers.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/error_handlers.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc
Normal file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/utils.cpython-310.pyc
Normal file
BIN
mars_toolkit/core/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
121
mars_toolkit/core/cif_utils.py
Normal file
121
mars_toolkit/core/cif_utils.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
CIF Utilities Module
|
||||
|
||||
This module provides basic functions for handling CIF (Crystallographic Information File) files,
|
||||
which are commonly used in materials science for representing crystal structures.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def read_cif_txt_file(file_path):
|
||||
"""
|
||||
Read the CIF file and return its content.
|
||||
|
||||
Args:
|
||||
file_path: Path to the CIF file
|
||||
|
||||
Returns:
|
||||
String content of the CIF file or None if an error occurs
|
||||
"""
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading file {file_path}: {e}")
|
||||
return None
|
||||
|
||||
def extract_cif_info(path: str, fields_name: list):
|
||||
"""
|
||||
Extract specific fields from the CIF description JSON file.
|
||||
|
||||
Args:
|
||||
path: Path to the JSON file containing CIF information
|
||||
fields_name: List of field categories to extract. Use 'all_fields' to extract all fields.
|
||||
Other options include 'basic_fields', 'energy_electronic_fields', 'metal_magentic_fields'
|
||||
|
||||
Returns:
|
||||
Dictionary containing the extracted fields
|
||||
"""
|
||||
basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
|
||||
energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
|
||||
metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites']
|
||||
|
||||
selected_fields = []
|
||||
if fields_name[0] == 'all_fields':
|
||||
selected_fields = basic_fields + energy_electronic_fields + metal_magentic_fields
|
||||
else:
|
||||
for field in fields_name:
|
||||
selected_fields.extend(locals().get(field, []))
|
||||
|
||||
with open(path, 'r') as f:
|
||||
docs = json.load(f)
|
||||
|
||||
new_docs = {}
|
||||
for field_name in selected_fields:
|
||||
new_docs[field_name] = docs.get(field_name, '')
|
||||
|
||||
return new_docs
|
||||
|
||||
def remove_symmetry_equiv_xyz(cif_content):
|
||||
"""
|
||||
Remove symmetry operations section from CIF file content.
|
||||
|
||||
This is often useful when working with CIF files in certain visualization tools
|
||||
or when focusing on the basic structure without symmetry operations.
|
||||
|
||||
Args:
|
||||
cif_content: CIF file content string
|
||||
|
||||
Returns:
|
||||
Cleaned CIF content string with symmetry operations removed
|
||||
"""
|
||||
lines = cif_content.split('\n')
|
||||
output_lines = []
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i].strip()
|
||||
|
||||
# 检测循环开始
|
||||
if line == 'loop_':
|
||||
# 查看下一行,检查是否是对称性循环
|
||||
next_lines = []
|
||||
j = i + 1
|
||||
while j < len(lines) and lines[j].strip().startswith('_'):
|
||||
next_lines.append(lines[j].strip())
|
||||
j += 1
|
||||
|
||||
# 检查是否包含对称性操作标签
|
||||
if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
|
||||
# 跳过整个循环块
|
||||
while i < len(lines):
|
||||
if i + 1 >= len(lines):
|
||||
break
|
||||
|
||||
next_line = lines[i + 1].strip()
|
||||
# 检查是否到达下一个循环或数据块
|
||||
if next_line == 'loop_' or next_line.startswith('data_'):
|
||||
break
|
||||
|
||||
# 检查是否到达原子位置部分
|
||||
if next_line.startswith('_atom_site_'):
|
||||
break
|
||||
|
||||
i += 1
|
||||
else:
|
||||
# 不是对称性循环,保留loop_行
|
||||
output_lines.append(lines[i])
|
||||
else:
|
||||
# 非循环开始行,直接保留
|
||||
output_lines.append(lines[i])
|
||||
|
||||
i += 1
|
||||
|
||||
return '\n'.join(output_lines)
|
||||
59
mars_toolkit/core/config.py
Normal file
59
mars_toolkit/core/config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Configuration Module
|
||||
|
||||
This module provides configuration settings for the Mars Toolkit.
|
||||
It includes API keys, endpoints, paths, and other configuration parameters.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
class Config:
|
||||
"""Configuration class for Mars Toolkit"""
|
||||
|
||||
# Materials Project
|
||||
MP_API_KEY = 'PMASAg256b814q3OaSRWeVc7MKx4mlKI'
|
||||
MP_ENDPOINT = 'https://api.materialsproject.org/'
|
||||
MP_TOPK = 3
|
||||
|
||||
LOCAL_MP_ROOT = '/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/'
|
||||
|
||||
# Proxy
|
||||
HTTP_PROXY = 'http://192.168.168.1:20171'
|
||||
HTTPS_PROXY = 'http://192.168.168.1:20171'
|
||||
|
||||
# FairChem
|
||||
FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
|
||||
FMAX = 0.05
|
||||
|
||||
# MatterGen
|
||||
MATTERGENMODEL_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
|
||||
MATTERGENMODEL_RESULT_PATH = 'results/'
|
||||
|
||||
# Dify
|
||||
DIFY_ROOT_URL = 'http://192.168.191.101:6080'
|
||||
DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'
|
||||
|
||||
# Searxng
|
||||
SEARXNG_HOST="http://192.168.191.101:40032/"
|
||||
|
||||
# Visualization
|
||||
VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'
|
||||
|
||||
@classmethod
|
||||
def as_dict(cls) -> Dict[str, Any]:
|
||||
"""Return all configuration settings as a dictionary"""
|
||||
return {
|
||||
key: value for key, value in cls.__dict__.items()
|
||||
if not key.startswith('__') and not callable(value)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def update(cls, **kwargs):
|
||||
"""Update configuration settings"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(cls, key):
|
||||
setattr(cls, key, value)
|
||||
|
||||
|
||||
# Create a global instance for easy access
|
||||
config = Config()
|
||||
55
mars_toolkit/core/error_handlers.py
Normal file
55
mars_toolkit/core/error_handlers.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Error Handlers Module
|
||||
|
||||
This module provides error handling utilities for the Mars Toolkit.
|
||||
It includes functions for handling various types of errors that may occur
|
||||
during toolkit operations.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException
|
||||
from typing import Any, Dict
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class APIError(HTTPException):
|
||||
"""自定义API错误类"""
|
||||
def __init__(self, status_code: int, detail: Any = None):
|
||||
super().__init__(status_code=status_code, detail=detail)
|
||||
logger.error(f"API Error: {status_code} - {detail}")
|
||||
|
||||
def handle_minio_error(e: Exception) -> Dict[str, str]:
|
||||
"""处理MinIO相关错误"""
|
||||
logger.error(f"MinIO operation failed: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"data": f"MinIO operation failed: {str(e)}"
|
||||
}
|
||||
|
||||
def handle_http_error(e: Exception) -> Dict[str, str]:
|
||||
"""处理HTTP请求错误"""
|
||||
logger.error(f"HTTP request failed: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"data": f"HTTP request failed: {str(e)}"
|
||||
}
|
||||
|
||||
def handle_validation_error(e: Exception) -> Dict[str, str]:
|
||||
"""处理数据验证错误"""
|
||||
logger.error(f"Validation failed: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"data": f"Validation failed: {str(e)}"
|
||||
}
|
||||
|
||||
def handle_general_error(e: Exception) -> Dict[str, str]:
|
||||
"""处理通用错误"""
|
||||
logger.error(f"Unexpected error: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"data": f"Unexpected error: {str(e)}"
|
||||
}
|
||||
213
mars_toolkit/core/llm_tools.py
Normal file
213
mars_toolkit/core/llm_tools.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
LLM Tools Module
|
||||
|
||||
This module provides decorators and utilities for defining, registering, and managing LLM tools.
|
||||
It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving
|
||||
registered tools for use with LLM APIs.
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args
|
||||
import docstring_parser
|
||||
from pydantic import BaseModel, create_model, Field
|
||||
|
||||
# Registry to store all registered tools
|
||||
_TOOL_REGISTRY = {}
|
||||
|
||||
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
|
||||
"""
|
||||
Decorator to mark a function as an LLM tool.
|
||||
|
||||
This decorator registers the function as an LLM tool, generates a JSON schema for it,
|
||||
and makes it available for retrieval through the get_tools function.
|
||||
|
||||
Args:
|
||||
name: Optional custom name for the tool. If not provided, the function name will be used.
|
||||
description: Optional custom description for the tool. If not provided, the function's
|
||||
docstring will be used.
|
||||
|
||||
Returns:
|
||||
The decorated function with additional attributes for LLM tool functionality.
|
||||
|
||||
Example:
|
||||
@llm_tool(name="weather_lookup", description="Get current weather for a location")
|
||||
def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
|
||||
'''Get weather information for a specific location.'''
|
||||
# Implementation...
|
||||
return {"temperature": 22.5, "conditions": "sunny"}
|
||||
"""
|
||||
# Handle case when decorator is used without parentheses: @llm_tool
|
||||
if callable(name):
|
||||
func = name
|
||||
name = None
|
||||
description = None
|
||||
return _llm_tool_impl(func, name, description)
|
||||
|
||||
# Handle case when decorator is used with parentheses: @llm_tool() or @llm_tool(name="xyz")
|
||||
def decorator(func: Callable) -> Callable:
|
||||
return _llm_tool_impl(func, name, description)
|
||||
|
||||
return decorator
|
||||
|
||||
def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
|
||||
"""Implementation of the llm_tool decorator."""
|
||||
# Get function signature and docstring
|
||||
sig = inspect.signature(func)
|
||||
doc = inspect.getdoc(func) or ""
|
||||
parsed_doc = docstring_parser.parse(doc)
|
||||
|
||||
# Determine tool name
|
||||
tool_name = name or func.__name__
|
||||
|
||||
# Determine tool description
|
||||
tool_description = description or doc
|
||||
|
||||
# Create parameter properties for JSON schema
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param_name, param in sig.parameters.items():
|
||||
# Skip self parameter for methods
|
||||
if param_name == "self":
|
||||
continue
|
||||
|
||||
param_type = param.annotation
|
||||
param_default = None if param.default is inspect.Parameter.empty else param.default
|
||||
param_required = param.default is inspect.Parameter.empty
|
||||
|
||||
# Get parameter description from docstring if available
|
||||
param_desc = ""
|
||||
for param_doc in parsed_doc.params:
|
||||
if param_doc.arg_name == param_name:
|
||||
param_desc = param_doc.description
|
||||
break
|
||||
|
||||
# Handle Annotated types
|
||||
if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
|
||||
args = get_args(param_type)
|
||||
param_type = args[0] # The actual type
|
||||
if len(args) > 1 and isinstance(args[1], str):
|
||||
param_desc = args[1] # The description
|
||||
|
||||
# Create property for parameter
|
||||
param_schema = {
|
||||
"type": _get_json_type(param_type),
|
||||
"description": param_desc,
|
||||
"title": param_name.replace("_", " ").title()
|
||||
}
|
||||
|
||||
# Add default value if available
|
||||
if param_default is not None:
|
||||
param_schema["default"] = param_default
|
||||
|
||||
properties[param_name] = param_schema
|
||||
|
||||
# Add to required list if no default value
|
||||
if param_required:
|
||||
required.append(param_name)
|
||||
|
||||
# Create JSON schema
|
||||
schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"description": tool_description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Create Pydantic model for args schema
|
||||
field_definitions = {}
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param_name == "self":
|
||||
continue
|
||||
|
||||
param_type = param.annotation
|
||||
param_default = ... if param.default is inspect.Parameter.empty else param.default
|
||||
|
||||
# Handle Annotated types
|
||||
if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
|
||||
args = get_args(param_type)
|
||||
param_type = args[0]
|
||||
description = args[1] if len(args) > 1 and isinstance(args[1], str) else ""
|
||||
field_definitions[param_name] = (param_type, Field(default=param_default, description=description))
|
||||
else:
|
||||
field_definitions[param_name] = (param_type, Field(default=param_default))
|
||||
|
||||
# Create args schema model
|
||||
model_name = f"{tool_name.title().replace('_', '')}Schema"
|
||||
args_schema = create_model(model_name, **field_definitions)
|
||||
|
||||
# 根据原始函数是否是异步函数来创建相应类型的包装函数
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await func(*args, **kwargs)
|
||||
else:
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Attach metadata to function
|
||||
wrapper.is_llm_tool = True
|
||||
wrapper.tool_name = tool_name
|
||||
wrapper.tool_description = tool_description
|
||||
wrapper.json_schema = schema
|
||||
wrapper.args_schema = args_schema
|
||||
|
||||
# Register the tool
|
||||
_TOOL_REGISTRY[tool_name] = wrapper
|
||||
|
||||
return wrapper
|
||||
|
||||
def get_tools() -> Dict[str, Callable]:
|
||||
"""
|
||||
Get all registered LLM tools.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping tool names to their corresponding functions.
|
||||
"""
|
||||
return _TOOL_REGISTRY
|
||||
|
||||
def get_tool_schemas() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get JSON schemas for all registered LLM tools.
|
||||
|
||||
Returns:
|
||||
A list of JSON schemas for all registered tools, suitable for use with LLM APIs.
|
||||
"""
|
||||
return [tool.json_schema for tool in _TOOL_REGISTRY.values()]
|
||||
|
||||
def _get_json_type(python_type: Any) -> str:
|
||||
"""
|
||||
Convert Python type to JSON schema type.
|
||||
|
||||
Args:
|
||||
python_type: Python type annotation
|
||||
|
||||
Returns:
|
||||
Corresponding JSON schema type as string
|
||||
"""
|
||||
if python_type is str:
|
||||
return "string"
|
||||
elif python_type is int:
|
||||
return "integer"
|
||||
elif python_type is float:
|
||||
return "number"
|
||||
elif python_type is bool:
|
||||
return "boolean"
|
||||
elif python_type is list or python_type is List:
|
||||
return "array"
|
||||
elif python_type is dict or python_type is Dict:
|
||||
return "object"
|
||||
else:
|
||||
# Default to string for complex types
|
||||
return "string"
|
||||
75
mars_toolkit/core/utils.py
Normal file
75
mars_toolkit/core/utils.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
import boto3
|
||||
import logging
|
||||
import logging.config
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# Material Project
|
||||
mp_api_key: Optional[str] = Field(None, env="MP_API_KEY")
|
||||
mp_endpoint: Optional[str] = Field(None, env="MP_ENDPOINT")
|
||||
mp_topk: Optional[int] = Field(3, env="MP_TOPK")
|
||||
|
||||
# Proxy
|
||||
http_proxy: Optional[str] = Field(None, env="HTTP_PROXY")
|
||||
https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY")
|
||||
|
||||
# FairChem
|
||||
fairchem_model_path: Optional[str] = Field(None, env="FAIRCHEM_MODEL_PATH")
|
||||
fmax: Optional[float] = Field(0.05, env="FMAX")
|
||||
|
||||
# MinIO
|
||||
minio_endpoint: Optional[str] = Field(None, env="MINIO_ENDPOINT")
|
||||
internal_minio_endpoint: Optional[str] = Field(None, env="INTERNAL_MINIO_ENDPOINT")
|
||||
minio_access_key: Optional[str] = Field(None, env="MINIO_ACCESS_KEY")
|
||||
minio_secret_key: Optional[str] = Field(None, env="MINIO_SECRET_KEY")
|
||||
minio_bucket: Optional[str] = Field("mars-toolkit", env="MINIO_BUCKET")
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
def setup_logging():
|
||||
"""配置日志记录"""
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
|
||||
log_file_path = os.path.join(parent_dir, 'mars_toolkit.log')
|
||||
|
||||
logging.config.dictConfig({
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
'formatters': {
|
||||
'standard': {
|
||||
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
'datefmt': '%Y-%m-%d %H:%M:%S'
|
||||
},
|
||||
},
|
||||
'handlers': {
|
||||
'console': {
|
||||
'level': 'INFO',
|
||||
'class': 'logging.StreamHandler',
|
||||
'formatter': 'standard'
|
||||
},
|
||||
'file': {
|
||||
'level': 'DEBUG',
|
||||
'class': 'logging.handlers.RotatingFileHandler',
|
||||
'filename': log_file_path,
|
||||
'maxBytes': 10485760, # 10MB
|
||||
'backupCount': 5,
|
||||
'formatter': 'standard'
|
||||
}
|
||||
},
|
||||
'loggers': {
|
||||
'': {
|
||||
'handlers': ['console', 'file'],
|
||||
'level': 'INFO',
|
||||
'propagate': True
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
# 初始化配置
|
||||
settings = Settings()
|
||||
Reference in New Issue
Block a user