mcp,生成数据代码
This commit is contained in:
306
server.py
Executable file
306
server.py
Executable file
@@ -0,0 +1,306 @@
|
||||
"""Mars Toolkit MCP Server implementation."""
|
||||
|
||||
import anyio
|
||||
import asyncio
|
||||
import click
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from prompts.material_synthesis import create_messages
|
||||
|
||||
# 添加mars_toolkit模块的路径
|
||||
sys.path.append('/home/ubuntu/50T/lzy/mars-mcp')
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.lowlevel import Server
|
||||
|
||||
# 导入提示词处理器
|
||||
#from prompts.material_synthesis import register_prompt_handlers
|
||||
|
||||
# 导入Mars Toolkit工具函数
|
||||
try:
|
||||
# 获取当前时间
|
||||
from mars_toolkit.misc.misc_tools import get_current_time
|
||||
# 网络搜索
|
||||
from mars_toolkit.query.web_search import search_online
|
||||
# 从Materials Project查询材料属性
|
||||
from mars_toolkit.query.mp_query import search_material_property_from_material_project
|
||||
# 从Materials Project获取晶体结构
|
||||
from mars_toolkit.query.mp_query import get_crystal_structures_from_materials_project
|
||||
# 从化学式获取Materials Project ID
|
||||
from mars_toolkit.query.mp_query import get_mpid_from_formula
|
||||
# 优化晶体结构
|
||||
from mars_toolkit.compute.structure_opt import optimize_crystal_structure
|
||||
# 生成材料
|
||||
from mars_toolkit.compute.material_gen import generate_material
|
||||
# 从OQMD获取化学成分
|
||||
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||
# 从知识库检索
|
||||
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||
# 预测属性
|
||||
from mars_toolkit.compute.property_pred import predict_properties
|
||||
|
||||
# 获取所有工具函数
|
||||
from mars_toolkit import get_tools, get_tool_schemas
|
||||
|
||||
MARS_TOOLKIT_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
print(f"警告: 无法导入Mars Toolkit: {e}", file=sys.stderr)
|
||||
MARS_TOOLKIT_AVAILABLE = False
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
app = Server("mars-toolkit-server")
|
||||
|
||||
|
||||
async def call_mars_toolkit_function(func_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
调用Mars Toolkit中的工具函数
|
||||
|
||||
Args:
|
||||
func_name: 工具函数名称
|
||||
arguments: 工具函数参数
|
||||
|
||||
Returns:
|
||||
工具函数的执行结果
|
||||
"""
|
||||
if not MARS_TOOLKIT_AVAILABLE:
|
||||
raise ValueError("Mars Toolkit不可用")
|
||||
|
||||
# 获取所有注册的工具函数
|
||||
tools = get_tools()
|
||||
|
||||
# 检查函数名是否存在于工具函数字典中
|
||||
if func_name not in tools:
|
||||
raise ValueError(f"函数 '{func_name}' 不存在于工具函数字典中")
|
||||
|
||||
# 获取对应的工具函数
|
||||
tool_func = tools[func_name]
|
||||
|
||||
# 调用工具函数
|
||||
if asyncio.iscoroutinefunction(tool_func):
|
||||
# 如果是异步函数,使用await调用
|
||||
result = await tool_func(**arguments)
|
||||
print("result1",result)
|
||||
else:
|
||||
# 如果是同步函数,直接调用
|
||||
result = tool_func(**arguments)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_tool_schemas_dict() -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
获取所有工具函数的模式字典
|
||||
|
||||
Returns:
|
||||
工具函数名称到模式的映射字典
|
||||
"""
|
||||
if not MARS_TOOLKIT_AVAILABLE:
|
||||
return {}
|
||||
|
||||
schemas = get_tool_schemas()
|
||||
schemas_dict = {}
|
||||
|
||||
for schema in schemas:
|
||||
func_name = schema["function"]["name"]
|
||||
schemas_dict[func_name] = schema
|
||||
|
||||
return schemas_dict
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--port", default=5666, help="Port to listen on for SSE")
|
||||
@click.option(
|
||||
"--transport",
|
||||
type=click.Choice(["stdio", "sse"]),
|
||||
default="sse",
|
||||
help="Transport type",
|
||||
)
|
||||
def main(port: int, transport: str='SSE') -> int:
|
||||
"""
|
||||
Mars Toolkit MCP Server主函数
|
||||
|
||||
Args:
|
||||
port: SSE传输的端口号
|
||||
transport: 传输类型,stdio或sse
|
||||
|
||||
Returns:
|
||||
退出码
|
||||
"""
|
||||
if not MARS_TOOLKIT_AVAILABLE:
|
||||
print("错误: Mars Toolkit不可用,请确保已正确安装", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
|
||||
# 获取工具函数模式字典
|
||||
schemas_dict = get_tool_schemas_dict()
|
||||
|
||||
# 注册提示词处理器
|
||||
#register_prompt_handlers(app)
|
||||
|
||||
|
||||
|
||||
|
||||
@app.list_prompts()
|
||||
async def list_prompts() -> list[types.Prompt]:
|
||||
return [
|
||||
types.Prompt(
|
||||
name="material_synthesis",
|
||||
description="生成材料并设计合成方案,使用mermaid绘制合成流程图",
|
||||
arguments=[
|
||||
types.PromptArgument(
|
||||
name="properties",
|
||||
description="材料性质及其值的JSON字符串,例如 {\"dft_band_gap\": \"2.0\"}",
|
||||
required=False,
|
||||
),
|
||||
types.PromptArgument(
|
||||
name="batch_size",
|
||||
description="生成材料的数量,默认为2",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
@app.get_prompt()
|
||||
async def get_prompt(
|
||||
name: str, arguments: dict[str, str] | None = None
|
||||
) -> types.GetPromptResult:
|
||||
if name != "material_synthesis":
|
||||
raise ValueError(f"未知的提示词: {name}")
|
||||
|
||||
if arguments is None:
|
||||
arguments = {}
|
||||
|
||||
# 解析properties参数
|
||||
properties = {}
|
||||
if "properties" in arguments and arguments["properties"]:
|
||||
try:
|
||||
import json
|
||||
properties = json.loads(arguments["properties"])
|
||||
except json.JSONDecodeError:
|
||||
properties = {}
|
||||
|
||||
# 解析batch_size参数
|
||||
batch_size = 2 # 默认值
|
||||
if "batch_size" in arguments and arguments["batch_size"]:
|
||||
try:
|
||||
batch_size = int(arguments["batch_size"])
|
||||
except ValueError:
|
||||
pass # 使用默认值
|
||||
|
||||
return types.GetPromptResult(
|
||||
messages=create_messages(properties=properties, batch_size=batch_size),
|
||||
description="生成材料并设计合成方案,使用mermaid绘制合成流程图",
|
||||
)
|
||||
@app.call_tool()
|
||||
async def call_tool(
|
||||
name: str, arguments: Dict[str, Any]
|
||||
) -> List[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
"""
|
||||
调用工具函数
|
||||
|
||||
Args:
|
||||
name: 工具函数名称
|
||||
arguments: 工具函数参数
|
||||
|
||||
Returns:
|
||||
工具函数的执行结果
|
||||
"""
|
||||
try:
|
||||
print(f"调用{name},参数为{arguments}")
|
||||
result = await call_mars_toolkit_function(name, arguments)
|
||||
print("result",result)
|
||||
# 将结果转换为字符串
|
||||
if isinstance(result, (dict, list)):
|
||||
result_str = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
else:
|
||||
result_str = str(result)
|
||||
|
||||
return [types.TextContent(type="text", text=result_str)]
|
||||
except Exception as e:
|
||||
error_msg = f"调用工具函数 {name} 时出错: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
@app.list_tools()
|
||||
async def list_tools() -> List[types.Tool]:
|
||||
"""
|
||||
列出所有可用的工具函数
|
||||
|
||||
Returns:
|
||||
工具函数列表
|
||||
"""
|
||||
tools = []
|
||||
print("列举所有可用的工具函数")
|
||||
for func_name, schema in schemas_dict.items():
|
||||
# 获取函数描述
|
||||
description = schema["function"].get("description", f"Mars Toolkit工具: {func_name}")
|
||||
|
||||
# 获取参数模式
|
||||
parameters = schema["function"].get("parameters", {})
|
||||
|
||||
# 创建工具
|
||||
tool = types.Tool(
|
||||
name=func_name,
|
||||
description=description,
|
||||
inputSchema=parameters,
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
if transport == "sse":
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
sse = SseServerTransport("/messages/")
|
||||
|
||||
async def handle_sse(request):
|
||||
async with sse.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
) as streams:
|
||||
await app.run(
|
||||
streams[0], streams[1], app.create_initialization_options()
|
||||
)
|
||||
|
||||
starlette_app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route("/sse", endpoint=handle_sse),
|
||||
Mount("/messages/", app=sse.handle_post_message),
|
||||
],
|
||||
)
|
||||
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
|
||||
else:
|
||||
from mcp.server.stdio import stdio_server
|
||||
|
||||
async def arun():
|
||||
async with stdio_server() as streams:
|
||||
await app.run(
|
||||
streams[0], streams[1], app.create_initialization_options()
|
||||
)
|
||||
|
||||
anyio.run(arun)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(get_tool_schemas_dict())
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user