工具函数零散能用版
This commit is contained in:
8
.gitignore
vendored
Normal file
8
.gitignore
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
|
||||
/.venv
|
||||
# 忽略整个mattergen目录中的所有更改
|
||||
mattergen/
|
||||
model_agent_test.py
|
||||
pyproject.toml
|
||||
/pretrained_models
|
||||
/mcp-python-sdk
|
||||
BIN
__pycache__/mattergen_wrapper.cpython-310.pyc
Normal file
BIN
__pycache__/mattergen_wrapper.cpython-310.pyc
Normal file
Binary file not shown.
BIN
__pycache__/mattergen_wrapper.cpython-312.pyc
Normal file
BIN
__pycache__/mattergen_wrapper.cpython-312.pyc
Normal file
Binary file not shown.
401
agent_test.py
Normal file
401
agent_test.py
Normal file
@@ -0,0 +1,401 @@
|
||||
import asyncio
|
||||
from tools_for_ms import *
|
||||
from api_key import *
|
||||
from openai import OpenAI
|
||||
import json
|
||||
from typing import Dict, List, Any, Union, Optional
|
||||
from rich.console import Console
|
||||
|
||||
# Load the registered tool implementations and their JSON schemas once at
# import time so the agent and the console share one registry.
tools = get_tools()
tools_schema = get_tool_schemas()
# NOTE(review): this is just a shallow copy of `tools` (dict(tools)); the
# rest of this file uses `tools` directly, so tool_map looks unused here.
tool_map = {tool_name: tool for tool_name, tool in tools.items()}

console = Console()
console.print(tools_schema)
|
||||
|
||||
class DualModelAgent:
    """
    Agent that drives two chat-completion backends ("qwq" and "gpt-4o")
    through OpenAI-compatible clients.

    qwq responses are handled as plain dicts (via ``model_dump()``) while
    gpt responses are handled as SDK objects, so each backend has its own
    ``extract_*`` helpers; tool execution is shared through ``call_tool``.
    """

    def __init__(self, qwq_model_name: str = "qwq-32b", gpt_model_name: str = "gpt-4o"):
        """
        Initialise one OpenAI-compatible client per backend.

        Args:
            qwq_model_name: model identifier for the qwq backend
            gpt_model_name: model identifier for the gpt backend
        """
        # Both clients currently share the same key/endpoint; they are kept
        # separate so the two backends can later point at different services.
        self.qwq_client = OpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_API_URL,
        )

        self.gpt_client = OpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_API_URL,
        )

        self.qwq_model_name = qwq_model_name
        self.gpt_model_name = gpt_model_name

        # JSON-schema tool descriptions, passed verbatim to the chat API.
        self.tools = tools_schema

    def get_qwq_response(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Query the qwq backend.

        Args:
            messages: chat history in OpenAI message format

        Returns:
            The completion converted to a plain dict.
        """
        completion = self.qwq_client.chat.completions.create(
            model=self.qwq_model_name,
            messages=messages,
            tools=self.tools,
            temperature=0.6,
            tool_choice='auto'
        )
        # qwq responses are consumed as dicts downstream.
        return completion.model_dump()

    def get_gpt_response(self, messages: List[Dict[str, Any]]) -> Any:
        """
        Query the gpt backend.

        Args:
            messages: chat history in OpenAI message format

        Returns:
            The raw SDK completion object.
        """
        completion = self.gpt_client.chat.completions.create(
            model=self.gpt_model_name,
            messages=messages,
            tools=self.tools,
            tool_choice="auto",
            temperature=0.6,
        )
        # gpt responses are consumed as objects downstream.
        return completion

    def extract_tool_calls_from_qwq(self, response: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
        """
        Pull the tool-call list out of a qwq (dict) response.

        Args:
            response: qwq response dict

        Returns:
            List of tool-call dicts, or None when the model made none.
        """
        assistant_message = response['choices'][0]['message']
        console.print("assistant_message",assistant_message)
        return assistant_message.get('tool_calls')

    def extract_tool_calls_from_gpt(self, response: Any) -> Optional[List[Any]]:
        """
        Pull the tool-call list out of a gpt (object) response.

        Args:
            response: gpt response object

        Returns:
            List of tool-call objects, or None when the model made none.
        """
        # NOTE(review): the SDK message object normally always exposes a
        # `tool_calls` attribute (None when no tools were called), so this
        # hasattr guard is presumably always true — confirm against the SDK.
        if hasattr(response.choices[0].message, 'tool_calls'):
            return response.choices[0].message.tool_calls
        return None

    def extract_content_from_qwq(self, response: Dict[str, Any]) -> str:
        """
        Pull the text content out of a qwq (dict) response.

        Args:
            response: qwq response dict

        Returns:
            Content string; "" when the model returned no content.
        """
        content = response['choices'][0]['message'].get('content')
        return content if content is not None else ""

    def extract_content_from_gpt(self, response: Any) -> str:
        """
        Pull the text content out of a gpt (object) response.

        Args:
            response: gpt response object

        Returns:
            Content string; "" when the model returned no content.
        """
        content = response.choices[0].message.content
        return content if content is not None else ""

    def extract_finish_reason_from_qwq(self, response: Dict[str, Any]) -> str:
        """
        Pull the finish reason out of a qwq (dict) response.

        Args:
            response: qwq response dict

        Returns:
            The finish_reason string of the first choice.
        """
        return response['choices'][0]['finish_reason']

    def extract_finish_reason_from_gpt(self, response: Any) -> str:
        """
        Pull the finish reason out of a gpt (object) response.

        Args:
            response: gpt response object

        Returns:
            The finish_reason string of the first choice.
        """
        return response.choices[0].finish_reason

    async def call_tool(self, tool_name: str, tool_arguments: Dict[str, Any]) -> str:
        """
        Execute a registered tool by name.

        Args:
            tool_name: key into the module-level ``tools`` registry
            tool_arguments: keyword arguments forwarded to the tool

        Returns:
            The tool's result, or an error string when lookup or execution
            fails (errors never propagate to the caller).
        """
        if tool_name in tools:
            tool_function = tools[tool_name]
            try:
                # Registered tools are awaited directly — assumes every tool
                # in the registry is a coroutine function (TODO confirm).
                tool_result = await tool_function(**tool_arguments)
                return tool_result
            except Exception as e:
                return f"工具调用错误: {str(e)}"
        else:
            return f"未找到工具: {tool_name}"

    async def chat_with_qwq(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
        """
        Run a tool-calling conversation loop against the qwq backend.

        Args:
            messages: initial chat history (not mutated; a copy is used)
            max_turns: hard cap on model round-trips

        Returns:
            The final assistant answer, or a limit message when max_turns
            is exhausted.
        """

        current_messages = messages.copy()
        turn = 0

        while turn < max_turns:
            turn += 1
            console.print(f"\n[bold cyan]第 {turn} 轮 QWQ 对话[/bold cyan]")
            console.print("message",current_messages)
            # Ask the model for the next step.
            response = self.get_qwq_response(current_messages)
            assistant_message = response['choices'][0]['message']
            # Some endpoints reject assistant messages with null content, so
            # normalise to "" before echoing the message back.
            if assistant_message['content'] is None:
                assistant_message['content'] = ""
            console.print("message", assistant_message)
            # Keep the assistant turn in the running context.
            current_messages.append(assistant_message)

            content = self.extract_content_from_qwq(response)
            tool_calls = self.extract_tool_calls_from_qwq(response)
            finish_reason = self.extract_finish_reason_from_qwq(response)

            console.print(f"[green]助手回复:[/green] {content}")

            # Done as soon as the model stops asking for tools.
            if tool_calls is None or finish_reason != "tool_calls":
                return content

            # Execute every requested tool and feed results back.
            for tool_call in tool_calls:
                tool_call_name = tool_call['function']['name']
                tool_call_arguments = json.loads(tool_call['function']['arguments'])

                console.print(f"[yellow]调用工具:[/yellow] {tool_call_name}")
                console.print(f"[yellow]工具参数:[/yellow] {tool_call_arguments}")

                tool_result = await self.call_tool(tool_call_name, tool_call_arguments)
                console.print(f"[blue]工具结果:[/blue] {tool_result}")

                # NOTE(review): unlike chat_with_gpt, no "tool_call_id" is
                # attached here — verify the qwq endpoint accepts tool
                # messages without it.
                current_messages.append({
                    "role": "tool",
                    "name": tool_call_name,
                    "content": tool_result
                })

        return "达到最大对话轮数限制"

    async def chat_with_gpt(self, messages: List[Dict[str, Any]], max_turns: int = 5) -> str:
        """
        Run a tool-calling conversation loop against the gpt backend.

        Args:
            messages: initial chat history (not mutated; a copy is used)
            max_turns: hard cap on model round-trips

        Returns:
            The final assistant answer, or a limit message when max_turns
            is exhausted.
        """
        current_messages = messages.copy()
        turn = 0

        while turn < max_turns:
            turn += 1
            console.print(f"\n[bold magenta]第 {turn} 轮 GPT 对话[/bold magenta]")

            # Ask the model for the next step.
            response = self.get_gpt_response(current_messages)
            assistant_message = response.choices[0].message

            # Keep the assistant turn in the running context (as a dict).
            current_messages.append(assistant_message.model_dump())

            content = self.extract_content_from_gpt(response)
            tool_calls = self.extract_tool_calls_from_gpt(response)
            finish_reason = self.extract_finish_reason_from_gpt(response)

            console.print(f"[green]助手回复:[/green] {content}")

            # Done as soon as the model stops asking for tools.
            if tool_calls is None or finish_reason != "tool_calls":
                return content

            # Execute every requested tool and feed results back.
            for tool_call in tool_calls:
                tool_call_name = tool_call.function.name
                tool_call_arguments = json.loads(tool_call.function.arguments)

                console.print(f"[yellow]调用工具:[/yellow] {tool_call_name}")
                console.print(f"[yellow]工具参数:[/yellow] {tool_call_arguments}")

                tool_result = await self.call_tool(tool_call_name, tool_call_arguments)
                console.print(f"[blue]工具结果:[/blue] {tool_result}")

                # Tool results are linked back by tool_call_id as the
                # OpenAI API requires.
                current_messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": tool_call_name,
                    "content": tool_result
                })

        return "达到最大对话轮数限制"

    async def chat_with_both_models(self, user_input: str, system_prompt: str = "你是一个有用的助手。") -> Dict[str, str]:
        """
        Query both backends concurrently with the same prompt.

        Args:
            user_input: the user's question
            system_prompt: system message prepended to the conversation

        Returns:
            Dict with keys "qwq" and "gpt" holding each model's final answer.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input}
        ]

        # Launch both conversations concurrently on the event loop.
        qwq_task = asyncio.create_task(self.chat_with_qwq(messages))
        gpt_task = asyncio.create_task(self.chat_with_gpt(messages))

        qwq_response, gpt_response = await asyncio.gather(qwq_task, gpt_task)

        return {
            "qwq": qwq_response,
            "gpt": gpt_response
        }
|
||||
|
||||
|
||||
async def main():
    """Interactive REPL: read a question, pick a backend (QWQ / GPT / both),
    print the final answer; loops until the user types 'exit'."""
    agent = DualModelAgent()

    while True:
        user_input = input("\n请输入问题(输入 'exit' 退出): ")
        #user_input='现在的时间'
        if user_input.lower() == 'exit':
            break

        model_choice = input("选择模型 (1: QWQ, 2: GPT, 3: 两者): ")


        try:
            if model_choice == '1':
                # QWQ backend only.
                messages = [
                    {"role": "system", "content": "你是一个有用的助手。"},
                    {"role": "user", "content": user_input}
                ]
                response = await agent.chat_with_qwq(messages)
                console.print(f"\n[bold cyan]QWQ 最终回答:[/bold cyan] {response}")

            elif model_choice == '2':
                # GPT backend only.
                messages = [
                    {"role": "system", "content": "你是一个有用的助手。"},
                    {"role": "user", "content": user_input}
                ]
                response = await agent.chat_with_gpt(messages)
                console.print(f"\n[bold magenta]GPT 最终回答:[/bold magenta] {response}")

            elif model_choice == '3':
                # Both backends concurrently.
                responses = await agent.chat_with_both_models(user_input)
                console.print(f"\n[bold cyan]QWQ 最终回答:[/bold cyan] {responses['qwq']}")
                console.print(f"\n[bold magenta]GPT 最终回答:[/bold magenta] {responses['gpt']}")

            else:
                console.print("[bold red]无效的选择,请输入 1、2 或 3[/bold red]")

        except Exception as e:
            # Keep the REPL alive on any failure; just report it.
            console.print(f"[bold red]错误:[/bold red] {str(e)}")
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":

    asyncio.run(main())
    #from tools_for_ms.services_tools.mp_service_tools import search_material_property_from_material_project
    #asyncio.run(search_material_property_from_material_project('Fe2O3'))

    # Ad-hoc tool-function test snippets (kept for reference):
    # tool_name = 'get_crystal_structures_from_materials_project'
    # result = asyncio.run(test_tool(tool_name))
    # print(result)

    #pass

    # TODO: knowledge-retrieval API interface / database
|
||||
8
api_key.py
Normal file
8
api_key.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# SECURITY NOTE(review): live-looking API credentials are committed here in
# plain text. Rotate these keys and load them from environment variables or
# an untracked secrets file instead of versioning them.
OPENAI_API_KEY='sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d'
OPENAI_API_URL='https://vip.apiyi.com/v1'

#OPENAI_API_KEY='gpustack_56f0adc61a865d22_c61cdbf601fa2cb95979d417618060e6'
#OPENAI_API_URL='http://192.168.191.100:5080/v1'
|
||||
|
||||
|
||||
|
||||
270
execute_tool_copy.py
Normal file
270
execute_tool_copy.py
Normal file
@@ -0,0 +1,270 @@
|
||||
import json
import asyncio
import concurrent.futures
from tools_for_ms.llm_tools import *
import threading
# Serialises appends to the shared output file across worker threads.
file_lock = threading.Lock()
from mysql.connector import pooling


# SECURITY NOTE(review): database credentials are hardcoded here — move them
# to configuration/environment variables before sharing this code.
connection_pool = pooling.MySQLConnectionPool(
    pool_name="mypool",
    pool_size=32,
    pool_reset_session=True,
    host='localhost',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)
