Files
multi_mcp/server.py
2025-05-09 14:16:33 +08:00

307 lines
9.0 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Mars Toolkit MCP Server implementation."""
import anyio
import asyncio
import click
import json
import logging
import os
import sys
import traceback
from typing import Any, Dict, List, Optional, Union
import time
from prompts.material_synthesis import create_messages
# 添加mars_toolkit模块的路径
sys.path.append('/home/ubuntu/50T/lzy/mars-mcp')
import mcp.types as types
from mcp.server.lowlevel import Server
# 导入提示词处理器
#from prompts.material_synthesis import register_prompt_handlers
# 导入Mars Toolkit工具函数
try:
# 获取当前时间
from mars_toolkit.misc.misc_tools import get_current_time
# 网络搜索
from mars_toolkit.query.web_search import search_online
# 从Materials Project查询材料属性
from mars_toolkit.query.mp_query import search_material_property_from_material_project
# 从Materials Project获取晶体结构
from mars_toolkit.query.mp_query import get_crystal_structures_from_materials_project
# 从化学式获取Materials Project ID
from mars_toolkit.query.mp_query import get_mpid_from_formula
# 优化晶体结构
from mars_toolkit.compute.structure_opt import optimize_crystal_structure
# 生成材料
from mars_toolkit.compute.material_gen import generate_material
# 从OQMD获取化学成分
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
# 从知识库检索
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
# 预测属性
from mars_toolkit.compute.property_pred import predict_properties
# 获取所有工具函数
from mars_toolkit import get_tools, get_tool_schemas
MARS_TOOLKIT_AVAILABLE = True
except ImportError as e:
print(f"警告: 无法导入Mars Toolkit: {e}", file=sys.stderr)
MARS_TOOLKIT_AVAILABLE = False
# Process-wide logging setup: INFO and above, with timestamp, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The low-level MCP server; prompt and tool handlers are registered on it when main() runs.
app = Server("mars-toolkit-server")
async def call_mars_toolkit_function(func_name: str, arguments: Dict[str, Any]) -> Any:
    """Look up a registered Mars Toolkit tool by name and invoke it.

    Args:
        func_name: Name of the tool; must be a key of the mapping returned
            by ``get_tools()``.
        arguments: Keyword arguments forwarded to the tool. ``None`` is
            treated as "no arguments".

    Returns:
        The tool's return value. If calling the tool yields an awaitable
        (a coroutine function, or a sync wrapper around one), it is awaited
        and its result is returned.

    Raises:
        ValueError: If Mars Toolkit could not be imported, or ``func_name``
            is not a registered tool.
    """
    import inspect

    if not MARS_TOOLKIT_AVAILABLE:
        raise ValueError("Mars Toolkit不可用")

    # All registered tool functions, keyed by name.
    tools = get_tools()
    if func_name not in tools:
        raise ValueError(f"函数 '{func_name}' 不存在于工具函数字典中")
    tool_func = tools[func_name]

    # Call first, then await if needed. Unlike asyncio.iscoroutinefunction(),
    # this also handles sync wrappers/decorators that return a coroutine.
    # NOTE: sync tools run directly on the event-loop thread and block it
    # for as long as they take.
    result = tool_func(**(arguments or {}))
    if inspect.isawaitable(result):
        result = await result

    # Log (stderr) instead of print(): under the stdio transport stdout carries
    # the MCP protocol stream, and stray output there corrupts it.
    logger.debug("Tool %s returned: %r", func_name, result)
    return result
def get_tool_schemas_dict() -> Dict[str, Dict[str, Any]]:
    """Index the Mars Toolkit tool schemas by function name.

    Returns:
        A dict mapping each schema's ``["function"]["name"]`` to the schema
        itself, in the order ``get_tool_schemas()`` yields them. A later
        schema with a duplicate name overwrites an earlier one. The dict is
        empty when Mars Toolkit could not be imported.
    """
    if MARS_TOOLKIT_AVAILABLE:
        return {
            entry["function"]["name"]: entry
            for entry in get_tool_schemas()
        }
    return {}
def _register_handlers(schemas_dict: Dict[str, Dict[str, Any]]) -> None:
    """Attach the prompt and tool handlers to the module-level ``app``.

    Args:
        schemas_dict: Tool schemas keyed by function name; ``list_tools``
            advertises exactly these entries.
    """

    @app.list_prompts()
    async def list_prompts() -> list[types.Prompt]:
        """Advertise the single ``material_synthesis`` prompt and its optional arguments."""
        return [
            types.Prompt(
                name="material_synthesis",
                description="生成材料并设计合成方案使用mermaid绘制合成流程图",
                arguments=[
                    types.PromptArgument(
                        name="properties",
                        description="材料性质及其值的JSON字符串例如 {\"dft_band_gap\": \"2.0\"}",
                        required=False,
                    ),
                    types.PromptArgument(
                        name="batch_size",
                        description="生成材料的数量默认为2",
                        required=False,
                    ),
                ],
            )
        ]

    @app.get_prompt()
    async def get_prompt(
        name: str, arguments: dict[str, str] | None = None
    ) -> types.GetPromptResult:
        """Render the ``material_synthesis`` prompt via ``create_messages``.

        ``properties`` is parsed as JSON. Invalid JSON, or JSON that is not an
        object, falls back to ``{}``. ``batch_size`` is parsed as an int. A
        missing, non-numeric or non-positive value falls back to 2.

        Raises:
            ValueError: If ``name`` is not ``"material_synthesis"``.
        """
        if name != "material_synthesis":
            raise ValueError(f"未知的提示词: {name}")
        arguments = arguments or {}

        # Parse the optional ``properties`` argument (JSON object string).
        properties: Dict[str, Any] = {}
        raw_properties = arguments.get("properties")
        if raw_properties:
            try:
                parsed = json.loads(raw_properties)
            except (json.JSONDecodeError, TypeError):
                parsed = None
            # Only a JSON object is accepted as the properties mapping.
            if isinstance(parsed, dict):
                properties = parsed

        # Parse the optional ``batch_size`` argument; default is 2.
        batch_size = 2
        raw_batch_size = arguments.get("batch_size")
        if raw_batch_size:
            try:
                requested = int(raw_batch_size)
            except (TypeError, ValueError):
                requested = 0
            if requested > 0:
                batch_size = requested

        return types.GetPromptResult(
            messages=create_messages(properties=properties, batch_size=batch_size),
            description="生成材料并设计合成方案使用mermaid绘制合成流程图",
        )

    @app.call_tool()
    async def call_tool(
        name: str, arguments: Dict[str, Any]
    ) -> List[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        """Run a toolkit tool and return its result as a single text item.

        dict/list results are rendered as indented JSON (non-ASCII kept);
        anything else is rendered with ``str()``. Any exception is logged
        with its traceback and returned to the client as the reply text
        instead of being raised.
        """
        try:
            # Logging (stderr), not print(): stdout is the protocol channel
            # under the stdio transport.
            logger.info("调用%s,参数为%s", name, arguments)
            result = await call_mars_toolkit_function(name, arguments)
            logger.debug("result %s", result)
            if isinstance(result, (dict, list)):
                result_str = json.dumps(result, ensure_ascii=False, indent=2)
            else:
                result_str = str(result)
            return [types.TextContent(type="text", text=result_str)]
        except Exception as e:
            error_msg = f"调用工具函数 {name} 时出错: {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            return [types.TextContent(type="text", text=error_msg)]

    @app.list_tools()
    async def list_tools() -> List[types.Tool]:
        """Expose every schema in ``schemas_dict`` as an MCP tool."""
        logger.info("列举所有可用的工具函数")
        return [
            types.Tool(
                name=func_name,
                description=schema["function"].get(
                    "description", f"Mars Toolkit工具: {func_name}"
                ),
                inputSchema=schema["function"].get("parameters", {}),
            )
            for func_name, schema in schemas_dict.items()
        ]


def _run_sse(port: int) -> None:
    """Serve ``app`` over SSE with Starlette + uvicorn on 0.0.0.0:``port``; blocks while serving."""
    import uvicorn
    from mcp.server.sse import SseServerTransport
    from starlette.applications import Starlette
    from starlette.responses import Response
    from starlette.routing import Mount, Route

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        async with sse.connect_sse(
            request.scope, request.receive, request._send
        ) as streams:
            await app.run(
                streams[0], streams[1], app.create_initialization_options()
            )
        # Starlette invokes an endpoint's return value as the ASGI response;
        # returning None raises a TypeError once the SSE stream ends.
        return Response()

    starlette_app = Starlette(
        # NOTE(review): debug=True sends error tracebacks to HTTP clients, and
        # the server binds to all interfaces below; consider disabling it in
        # production.
        debug=True,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )
    uvicorn.run(starlette_app, host="0.0.0.0", port=port)


def _run_stdio() -> None:
    """Serve ``app`` over stdin/stdout; blocks until ``app.run`` returns."""
    from mcp.server.stdio import stdio_server

    async def arun():
        async with stdio_server() as streams:
            await app.run(
                streams[0], streams[1], app.create_initialization_options()
            )

    anyio.run(arun)


@click.command()
@click.option("--port", default=5666, help="Port to listen on for SSE")
@click.option(
    "--transport",
    type=click.Choice(["stdio", "sse"]),
    default="sse",
    help="Transport type",
)
def main(port: int, transport: str = "sse") -> int:
    """Start the Mars Toolkit MCP server.

    Args:
        port: Port the SSE transport listens on (unused for stdio).
        transport: ``"sse"`` serves over HTTP/SSE; any other value
            (i.e. ``"stdio"``) serves over stdin/stdout.

    Returns:
        0 once the selected transport stops serving.

    Raises:
        SystemExit: With code 1 if Mars Toolkit could not be imported. In
            click's standalone mode the callback's return value is discarded
            and the process would exit 0, so the failure status is raised
            explicitly.
    """
    if not MARS_TOOLKIT_AVAILABLE:
        print("错误: Mars Toolkit不可用请确保已正确安装", file=sys.stderr)
        sys.exit(1)

    _register_handlers(get_tool_schemas_dict())

    if transport == "sse":
        _run_sse(port)
    else:
        _run_stdio()
    return 0
if __name__ == "__main__":
    # Dump the schemas to stderr, not stdout: with ``--transport stdio``,
    # stdout is the MCP protocol channel, and a stray line there breaks the
    # client's message parsing.
    print(get_tool_schemas_dict(), file=sys.stderr)
    main()