# File listing metadata: 307 lines, 9.0 KiB, Python, executable file.
"""Mars Toolkit MCP Server implementation."""
|
||
|
||
import anyio
|
||
import asyncio
|
||
import click
|
||
import json
|
||
import logging
|
||
import os
|
||
import sys
|
||
import traceback
|
||
from typing import Any, Dict, List, Optional, Union
|
||
import time
|
||
from prompts.material_synthesis import create_messages
|
||
|
||
# 添加mars_toolkit模块的路径
|
||
sys.path.append('/home/ubuntu/50T/lzy/mars-mcp')
|
||
|
||
import mcp.types as types
|
||
from mcp.server.lowlevel import Server
|
||
|
||
# 导入提示词处理器
|
||
#from prompts.material_synthesis import register_prompt_handlers
|
||
|
||
# Import the Mars Toolkit tool functions. A failed import is reported on
# stderr and recorded in MARS_TOOLKIT_AVAILABLE instead of aborting the module.
# NOTE(review): the individual tool names below are not referenced in this
# file; presumably importing them is what registers the tools with
# mars_toolkit -- confirm before removing any of them.
try:
    # Current time
    from mars_toolkit.misc.misc_tools import get_current_time
    # Web search
    from mars_toolkit.query.web_search import search_online
    # Materials Project: property search, crystal structures, formula -> MP id
    from mars_toolkit.query.mp_query import (
        search_material_property_from_material_project,
        get_crystal_structures_from_materials_project,
        get_mpid_from_formula,
    )
    # Crystal structure optimization
    from mars_toolkit.compute.structure_opt import optimize_crystal_structure
    # Material generation
    from mars_toolkit.compute.material_gen import generate_material
    # Chemical composition lookup from OQMD
    from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
    # Knowledge-base retrieval
    from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
    # Property prediction
    from mars_toolkit.compute.property_pred import predict_properties

    # Accessors for the registered tools and their schemas
    from mars_toolkit import get_tools, get_tool_schemas
except ImportError as exc:
    print(f"警告: 无法导入Mars Toolkit: {exc}", file=sys.stderr)
    MARS_TOOLKIT_AVAILABLE = False
else:
    MARS_TOOLKIT_AVAILABLE = True
|
||
|
||
# Configure logging for the whole process at INFO level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Low-level MCP server instance; the handlers are registered on it in main().
app = Server("mars-toolkit-server")
|
||
|
||
|
||
async def call_mars_toolkit_function(func_name: str, arguments: Dict[str, Any]) -> Any:
    """Invoke a registered Mars Toolkit tool by name.

    Args:
        func_name: Key of the tool in the mapping returned by ``get_tools()``.
        arguments: Keyword arguments forwarded to the tool. ``None`` is
            treated as "no arguments".

    Returns:
        The tool's result. If the call produces an awaitable, it is awaited
        and the awaited value is returned.

    Raises:
        ValueError: If Mars Toolkit could not be imported, or if
            ``func_name`` is not a registered tool.
    """
    import inspect  # local import: only this function needs it

    if not MARS_TOOLKIT_AVAILABLE:
        raise ValueError("Mars Toolkit不可用")

    # Look the tool up in the registry of all registered tool functions.
    tools = get_tools()
    if func_name not in tools:
        raise ValueError(f"函数 '{func_name}' 不存在于工具函数字典中")

    # Decide on the *result* rather than on the function object: an async tool
    # hidden behind a plain wrapper or functools.partial is not detected by
    # iscoroutinefunction(), which would hand back an un-awaited coroutine.
    # No print() here: under the stdio transport stdout carries the protocol.
    result = tools[func_name](**(arguments or {}))
    if inspect.isawaitable(result):
        result = await result
    return result
|
||
|
||
|
||
def get_tool_schemas_dict() -> Dict[str, Dict[str, Any]]:
    """Index the Mars Toolkit tool schemas by function name.

    Returns:
        A mapping from each schema's ``schema["function"]["name"]`` to the
        schema itself, or an empty dict when Mars Toolkit is unavailable.
        If two schemas share a name, the later one wins.
    """
    if MARS_TOOLKIT_AVAILABLE:
        return {entry["function"]["name"]: entry for entry in get_tool_schemas()}
    return {}