|
||||
def process_retrieval_from_knowledge_base(data):
    """Look up a synthesis-scheme record by DOI and/or mp_id and render it
    as markdown.

    Args:
        data: mapping that may carry 'doi' and/or 'mp_id' keys.

    Returns:
        Markdown text built from the non-empty fields of the first matching
        row, or "" when no identifier was supplied or nothing matched.
    """
    doi = data.get('doi')
    mp_id = data.get('mp_id')

    # At least one identifier is required to search.
    if doi is None and mp_id is None:
        return ""

    # Assemble the WHERE clause from whichever identifiers are present
    # (parameterised — no string interpolation of user values).
    conditions = []
    params = []
    if doi is not None:
        conditions.append("doi = %s")
        params.append(doi)
    if mp_id is not None:
        conditions.append("mp_id = %s")
        params.append(mp_id)
    query = "SELECT * FROM mp_synthesis_scheme_info WHERE " + " OR ".join(conditions)

    # Borrow a pooled connection; always return it and close the cursor.
    conn = connection_pool.get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query, params)
            result = cursor.fetchone()  # only the first matching row is used
        finally:
            cursor.close()
    finally:
        conn.close()

    if not result:
        return ""

    # Render every non-empty content field (identifiers excluded) as its own
    # markdown section.
    content_fields = (
        "target_material",
        "reaction_string",
        "chara_structure",
        "chara_performance",
        "chara_application",
        "synthesis_schemes",
    )
    sections = []
    for field_name in content_fields:
        field_text = result.get(field_name, "")
        if field_text and field_text.strip():
            sections.append(f"\n## {field_name}\n{field_text}\n\n")

    return "".join(sections)
|
||||
async def execute_tool_from_dict(input_dict: dict):
    """
    Dispatch a tool call described by a dict to the registered tool function.

    Args:
        input_dict: e.g.
            {"name": "search_material_property_from_material_project",
             "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}
            "arguments" may be a dict or a JSON string.

    Returns:
        The tool function's result, or a {"status": "error", ...} dict when
        the name is missing/unknown or execution fails.
    """
    try:
        # (historical variant accepted a JSON string)
        # input_dict = json.loads(input_str)

        func_name = input_dict.get("name")
        arguments_data = input_dict.get("arguments")
        #print('func_name', func_name)
        #print("argument", arguments_data)
        if not func_name:
            return {"status": "error", "message": "未提供函数名称"}

        # Resolve against the current tool registry.
        tools = get_tools()

        if func_name not in tools:
            return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"}

        tool_func = tools[func_name]

        # Normalise the arguments into a kwargs dict.
        arguments = {}
        if arguments_data:
            if isinstance(arguments_data, dict):
                # Already a dict — use as-is.
                arguments = arguments_data
            elif isinstance(arguments_data, str):
                try:
                    # First attempt: parse directly as JSON.
                    arguments = json.loads(arguments_data)
                except json.JSONDecodeError:
                    # Second attempt: undo common over-escaping (\" and \\)
                    # produced by some model outputs, then re-parse.
                    fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\')
                    try:
                        arguments = json.loads(fixed_str)
                    except json.JSONDecodeError:
                        # Last resort: hand the raw string to the tool.
                        arguments = {"raw_string": arguments_data}

        # Invoke the tool, awaiting it only when it is a coroutine function.
        if asyncio.iscoroutinefunction(tool_func):
            result = await tool_func(**arguments)
        else:
            result = tool_func(**arguments)
        # if func_name=='generate_material':
        #     print("xxxxx",result)
        return result

    except json.JSONDecodeError as e:
        return {"status": "error", "message": f"JSON解析错误: {str(e)}"}
    except Exception as e:
        return {"status": "error", "message": f"执行过程中出错: {str(e)}"}
|
||||
|
||||
|
||||
|
||||
|
||||
# # 示例用法
|
||||
# if __name__ == "__main__":
|
||||
# # 示例输入
|
||||
# input_str = '{"name": "search_material_property_from_material_project", "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}'
|
||||
|
||||
# # 调用函数
|
||||
# result = asyncio.run(execute_tool_from_string(input_str))
|
||||
# print(result)
|
||||
|
||||
|
||||
def worker(data, output_file_path):
    """Process one record: execute every entry in data["function_calls"],
    store the concatenated formatted results on the record, and append it
    to output_file_path (JSONL).

    Args:
        data: record dict with a "function_calls" list of tool-call dicts.
        output_file_path: JSONL file appended to under file_lock.

    Returns:
        "Processed successfully" on success, or "Error processing: ..." —
        exceptions never propagate to the thread pool.
    """
    try:
        func_contents = data["function_calls"]
        func_results = []
        formatted_results = []  # per-tool formatted blocks, joined at the end
        for func in func_contents:
            if func.get("name") == 'retrieval_from_knowledge_base':
                func_name = func.get("name")
                arguments_data = func.get("arguments")
                # Knowledge-base lookups bypass the generic dispatcher and go
                # straight to the local MySQL table. NOTE(review): the whole
                # record `data` is passed, not this call's `arguments` —
                # confirm that is intentional.
                result = process_retrieval_from_knowledge_base(data)
                func_results.append({"function": func['name'], "result": result})
                formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
                formatted_results.append(formatted_result)
            else:
                # Generic path: dispatch through the tool registry.
                result = asyncio.run(execute_tool_from_dict(func))
                func_results.append({"function": func['name'], "result": result})
                func_name = func.get("name")
                formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
                formatted_results.append(formatted_result)

        final_result = "\n\n\n".join(formatted_results)
        # NOTE(review): 'obeservation' is a typo for 'observation', but the
        # key is part of the emitted data schema — fix it only together with
        # every downstream consumer of these JSONL files.
        data['obeservation']=final_result

        # Append under the lock so concurrent workers don't interleave lines.
        with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                writer.write(data)
        return f"Processed successfully"

    except Exception as e:
        return f"Error processing: {str(e)}"
|
||||
|
||||
|
||||
|
||||
def main(datas, output_file_path, max_workers=1):
    """Fan the records out to worker() on a thread pool, appending results
    to output_file_path, with a tqdm progress bar and a success/failure
    tally.

    Args:
        datas: list of record dicts to process.
        output_file_path: JSONL file the workers append to.
        max_workers: thread-pool size.
    """
    # NOTE(review): random, os and Error are imported but unused below.
    import random
    from tqdm import tqdm
    import os
    from mysql.connector import pooling, Error

    pbar = tqdm(total=len(datas), desc="Processing papers")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit every record and remember which future maps to which input.
        future_to_path = {}
        for path in datas:
            future = executor.submit(worker, path, output_file_path)
            future_to_path[future] = path

        completed = 0
        failed = 0
        for future in concurrent.futures.as_completed(future_to_path):
            path = future_to_path[future]
            try:
                result = future.result()
                # worker() signals success via its return string.
                if "successfully" in result:
                    completed += 1
                else:
                    failed += 1
                pbar.update(1)
                # Refresh the tallies every 100 finished records.
                if (completed + failed) % 100 == 0:
                    pbar.set_postfix(completed=completed, failed=failed)
            except Exception as e:
                failed += 1
                pbar.update(1)
                print(f"\nWorker for {path} generated an exception: {e}")

    pbar.close()
    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
    import datetime
    # NOTE: this import also makes `jsonlines` a module-level global, which
    # worker() relies on when writing its output.
    import jsonlines
    datas = []
    with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
        for obj in reader:
            datas.append(obj)

    print(len(datas))
    # print()
    # Timestamped output file so repeated runs never clobber each other.
    output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
    main(datas, output_file,max_workers=8)

    # Ad-hoc test of the knowledge-base lookup (kept for reference):
    # print("开始测试 process_retrieval_from_knowledge_base 函数...")
    # data={'doi':'10.1016_s0025-5408(01)00495-0','mp_id':None}
    # result = process_retrieval_from_knowledge_base(data)
    # print("函数执行结果:")
    # print(result)
    # print("测试完成")
