76 lines
2.5 KiB
Python
76 lines
2.5 KiB
Python
|
||
|
||
import asyncio
|
||
import json
|
||
import requests
|
||
import codecs
|
||
from typing import Dict, Any
|
||
|
||
from mars_toolkit.core.llm_tools import llm_tool
|
||
from mars_toolkit.core.config import config
|
||
|
||
@llm_tool(
    name="retrieval_from_knowledge_base",
    description="Retrieve information from local materials science literature knowledge base"
)
async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
    """
    Query the local materials science literature knowledge base via the Dify API.

    Args:
        query: Query string, e.g. a material name such as "CsPbBr3".
        topk: Value sent to the Dify app as the "topK" input (default 3).

    Returns:
        On success, a JSON-formatted string with the keys "id", "title",
        "content" and "score" extracted from the Dify response. On any
        failure (network error, non-2xx HTTP status, malformed response),
        a string of the form "错误: <message>" instead of raising.
    """
    # Dify chat-messages endpoint.
    url = f'{config.DIFY_ROOT_URL}/v1/chat-messages'

    # Request headers: API key and JSON content type.
    headers = {
        'Authorization': f'Bearer {config.DIFY_API_KEY}',
        'Content-Type': 'application/json'
    }

    # Request payload.
    data = {
        "inputs": {"topK": topk},      # maximum number of results to return
        "query": query,                # the query string
        "response_mode": "blocking",   # wait for the complete response
        "conversation_id": "",         # no conversation: every query is independent
        "user": "abc-123"              # user identifier
    }

    try:
        # requests is synchronous; run it in the default executor so this
        # coroutine does not block the event loop for the duration of the
        # (potentially very long, up to 1111 s) HTTP call.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(url, headers=headers, json=data, timeout=1111),
        )

        # Surface HTTP-level failures (401, 500, ...) as errors instead of
        # silently producing a result full of None values.
        response.raise_for_status()

        # Parse the body as JSON directly. The JSON parser already resolves
        # \uXXXX escapes; pre-decoding the raw text with 'unicode_escape'
        # would turn \" and \n inside string values into raw quotes/newlines
        # (making the document invalid JSON) and garble non-ASCII characters.
        result_json = response.json()

        # "metadata" may be absent or null; treat both as empty.
        # NOTE(review): assumes document_id/score sit directly under
        # "metadata" in the Dify response — confirm against the API schema.
        metadata = result_json.get("metadata") or {}

        # Collect the fields of interest.
        useful_info = {
            "id": metadata.get("document_id"),         # document ID
            "title": result_json.get("title"),         # document title
            "content": result_json.get("answer", ""),  # content, taken from the 'answer' field
            "score": metadata.get("score")             # relevance score
        }

        return json.dumps(useful_info, ensure_ascii=False, indent=2)

    except Exception as e:
        # Tool boundary: report every failure to the caller as an error
        # string rather than propagating the exception.
        return f"错误: {str(e)}"
|