|
||
|
||
|
||
@click.command()
@click.option("--port", default=5666, help="Port to listen on for SSE")
@click.option(
    "--transport",
    type=click.Choice(["stdio", "sse"]),
    default="sse",
    help="Transport type",
)
def main(port: int, transport: str = "sse") -> int:
    """Run the Mars Toolkit MCP server.

    Registers the prompt and tool handlers on the module-level ``app`` and
    serves it over the selected transport. Diagnostics go through ``logger``
    (stderr by default) instead of ``print`` because under the stdio transport
    the process's stdout carries the protocol stream.

    Args:
        port: TCP port for the SSE transport; unused for stdio.
        transport: Either ``"sse"`` or ``"stdio"``.

    Returns:
        0 once the server has stopped. Click discards this value when the
        command runs in standalone mode.

    Raises:
        click.exceptions.Exit: With exit code 1 when Mars Toolkit could not be
            imported. Click turns this into the process exit status in
            standalone mode, or into the return value of
            ``main(standalone_mode=False)``.
    """
    if not MARS_TOOLKIT_AVAILABLE:
        print("错误: Mars Toolkit不可用,请确保已正确安装", file=sys.stderr)
        # A plain ``return 1`` is ignored by click in standalone mode, which
        # would make the process exit with status 0 despite the failure.
        raise click.exceptions.Exit(1)

    # Tool schemas keyed by function name; captured by list_tools() below.
    schemas_dict = get_tool_schemas_dict()

    @app.list_prompts()
    async def list_prompts() -> list[types.Prompt]:
        """Advertise the single ``material_synthesis`` prompt."""
        return [
            types.Prompt(
                name="material_synthesis",
                description="生成材料并设计合成方案,使用mermaid绘制合成流程图",
                arguments=[
                    types.PromptArgument(
                        name="properties",
                        description="材料性质及其值的JSON字符串,例如 {\"dft_band_gap\": \"2.0\"}",
                        required=False,
                    ),
                    types.PromptArgument(
                        name="batch_size",
                        description="生成材料的数量,默认为2",
                        required=False,
                    ),
                ],
            )
        ]

    @app.get_prompt()
    async def get_prompt(
        name: str, arguments: dict[str, str] | None = None
    ) -> types.GetPromptResult:
        """Build the ``material_synthesis`` prompt messages.

        ``properties`` is parsed as a JSON object and ``batch_size`` as an
        integer; a missing or invalid value falls back to ``{}`` and ``2``.

        Raises:
            ValueError: If ``name`` is not ``"material_synthesis"``.
        """
        if name != "material_synthesis":
            raise ValueError(f"未知的提示词: {name}")

        arguments = arguments or {}

        # Parse the optional ``properties`` argument.
        properties: Dict[str, Any] = {}
        raw_properties = arguments.get("properties")
        if raw_properties:
            try:
                parsed = json.loads(raw_properties)
            except (TypeError, json.JSONDecodeError):
                logger.warning("Ignoring invalid 'properties' value: %r", raw_properties)
            else:
                # Valid JSON that is not an object (e.g. a list or number) is
                # treated like invalid JSON so create_messages gets a dict.
                if isinstance(parsed, dict):
                    properties = parsed
                else:
                    logger.warning(
                        "Ignoring non-object 'properties' value: %r", raw_properties
                    )

        # Parse the optional ``batch_size`` argument.
        batch_size = 2  # default
        raw_batch_size = arguments.get("batch_size")
        if raw_batch_size:
            try:
                batch_size = int(raw_batch_size)
            except (TypeError, ValueError):
                logger.warning(
                    "Ignoring invalid 'batch_size' value: %r", raw_batch_size
                )

        return types.GetPromptResult(
            messages=create_messages(properties=properties, batch_size=batch_size),
            description="生成材料并设计合成方案,使用mermaid绘制合成流程图",
        )

    @app.call_tool()
    async def call_tool(
        name: str, arguments: Dict[str, Any]
    ) -> List[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        """Run a Mars Toolkit tool and return its result as text content.

        Args:
            name: Tool function name.
            arguments: Keyword arguments for the tool.

        Returns:
            A single ``TextContent``: dict/list results are rendered as
            indented JSON, anything else with ``str()``. On failure the text
            is the error message followed by the traceback.
        """
        try:
            logger.info("调用%s,参数为%s", name, arguments)
            result = await call_mars_toolkit_function(name, arguments)
            logger.debug("result %s", result)

            # Convert the result to a string.
            if isinstance(result, (dict, list)):
                result_str = json.dumps(result, ensure_ascii=False, indent=2)
            else:
                result_str = str(result)

            return [types.TextContent(type="text", text=result_str)]
        except Exception as e:
            # Handler boundary: report the failure to the client as text
            # instead of letting the exception escape the handler.
            error_msg = f"调用工具函数 {name} 时出错: {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            return [types.TextContent(type="text", text=error_msg)]

    @app.list_tools()
    async def list_tools() -> List[types.Tool]:
        """List every tool found in the Mars Toolkit schemas.

        Returns:
            One ``types.Tool`` per schema; a schema without a description gets
            a generic one, and one without parameters gets ``{}``.
        """
        logger.info("列举所有可用的工具函数")
        return [
            types.Tool(
                name=func_name,
                description=schema["function"].get(
                    "description", f"Mars Toolkit工具: {func_name}"
                ),
                inputSchema=schema["function"].get("parameters", {}),
            )
            for func_name, schema in schemas_dict.items()
        ]

    if transport == "sse":
        from mcp.server.sse import SseServerTransport
        from starlette.applications import Starlette
        from starlette.routing import Mount, Route

        sse = SseServerTransport("/messages/")

        async def handle_sse(request):
            # ``request._send`` is a private Starlette attribute; it is passed
            # through because connect_sse() takes the raw ASGI send callable.
            # NOTE(review): newer MCP SDK examples return an empty Response()
            # after this block -- confirm against the installed versions.
            async with sse.connect_sse(
                request.scope, request.receive, request._send
            ) as streams:
                await app.run(
                    streams[0], streams[1], app.create_initialization_options()
                )

        # NOTE(review): debug=True together with host 0.0.0.0 exposes debug
        # tracebacks on every interface -- confirm this is intended.
        starlette_app = Starlette(
            debug=True,
            routes=[
                Route("/sse", endpoint=handle_sse),
                Mount("/messages/", app=sse.handle_post_message),
            ],
        )

        import uvicorn

        uvicorn.run(starlette_app, host="0.0.0.0", port=port)
    else:
        from mcp.server.stdio import stdio_server

        async def arun():
            async with stdio_server() as streams:
                await app.run(
                    streams[0], streams[1], app.create_initialization_options()
                )

        anyio.run(arun)

    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Startup diagnostic: dump the tool schemas. It goes to stderr because
    # with --transport stdio the process's stdout is the protocol stream.
    print(get_tool_schemas_dict(), file=sys.stderr)
    main()
|
||
|