|
||||
26
mattergen_wrapper.py
Normal file
26
mattergen_wrapper.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
This is a wrapper module that provides access to the mattergen modules
|
||||
by modifying the Python path at runtime.
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add the mattergen directory to the Python path
|
||||
mattergen_dir = os.path.join(os.path.dirname(__file__), 'mattergen')
|
||||
sys.path.insert(0, mattergen_dir)
|
||||
|
||||
# Import the necessary modules from the mattergen package
|
||||
try:
|
||||
from mattergen import generator
|
||||
from mattergen.common.data import chemgraph
|
||||
from mattergen.common.data.types import TargetProperty
|
||||
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
|
||||
from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME
|
||||
except ImportError as e:
|
||||
print(f"Error importing mattergen modules: {e}")
|
||||
print(f"Python path: {sys.path}")
|
||||
raise
|
||||
|
||||
# Re-export the modules
|
||||
__all__ = ['generator', 'chemgraph', 'TargetProperty', 'MatterGenCheckpointInfo', 'PRETRAINED_MODEL_NAME']
|
||||
190
test_tools.py
Normal file
190
test_tools.py
Normal file
@@ -0,0 +1,190 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
async def test_tool(tool_name: str) -> str:
    """
    Smoke-test a single registered tool by name with canned inputs.

    Each branch imports its tool lazily so unrelated heavy dependencies are
    only loaded when that tool is actually tested.

    Args:
        tool_name: name of the tool function to exercise.

    Returns:
        A human-readable success message (with result type and value), or a
        failure message including the traceback.
    """
    try:
        print(f"开始测试工具: {tool_name}")

        if tool_name == "get_current_time":
            from tools_for_ms.basic_tools import get_current_time
            result = await get_current_time(timezone="Asia/Shanghai")

        elif tool_name == "search_online":
            from tools_for_ms.basic_tools import search_online
            #from tools_for_ms.basic_tools import search_online
            result = await search_online(query="material science", num_results=2)

        elif tool_name == "search_material_property_from_material_project":
            from tools_for_ms.services_tools.mp_tools import search_material_property_from_material_project
            result = await search_material_property_from_material_project(formula="Fe2O3")

        elif tool_name == "get_crystal_structures_from_materials_project":
            from tools_for_ms.services_tools.mp_tools import get_crystal_structures_from_materials_project
            result = await get_crystal_structures_from_materials_project(
                formulas=["Fe2O3"])
        elif tool_name == "get_mpid_from_formula":
            from tools_for_ms.query_tools.mp_tools import get_mpid_from_formula
            result = await get_mpid_from_formula(['Fe2O3'])
        elif tool_name == "optimize_crystal_structure":
            from tools_for_ms.services_tools.fairchem_tools import optimize_crystal_structure
            # Minimal CIF used purely as test input.
            simple_cif = """
data_simple
_cell_length_a 4.0
_cell_length_b 4.0
_cell_length_c 4.0
_cell_angle_alpha 90
_cell_angle_beta 90
_cell_angle_gamma 90
_symmetry_space_group_name_H-M 'P 1'
loop_
_atom_site_label
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Si 0.0 0.0 0.0
O 0.25 0.25 0.25
"""
            result = await optimize_crystal_structure(content=simple_cif, input_format="cif")

        elif tool_name == "generate_material":
            from tools_for_ms.services_tools.mattergen_tools import generate_material
            # Small batch keeps the generation test cheap.
            result = await generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1)

        elif tool_name == "fetch_chemical_composition_from_OQMD":
            from tools_for_ms.services_tools.oqmd_tools import fetch_chemical_composition_from_OQMD
            result = await fetch_chemical_composition_from_OQMD(composition="Fe2O3")

        elif tool_name == "retrieval_from_knowledge_base":
            from tools_for_ms.query_tools.search_dify import retrieval_from_knowledge_base
            result = await retrieval_from_knowledge_base(query="CsPbBr3", topk=3)

        elif tool_name == "predict_properties":
            from tools_for_ms.services_tools.mattersim_tools import predict_properties
            # CsPbBr3 structure used purely as test input.
            _cif = """
# generated using pymatgen
data_CsPbBr3
_symmetry_space_group_name_H-M 'P 1'
_cell_length_a 8.37036600
_cell_length_b 8.42533500
_cell_length_c 12.01129500
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 1
_chemical_formula_structural CsPbBr3
_chemical_formula_sum 'Cs4 Pb4 Br12'
_cell_volume 847.07421031
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Cs Cs0 1 0.50831300 0.46818500 0.25000000 1
Cs Cs1 1 0.00831300 0.03181500 0.75000000 1
Cs Cs2 1 0.99168700 0.96818500 0.25000000 1
Cs Cs3 1 0.49168700 0.53181500 0.75000000 1
Pb Pb4 1 0.50000000 0.00000000 0.50000000 1
Pb Pb5 1 0.00000000 0.50000000 0.00000000 1
Pb Pb6 1 0.00000000 0.50000000 0.50000000 1
Pb Pb7 1 0.50000000 0.00000000 0.00000000 1
Br Br8 1 0.54824500 0.99370800 0.75000000 1
Br Br9 1 0.04824500 0.50629200 0.25000000 1
Br Br10 1 0.79480800 0.20538600 0.02568800 1
Br Br11 1 0.20519200 0.79461400 0.97431200 1
Br Br12 1 0.29480800 0.29461400 0.97431200 1
Br Br13 1 0.70519200 0.70538600 0.02568800 1
Br Br14 1 0.29480800 0.29461400 0.52568800 1
Br Br15 1 0.70519200 0.70538600 0.47431200 1
Br Br16 1 0.20519200 0.79461400 0.52568800 1
Br Br17 1 0.79480800 0.20538600 0.47431200 1
Br Br18 1 0.95175500 0.49370800 0.75000000 1
Br Br19 1 0.45175500 0.00629200 0.25000000 1

"""
            result = await predict_properties(cif_content=_cif)

        # A commented-out "visualize_cif" branch (with its sample CIF) and a
        # commented-out final `else: return f"未知工具: {tool_name}"` were
        # here in an earlier revision. NOTE(review): with the else branch
        # disabled, an unknown tool_name leaves `result` unbound and the
        # f-string below raises NameError, which the except clause reports.

        print(f"工具 {tool_name} 测试完成")
        return f"工具 {tool_name} 测试成功,返回结果类型: {type(result)},返回的结果{result}"

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return f"工具 {tool_name} 测试失败: {str(e)}\n{error_details}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# 查询工具
|
||||
# search_material_property_from_material_project 在material project中通过化学式查询材料性质 ✅
|
||||
# get_crystal_structures_from_materials_project 在material project中通过化学式查询晶体性质✅
|
||||
# fetch_chemical_composition_from_OQMD 在OQMD中通过化学式查询获取化学组成✅
|
||||
# search_online
|
||||
# 生成内容的工具
|
||||
# optimize_crystal_structure 使用fairchem 优化晶体结构✅
|
||||
# predict_properties 使用mattersim 预测晶体性质 ✅
|
||||
# generate_material 使用matter 预测晶体性质✅
|
||||
# 测试 MatterSim 工具
|
||||
tool_name ='search_material_property_from_material_project'
|
||||
result = asyncio.run(test_tool(tool_name))
|
||||
print(result)
|
||||
16
tools_for_ms/__init__.py
Normal file
16
tools_for_ms/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Tools package for LLM function calling.
|
||||
|
||||
This package provides utilities for defining, registering, and managing LLM tools.
|
||||
"""
|
||||
|
||||
from .llm_tools import llm_tool, get_tools, get_tool_schemas
|
||||
from .basic_tools import *
|
||||
from .services_tools.oqmd_tools import fetch_chemical_composition_from_OQMD
|
||||
from .services_tools.mp_tools import search_material_property_from_material_project,get_crystal_structures_from_materials_project
|
||||
from .services_tools.search_dify import retrieval_from_knowledge_base
|
||||
from .services_tools.fairchem_tools import optimize_crystal_structure
|
||||
#from .services_tools.mattergen_tools import generate_material
|
||||
#from .services_tools.mattersim_tools import predict_properties
|
||||
|
||||
__all__ = ["llm_tool", "get_tools", "get_tool_schemas"]
|
||||
BIN
tools_for_ms/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
tools_for_ms/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tools_for_ms/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
tools_for_ms/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/__pycache__/api_key.cpython-310.pyc
Normal file
BIN
tools_for_ms/__pycache__/api_key.cpython-310.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/__pycache__/basic_tools.cpython-310.pyc
Normal file
BIN
tools_for_ms/__pycache__/basic_tools.cpython-310.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/__pycache__/basic_tools.cpython-312.pyc
Normal file
BIN
tools_for_ms/__pycache__/basic_tools.cpython-312.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/__pycache__/llm_tools.cpython-310.pyc
Normal file
BIN
tools_for_ms/__pycache__/llm_tools.cpython-310.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/__pycache__/llm_tools.cpython-311.pyc
Normal file
BIN
tools_for_ms/__pycache__/llm_tools.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/__pycache__/llm_tools.cpython-312.pyc
Normal file
BIN
tools_for_ms/__pycache__/llm_tools.cpython-312.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/__pycache__/utils.cpython-310.pyc
Normal file
BIN
tools_for_ms/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/__pycache__/utils.cpython-312.pyc
Normal file
BIN
tools_for_ms/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
123
tools_for_ms/basic_tools.py
Normal file
123
tools_for_ms/basic_tools.py
Normal file
@@ -0,0 +1,123 @@
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Dict, Any
|
||||
import pytz
|
||||
|
||||
from langchain_community.utilities import SearxSearchWrapper
|
||||
|
||||
from .llm_tools import llm_tool
|
||||
|
||||
# Json Schema 将函数转化为大模型能够理解的格式(因为大模型训练时,调用函数相关的数据使用Json Schema的格式)
|
||||
#1. 使用@tool装饰器装饰函数。
|
||||
#2. 使用Annotated为参数添加描述。
|
||||
#3. 完善函数的docstring以明确工具的功能。 让模型在调用函数时能清楚每个模块的功能
|
||||
|
||||
# @tool
|
||||
# def online_search(
|
||||
# query: Annotated[str, "The search term to find scientific content in English"]
|
||||
# ) -> str:
|
||||
# """Searches scientific information on the Internet and returns results in English."""
|
||||
# search = SearxSearchWrapper(
|
||||
# searx_host="http://192.168.191.101:40032/",
|
||||
# categories=["science"],
|
||||
# k=20
|
||||
# )
|
||||
|
||||
# return search.run(query, language='es', num_results=2)
|
||||
|
||||
|
||||
@llm_tool(name="get_current_time", description="Get current date and time in specified timezone")
async def get_current_time(timezone: str = "UTC") -> str:
    """Return the current date and time in the given timezone.

    Args:
        timezone: Timezone name (e.g., UTC, Asia/Shanghai, America/New_York)

    Returns:
        A human-readable date/time string, or an error message when the
        timezone name is unknown.
    """
    # Resolve the zone first; this is the only step that can fail.
    try:
        zone = pytz.timezone(timezone)
    except pytz.exceptions.UnknownTimeZoneError:
        return f"Unknown timezone: {timezone}. Please use a valid timezone such as 'UTC', 'Asia/Shanghai', etc."
    stamp = datetime.now(zone).strftime('%Y-%m-%d %H:%M:%S %Z')
    return f"The current {timezone} time is: {stamp}"
|
||||
|
||||
@llm_tool(name="search_online", description="Search scientific information online and return results as a string")
async def search_online(
    query: Annotated[str, "Search term"],
    num_results: Annotated[int, "Number of results (1-20)"] = 5
) -> str:
    """
    Searches for scientific information online and returns results as a formatted string.

    Args:
        query: Search term for scientific content
        num_results: Number of results to return (1-20)

    Returns:
        Formatted string with search results (titles, snippets, links)
    """
    # Coerce to int first: LLM tool calls frequently pass the count as a string.
    try:
        num_results = int(num_results)
    except (TypeError, ValueError):
        num_results = 5

    # Clamp to the documented 1-20 range.
    num_results = max(1, min(num_results, 20))

    # Initialize search wrapper
    search = SearxSearchWrapper(
        searx_host="http://192.168.191.101:40032/",
        categories=["science"],
        k=num_results,
    )

    # SearxSearchWrapper is synchronous; run it in a worker thread so the
    # event loop is not blocked.  (asyncio.get_event_loop() inside a running
    # coroutine is deprecated; asyncio.to_thread is the supported form and
    # matches its use elsewhere in this project.)
    # NOTE(review): `language` is passed as a list here while the earlier
    # commented-out code used a plain string -- confirm the wrapper accepts
    # a list before relying on multi-language results.
    raw_results = await asyncio.to_thread(
        lambda: search.results(query, language=['en', 'zh'], num_results=num_results)
    )

    # Normalize each raw hit into the fields we report.
    formatted_results = [
        {
            "title": hit.get("title", ""),
            "snippet": hit.get("snippet", ""),
            "link": hit.get("link", ""),
            "source": hit.get("source", ""),
        }
        for hit in raw_results
    ]

    # Build the report with a single join instead of repeated += concatenation.
    header = f"Search Results for '{query}' ({len(formatted_results)} items):\n\n"
    blocks = [
        f"Result {i}:\n"
        f"Title: {result['title']}\n"
        f"Summary: {result['snippet']}\n"
        f"Link: {result['link']}\n"
        f"Source: {result['source']}\n\n"
        for i, result in enumerate(formatted_results, 1)
    ]
    return header + "".join(blocks)
|
||||
|
||||
|
||||
# 让大模型可以根据函数名,直接调用函数
|
||||
# tool_map = {
|
||||
# "online_search": online_search,
|
||||
# "get_current_time": get_current_time,
|
||||
# }
|
||||
|
||||
|
||||
|
||||
|
||||
#####要用时得加修饰符@tool;为了实现异步,不用@tool
|
||||
#tools = [online_search,get_current_time]
|
||||
#tools_json_shcema = [convert_to_openai_function_format(tool.args_schema.model_json_schema()) for tool in tools]
|
||||
213
tools_for_ms/llm_tools.py
Normal file
213
tools_for_ms/llm_tools.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
LLM Tools Module
|
||||
|
||||
This module provides decorators and utilities for defining, registering, and managing LLM tools.
|
||||
It allows marking functions as LLM tools, generating JSON schemas for them, and retrieving
|
||||
registered tools for use with LLM APIs.
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, get_type_hints, get_origin, get_args
|
||||
import docstring_parser
|
||||
from pydantic import BaseModel, create_model, Field
|
||||
|
||||
# Registry to store all registered tools
|
||||
_TOOL_REGISTRY = {}
|
||||
|
||||
def llm_tool(name: Optional[str] = None, description: Optional[str] = None):
    """
    Decorator that registers a function as an LLM tool.

    The wrapped function is stored in the module-level registry and a JSON
    schema describing its parameters is generated, so the tool can later be
    retrieved via get_tools / get_tool_schemas and exposed to LLM APIs.

    Args:
        name: Optional custom name for the tool; defaults to the function name.
        description: Optional custom description; defaults to the docstring.

    Returns:
        The decorated function with additional attributes for LLM tool
        functionality.

    Example:
        @llm_tool(name="weather_lookup", description="Get current weather for a location")
        def get_weather(location: str, units: str = "metric") -> Dict[str, Any]:
            '''Get weather information for a specific location.'''
            return {"temperature": 22.5, "conditions": "sunny"}
    """
    # Bare usage (@llm_tool without parentheses): `name` is actually the
    # function being decorated, and no overrides were given.
    if callable(name):
        return _llm_tool_impl(name, None, None)

    # Parenthesized usage: @llm_tool() or @llm_tool(name="xyz", ...)
    def decorator(func: Callable) -> Callable:
        return _llm_tool_impl(func, name, description)

    return decorator
|
||||
|
||||
def _llm_tool_impl(func: Callable, name: Optional[str] = None, description: Optional[str] = None) -> Callable:
    """Implementation of the llm_tool decorator.

    Builds an OpenAI-function-calling JSON schema and a pydantic args model
    for *func*, wraps the function (preserving sync/async-ness), attaches the
    metadata, and records it in the module-level _TOOL_REGISTRY.

    Args:
        func: The function being registered as a tool.
        name: Optional tool-name override (defaults to func.__name__).
        description: Optional description override (defaults to the docstring).

    Returns:
        The wrapper function carrying tool metadata attributes
        (is_llm_tool, tool_name, tool_description, json_schema, args_schema).
    """
    # Get function signature and docstring
    sig = inspect.signature(func)
    doc = inspect.getdoc(func) or ""
    parsed_doc = docstring_parser.parse(doc)

    # Determine tool name
    tool_name = name or func.__name__

    # Determine tool description
    tool_description = description or doc

    # Create parameter properties for JSON schema
    properties = {}
    required = []

    for param_name, param in sig.parameters.items():
        # Skip self parameter for methods
        if param_name == "self":
            continue

        param_type = param.annotation
        param_default = None if param.default is inspect.Parameter.empty else param.default
        param_required = param.default is inspect.Parameter.empty

        # Get parameter description from docstring if available
        param_desc = ""
        for param_doc in parsed_doc.params:
            if param_doc.arg_name == param_name:
                param_desc = param_doc.description
                break

        # Handle Annotated types: unwrap to the real type and prefer the
        # inline string annotation as the description.
        # NOTE(review): detecting Annotated via the origin's __name__ relies
        # on typing internals -- confirm it behaves on every Python version
        # this runs under (3.10-3.12 per the committed .pyc files).
        if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
            args = get_args(param_type)
            param_type = args[0]  # The actual type
            if len(args) > 1 and isinstance(args[1], str):
                param_desc = args[1]  # The description

        # Create property for parameter
        param_schema = {
            "type": _get_json_type(param_type),
            "description": param_desc,
            "title": param_name.replace("_", " ").title()
        }

        # Add default value if available.  Note: an explicit default of None
        # is indistinguishable from "no default" here and is omitted from the
        # schema; falsy defaults such as 0/False/"" ARE included.
        if param_default is not None:
            param_schema["default"] = param_default

        properties[param_name] = param_schema

        # Add to required list if no default value
        if param_required:
            required.append(param_name)

    # Create JSON schema (OpenAI function-calling format)
    schema = {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required
            }
        }
    }

    # Create Pydantic model for args schema
    field_definitions = {}
    for param_name, param in sig.parameters.items():
        if param_name == "self":
            continue

        param_type = param.annotation
        # Ellipsis marks the field as required for pydantic's create_model.
        param_default = ... if param.default is inspect.Parameter.empty else param.default

        # Handle Annotated types
        if get_origin(param_type) is not None and get_origin(param_type).__name__ == "Annotated":
            args = get_args(param_type)
            param_type = args[0]
            # NOTE(review): this rebinds the outer `description` parameter;
            # harmless today because tool_description was computed above, but
            # a local rename would be safer.
            description = args[1] if len(args) > 1 and isinstance(args[1], str) else ""
            field_definitions[param_name] = (param_type, Field(default=param_default, description=description))
        else:
            field_definitions[param_name] = (param_type, Field(default=param_default))

    # Create args schema model, e.g. "my_tool" -> "MyToolSchema"
    model_name = f"{tool_name.title().replace('_', '')}Schema"
    args_schema = create_model(model_name, **field_definitions)

    # Wrap the function, preserving whether it is sync or async so callers
    # can await async tools and call sync tools directly.
    if asyncio.iscoroutinefunction(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
    else:
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

    # Attach metadata to function
    wrapper.is_llm_tool = True
    wrapper.tool_name = tool_name
    wrapper.tool_description = tool_description
    wrapper.json_schema = schema
    wrapper.args_schema = args_schema

    # Register the tool (last registration wins on a name collision)
    _TOOL_REGISTRY[tool_name] = wrapper

    return wrapper
|
||||
|
||||
def get_tools() -> Dict[str, Callable]:
    """
    Get all registered LLM tools.

    Returns:
        A dictionary mapping tool names to their corresponding functions.
        Note: this is the live module-level registry, not a copy -- mutating
        the returned dict affects subsequent lookups.
    """
    return _TOOL_REGISTRY
|
||||
|
||||
def get_tool_schemas() -> List[Dict[str, Any]]:
    """
    Get JSON schemas for all registered LLM tools.

    Returns:
        A list of JSON schemas (one per registered tool) in the
        {"type": "function", "function": {...}} shape produced by the
        llm_tool decorator, suitable for the `tools` argument of an LLM API
        call.
    """
    return [tool.json_schema for tool in _TOOL_REGISTRY.values()]
|
||||
|
||||
def _get_json_type(python_type: Any) -> str:
|
||||
"""
|
||||
Convert Python type to JSON schema type.
|
||||
|
||||
Args:
|
||||
python_type: Python type annotation
|
||||
|
||||
Returns:
|
||||
Corresponding JSON schema type as string
|
||||
"""
|
||||
if python_type is str:
|
||||
return "string"
|
||||
elif python_type is int:
|
||||
return "integer"
|
||||
elif python_type is float:
|
||||
return "number"
|
||||
elif python_type is bool:
|
||||
return "boolean"
|
||||
elif python_type is list or python_type is List:
|
||||
return "array"
|
||||
elif python_type is dict or python_type is Dict:
|
||||
return "object"
|
||||
else:
|
||||
# Default to string for complex types
|
||||
return "string"
|
||||
23
tools_for_ms/services_tools/Configs.py
Normal file
23
tools_for_ms/services_tools/Configs.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Materials Project API access.
# NOTE(review): this API key is committed in plain text -- it should be
# rotated and loaded from an environment variable or a secrets store.
MP_API_KEY='PMASAg256b814q3OaSRWeVc7MKx4mlKI'
MP_ENDPOINT='https://api.materialsproject.org/'
# Number of top results to keep from Materials Project queries.
MP_TOPK = 3
# Local cache directory of Materials Project property files.
LOCAL_MP_PROPERTY_ROOT='/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Props/'


# Proxy
HTTP_PROXY='http://192.168.168.1:20171' #'http://127.0.0.1:7897' #192.168.191.101:20171
HTTPS_PROXY='http://192.168.168.1:20171'#'http://127.0.0.1:7897'

# FairChem structure-optimization checkpoint, and the force threshold passed
# to the FIRE optimizer as dyn.run(fmax=FMAX).
FAIRCHEM_MODEL_PATH='/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
FMAX=0.05

# MatterGen generation-model checkpoints and result directory.
MATTERGENMODEL_ROOT='/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
MATTERGENMODEL_RESULT_PATH='results/'

# Dify service endpoint and key.
# NOTE(review): another plain-text credential -- move into environment config.
DIFY_ROOT_URL='http://192.168.191.101:6080'
DIFY_API_KEY='app-IKZrS1RqIyurPSzR73mz6XSA'

# Output directory for CIF visualizations.
VIZ_CIF_OUTPUT_ROOT='/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'
|
||||
BIN
tools_for_ms/services_tools/__pycache__/Configs.cpython-310.pyc
Normal file
BIN
tools_for_ms/services_tools/__pycache__/Configs.cpython-310.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/services_tools/__pycache__/Configs.cpython-312.pyc
Normal file
BIN
tools_for_ms/services_tools/__pycache__/Configs.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tools_for_ms/services_tools/__pycache__/mp_tools.cpython-310.pyc
Normal file
BIN
tools_for_ms/services_tools/__pycache__/mp_tools.cpython-310.pyc
Normal file
Binary file not shown.
BIN
tools_for_ms/services_tools/__pycache__/mp_tools.cpython-312.pyc
Normal file
BIN
tools_for_ms/services_tools/__pycache__/mp_tools.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tools_for_ms/services_tools/__pycache__/utils.cpython-310.pyc
Normal file
BIN
tools_for_ms/services_tools/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
49
tools_for_ms/services_tools/error_handlers.py
Normal file
49
tools_for_ms/services_tools/error_handlers.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException
|
||||
from typing import Any, Dict
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class APIError(HTTPException):
    """Custom API error class.

    Extends FastAPI's HTTPException so it can be raised directly from
    endpoint handlers; the constructor additionally logs the status code
    and detail at ERROR level.
    """
    def __init__(self, status_code: int, detail: Any = None):
        super().__init__(status_code=status_code, detail=detail)
        logger.error(f"API Error: {status_code} - {detail}")
|
||||
|
||||
def handle_minio_error(e: Exception) -> Dict[str, str]:
    """Log a MinIO failure and return a uniform error payload."""
    message = f"MinIO operation failed: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}
|
||||
|
||||
def handle_http_error(e: Exception) -> Dict[str, str]:
    """Log an HTTP request failure and return a uniform error payload."""
    message = f"HTTP request failed: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}
|
||||
|
||||
def handle_validation_error(e: Exception) -> Dict[str, str]:
    """Log a data-validation failure and return a uniform error payload."""
    message = f"Validation failed: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}
|
||||
|
||||
def handle_general_error(e: Exception) -> Dict[str, str]:
    """Log an unexpected error and return a uniform error payload."""
    message = f"Unexpected error: {str(e)}"
    logger.error(message)
    return {"status": "error", "data": message}
|
||||
386
tools_for_ms/services_tools/fairchem_tools.py
Normal file
386
tools_for_ms/services_tools/fairchem_tools.py
Normal file
@@ -0,0 +1,386 @@
|
||||
# 输入content-> 转换content生成atom(ase能处理的格式)->用matgen生成优化后的结构-再生成对称性cif。
|
||||
|
||||
from io import StringIO
|
||||
import sys
|
||||
import tempfile
|
||||
from ase.io import read, write
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
import os
|
||||
|
||||
from .. import llm_tool
|
||||
os.environ["PYTHONWARNINGS"] = "ignore"
|
||||
|
||||
# 或者更精细的控制
|
||||
os.environ["PYTHONWARNINGS"] = "ignore::DeprecationWarning"
|
||||
|
||||
from typing import Optional
|
||||
import logging
|
||||
from pymatgen.core.structure import Structure
|
||||
from .error_handlers import handle_general_error
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
logger = logging.getLogger(__name__)
|
||||
from pymatgen.io.cif import CifWriter
|
||||
from ase.atoms import Atoms
|
||||
from .Configs import*
|
||||
calc = None
|
||||
|
||||
# def init_model():
|
||||
# """初始化FairChem模型"""
|
||||
# global calc
|
||||
# try:
|
||||
# from fairchem.core import OCPCalculator
|
||||
# calc = OCPCalculator(checkpoint_path= FAIRCHEM_MODEL_PATH)
|
||||
# logger.info("FairChem model initialized successfully")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to initialize FairChem model: {str(e)}")
|
||||
# raise
|
||||
# init_model()
|
||||
# # 格式转化
|
||||
# def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
|
||||
|
||||
# '''example:
|
||||
# input_format = "xyz" cif vasp 等等
|
||||
# content = """5
|
||||
# H2O molecule with an extra oxygen and hydrogen
|
||||
# O 0.0 0.0 0.0
|
||||
# H 0.0 0.0 0.9
|
||||
# H 0.0 0.9 0.0
|
||||
# O 1.0 0.0 0.0
|
||||
# H 1.0 0.0 0.9
|
||||
# return Atoms(symbols='OH2OH', pbc=False)
|
||||
# """
|
||||
# '''
|
||||
|
||||
# """将输入内容转换为Atoms对象"""
|
||||
# try:
|
||||
# with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
|
||||
# tmp_file.write(content)
|
||||
# tmp_path = tmp_file.name
|
||||
|
||||
# atoms = read(tmp_path)
|
||||
# os.unlink(tmp_path)
|
||||
# return atoms
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to convert structure: {str(e)}")
|
||||
# return None
|
||||
|
||||
# def generate_symmetry_cif(structure: Structure) -> str:
|
||||
# """生成对称性CIF"""
|
||||
# analyzer = SpacegroupAnalyzer(structure)
|
||||
# structure = analyzer.get_refined_structure()
|
||||
|
||||
# with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
|
||||
# cif_writer = CifWriter(structure, symprec=0.1, refine_struct=True)
|
||||
# cif_writer.write_file(tmp_file.name)
|
||||
# tmp_file.seek(0)
|
||||
# return tmp_file.read()
|
||||
|
||||
|
||||
# def optimize_structure(atoms: Atoms, output_format: str):
|
||||
# """优化晶体结构"""
|
||||
# atoms.calc = calc
|
||||
# try:
|
||||
# import io
|
||||
# # from contextlib import redirect_stdout
|
||||
|
||||
# # # 创建StringIO对象捕获输出
|
||||
# # f = io.StringIO()
|
||||
# # dyn = FIRE(FrechetCellFilter(atoms))
|
||||
|
||||
# # # 同时捕获并输出到控制台
|
||||
# # with redirect_stdout(f):
|
||||
# # dyn.run(fmax=FMAX)
|
||||
# # # 获取捕获的日志
|
||||
# # optimization_log = f.getvalue()
|
||||
|
||||
# temp_output = StringIO()
|
||||
# # 保存原始的stdout
|
||||
# original_stdout = sys.stdout
|
||||
# # 重定向stdout到StringIO对象
|
||||
# sys.stdout = temp_output
|
||||
# dyn = FIRE(FrechetCellFilter(atoms))
|
||||
# dyn.run(fmax=FMAX)
|
||||
# sys.stdout = original_stdout
|
||||
# output_string = temp_output.getvalue()
|
||||
|
||||
# temp_output.close()
|
||||
# optimization_log = output_string
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# # 同时输出到控制台
|
||||
|
||||
# total_energy = atoms.get_potential_energy()
|
||||
|
||||
# except Exception as e:
|
||||
# return handle_general_error(e)
|
||||
|
||||
# #atoms.get_potential_energy() 函数解析
|
||||
# # atoms.get_potential_energy() 是 ASE (Atomic Simulation Environment) 中 Atoms 对象的一个方法,用于获取原子系统的势能(或总能量)。
|
||||
|
||||
# # 功能与用途
|
||||
# # 获取能量 :返回原子系统的计算总能量,通常以电子伏特 (eV) 为单位。
|
||||
# # 用途 :
|
||||
# # 评估结构稳定性(能量越低的结构通常越稳定)
|
||||
# # 计算反应能垒和反应能
|
||||
# # 分析能量随结构变化的趋势
|
||||
# # 作为结构优化的目标函数
|
||||
# # 计算分子或材料的吸附能、形成能等
|
||||
# # 工作原理
|
||||
# # 计算引擎依赖 :
|
||||
# # 该方法不会自行计算能量,而是从附加到 Atoms 对象的计算器 (calculator) 获取能量
|
||||
# # 需要先给 Atoms 对象设置一个计算器(如 VASP、Quantum ESPRESSO、GPAW 等)
|
||||
# # 执行机制 :
|
||||
# # 如果能量已计算过且原子结构未改变,则返回缓存值
|
||||
# # 否则会触发计算器执行能量计算
|
||||
|
||||
# # 处理对称性
|
||||
# if output_format == "cif":
|
||||
# optimized_structure = Structure.from_ase_atoms(atoms)
|
||||
# content = generate_symmetry_cif(optimized_structure)
|
||||
# #print('xxx',content)
|
||||
# #print('yyy',total_energy)
|
||||
# # 格式化返回结果
|
||||
# format_result = f"""
|
||||
# The following is the optimized crystal structure information:
|
||||
# ### Optimization Results (using FIRE(eqV2_86M) algorithm):
|
||||
# **Total Energy: {total_energy} eV**
|
||||
|
||||
# #### Optimizing Log:
|
||||
# ```text
|
||||
# {optimization_log}
|
||||
# ```
|
||||
# ### Optimized {output_format.upper()} Content:
|
||||
# ```{content}
|
||||
# {optimized_structure[:300]}
|
||||
# ```
|
||||
# """
|
||||
# print("output_log",format_result)
|
||||
|
||||
input_format = "cif" # generated using pymatgen
|
||||
content = """
|
||||
data_H2O
|
||||
_symmetry_space_group_name_H-M 'P 1'
|
||||
_cell_length_a 7.60356659
|
||||
_cell_length_b 7.60356659
|
||||
_cell_length_c 7.14296200
|
||||
_cell_angle_alpha 90.00000000
|
||||
_cell_angle_beta 90.00000000
|
||||
_cell_angle_gamma 120.00000516
|
||||
_symmetry_Int_Tables_number 1
|
||||
_chemical_formula_structural H2O
|
||||
_chemical_formula_sum 'H24 O12'
|
||||
_cell_volume 357.63799926
|
||||
_cell_formula_units_Z 12
|
||||
loop_
|
||||
_symmetry_equiv_pos_site_id
|
||||
_symmetry_equiv_pos_as_xyz
|
||||
1 'x, y, z'
|
||||
loop_
|
||||
_atom_site_type_symbol
|
||||
_atom_site_label
|
||||
_atom_site_symmetry_multiplicity
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
H H0 1 0.33082300 0.33082300 0.69642800 1
|
||||
H H1 1 0.66917700 0.00000000 0.69642800 1
|
||||
H H2 1 0.00000000 0.66917700 0.69642800 1
|
||||
H H3 1 0.66917700 0.66917700 0.19642800 1
|
||||
H H4 1 0.33082300 0.00000000 0.19642800 1
|
||||
H H5 1 0.00000000 0.33082300 0.19642800 1
|
||||
H H6 1 0.45234700 0.45234700 0.51064600 1
|
||||
H H7 1 0.54765300 0.00000000 0.51064600 1
|
||||
H H8 1 0.00000000 0.54765300 0.51064600 1
|
||||
H H9 1 0.54765300 0.54765300 0.01064600 1
|
||||
H H10 1 0.45234700 0.00000000 0.01064600 1
|
||||
H H11 1 0.00000000 0.45234700 0.01064600 1
|
||||
H H12 1 0.78617100 0.66371600 0.47884700 1
|
||||
H H13 1 0.33628400 0.12245500 0.47884700 1
|
||||
H H14 1 0.87754500 0.21382900 0.47884700 1
|
||||
H H15 1 0.66371600 0.78617100 0.47884700 1
|
||||
H H16 1 0.12245500 0.33628400 0.47884700 1
|
||||
H H17 1 0.21382900 0.87754500 0.47884700 1
|
||||
H H18 1 0.21382900 0.33628400 0.97884700 1
|
||||
H H19 1 0.66371600 0.87754500 0.97884700 1
|
||||
H H20 1 0.12245500 0.78617100 0.97884700 1
|
||||
H H21 1 0.33628400 0.21382900 0.97884700 1
|
||||
H H22 1 0.87754500 0.66371600 0.97884700 1
|
||||
H H23 1 0.78617100 0.12245500 0.97884700 1
|
||||
O O24 1 0.32664200 0.32664200 0.55565800 1
|
||||
O O25 1 0.67335800 0.00000000 0.55565800 1
|
||||
O O26 1 0.00000000 0.67335800 0.55565800 1
|
||||
O O27 1 0.67335800 0.67335800 0.05565800 1
|
||||
O O28 1 0.32664200 0.00000000 0.05565800 1
|
||||
O O29 1 0.00000000 0.32664200 0.05565800 1
|
||||
O O30 1 0.66060500 0.66060500 0.42957500 1
|
||||
O O31 1 0.33939500 0.00000000 0.42957500 1
|
||||
O O32 1 0.00000000 0.33939500 0.42957500 1
|
||||
O O33 1 0.33939500 0.33939500 0.92957500 1
|
||||
O O34 1 0.66060500 0.00000000 0.92957500 1
|
||||
O O35 1 0.00000000 0.66060500 0.92957500 1
|
||||
"""
|
||||
# atoms=convert_structure(input_format=input_format,content=content)
|
||||
# optimize_structure(atoms=atoms,output_format='cif')
|
||||
|
||||
# 添加新的异步LLM工具,包装optimize_structure功能
|
||||
import asyncio
|
||||
from io import StringIO
|
||||
import sys
|
||||
import tempfile
|
||||
from ase.io import read, write
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from ase.atoms import Atoms
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化FairChem模型
|
||||
calc = None
|
||||
|
||||
def init_model():
    """Lazily initialize the global FairChem calculator (no-op if loaded)."""
    global calc
    if calc is None:
        try:
            # Imported lazily so the module can be loaded without fairchem
            # installed; only tools that actually optimize need the model.
            from fairchem.core import OCPCalculator
            from tools_for_ms.services_tools.Configs import FAIRCHEM_MODEL_PATH

            calc = OCPCalculator(checkpoint_path=FAIRCHEM_MODEL_PATH)
            logger.info("FairChem model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize FairChem model: {str(e)}")
            raise
|
||||
|
||||
def convert_structure(input_format: str, content: str) -> Optional[Atoms]:
    """Convert a structure file's text content into an ASE Atoms object.

    Args:
        input_format: File-format extension understood by ase.io.read
            (e.g. "cif", "xyz", "vasp").
        content: The structure file content as a string.

    Returns:
        The parsed Atoms object, or None if parsing failed.
    """
    tmp_path = None
    try:
        # ase.io.read dispatches on the file extension, so write the content
        # to a temp file carrying the right suffix.
        with tempfile.NamedTemporaryFile(suffix=f".{input_format}", mode="w", delete=False) as tmp_file:
            tmp_file.write(content)
            tmp_path = tmp_file.name

        return read(tmp_path)
    except Exception as e:
        logger.error(f"Failed to convert structure: {str(e)}")
        return None
    finally:
        # Always remove the temp file -- the original only deleted it on the
        # success path, leaking one file per failed conversion.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
||||
|
||||
def generate_symmetry_cif(structure: Structure) -> str:
    """Refine a structure's symmetry and return it as CIF text.

    Args:
        structure: pymatgen Structure to symmetrize.

    Returns:
        CIF file content of the symmetry-refined structure.
    """
    analyzer = SpacegroupAnalyzer(structure)
    structure_refined = analyzer.get_refined_structure()

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".cif", mode="w+", delete=False) as tmp_file:
            tmp_path = tmp_file.name
            # CifWriter re-detects symmetry (symprec) and writes to the file;
            # read the file back as the return value.
            cif_writer = CifWriter(structure_refined, symprec=0.1, refine_struct=True)
            cif_writer.write_file(tmp_file.name)
            tmp_file.seek(0)
            return tmp_file.read()
    finally:
        # delete=False means nothing removes the temp file automatically;
        # the original leaked one file per call.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
||||
|
||||
def optimize_structure(atoms: Atoms, output_format: str) -> Dict[str, Any]:
    """Relax a structure with the global FairChem calculator and report results.

    Args:
        atoms: ASE Atoms to optimize; the module-level `calc` is attached.
        output_format: Format for the returned structure content ("cif"
            gets a symmetry-refined CIF; anything else is written via
            ase.io.write).

    Returns:
        A markdown-formatted report string with the total energy, the FIRE
        optimization log, and a 300-character preview of the optimized
        structure.  (The Dict[str, Any] annotation is historical -- the dict
        return was replaced by the formatted string.)

    Raises:
        Exception: re-raised after logging if any optimization step fails.
    """
    atoms.calc = calc

    try:
        # Capture the FIRE progress output, which goes to stdout.
        temp_output = StringIO()
        original_stdout = sys.stdout
        sys.stdout = temp_output
        try:
            from tools_for_ms.services_tools.Configs import FMAX
            # FrechetCellFilter lets the cell relax together with the ions.
            dyn = FIRE(FrechetCellFilter(atoms))
            dyn.run(fmax=FMAX)
        finally:
            # Restore stdout even if the optimizer raises; the original left
            # sys.stdout redirected on failure, silencing the whole process.
            sys.stdout = original_stdout

        optimization_log = temp_output.getvalue()
        temp_output.close()

        # Total energy from the attached calculator (cached by ASE).
        total_energy = atoms.get_potential_energy()

        # Produce the structure content in the requested format.
        if output_format == "cif":
            optimized_structure = Structure.from_ase_atoms(atoms)
            content = generate_symmetry_cif(optimized_structure)
        else:
            tmp_path = None
            try:
                with tempfile.NamedTemporaryFile(suffix=f".{output_format}", mode="w+", delete=False) as tmp_file:
                    tmp_path = tmp_file.name
                    write(tmp_file.name, atoms)
                    tmp_file.seek(0)
                    content = tmp_file.read()
            finally:
                # Clean up the delete=False temp file (leaked in the original).
                if tmp_path is not None and os.path.exists(tmp_path):
                    os.unlink(tmp_path)

        # Markdown report returned to the caller.
        format_result = f"""
The following is the optimized crystal structure information:
### Optimization Results (using FIRE(eqV2_86M) algorithm):
**Total Energy: {total_energy} eV**

#### Optimizing Log:
```text
{optimization_log}
```
### Optimized {output_format.upper()} Content:
```
{content[:300]}
```
"""
        return format_result
    except Exception as e:
        logger.error(f"Failed to optimize structure: {str(e)}")
        raise e
|
||||
|
||||
@llm_tool(name="optimize_crystal_structure",
          description="Optimize crystal structure using FairChem model")
async def optimize_crystal_structure(
    content: str,
    input_format: str = "cif",
    output_format: str = "cif"
) -> Dict[str, Any]:
    """
    Optimize crystal structure using FairChem model.

    Args:
        content: Crystal structure content string
        input_format: Input format (cif, xyz, vasp)
        output_format: Output format (cif, xyz, vasp)

    Returns:
        Optimized structure with energy and optimization log

    Raises:
        ValueError: if the input content cannot be parsed as `input_format`.
    """
    # Load the FairChem calculator on first use.
    if calc is None:
        init_model()

    def run_optimization():
        # Parse first, then relax; both steps are synchronous and CPU-bound.
        parsed_atoms = convert_structure(input_format, content)
        if parsed_atoms is None:
            raise ValueError(f"无法转换输入的{input_format}格式内容,请检查格式是否正确")
        return optimize_structure(parsed_atoms, output_format)

    # Run the blocking work in a worker thread so the event loop stays
    # responsive; exceptions propagate to the awaiting caller unchanged.
    return await asyncio.to_thread(run_optimization)
|
||||
412
tools_for_ms/services_tools/mattergen_tools.py
Normal file
412
tools_for_ms/services_tools/mattergen_tools.py
Normal file
@@ -0,0 +1,412 @@
|
||||
"""
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
import os
|
||||
import datetime
|
||||
import asyncio
|
||||
import zipfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
from ase.io import read, write
|
||||
from pymatgen.core.structure import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
# Use our wrapper module instead of direct imports
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
|
||||
#import mattergen_wrapper
|
||||
|
||||
# Access the modules through the wrapper
|
||||
# The generator module is re-exported as an attribute, not a submodule
|
||||
from mattergen_wrapper import generator
|
||||
CrystalGenerator = generator.CrystalGenerator
|
||||
from mattergen.common.data.types import TargetProperty
|
||||
from mattergen.common.utils.eval_utils import MatterGenCheckpointInfo
|
||||
from mattergen.common.utils.data_classes import (
|
||||
PRETRAINED_MODEL_NAME,
|
||||
MatterGenCheckpointInfo,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
from tools_for_ms.services_tools.Configs import *
|
||||
from tools_for_ms.llm_tools import llm_tool
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
def format_cif_content(content):
    """
    Format CIF content by removing unnecessary headers and organizing each CIF file.

    Args:
        content: String containing CIF content, possibly with PK headers

    Returns:
        Formatted string with each CIF file properly labeled and formatted
    """
    # Nothing to format.
    if not content or content.strip() == '':
        return ''

    # Strip zip-archive ("PK") residue: anything from a PK marker up to the
    # first CIF field, and trailing PK data with no CIF field after it.
    content = re.sub(r'PK.*?(?=_chemical_formula_structural)', '', content, flags=re.DOTALL)
    content = re.sub(r'PK[^_]*$', '', content, flags=re.DOTALL)
    content = re.sub(r'PK.*?(?!.*_chemical_formula_structural)$', '', content, flags=re.DOTALL)

    # Every CIF block starts at a _chemical_formula_structural field (the
    # field itself is kept inside the block).
    starts = [m.start() for m in re.finditer(r'_chemical_formula_structural', content)]
    if not starts:
        return ''

    # Pair each start with the next start (or end-of-string) and keep only
    # blocks whose formula value can be extracted.
    labelled_blocks = []
    for start, end in zip(starts, starts[1:] + [len(content)]):
        block = content[start:end].strip()
        formula_match = re.search(r'_chemical_formula_structural\s+(\S+)', block)
        if formula_match:
            labelled_blocks.append((formula_match.group(1), block))

    # Wrap each block in numbered [cif N begin] / [cif N end] markers.
    return "\n".join(
        f"[cif {idx} begin]\ndata_{formula}\n{block}\n[cif {idx} end]"
        for idx, (formula, block) in enumerate(labelled_blocks, 1)
    )
|
||||
|
||||
|
||||
|
||||
def convert_values(data_str):
    """Parse a JSON string into Python data; return the input unchanged when it is not valid JSON."""
    try:
        return json.loads(data_str)
    except json.JSONDecodeError:
        # Not JSON -- hand the raw string back so the caller can decide.
        return data_str
|
||||
|
||||
|
||||
def preprocess_property(property_name: str, property_value: Union[str, float, int]) -> Tuple[str, Any]:
    """
    Validate and normalize a (property_name, property_value) pair.

    String values for numeric properties are converted to float; numeric values
    for chemical_system are coerced to str; space_group values are range-checked.

    Args:
        property_name: Name of the property; must be one of the supported names.
        property_value: Value of the property (can be string, float, or int).

    Returns:
        Tuple of (property_name, processed_value).

    Raises:
        ValueError: If property_name is unsupported, a numeric property value
            cannot be parsed as a number, or space_group is outside [1, 230].
    """
    valid_properties = [
        "dft_mag_density", "dft_bulk_modulus", "dft_shear_modulus",
        "energy_above_hull", "formation_energy_per_atom", "space_group",
        "hhi_score", "ml_bulk_modulus", "chemical_system", "dft_band_gap",
    ]

    if property_name not in valid_properties:
        raise ValueError(f"Invalid property_name: {property_name}. Must be one of: {', '.join(valid_properties)}")

    # Convert string values to float for every numeric property.
    if isinstance(property_value, str) and property_name != "chemical_system":
        try:
            property_value = float(property_value)
        except ValueError:
            # BUGFIX: previously the failed conversion was silently swallowed,
            # and for space_group the later `str < int` comparison crashed with
            # a TypeError. Fail fast with a clear message instead.
            raise ValueError(
                f"Invalid value for {property_name}: {property_value!r} is not numeric."
            )

    # Handle special cases for properties that need specific types.
    if property_name == "chemical_system":
        if isinstance(property_value, (int, float)):
            logger.warning(f"Converting numeric property_value {property_value} to string for chemical_system property")
            property_value = str(property_value)
    elif property_name == "space_group":
        space_group = property_value
        if space_group < 1 or space_group > 230:
            raise ValueError(f"Invalid space_group value: {space_group}. Must be an integer between 1 and 230.")

    return property_name, property_value
|
||||
|
||||
|
||||
def main(
    output_path: str,
    pretrained_name: PRETRAINED_MODEL_NAME | None = None,
    model_path: str | None = None,
    batch_size: int = 2,
    num_batches: int = 1,
    config_overrides: list[str] | None = None,
    checkpoint_epoch: Literal["best", "last"] | int = "last",
    properties_to_condition_on: TargetProperty | None = None,
    sampling_config_path: str | None = None,
    sampling_config_name: str = "default",
    sampling_config_overrides: list[str] | None = None,
    record_trajectories: bool = True,
    diffusion_guidance_factor: float | None = None,
    strict_checkpoint_loading: bool = True,
    target_compositions: list[dict[str, int]] | None = None,
):
    """
    Generate crystal structures with a MatterGen diffusion model and write them to disk.

    Exactly one of ``pretrained_name`` or ``model_path`` must be provided.

    Args:
        output_path: Path to output directory (created if it does not exist).
        pretrained_name: Name of a pretrained checkpoint to fetch from the Hugging Face hub.
        model_path: Path to DiffusionLightningModule checkpoint directory.
        batch_size: Number of structures per sampling batch.
        num_batches: Number of batches to sample.
        config_overrides: Overrides for the model config, e.g., `model.num_layers=3 model.hidden_dim=128`.
        checkpoint_epoch: Epoch to load from the checkpoint ("best", "last", or an epoch number).
            (default: "last")
        properties_to_condition_on: Property value to draw conditional sampling with respect to.
            When this value is an empty dictionary (default), unconditional samples are drawn.
        sampling_config_path: Path to the sampling config file. (default: None, in which case the
            package default sampling config path is used)
        sampling_config_name: Name of the sampling config (corresponds to
            `{sampling_config_path}/{sampling_config_name}.yaml` on disk). (default: default)
        sampling_config_overrides: Overrides for the sampling config, e.g.,
            `condition_loader_partial.batch_size=32`.
        record_trajectories: Whether to record the trajectories of the generated structures.
            (default: True)
        diffusion_guidance_factor: Strength of diffusion guidance; treated as 0.0 when None.
        strict_checkpoint_loading: Whether to raise an exception when not all parameters from the
            checkpoint can be matched to the model.
        target_compositions: List of dictionaries with target compositions to condition on. Each
            dictionary should have the form `{element: number_of_atoms}`. If None, the target
            compositions are not conditioned on. Only supported for models trained for crystal
            structure prediction (CSP). (default: None)

    NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between
    the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
    """
    assert (
        pretrained_name is not None or model_path is not None
    ), "Either pretrained_name or model_path must be provided."
    assert (
        pretrained_name is None or model_path is None
    ), "Only one of pretrained_name or model_path can be provided."

    # exist_ok avoids the TOCTOU race between the existence check and creation.
    os.makedirs(output_path, exist_ok=True)

    sampling_config_overrides = sampling_config_overrides or []
    config_overrides = config_overrides or []
    properties_to_condition_on = properties_to_condition_on or {}
    target_compositions = target_compositions or []

    if pretrained_name is not None:
        checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
            pretrained_name, config_overrides=config_overrides
        )
    else:
        checkpoint_info = MatterGenCheckpointInfo(
            model_path=Path(model_path).resolve(),
            load_epoch=checkpoint_epoch,
            config_overrides=config_overrides,
            strict_checkpoint_loading=strict_checkpoint_loading,
        )
    _sampling_config_path = Path(sampling_config_path) if sampling_config_path is not None else None
    generator = CrystalGenerator(
        checkpoint_info=checkpoint_info,
        properties_to_condition_on=properties_to_condition_on,
        batch_size=batch_size,
        num_batches=num_batches,
        sampling_config_name=sampling_config_name,
        sampling_config_path=_sampling_config_path,
        sampling_config_overrides=sampling_config_overrides,
        record_trajectories=record_trajectories,
        diffusion_guidance_factor=(
            diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
        ),
        target_compositions_dict=target_compositions,
    )
    generator.generate(output_dir=Path(output_path))
|
||||
|
||||
|
||||
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
async def generate_material(
    properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
    batch_size: int = 2,
    num_batches: int = 1,
    diffusion_guidance_factor: float = 2.0
) -> str:
    """
    Generate crystal structures with optional property constraints.

    This unified function can generate materials in three modes:
    1. Unconditional generation (no properties specified)
    2. Single property conditional generation (one property specified)
    3. Multi-property conditional generation (multiple properties specified)

    Args:
        properties: Optional property constraints. Can be:
            - None or empty dict for unconditional generation
            - Dict with single key-value pair for single property conditioning
            - Dict with multiple key-value pairs for multi-property conditioning
            A JSON string is also accepted and decoded first.
            Valid property names include: "dft_band_gap", "chemical_system", etc.
        batch_size: Number of structures per batch
        num_batches: Number of batches to generate
        diffusion_guidance_factor: Controls adherence to target properties
            (ignored / forced to 0.0 for unconditional generation)

    Returns:
        Descriptive text with generated crystal structures in CIF format

    Raises:
        ValueError: If `properties` is a string that is not valid JSON, or a
            property fails validation in preprocess_property.
    """
    # Use the configured results directory.
    output_dir = MATTERGENMODEL_RESULT_PATH

    # Accept a JSON-encoded string for convenience (e.g. from an LLM tool call).
    if isinstance(properties, str):
        try:
            properties = json.loads(properties)
        except json.JSONDecodeError:
            raise ValueError(f"Invalid properties JSON string: {properties}")

    # Default to empty dict if None.
    properties = properties or {}

    # Choose generation mode and model checkpoint based on the properties.
    if not properties:
        # Unconditional generation: always uses the base model.
        model_path = os.path.join(MATTERGENMODEL_ROOT, "mattergen_base")
        properties_to_condition_on = None
        generation_type = "unconditional"
        property_description = "unconditionally"
    else:
        # Conditional generation (single or multi-property).
        properties_to_condition_on = {}

        # Validate/normalize each property value (type coercion, range checks).
        for property_name, property_value in properties.items():
            _, processed_value = preprocess_property(property_name, property_value)
            properties_to_condition_on[property_name] = processed_value

        # Determine which fine-tuned model directory to use.
        if len(properties) == 1:
            # Single property conditioning: one model per property.
            property_name = list(properties.keys())[0]
            # NOTE(review): this mapping is currently the identity for every
            # key -- presumably a placeholder for future divergence between
            # property names and checkpoint directory names.
            property_to_model = {
                "dft_mag_density": "dft_mag_density",
                "dft_bulk_modulus": "dft_bulk_modulus",
                "dft_shear_modulus": "dft_shear_modulus",
                "energy_above_hull": "energy_above_hull",
                "formation_energy_per_atom": "formation_energy_per_atom",
                "space_group": "space_group",
                "hhi_score": "hhi_score",
                "ml_bulk_modulus": "ml_bulk_modulus",
                "chemical_system": "chemical_system",
                "dft_band_gap": "dft_band_gap"
            }
            model_dir = property_to_model.get(property_name, property_name)
            generation_type = "single_property"
            property_description = f"conditioned on {property_name} = {properties[property_name]}"
        else:
            # Multi-property conditioning: only two trained combinations exist.
            property_keys = set(properties.keys())
            if property_keys == {"dft_mag_density", "hhi_score"}:
                model_dir = "dft_mag_density_hhi_score"
            elif property_keys == {"chemical_system", "energy_above_hull"}:
                model_dir = "chemical_system_energy_above_hull"
            else:
                # No dedicated multi-property model: fall back to the first
                # property's single-property model.
                first_property = list(properties.keys())[0]
                model_dir = first_property
            generation_type = "multi_property"
            property_description = f"conditioned on multiple properties: {', '.join([f'{name} = {value}' for name, value in properties.items()])}"

        # Construct the full model path.
        model_path = os.path.join(MATTERGENMODEL_ROOT, model_dir)

        # Fall back to the base model when the fine-tuned checkpoint is absent.
        if not os.path.exists(model_path):
            logger.warning(f"Model directory for {model_dir} not found. Using base model instead.")
            model_path = os.path.join(MATTERGENMODEL_ROOT, "mattergen_base")

    # Run the actual sampling; writes result archives into output_dir.
    main(
        output_path=output_dir,
        model_path=model_path,
        batch_size=batch_size,
        num_batches=num_batches,
        properties_to_condition_on=properties_to_condition_on,
        record_trajectories=True,
        diffusion_guidance_factor=diffusion_guidance_factor if properties else 0.0
    )

    # Collect the generated artifacts from disk.
    result_dict = {}

    # Paths written by main()/CrystalGenerator.
    cif_zip_path = os.path.join(output_dir, "generated_crystals_cif.zip")
    xyz_file_path = os.path.join(output_dir, "generated_crystals.extxyz")
    trajectories_zip_path = os.path.join(output_dir, "generated_trajectories.zip")

    # Read the CIF zip file (raw bytes).
    if os.path.exists(cif_zip_path):
        with open(cif_zip_path, 'rb') as f:
            result_dict['cif_content'] = f.read()

    # Build a human-readable title/description for the chosen mode.
    if generation_type == "unconditional":
        title = "Generated Material Structures"
        description = "These structures were generated unconditionally, meaning no specific properties were targeted."
    elif generation_type == "single_property":
        property_name = list(properties.keys())[0]
        property_value = properties[property_name]
        title = f"Generated Material Structures Conditioned on {property_name} = {property_value}"
        description = f"These structures were generated with property conditioning, targeting a {property_name} value of {property_value}."
    else:  # multi_property
        title = "Generated Material Structures Conditioned on Multiple Properties"
        description = "These structures were generated with multi-property conditioning, targeting the specified property values."

    # Assemble the final markdown report.
    # NOTE(review): the zip archive's raw bytes are decoded with
    # errors='replace' and then cleaned by format_cif_content (which strips
    # the PK archive headers) -- this relies on the CIF text being stored
    # uncompressed inside the zip; confirm with the generator's output format.
    prompt = f"""
# {title}

This data contains {batch_size * num_batches} crystal structures generated by the MatterGen model, {property_description}.

{'' if generation_type == 'unconditional' else f'''
A diffusion guidance factor of {diffusion_guidance_factor} was used, which controls how strongly
the generation adheres to the specified property values. Higher values produce samples that more
closely match the target properties but may reduce diversity.
'''}

## CIF Files (Crystallographic Information Files)

- Standard format for crystallographic structures
- Contains unit cell parameters, atomic positions, and symmetry information
- Used by crystallographic software and visualization tools

```
{format_cif_content(result_dict.get('cif_content', b'').decode('utf-8', errors='replace') if isinstance(result_dict.get('cif_content', b''), bytes) else str(result_dict.get('cif_content', '')))}
```

{description}
You can use these structures for materials discovery, property prediction, or further analysis.
"""

    # Best-effort cleanup of the on-disk artifacts after they were read.
    try:
        if os.path.exists(cif_zip_path):
            os.remove(cif_zip_path)
        if os.path.exists(xyz_file_path):
            os.remove(xyz_file_path)
        if os.path.exists(trajectories_zip_path):
            os.remove(trajectories_zip_path)
    except Exception as e:
        logger.warning(f"Error cleaning up files: {e}")
    return prompt
|
||||
65
tools_for_ms/services_tools/mattersim_tools.py
Normal file
65
tools_for_ms/services_tools/mattersim_tools.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from ..llm_tools import llm_tool
|
||||
import torch
|
||||
import numpy as np
|
||||
from ase.units import GPa
|
||||
from mattersim.forcefield import MatterSimCalculator
|
||||
import asyncio
|
||||
from .fairchem_tools import convert_structure
|
||||
|
||||
@llm_tool(
    name="predict_properties",
    description="Predict energy, forces, and stress of crystal structures based on CIF string",
)
async def predict_properties(cif_content: str) -> str:
    """
    Use MatterSim to predict energy, forces, and stress of crystal structures.

    Args:
        cif_content: Crystal structure string in CIF format

    Returns:
        String containing prediction results (markdown-formatted), or an error
        message when the CIF string cannot be parsed.
    """
    # The model inference is blocking, so it runs in a worker thread (see the
    # asyncio.to_thread call below) to keep the event loop responsive.
    def run_prediction():
        # Convert the CIF string into an Atoms object via convert_structure.
        structure = convert_structure("cif", cif_content)
        if structure is None:
            return "Unable to parse CIF string. Please check if the format is correct."

        # Prefer GPU when available.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Attach the MatterSim calculator to compute properties.
        structure.calc = MatterSimCalculator(device=device)

        # Query energy, forces, and the full (non-Voigt) stress tensor.
        energy = structure.get_potential_energy()
        forces = structure.get_forces()
        stresses = structure.get_stress(voigt=False)

        # Energy per atom.
        num_atoms = len(structure)
        energy_per_atom = energy / num_atoms

        # Stress in both the native eV/A^3 units and GPa.
        stresses_ev_a3 = stresses
        stresses_gpa = stresses / GPa

        # Build the markdown report returned to the caller.
        prompt = f"""
## {structure.get_chemical_formula()} Crystal Structure Property Prediction Results

Prediction results using the provided CIF structure:

- Total Energy (eV): {energy}
- Energy per Atom (eV/atom): {energy_per_atom:.4f}
- Forces (eV/Angstrom): {forces[0]} # Forces on the first atom
- Stress (GPa): {stresses_gpa[0][0]} # First component of the stress tensor
- Stress (eV/A^3): {stresses_ev_a3[0][0]} # First component of the stress tensor

"""
        return prompt

    # Execute the prediction asynchronously in a thread.
    return await asyncio.to_thread(run_prediction)
|
||||
644
tools_for_ms/services_tools/mp_tools.py
Normal file
644
tools_for_ms/services_tools/mp_tools.py
Normal file
@@ -0,0 +1,644 @@
|
||||
"""
|
||||
Materials Project API Service Tools
|
||||
|
||||
This module provides functions for querying the Materials Project database,
|
||||
processing search results, and formatting responses. It includes a LLM tool
|
||||
for integration with large language models.
|
||||
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import glob
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
import datetime
|
||||
import os
|
||||
from multiprocessing import Process, Manager
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from mp_api.client import MPRester
|
||||
from pymatgen.core import Structure
|
||||
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
||||
from pymatgen.io.cif import CifWriter
|
||||
from ..services_tools import Configs
|
||||
from ..utils import settings, handle_minio_upload
|
||||
from .error_handlers import handle_general_error
|
||||
from ..llm_tools import llm_tool
|
||||
def read_cif_txt_file(file_path):
    """Return the UTF-8 text content of *file_path*, or None when reading fails."""
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            content = handle.read()
    except Exception as e:
        # Report the failure but keep going -- callers treat None as "missing".
        print(f"Error reading file {file_path}: {e}")
        return None
    return content
|
||||
|
||||
def get_extra_cif_info(path: str, fields_name: list):
    """
    Extract selected field groups from a material-description JSON file.

    Args:
        path: Path to a JSON file holding a flat dict of material properties.
        fields_name: Field-group names to extract ('basic_fields',
            'energy_electronic_fields', 'metal_magentic_fields'), or
            ['all_fields'] as the first element to select every group.
            Unknown group names contribute nothing (as before).

    Returns:
        Dict mapping each selected field to its value, with '' for fields
        absent from the file.
    """
    # Explicit mapping instead of the fragile locals().get(...) lookup, which
    # depended on these lists being local variables by exactly these names.
    field_groups = {
        'basic_fields': ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density'],
        'energy_electronic_fields': ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct'],
        'metal_magentic_fields': ['is_metal', 'is_magnetic', 'ordering', 'total_magnetization', 'num_magnetic_sites'],
    }

    if fields_name[0] == 'all_fields':
        # Every group, in declaration order (basic -> energy -> magnetic).
        selected_fields = [field for group in field_groups.values() for field in group]
    else:
        selected_fields = []
        for group_name in fields_name:
            selected_fields.extend(field_groups.get(group_name, []))

    with open(path, 'r') as f:
        docs = json.load(f)

    return {field: docs.get(field, '') for field in selected_fields}
|
||||
|
||||
def remove_symmetry_equiv_xyz(cif_content):
    """
    Remove symmetry operations section from CIF file content.

    Scans for a `loop_` whose tag lines contain `_symmetry_equiv_pos_as_xyz`
    and drops that entire loop block (its tags and operation rows), leaving
    every other line untouched.

    Args:
        cif_content: CIF file content string

    Returns:
        Cleaned CIF content string
    """
    lines = cif_content.split('\n')
    output_lines = []

    i = 0
    while i < len(lines):
        line = lines[i].strip()

        # Detect the start of a loop block.
        if line == 'loop_':
            # Peek ahead: collect the consecutive `_tag` lines that define
            # this loop's columns, to see whether it is the symmetry loop.
            next_lines = []
            j = i + 1
            while j < len(lines) and lines[j].strip().startswith('_'):
                next_lines.append(lines[j].strip())
                j += 1

            # Does this loop carry the symmetry-operation tag?
            if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
                # Skip the whole loop block: advance i until the line AFTER i
                # begins the next loop, a new data block, or the atom-site
                # section. (The outer `i += 1` then lands on that boundary.)
                while i < len(lines):
                    if i + 1 >= len(lines):
                        break

                    next_line = lines[i + 1].strip()
                    # Reached the next loop or data block.
                    if next_line == 'loop_' or next_line.startswith('data_'):
                        break

                    # Reached the atom-position section.
                    if next_line.startswith('_atom_site_'):
                        break

                    i += 1
            else:
                # Not the symmetry loop -- keep the loop_ line.
                output_lines.append(lines[i])
        else:
            # Ordinary line: keep as-is.
            output_lines.append(lines[i])

        i += 1

    return '\n'.join(output_lines)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_bool(param: str) -> bool | None:
|
||||
"""
|
||||
Parse a string parameter into a boolean value.
|
||||
|
||||
Args:
|
||||
param: String parameter to parse (e.g., "true", "false")
|
||||
|
||||
Returns:
|
||||
Boolean value if param is not empty, None otherwise
|
||||
"""
|
||||
if not param:
|
||||
return None
|
||||
return param.lower() == 'true'
|
||||
|
||||
def parse_list(param: str) -> List[str] | None:
|
||||
"""
|
||||
Parse a comma-separated string into a list of strings.
|
||||
|
||||
Args:
|
||||
param: Comma-separated string (e.g., "Li,Fe,O")
|
||||
|
||||
Returns:
|
||||
List of strings if param is not empty, None otherwise
|
||||
"""
|
||||
if not param:
|
||||
return None
|
||||
return param.split(',')
|
||||
|
||||
def parse_tuple(param: str) -> tuple[float, float] | None:
|
||||
"""
|
||||
Parse a comma-separated string into a tuple of two float values.
|
||||
|
||||
Used for range parameters like band_gap, density, etc.
|
||||
|
||||
Args:
|
||||
param: Comma-separated string of two numbers (e.g., "0,3.5")
|
||||
|
||||
Returns:
|
||||
Tuple of two float values if param is valid, None otherwise
|
||||
"""
|
||||
if not param:
|
||||
return None
|
||||
try:
|
||||
values = param.split(',')
|
||||
return (float(values[0]), float(values[1]))
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def parse_search_parameters(query_params: Dict[str, str]) -> Dict[str, Any]:
    """
    Convert raw string query parameters into typed Materials Project search arguments.

    Range parameters become float tuples, multi-valued parameters become lists,
    flags become booleans, and chunk_size becomes an int (default 5).
    """
    range_keys = ('band_gap', 'density', 'formation_energy', 'total_energy', 'num_elements', 'volume')
    multi_keys = ('chemsys', 'crystal_system', 'elements', 'exclude_elements', 'formula', 'material_ids')
    flag_keys = ('is_gap_direct', 'is_metal', 'is_stable')

    parsed: Dict[str, Any] = {key: parse_tuple(query_params.get(key)) for key in range_keys}
    parsed.update({key: parse_list(query_params.get(key)) for key in multi_keys})
    parsed.update({key: parse_bool(query_params.get(key)) for key in flag_keys})
    parsed['magnetic_ordering'] = query_params.get('magnetic_ordering')
    parsed['chunk_size'] = int(query_params.get('chunk_size', '5'))
    return parsed
|
||||
|
||||
def process_search_results(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]] | str:
|
||||
"""
|
||||
Process search results from the Materials Project API.
|
||||
|
||||
Extracts relevant fields from each document and formats them into a consistent structure.
|
||||
|
||||
Returns:
|
||||
List of processed documents or error message string if an exception occurs
|
||||
"""
|
||||
try:
|
||||
fields = [
|
||||
'formula_pretty', 'nsites', 'nelements', 'material_id', 'chemsys',
|
||||
'volume', 'density', 'density_atomic', 'cbm', 'vbm', 'band_gap',
|
||||
'is_gap_direct', 'is_stable', 'formation_energy_per_atom',
|
||||
'energy_above_hull', 'is_metal', 'total_magnetization', 'efermi',
|
||||
'is_magnetic', 'ordering', 'bulk_modulus', 'shear_modulus',
|
||||
'universal_anisotropy', 'theoretical'
|
||||
]
|
||||
|
||||
res = []
|
||||
for doc in docs:
|
||||
try:
|
||||
new_docs = {}
|
||||
for field_name in fields:
|
||||
new_docs[field_name] = doc.get(field_name, '')
|
||||
res.append(new_docs)
|
||||
except Exception as e:
|
||||
# logger.warning(f"Error processing document: {str(e)}")
|
||||
continue
|
||||
return res
|
||||
except Exception as e:
|
||||
error_msg = f"Error in process_search_results: {str(e)}"
|
||||
# logger.error(error_msg)
|
||||
import traceback
|
||||
# logger.error(traceback.format_exc())
|
||||
return error_msg
|
||||
|
||||
|
||||
def _search_worker(queue, api_key, **kwargs):
    """
    Worker function for executing Materials Project API searches.

    Runs in a separate process to perform the actual API call and puts results
    in the queue. On failure, the exception object itself is put in the queue
    so the parent process can surface the error.

    Args:
        queue: Multiprocessing queue for returning results
        api_key: Materials Project API key
        **kwargs: Search parameters to pass to the API
    """
    try:
        import os
        import traceback
        # Route this subprocess's traffic through the configured proxies
        # (empty string effectively disables the proxy).
        os.environ['HTTP_PROXY'] = Configs.HTTP_PROXY or ''
        os.environ['HTTPS_PROXY'] = Configs.HTTPS_PROXY or ''


        # Initialize the MPRester client.
        with MPRester(api_key) as mpr:

            result = mpr.materials.summary.search(**kwargs)

            # Convert results into plain dicts so they can be pickled through
            # the multiprocessing queue.
            if result:

                # Fall through the pydantic v2 / v1 / plain-object APIs in turn.
                processed_results = []
                for doc in result:
                    try:
                        # pydantic v2 documents.
                        processed_doc = doc.model_dump()
                        processed_results.append(processed_doc)
                    except AttributeError:
                        # pydantic v1 documents.
                        try:
                            processed_doc = doc.dict()
                            processed_results.append(processed_doc)
                        except AttributeError:
                            # Plain objects: use __dict__ directly.
                            if hasattr(doc, "__dict__"):
                                processed_doc = doc.__dict__
                                # Drop attributes known to break serialization.
                                if "_sa_instance_state" in processed_doc:
                                    del processed_doc["_sa_instance_state"]
                                processed_results.append(processed_doc)
                            else:
                                # Last resort: pass the object through as-is.
                                processed_results.append(doc)

                queue.put(processed_results)
            else:
                # No matches: report an empty result set.
                queue.put([])
    except Exception as e:
        # Ship the exception object back to the parent process.
        queue.put(e)
|
||||
|
||||
|
||||
async def execute_search(search_args: Dict[str, Any], timeout: int = 120) -> List[Dict[str, Any]] | str:
    """
    Execute a search against the Materials Project API.

    Runs the search in a separate process so a hung API call can be terminated
    after *timeout* seconds.

    Args:
        search_args: Dictionary of search parameters
        timeout: Maximum time in seconds to wait for the search to complete

    Returns:
        List of document dictionaries from the search results, or an error
        message string if the search fails or times out.
    """
    # BUGFIX: the local variable below shadows the stdlib `queue` module, so
    # the original `except queue.Empty:` looked up `.Empty` on the Manager
    # queue proxy and raised AttributeError instead of catching the timeout.
    # Bind the exception class from the module before shadowing.
    from queue import Empty as QueueEmpty

    # Ensure the formula parameter is a list, as the MP API client expects.
    if 'formula' in search_args and isinstance(search_args['formula'], str):
        search_args['formula'] = [search_args['formula']]

    manager = Manager()
    queue = manager.Queue()

    try:
        p = Process(target=_search_worker, args=(queue, Configs.MP_API_KEY), kwargs=search_args)
        p.start()
        p.join(timeout=timeout)

        if p.is_alive():
            # The worker exceeded the deadline; kill it and report the timeout.
            p.terminate()
            p.join()
            return f"Request timed out after {timeout} seconds"

        try:
            result = queue.get(timeout=timeout)

            # The worker ships exceptions back through the queue.
            if isinstance(result, Exception):
                return f"Error in search worker: {str(result)}"

            return result

        except QueueEmpty:
            return "Failed to retrieve data from queue (timeout)"
    except Exception as e:
        return f"Error in execute_search: {str(e)}"
|
||||
|
||||
@llm_tool(name="search_material_property_from_material_project", description="Search materials in Materials Project database by formula and properties")
|
||||
async def search_material_property_from_material_project(
|
||||
formula: str | list[str],
|
||||
chemsys: Optional[str | list[str] | None] = None,
|
||||
crystal_system: Optional[str | list[str] | None] = None,
|
||||
is_gap_direct: Optional[bool | None] = None,
|
||||
is_stable: Optional[bool | None] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Search materials in Materials Project database.
|
||||
|
||||
Args:
|
||||
formula: Chemical formula(s) (e.g., "Fe2O3" or ["ABO3", "Si*"])
|
||||
chemsys: Chemical system(s) (e.g., "Li-Fe-O")
|
||||
crystal_system: Crystal system(s) (e.g., "Cubic")
|
||||
is_gap_direct: Filter for direct band gap materials
|
||||
is_stable: Filter for thermodynamically stable materials
|
||||
Returns:
|
||||
JSON formatted material properties data
|
||||
"""
|
||||
# print(f"search_material_property_from_material_project called with formula: {formula}, type: {type(formula)}")
|
||||
|
||||
# 验证晶系参数
|
||||
VALID_CRYSTAL_SYSTEMS = ['Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal', 'Cubic']
|
||||
|
||||
# 验证晶系参数是否有效
|
||||
if crystal_system is not None:
|
||||
if isinstance(crystal_system, str):
|
||||
if crystal_system not in VALID_CRYSTAL_SYSTEMS:
|
||||
return "Input should be 'Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal' or 'Cubic'"
|
||||
elif isinstance(crystal_system, list):
|
||||
for cs in crystal_system:
|
||||
if cs not in VALID_CRYSTAL_SYSTEMS:
|
||||
return "Input should be 'Triclinic', 'Monoclinic', 'Orthorhombic', 'Tetragonal', 'Trigonal', 'Hexagonal' or 'Cubic'"
|
||||
|
||||
# 确保 formula 是列表类型
|
||||
if isinstance(formula, str):
|
||||
formula = [formula]
|
||||
# print(f"Converted formula to list: {formula}")
|
||||
|
||||
|
||||
|
||||
params = {
|
||||
"chemsys": chemsys,
|
||||
"crystal_system": crystal_system,
|
||||
"formula": formula,
|
||||
"is_gap_direct": is_gap_direct,
|
||||
"is_stable": is_stable,
|
||||
"chunk_size": 5,
|
||||
|
||||
}
|
||||
|
||||
# Filter out None values
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
# print("Parameters after filtering:", params)
|
||||
mp_id_list = await get_mpid_from_formula(formula=formula)
|
||||
try:
|
||||
res=[]
|
||||
# Execute search against Materials Project API
|
||||
#docs = await execute_search(params)
|
||||
# for mp_id in id_list:
|
||||
for mp_id in mp_id_list:
|
||||
crystal_props = get_extra_cif_info(Configs.LOCAL_MP_PROPERTY_ROOT+f"{mp_id}.json", ['all_fields'])
|
||||
res.append(crystal_props)
|
||||
|
||||
|
||||
|
||||
#res = process_search_results(docs)
|
||||
# print(f"Processed {len(res)} results")
|
||||
|
||||
if len(res) == 0:
|
||||
# print("No results found")
|
||||
return "No results found, please try again."
|
||||
|
||||
# Format response with top results
|
||||
# print(f"Formatting top {Configs.MP_TOPK} results")
|
||||
try:
|
||||
# 创建包含索引的JSON结果
|
||||
formatted_results = []
|
||||
for i, item in enumerate(res[:Configs.MP_TOPK], 1):
|
||||
formatted_result = f"[property {i} begin]\n"
|
||||
formatted_result += json.dumps(item, indent=2)
|
||||
formatted_result += f"\n[property {i} end]\n\n"
|
||||
formatted_results.append(formatted_result)
|
||||
|
||||
# 将所有结果合并为一个字符串
|
||||
res_chunk = "\n\n".join(formatted_results)
|
||||
res_template = f"""
|
||||
Here are the search results from the Materials Project database:
|
||||
Due to length limitations, only the top {Configs.MP_TOPK} results are shown below:\n
|
||||
{res_chunk}
|
||||
If you need more results, please modify your search criteria or try different query parameters.
|
||||
"""
|
||||
# print("Successfully formatted results")
|
||||
return res_template
|
||||
except Exception as format_error:
|
||||
# print(f"Error formatting results: {str(format_error)}")
|
||||
import traceback
|
||||
# print(traceback.format_exc())
|
||||
return str(format_error)
|
||||
|
||||
except Exception as e:
|
||||
# print(f"Error in search_material_property_from_material_project: {str(e)}")
|
||||
import traceback
|
||||
# print(traceback.format_exc())
|
||||
return str(e)
|
||||
|
||||
|
||||
@llm_tool(name="get_crystal_structures_from_materials_project", description="Get symmetrized crystal structures CIF data from Materials Project database by chemical formula")
|
||||
async def get_crystal_structures_from_materials_project(
|
||||
formulas: list[str],
|
||||
conventional_unit_cell: bool = True,
|
||||
symprec: float = 0.1
|
||||
) -> str:
|
||||
"""
|
||||
Get crystal structures from Materials Project database by chemical formula and apply symmetrization.
|
||||
|
||||
Args:
|
||||
formulas: List of chemical formulas (e.g., ["Fe2O3", "SiO2", "TiO2"])
|
||||
conventional_unit_cell: Whether to return conventional unit cell (True) or primitive cell (False)
|
||||
symprec: Precision parameter for symmetrization
|
||||
|
||||
Returns:
|
||||
Formatted text containing symmetrized CIF data
|
||||
"""
|
||||
# 确保 formulas 是列表类型
|
||||
# if isinstance(formulas, str):
|
||||
# formulas = [formulas]
|
||||
|
||||
# try:
|
||||
# # 构建搜索参数
|
||||
# search_args = {
|
||||
# "formula": formulas,
|
||||
# "fields": ["material_id",]
|
||||
# }
|
||||
|
||||
# # 使用execute_search函数查询晶体结构信息
|
||||
# docs = await execute_search(search_args, timeout=60)
|
||||
|
||||
# if isinstance(docs, str):
|
||||
# # 如果返回的是字符串,说明发生了错误
|
||||
# return f"获取晶体结构时出错: {docs}"
|
||||
|
||||
# if not docs:
|
||||
# return "未找到指定化学式的晶体结构数据。"
|
||||
|
||||
# 处理结果
|
||||
# result = {}
|
||||
# for i, doc in enumerate(docs):
|
||||
# try:
|
||||
# # 获取材料ID和结构
|
||||
# material_id = doc.get('material_id')
|
||||
# structure_data = doc.get('structure')
|
||||
|
||||
# if not structure_data:
|
||||
# continue
|
||||
|
||||
# # 将结构数据转换为pymatgen Structure对象
|
||||
result={}
|
||||
mp_id_list=await get_mpid_from_formula(formula=formulas)
|
||||
|
||||
|
||||
for i,mp_id in enumerate(mp_id_list):
|
||||
cif_file = glob.glob(f"/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/MPDatasets/{mp_id}.cif")[0]
|
||||
#print('111',cif_file)
|
||||
structure = Structure.from_file(cif_file)
|
||||
# 如果需要常规单元格
|
||||
if conventional_unit_cell:
|
||||
structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure()
|
||||
|
||||
# 对结构进行对称化处理
|
||||
sga = SpacegroupAnalyzer(structure, symprec=symprec)
|
||||
symmetrized_structure = sga.get_refined_structure()
|
||||
|
||||
# 使用CifWriter生成CIF数据
|
||||
cif_writer = CifWriter(symmetrized_structure, symprec=symprec, refine_struct=True)
|
||||
cif_data = str(cif_writer)
|
||||
|
||||
# 删除CIF文件中的对称性操作部分
|
||||
cif_data = remove_symmetry_equiv_xyz(cif_data)
|
||||
cif_data=cif_data.replace('# generated using pymatgen',"")
|
||||
# 生成一个唯一的键
|
||||
formula = structure.composition.reduced_formula
|
||||
key = f"{formula}_{i}"
|
||||
|
||||
result[key] = cif_data
|
||||
|
||||
# 只保留前Configs.MP_TOPK个结果
|
||||
if len(result) >= Configs.MP_TOPK:
|
||||
break
|
||||
|
||||
# except Exception as e:
|
||||
# continue
|
||||
|
||||
# 格式化响应
|
||||
try:
|
||||
prompt = f"""
|
||||
# Materials Project Symmetrized Crystal Structure Data
|
||||
|
||||
Below are symmetrized crystal structure data for {len(result)} materials from the Materials Project database, in CIF (Crystallographic Information File) format.
|
||||
These structures have been analyzed and optimized for symmetry using SpacegroupAnalyzer with precision parameter symprec={symprec}.\n
|
||||
"""
|
||||
|
||||
for i, (key, cif_data) in enumerate(result.items(), 1):
|
||||
prompt += f"[cif {i} begin]\n"
|
||||
prompt += cif_data
|
||||
prompt += f"\n[cif {i} end]\n\n"
|
||||
|
||||
prompt += """
|
||||
## Usage Instructions
|
||||
|
||||
1. You can copy the above CIF data and save it as .cif files
|
||||
2. Open these files with crystal structure visualization software (such as VESTA, Mercury, Avogadro, etc.)
|
||||
3. These structures can be used for further material analysis, simulation, or visualization
|
||||
|
||||
CIF files contain complete structural information of crystals, including cell parameters, atomic coordinates, symmetry, etc.
|
||||
Symmetrization helps identify and optimize crystal symmetry, making the structure more standardized and accurate.
|
||||
"""
|
||||
return prompt
|
||||
|
||||
except Exception as format_error:
|
||||
import traceback
|
||||
return str(format_error)
|
||||
|
||||
|
||||
@llm_tool(name="get_mpid_from_formula", description="Get material IDs (mpid) from Materials Project database by chemical formula")
|
||||
async def get_mpid_from_formula(formula: str) -> str:
|
||||
"""
|
||||
Get material IDs (mpid) from Materials Project database by chemical formula.
|
||||
Returns mpids for the lowest energy structures.
|
||||
|
||||
Args:
|
||||
formula: Chemical formula (e.g., "Fe2O3")
|
||||
|
||||
Returns:
|
||||
Formatted text containing material IDs
|
||||
"""
|
||||
# 确保 formula 是列表类型,因为 _search_by_formula_worker 需要列表输入
|
||||
|
||||
os.environ['HTTP_PROXY'] = Configs.HTTP_PROXY or ''
|
||||
os.environ['HTTPS_PROXY'] = Configs.HTTPS_PROXY or ''
|
||||
id_list = []
|
||||
with MPRester(Configs.MP_API_KEY) as mpr:
|
||||
docs = mpr.materials.summary.search(formula=formula)#这里设定搜索条件id list =[]for doc in docs:#获取材料索引号id list.append(doc.material id)
|
||||
for doc in docs:
|
||||
id_list.append(doc.material_id)
|
||||
return id_list
|
||||
# cif_description_list= []
|
||||
# cif_information_list=[]
|
||||
# crystal_props_list=[]
|
||||
# #print("mp_id",id_list)
|
||||
# for mp_id in id_list:
|
||||
# cif_description = read_cif_txt_file('/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Text-bond/{}.txt'.format(mp_id))
|
||||
# cif_information = read_cif_txt_file('/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Symmetry_MPDatasets/{}_symmetrized.cif'.format(mp_id))
|
||||
# cif_information = cif_information.replace('# generated using pymatgen\n', '')
|
||||
|
||||
# crystal_props = get_extra_cif_info("/home/ubuntu/sas0/LYT/paper_dataset/mp_cif/Props/{}.json".format(mp_id), ['all_fields'])
|
||||
# cif_description_list.append(cif_description)
|
||||
# cif_information_list.append(cif_information)
|
||||
96
tools_for_ms/services_tools/oqmd_tools.py
Normal file
96
tools_for_ms/services_tools/oqmd_tools.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import logging
|
||||
import httpx
|
||||
import pandas as pd
|
||||
from bs4 import BeautifulSoup
|
||||
from io import StringIO
|
||||
from typing import Annotated
|
||||
|
||||
from ..llm_tools import llm_tool
|
||||
from ..llm_tools import *
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@llm_tool(name="fetch_chemical_composition_from_OQMD", description="Fetch material data for a chemical composition from OQMD database")
|
||||
async def fetch_chemical_composition_from_OQMD (
|
||||
composition: Annotated[str, "Chemical formula (e.g., Fe2O3, LiFePO4)"]
|
||||
) -> str:
|
||||
"""
|
||||
Fetch material data for a chemical composition from OQMD database.
|
||||
|
||||
Args:
|
||||
composition: Chemical formula (e.g., Fe2O3, LiFePO4)
|
||||
|
||||
Returns:
|
||||
Formatted text with material information and property tables
|
||||
"""
|
||||
# Fetch data from OQMD
|
||||
url = f"https://www.oqmd.org/materials/composition/{composition}"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=100.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# Validate response content
|
||||
if not response.text or len(response.text) < 100:
|
||||
raise ValueError("Invalid response content from OQMD API")
|
||||
|
||||
# Parse HTML data
|
||||
html = response.text
|
||||
soup = BeautifulSoup(html, 'html.parser')
|
||||
|
||||
# Parse basic data
|
||||
basic_data = []
|
||||
h1_element = soup.find('h1')
|
||||
if h1_element:
|
||||
basic_data.append(h1_element.text.strip())
|
||||
else:
|
||||
basic_data.append(f"Material: {composition}")
|
||||
|
||||
for script in soup.find_all('p'):
|
||||
if script:
|
||||
combined_text = ""
|
||||
for element in script.contents:
|
||||
if hasattr(element, 'name') and element.name == 'a' and 'href' in element.attrs:
|
||||
url = "https://www.oqmd.org" + element['href']
|
||||
combined_text += f"[{element.text.strip()}]({url}) "
|
||||
elif hasattr(element, 'text'):
|
||||
combined_text += element.text.strip() + " "
|
||||
else:
|
||||
combined_text += str(element).strip() + " "
|
||||
basic_data.append(combined_text.strip())
|
||||
|
||||
# Parse table data
|
||||
table_data = ""
|
||||
table = soup.find('table')
|
||||
if table:
|
||||
try:
|
||||
df = pd.read_html(StringIO(str(table)))[0]
|
||||
df = df.fillna('')
|
||||
df = df.replace([float('inf'), float('-inf')], '')
|
||||
table_data = df.to_markdown(index=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing table: {str(e)}")
|
||||
table_data = "Error parsing table data"
|
||||
|
||||
# Integrate data into a single text
|
||||
combined_text = "\n\n".join(basic_data)
|
||||
if table_data:
|
||||
combined_text += "\n\n## Material Properties Table\n\n" + table_data
|
||||
|
||||
return combined_text
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"OQMD API request failed: {str(e)}")
|
||||
return f"Error: OQMD API request failed - {str(e)}"
|
||||
except httpx.TimeoutException:
|
||||
logger.error("OQMD API request timed out")
|
||||
return "Error: OQMD API request timed out"
|
||||
except httpx.NetworkError as e:
|
||||
logger.error(f"Network error occurred: {str(e)}")
|
||||
return f"Error: Network error occurred - {str(e)}"
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid response content: {str(e)}")
|
||||
return f"Error: Invalid response content - {str(e)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}")
|
||||
return f"Error: Unexpected error occurred - {str(e)}"
|
||||
80
tools_for_ms/services_tools/search_dify.py
Normal file
80
tools_for_ms/services_tools/search_dify.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import asyncio
|
||||
import json
|
||||
from .Configs import DIFY_API_KEY
|
||||
import requests
|
||||
import codecs
|
||||
from ..llm_tools import llm_tool
|
||||
|
||||
@llm_tool(
|
||||
name="retrieval_from_knowledge_base",
|
||||
description="Retrieve information from local materials science literature knowledge base"
|
||||
)
|
||||
async def retrieval_from_knowledge_base(query: str, topk: int = 3) -> str:
|
||||
"""
|
||||
检索本地材料科学文献知识库中的相关信息
|
||||
|
||||
输入:
|
||||
query: 查询字符串,如材料名称"CsPbBr3"
|
||||
topk: 返回结果数量,默认3条
|
||||
|
||||
输出:
|
||||
包含文档ID、标题和相关性分数的字典
|
||||
"""
|
||||
# 设置Dify API的URL端点
|
||||
url = 'http://192.168.191.101:6080/v1/chat-messages'
|
||||
|
||||
# 配置请求头,包含API密钥和内容类型
|
||||
headers = {
|
||||
'Authorization': f'Bearer {DIFY_API_KEY}', # 使用配置文件中的API密钥
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
# 准备请求数据
|
||||
data = {
|
||||
"inputs": {"topK": topk}, # 设置返回的最大结果数量
|
||||
"query": query, # 设置查询字符串
|
||||
"response_mode": "blocking", # 使用阻塞模式,等待并获取完整响应
|
||||
"conversation_id": "", # 不使用会话ID,每次都是独立查询
|
||||
"user": "abc-123" # 用户标识符
|
||||
}
|
||||
|
||||
try:
|
||||
# 发送POST请求到Dify API并获取响应
|
||||
# 设置较长的超时时间(1111秒)以处理可能的长时间响应
|
||||
response = requests.post(url, headers=headers, json=data, timeout=1111)
|
||||
|
||||
# 获取响应文本
|
||||
response_text = response.text
|
||||
useful_results = [] # 初始化结果列表(当前未使用)
|
||||
|
||||
# 解码响应文本中的Unicode转义序列
|
||||
response_text = codecs.decode(response_text, 'unicode_escape')
|
||||
print(response_text) # 打印完整响应用于调试
|
||||
|
||||
# 将响应文本解析为JSON对象
|
||||
result_json = json.loads(response_text)
|
||||
|
||||
# 从响应中提取元数据
|
||||
metadata = result_json.get("metadata", {})
|
||||
|
||||
# 构建包含关键信息的结果字典
|
||||
useful_info = {
|
||||
"id": metadata.get("document_id"), # 文档ID
|
||||
"title": result_json.get("title"), # 文档标题
|
||||
"content": None, # 内容字段设为空,注意:原字典使用'answer'字段存储内容
|
||||
"metadata": None, # 元数据字段设为空
|
||||
"embedding": None, # 嵌入向量字段设为空
|
||||
"score": metadata.get("score") # 相关性分数
|
||||
}
|
||||
|
||||
# 返回提取的有用信息
|
||||
return useful_info
|
||||
|
||||
except Exception as e:
|
||||
# 捕获并处理所有可能的异常,返回错误信息
|
||||
return f"错误: {str(e)}"
|
||||
|
||||
# Test code executed when this script is run directly.
if __name__ == "__main__":
    # Exercise the tool with the sample query "CsPbBr3".
    print(asyncio.run(retrieval_from_knowledge_base('CsPbBr3')))
|
||||
101
tools_for_ms/utils.py
Normal file
101
tools_for_ms/utils.py
Normal file
@@ -0,0 +1,101 @@
|
||||
|
||||
|
||||
import os
|
||||
import boto3
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from environment variables / a .env file.

    All fields are optional so the toolkit can start with a partial
    configuration; tools that need a missing key fail at call time instead.
    """

    # Materials Project API access
    mp_api_key: Optional[str] = Field(None, env="MP_API_KEY")    # MP API key
    mp_endpoint: Optional[str] = Field(None, env="MP_ENDPOINT")  # custom MP endpoint, if any
    mp_topk: Optional[int] = Field(3, env="MP_TOPK")             # max results returned by MP tools

    # Outbound HTTP(S) proxy used for external API calls
    http_proxy: Optional[str] = Field(None, env="HTTP_PROXY")
    https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY")

    # FairChem relaxation model
    fairchem_model_path: Optional[str] = Field(None, env="FAIRCHEM_MODEL_PATH")
    fmax: Optional[float] = Field(0.05, env="FMAX")  # force-convergence threshold

    # MinIO object storage (S3-compatible)
    minio_endpoint: Optional[str] = Field(None, env="MINIO_ENDPOINT")                    # public endpoint
    internal_minio_endpoint: Optional[str] = Field(None, env="INTERNAL_MINIO_ENDPOINT")  # in-cluster endpoint
    minio_access_key: Optional[str] = Field(None, env="MINIO_ACCESS_KEY")
    minio_secret_key: Optional[str] = Field(None, env="MINIO_SECRET_KEY")
    minio_bucket: Optional[str] = Field("mars-toolkit", env="MINIO_BUCKET")

    class Config:
        # pydantic v1-style settings source configuration
        # NOTE(review): Field(..., env=...) and class Config are the pydantic v1
        # API; if the project is on pydantic-settings v2, these should migrate
        # to model_config / validation_alias — confirm the installed version.
        env_file = ".env"
        env_file_encoding = "utf-8"
|
||||
|
||||
def get_minio_client(settings: Settings):
    """Build an S3-compatible boto3 client for MinIO.

    Prefers the in-cluster (internal) endpoint when configured, otherwise
    falls back to the public endpoint.
    """
    endpoint = settings.internal_minio_endpoint or settings.minio_endpoint
    return boto3.client(
        's3',
        endpoint_url=endpoint,
        aws_access_key_id=settings.minio_access_key,
        aws_secret_access_key=settings.minio_secret_key,
    )
|
||||
|
||||
def handle_minio_upload(file_path: str, file_name: str) -> str:
    """Upload a file to MinIO and return a presigned download URL.

    Args:
        file_path: Local path of the file to upload.
        file_name: Object key to store the file under in the bucket.

    Returns:
        A presigned URL (valid for 1 hour) rewritten to the public endpoint,
        or an error message produced by ``handle_minio_error`` on failure.
    """
    try:
        client = get_minio_client(settings)
        client.upload_file(file_path, settings.minio_bucket, file_name, ExtraArgs={"ACL": "private"})

        # Generate a presigned URL for the uploaded object.
        url = client.generate_presigned_url(
            'get_object',
            Params={'Bucket': settings.minio_bucket, 'Key': file_name},
            ExpiresIn=3600
        )
        # BUGFIX: the original called url.replace(internal or "", public).
        # With an unset internal endpoint the needle is the EMPTY string, and
        # str.replace("", x) inserts x between every character of the URL.
        # Only rewrite when both endpoints are actually configured.
        if settings.internal_minio_endpoint and settings.minio_endpoint:
            url = url.replace(settings.internal_minio_endpoint, settings.minio_endpoint)
        return url
    except Exception as e:
        # Local import avoids a circular dependency with the services package.
        from tools_for_ms.services_tools.error_handlers import handle_minio_error
        return handle_minio_error(e)
|
||||
|
||||
def setup_logging():
    """Configure logging: INFO to console, DEBUG to a rotating file.

    The log file ``mars_toolkit.log`` is placed in the parent directory of
    this package, capped at 10 MB with 5 rotated backups.
    """
    # BUGFIX: logging.config is a submodule that `import logging` does NOT
    # import; the original referenced logging.config.dictConfig without ever
    # importing it, raising AttributeError at call time.
    import logging.config

    parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    log_file_path = os.path.join(parent_dir, 'mars_toolkit.log')
    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'standard': {
                'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                'datefmt': '%Y-%m-%d %H:%M:%S'
            },
        },
        'handlers': {
            'console': {
                'level': 'INFO',
                'class': 'logging.StreamHandler',
                'formatter': 'standard'
            },
            'file': {
                'level': 'DEBUG',
                'class': 'logging.handlers.RotatingFileHandler',
                'filename': log_file_path,
                'maxBytes': 10485760,  # 10MB
                'backupCount': 5,
                'formatter': 'standard'
            }
        },
        'loggers': {
            '': {  # root logger: both handlers, INFO threshold
                'handlers': ['console', 'file'],
                'level': 'INFO',
                'propagate': True
            }
        }
    })
|
||||
|
||||
# Initialize the module-level configuration singleton (reads the environment
# and .env once at import time).
settings = Settings()
|
||||
Reference in New Issue
Block a user