mcp, data generation code

This commit is contained in:
lzy
2025-04-16 11:15:01 +08:00
parent 72045e5cfe
commit 6b92e54a41
66 changed files with 1938 additions and 1483 deletions

0
.gitignore vendored Normal file → Executable file

Binary file not shown.

0
__pycache__/execute_tool_copy.cpython-310.pyc Normal file → Executable file

0
__pycache__/mattergen_wrapper.cpython-310.pyc Normal file → Executable file

0
__pycache__/normalize_material_args.cpython-310.pyc Normal file → Executable file

0
agent_test.py Normal file → Executable file

0
api_key.py Normal file → Executable file


@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
读取JSONL文件打印包含'generate_material content begin'的行
"""
import json
import sys
from rich.console import Console
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
console = Console()
def count_lines(file_path):
"""计算文件的总行数"""
with open(file_path, 'r', encoding='utf-8') as f:
return sum(1 for _ in f)
def filter_generate_material(file_path):
"""
Read a JSONL file and print the lines that contain 'generate_material content begin'.
Args:
file_path: path to the JSONL file
"""
# Count the total number of lines for the progress bar
total_lines = count_lines(file_path)
console.print(f"[bold green]Total lines in file: {total_lines}[/bold green]")
# Match counter
match_count = 0
# Use a progress bar to show processing progress
with Progress(
TextColumn("[bold blue]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
) as progress:
# Create the progress-bar task
task_id = progress.add_task("[bold green]Processing JSONL file...", total=total_lines)
# Process the file line by line
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
try:
# Parse the JSON
data = json.loads(line.strip())
# Check whether the record contains an 'observation' field
# if 'observation' in data:
# # Check whether the observation contains the target string
# if 'generate_material content begin' in data['observation'] and '[generate_material content begin]Error' not in data['observation']:
# console.print(f"[bold cyan]Line {line_num} contains the target content:[/bold cyan]")
# console.print(line.strip())
# console.print("-" * 80)
# print(data)
# match_count += 1
# return
if 'solution' in data:
print("line_num", line_num)
print("solution", data['solution'])
if data['solution'] != "":
match_count += 1
print(data['solution'])
return
except json.JSONDecodeError:
console.print(f"[bold red]第 {line_num} 行JSON解析错误[/bold red]")
except Exception as e:
console.print(f"[bold red]处理第 {line_num} 行时出错: {str(e)}[/bold red]")
# 更新进度条
progress.update(task_id, advance=1)
console.print(f"[bold green]处理完成,共找到 {match_count} 行包含 'generate_material content begin'[/bold green]")
if __name__ == "__main__":
# Default file path
file_path ='/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl'
# "/home/ubuntu/50T/lzy/mars-mcp/mars-agent_data_20250408205427.jsonl"
# If a command-line argument was provided, use it as the file path
if len(sys.argv) > 1:
file_path = sys.argv[1]
console.print(f"[bold blue]正在处理文件: {file_path}[/bold blue]")
filter_generate_material(file_path)
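For reference, a minimal sketch of the JSONL records this script expects. The field names 'solution' and 'observation' come from the code above; the sample values and the temporary file are hypothetical:

import json
import tempfile

# Two hypothetical records: one with an empty 'solution', one with a non-empty one.
records = [
    {"solution": "", "observation": "[generate_material content begin]...[generate_material content end]"},
    {"solution": "Sol-gel synthesis of ZnO ...", "observation": ""},
]
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False, encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

# The function above prints and stops at the first record whose 'solution' is non-empty:
# filter_generate_material(f.name)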


@@ -17,7 +17,7 @@ import random
# Initialize colorama
init(autoreset=True)
from typing import Dict, Union, Any, Optional
from typing import Dict, Union, Any, Optional, List
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
@@ -110,6 +110,8 @@ def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
import requests
# Create a database connection pool (used only for the initial data load)
connection_pool = pooling.MySQLConnectionPool(
pool_name="mypool",
pool_size=32,
@@ -120,7 +122,50 @@ connection_pool = pooling.MySQLConnectionPool(
database='metadata_mat_papers'
)
# In-memory cache for records loaded from the database
# Structure: {doi: record, mp_id: record}
memory_cache = {}
def load_data_to_memory():
"""
Load all records from the database into memory.
"""
print(f"{Fore.CYAN}{Style.BRIGHT}Loading data from the database into memory...{Style.RESET_ALL}")
conn = connection_pool.get_connection()
try:
cursor = conn.cursor(dictionary=True)
try:
# Query all records
cursor.execute("SELECT * FROM mp_synthesis_scheme_info")
records = cursor.fetchall()
# Add each record to the in-memory cache
for record in records:
doi = record.get('doi')
mp_id = record.get('mp_id')
# Key by doi if present
if doi:
memory_cache[doi] = record
# Key by mp_id if present
if mp_id:
memory_cache[mp_id] = record
print(f"{Fore.GREEN}{Style.BRIGHT}成功加载 {len(records)} 条记录到内存中{Style.RESET_ALL}")
print(f"{Fore.GREEN}{Style.BRIGHT}内存缓存中的键数量: {len(memory_cache)}{Style.RESET_ALL}")
finally:
cursor.close()
finally:
conn.close()
# Load the data into memory at program startup
load_data_to_memory()
async def process_retrieval_from_knowledge_base(data):
doi = data.get('doi')
mp_id = data.get('mp_id')
@@ -128,31 +173,12 @@ async def process_retrieval_from_knowledge_base(data):
if doi is None and mp_id is None:
return "" # 如果没有提供查询参数,返回空字符串
# Build the SQL query condition
query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
params = []
if doi is not None and mp_id is not None:
query += "doi = %s OR mp_id = %s"
params = [doi, mp_id]
elif doi is not None:
query += "doi = %s"
params = [doi]
else: # mp_id is not None
query += "mp_id = %s"
params = [mp_id]
# Query matching records from the database
conn = connection_pool.get_connection()
try:
cursor = conn.cursor(dictionary=True)
try:
cursor.execute(query, params)
result = cursor.fetchone() # Fetch the first matching record
finally:
cursor.close()
finally:
conn.close()
# Look up the matching record in the in-memory cache
result = None
if doi is not None and doi in memory_cache:
result = memory_cache[doi]
elif mp_id is not None and mp_id in memory_cache:
result = memory_cache[mp_id]
# Check whether a matching record was found
if not result:
@@ -265,13 +291,13 @@ def worker(data, output_file_path):
arguments_data = func.get("arguments")
# Print the function name with rich text
print(f"{Fore.CYAN}{Style.BRIGHT}[Function]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
#print(f"{Fore.CYAN}{Style.BRIGHT}[Function]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
# Print the arguments with rich text
print(f"{Fore.CYAN}{Style.BRIGHT}[Arguments]{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
#print(f"{Fore.CYAN}{Style.BRIGHT}[Arguments]{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
if func.get("name") == 'retrieval_from_knowledge_base':
pass
# delay_time = random.uniform(5, 10)
# time.sleep(delay_time)
result = asyncio.run(process_retrieval_from_knowledge_base(data))
@@ -281,60 +307,39 @@ def worker(data, output_file_path):
formatted_results.append(formatted_result)
elif func.get("name") == 'generate_material':
# # Normalize the arguments
# try:
# # Ensure arguments_data is a dict
# if isinstance(arguments_data, str):
# try:
# arguments_data = json.loads(arguments_data)
# except json.JSONDecodeError as e:
# print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
# continue
try:
# Ensure arguments_data is a dict
if isinstance(arguments_data, str):
try:
arguments_data = json.loads(arguments_data)
except json.JSONDecodeError as e:
#print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
continue
# # Normalize the arguments
# normalized_args = normalize_material_args(arguments_data)
# # print(f"{Fore.CYAN}{Style.BRIGHT}[Function]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
# # print(f"{Fore.CYAN}{Style.BRIGHT}[Raw arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
# # print(f"{Fore.CYAN}{Style.BRIGHT}[Normalized arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
# Normalize the arguments
normalized_args = normalize_material_args(arguments_data)
# # Prefer the mattergen function
# try:
# # output = asyncio.run(generate_material(**normalized_args))
# output = generate_material(**normalized_args)
# Prefer the mattergen function
try:
# # Add a delay to simulate extra tool-function calls
output = generate_material(**normalized_args)
# # Random delay of 5-10 seconds
# # delay_time = random.uniform(5, 10)
# # print(f"{Fore.MAGENTA}{Style.BRIGHT}Executing extra tool calls, estimated to take {delay_time:.2f} seconds...{Style.RESET_ALL}")
# # time.sleep(delay_time)
# # # Simulated log output from other tool calls
# # print(f"{Fore.BLUE}Analyzing the generated material structure...{Style.RESET_ALL}")
# # time.sleep(0.5)
# # print(f"{Fore.BLUE}Computing structural stability...{Style.RESET_ALL}")
# # time.sleep(0.5)
# # print(f"{Fore.BLUE}Validating property constraints...{Style.RESET_ALL}")
# # time.sleep(0.5)
# # print(f"{Fore.GREEN}{Style.BRIGHT}Extra tool calls complete{Style.RESET_ALL}")
except Exception as e:
#print(f"{Fore.RED}mattergen出错尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
continue
# Append the result to func_results
func_results.append({"function": func_name, "result": output})
# except Exception as e:
# print(f"{Fore.RED}mattergen failed, trying generate_material: {str(e)}{Style.RESET_ALL}")
# # Append the result to func_results
# func_results.append({"function": func_name, "result": output})
# # Format the result
# formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
# formatted_results.append(formatted_result)
# except Exception as e:
# print(f"{Fore.RED}Error while processing generate_material arguments: {e}{Style.RESET_ALL}")
# import traceback
# print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
pass
# Format the result
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
formatted_results.append(formatted_result)
except Exception as e:
#print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
import traceback
#print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
continue
else:
# delay_time = random.uniform(5, 10)
# time.sleep(delay_time)
result = asyncio.run(execute_tool_from_dict(func))
func_results.append({"function": func['name'], "result": result})
@@ -347,17 +352,17 @@ def worker(data, output_file_path):
final_result = "\n\n\n".join(formatted_results)
data['observation'] = final_result
# Print begin/end markers with rich text
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT BEGIN {'#'*50}{Style.RESET_ALL}")
print(data['observation'])
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT END {'#'*50}{Style.RESET_ALL}")
# Print begin/end markers with rich text
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT BEGIN {'#'*50}{Style.RESET_ALL}")
# print(data['observation'])
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT END {'#'*50}{Style.RESET_ALL}")
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data) # the record now includes its 'observation'
return f"Processed successfully"
except Exception as e:
print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
#print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
return f"Error processing: {str(e)}"
@@ -365,7 +370,6 @@ def main(datas, output_file_path, max_workers=1):
import random
from tqdm import tqdm
import os
from mysql.connector import pooling, Error
# Create the progress bar
pbar = tqdm(total=len(datas), desc="Processing papers")
@@ -403,21 +407,26 @@ def main(datas, output_file_path, max_workers=1):
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
if __name__ == '__main__':
import datetime
import jsonlines
datas = []
datas_with_solution = []
datas_without_solution = []
with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
for obj in reader:
datas.append(obj)
if obj['solution']!='':
datas_with_solution.append(obj)
else:
datas_without_solution.append(obj)
print(len(datas))
# print()
output_file = f"./filter_ok_questions_solutions_agent_other_tools_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
datas_with_solution = datas_with_solution[:5000]
datas_without_solution = datas_without_solution[:5000]
datas = datas_with_solution + datas_without_solution
import random
random.shuffle(datas)
output_file = f"./filter_ok_questions_solutions_agent_data10000_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(datas, output_file, max_workers=32)
# Example 1: using a correct JSON format
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
# argument = json.loads(argument)
# print(json.dumps(argument, indent=2))
# asyncio.run(mattergen(**argument))
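The change above swaps per-request SQL queries for a single upfront load into a dict keyed by both 'doi' and 'mp_id'. A minimal standalone sketch of that pattern, with hypothetical record contents:

memory_cache = {}

def load_records(records):
    # Index each record under every key it carries.
    for record in records:
        for key in (record.get('doi'), record.get('mp_id')):
            if key:
                memory_cache[key] = record

def lookup(doi=None, mp_id=None):
    # Same lookup order as process_retrieval_from_knowledge_base: doi first, then mp_id.
    if doi is not None and doi in memory_cache:
        return memory_cache[doi]
    if mp_id is not None and mp_id in memory_cache:
        return memory_cache[mp_id]
    return None

load_records([{'doi': '10.1000/example', 'mp_id': 'mp-149', 'scheme': '...'}])
assert lookup(mp_id='mp-149')['doi'] == '10.1000/example'

The trade-off is memory: the whole table lives in RAM, and each row is stored once per key, so the cache grows with both the row count and the number of keys per row.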


@@ -17,7 +17,7 @@ import random
# Initialize colorama
init(autoreset=True)
from typing import Dict, Union, Any, Optional
from typing import Dict, Union, Any, Optional, List
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
@@ -110,6 +110,8 @@ def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
import requests
# Create a database connection pool (used only for the initial data load)
connection_pool = pooling.MySQLConnectionPool(
pool_name="mypool",
pool_size=32,
@@ -120,7 +122,50 @@ connection_pool = pooling.MySQLConnectionPool(
database='metadata_mat_papers'
)
# In-memory cache for records loaded from the database
# Structure: {doi: record, mp_id: record}
memory_cache = {}
def load_data_to_memory():
"""
Load all records from the database into memory.
"""
print(f"{Fore.CYAN}{Style.BRIGHT}Loading data from the database into memory...{Style.RESET_ALL}")
conn = connection_pool.get_connection()
try:
cursor = conn.cursor(dictionary=True)
try:
# Query all records
cursor.execute("SELECT * FROM mp_synthesis_scheme_info")
records = cursor.fetchall()
# Add each record to the in-memory cache
for record in records:
doi = record.get('doi')
mp_id = record.get('mp_id')
# Key by doi if present
if doi:
memory_cache[doi] = record
# Key by mp_id if present
if mp_id:
memory_cache[mp_id] = record
print(f"{Fore.GREEN}{Style.BRIGHT}成功加载 {len(records)} 条记录到内存中{Style.RESET_ALL}")
print(f"{Fore.GREEN}{Style.BRIGHT}内存缓存中的键数量: {len(memory_cache)}{Style.RESET_ALL}")
finally:
cursor.close()
finally:
conn.close()
# Load the data into memory at program startup
load_data_to_memory()
async def process_retrieval_from_knowledge_base(data):
doi = data.get('doi')
mp_id = data.get('mp_id')
@@ -128,31 +173,12 @@ async def process_retrieval_from_knowledge_base(data):
if doi is None and mp_id is None:
return "" # 如果没有提供查询参数,返回空字符串
# Build the SQL query condition
query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
params = []
if doi is not None and mp_id is not None:
query += "doi = %s OR mp_id = %s"
params = [doi, mp_id]
elif doi is not None:
query += "doi = %s"
params = [doi]
else: # mp_id is not None
query += "mp_id = %s"
params = [mp_id]
# Query matching records from the database
conn = connection_pool.get_connection()
try:
cursor = conn.cursor(dictionary=True)
try:
cursor.execute(query, params)
result = cursor.fetchone() # Fetch the first matching record
finally:
cursor.close()
finally:
conn.close()
# Look up the matching record in the in-memory cache
result = None
if doi is not None and doi in memory_cache:
result = memory_cache[doi]
elif mp_id is not None and mp_id in memory_cache:
result = memory_cache[mp_id]
# Check whether a matching record was found
if not result:
@@ -265,62 +291,43 @@ def worker(data, output_file_path):
arguments_data = func.get("arguments")
# Print the function name with rich text
print(f"{Fore.CYAN}{Style.BRIGHT}[Function]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
#print(f"{Fore.CYAN}{Style.BRIGHT}[Function]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
# Print the arguments with rich text
print(f"{Fore.CYAN}{Style.BRIGHT}[Arguments]{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
#print(f"{Fore.CYAN}{Style.BRIGHT}[Arguments]{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
if func.get("name") == 'retrieval_from_knowledge_base':
pass
# delay_time = random.uniform(5, 10)
# time.sleep(delay_time)
# result = asyncio.run(process_retrieval_from_knowledge_base(data))
# func_results.append({"function": func['name'], "result": result})
# # Format the result
# formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
# formatted_results.append(formatted_result)
result = asyncio.run(process_retrieval_from_knowledge_base(data))
func_results.append({"function": func['name'], "result": result})
# Format the result
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result)
elif func.get("name") == 'generate_material':
# Normalize the arguments
try:
# Ensure arguments_data is a dict
if isinstance(arguments_data, str):
try:
arguments_data = json.loads(arguments_data)
except json.JSONDecodeError as e:
print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
#print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
continue
# Normalize the arguments
normalized_args = normalize_material_args(arguments_data)
# print(f"{Fore.CYAN}{Style.BRIGHT}[Function]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
# print(f"{Fore.CYAN}{Style.BRIGHT}[Raw arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
# print(f"{Fore.CYAN}{Style.BRIGHT}[Normalized arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
# Prefer the mattergen function
try:
# output = asyncio.run(generate_material(**normalized_args))
output = generate_material(**normalized_args)
# Add a delay to simulate extra tool-function calls
# Random delay of 5-10 seconds
# delay_time = random.uniform(5, 10)
# print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
# time.sleep(delay_time)
# # Simulated log output from other tool calls
# print(f"{Fore.BLUE}Analyzing the generated material structure...{Style.RESET_ALL}")
# time.sleep(0.5)
# print(f"{Fore.BLUE}Computing structural stability...{Style.RESET_ALL}")
# time.sleep(0.5)
# print(f"{Fore.BLUE}Validating property constraints...{Style.RESET_ALL}")
# time.sleep(0.5)
# print(f"{Fore.GREEN}{Style.BRIGHT}Extra tool calls complete{Style.RESET_ALL}")
except Exception as e:
print(f"{Fore.RED}mattergen出错尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
#print(f"{Fore.RED}mattergen出错尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
continue
# Append the result to func_results
func_results.append({"function": func_name, "result": output})
@@ -328,36 +335,34 @@ def worker(data, output_file_path):
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
formatted_results.append(formatted_result)
except Exception as e:
print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
#print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
import traceback
print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
#print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
continue
else:
# delay_time = random.uniform(5, 10)
# time.sleep(delay_time)
pass
# result = asyncio.run(execute_tool_from_dict(func))
# func_results.append({"function": func['name'], "result": result})
# # Format the result
# func_name = func.get("name")
# formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
# formatted_results.append(formatted_result)
result = asyncio.run(execute_tool_from_dict(func))
func_results.append({"function": func['name'], "result": result})
# Format the result
func_name = func.get("name")
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result)
# Join all formatted results together
final_result = "\n\n\n".join(formatted_results)
data['observation'] = final_result
# Print begin/end markers with rich text
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT BEGIN {'#'*50}{Style.RESET_ALL}")
print(data['observation'])
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT END {'#'*50}{Style.RESET_ALL}")
# Print begin/end markers with rich text
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT BEGIN {'#'*50}{Style.RESET_ALL}")
# print(data['observation'])
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT END {'#'*50}{Style.RESET_ALL}")
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data) # the record now includes its 'observation'
return f"Processed successfully"
except Exception as e:
print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
#print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
return f"Error processing: {str(e)}"
@@ -365,7 +370,6 @@ def main(datas, output_file_path, max_workers=1):
import random
from tqdm import tqdm
import os
from mysql.connector import pooling, Error
# Create the progress bar
pbar = tqdm(total=len(datas), desc="Processing papers")
@@ -409,12 +413,13 @@ if __name__ == '__main__':
datas = []
with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
for obj in reader:
datas.append(obj)
#if obj['solution']!='':
datas.append(obj)
print(len(datas))
# print()
output_file = f"./filter_ok_questions_solutions_agent_mattergen_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(datas, output_file, max_workers=1)
output_file = f"./filter_ok_questions_solutions_agent_data_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(datas, output_file, max_workers=32)
# Example 1: using a correct JSON format
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'

91
generate_data/grpo_tools.py Executable file

@@ -0,0 +1,91 @@
import jsonlines
import argparse
import generate_data.utils as utils
import glob
import json
from ase import io
import tempfile
import re
from pymatgen.io.vasp import Poscar
from pymatgen.io.cif import CifParser
import threading
import concurrent.futures
import copy
from grpo_utils import generate_design_question, generate_props_question, generate_obs_response
# Create a lock for file writing
file_lock = threading.Lock()
def worker(data, output_file_path):
try:
messages = copy.deepcopy(data['messages'])
obs = data['observation']
messages[-1]['content'] = messages[-1]['content'].split("<answer>")[-1].split("</answer>")[0]
messages.append({"role": "user", "content": obs})
data['messages'].append({"role": "user", "content": obs})
# print(messages)
# print(obs)
reasoning_content, response = generate_obs_response(messages)
data['messages'].append({"role": "assistant", "content": f"<think>\n{reasoning_content}</think>\n<answer>\n{response}</answer>\n"})
# Use the lock to safely write to the file
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(messages)
return f"Processed successfully"
except Exception as e:
return f"Error processing: {str(e)}"
def main(input_file_path, output_file_path, max_workers=1):
import random
from tqdm import tqdm
import os
datas = None
with jsonlines.open(input_file_path, mode='r') as reader:
datas = [line for line in reader]
# Create the progress bar
pbar = tqdm(total=len(datas), desc="Processing CIF files")
# Create a thread pool
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit tasks to the executor
future_to_data = {}
for data in datas:
future = executor.submit(worker, data, output_file_path)
future_to_data[future] = data
# Process the results
completed = 0
failed = 0
for future in concurrent.futures.as_completed(future_to_data):
data = future_to_data[future]
try:
result = future.result()
if "successfully" in result:
completed += 1
else:
failed += 1
# Update the progress bar
pbar.update(1)
# Refresh the statistics every 100 items
if (completed + failed) % 100 == 0:
pbar.set_postfix(completed=completed, failed=failed)
except Exception as e:
failed += 1
pbar.update(1)
print(f"\nWorker for {data} generated an exception: {e}")
pbar.close()
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
if __name__ == '__main__':
import datetime
origin_file = "/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl"
output_file = f"agent_questions_solutions_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(origin_file, output_file)
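worker() above recovers the model's final answer by splitting on the <answer> tags before appending the observation as a new user turn. A small illustration of that extraction, with hypothetical strings:

# Assistant turns are stored as "<think>...</think>\n<answer>...</answer>".
content = "<think>\nsome reasoning\n</think>\n<answer>\nfinal answer text\n</answer>\n"

# The same split chain as in worker(): take the text between the answer tags.
answer = content.split("<answer>")[-1].split("</answer>")[0]
print(answer.strip())  # -> final answer text

The split is forgiving: if the tags are missing, it returns the whole string instead of raising.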

294
generate_data/grpo_utils.py Executable file

@@ -0,0 +1,294 @@
import jsonlines
import argparse
import generate_data.utils as utils
import glob
import json
from ase import io
import tempfile
import re
from pymatgen.io.vasp import Poscar
from pymatgen.io.cif import CifParser
import threading
import concurrent.futures
# Create a lock for file writing
file_lock = threading.Lock()
def generate_design_question(crystal_desc, cif_info, crystal_props, max_retries=3, initial_backoff=1.0):
instruction = """
{crystal_desc}
### The corresponding crystal structure data (CIF) is as follows:
{cif_info}
### The physicochemical properties of this crystal structure are:
{crystal_props}
Based on the information above, I need to write exam questions for a materials science PhD exam; the questions must require the candidates to reproduce the complete CIF file given above. If you were setting the exam, how would you pose the questions?
In other words, the answer to each question we pose must be the complete CIF file mentioned above. Naturally, your question must provide sufficient information about the crystal structure.
That information, however, should be abstract and indirect rather than overly explicit; apart from the explicit chemical formula, avoid giving too much precise data, so that the candidates must infer some of it through reasoning, which raises the difficulty.
Half of the questions should be in Chinese and half in English, for better interaction with the model.
First generate 10 example questions, then select the 2 best of them, and output them in the following format:
```json
{
"selected_questions": [
{
"question_id": 1,
"question_text": "问题1的完整内容...",
},
{
"question_id": 2,
"question_text": "问题2的完整内容...",
}
]
}
```
"""
instruction = instruction.replace("{crystal_desc}", crystal_desc).replace("{cif_info}", cif_info).replace("{crystal_props}", crystal_props)
messages=[
{"role": "system", "content": ""},
{"role": "user", "content": instruction}
]
import time
start_time = time.time()
_response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
# reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
# print(f"Time: {time.time() - start_time}")
if _response == 'apierror' or _response == 'unexpectederror':
return _response
# Try to extract the JSON portion from the response
json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
if json_match:
json_str = json_match.group(1)
try:
questions_data = json.loads(json_str)
return questions_data
except json.JSONDecodeError:
# If JSON parsing fails, clean the string and try again
cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
try:
questions_data = json.loads(cleaned_json)
return questions_data
except:
return {"error": "Failed to parse JSON response", "raw_response": _response}
else:
# If no JSON block was found, return the raw response
return {"error": "No JSON format found in response", "raw_response": _response}
def generate_props_question(crystal_desc, cif_info, crystal_props, max_retries=3, initial_backoff=1.0):
instruction = """
{crystal_desc}
### The corresponding crystal structure data (CIF) is as follows:
{cif_info}
### The physicochemical properties of this crystal structure are:
{crystal_props}
Based on the information above, I need to write exam questions for a materials science PhD exam; the questions must require the candidates to answer, from the CIF file, the physicochemical properties given above. If you were setting the exam, how would you pose the questions?
In other words, the answer to each question we pose must be the physicochemical properties mentioned above. Naturally, your question should contain a <placeholder> tag standing in for the given CIF file.
The candidates should, from the given CIF file and through deep thought and reasoning, analyze all the physicochemical properties of this crystal material mentioned above, and answer all of them in JSON format.
Half of the questions should be in Chinese and half in English, for better interaction with the model.
Example questions:
1. <placeholder>\nBased on the CIF file provided above, please xxx
2. Based on the CIF file provided below, please xxx\n <<placeholder>>
First generate 10 example questions, then select the 2 best of them, and output them in the following format:
```json
{
"selected_questions": [
{
"question_id": 1,
"question_text": "问题1的完整内容...",
},
{
"question_id": 2,
"question_text": "问题2的完整内容...",
}
]
}
```
"""
instruction = instruction.replace("{crystal_desc}", crystal_desc).replace("{cif_info}", cif_info).replace("{crystal_props}", crystal_props)
messages=[
{"role": "system", "content": ""},
{"role": "user", "content": instruction}
]
import time
start_time = time.time()
_response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
# reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
# print(f"Time: {time.time() - start_time}")
if _response == 'apierror' or _response == 'unexpectederror':
return _response
# Try to extract the JSON portion from the response
json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
if json_match:
json_str = json_match.group(1)
try:
questions_data = json.loads(json_str)
return questions_data
except json.JSONDecodeError:
# If JSON parsing fails, clean the string and try again
cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
try:
questions_data = json.loads(cleaned_json)
return questions_data
except:
return {"error": "Failed to parse JSON response", "raw_response": _response}
else:
# If no JSON block was found, return the raw response
return {"error": "No JSON format found in response", "raw_response": _response}
def generate_papers_other_question(paper_info, max_retries=3, initial_backoff=1.0):
instruction = """
{paper_info}
Based on the information above, I need to write questions for materials science PhD students; the questions should test whether they have fully mastered the material's reaction equations, structure, performance, and applications. If you were setting the questions, how would you pose them?
Your questions should include suitable information about the material and be self-contained (with only the question in hand, none of its key information is missing), yet they need to be hard and deep enough that the students must think and reason carefully before they can answer accurately.
Since the questions are aimed at PhD students, they should be oriented toward research value. Where reaction equations, or specific reagent quantities concerning structure, performance, and applications are involved, require answers that are as numerically precise as possible (provided those values appear in the text above).
First generate 12 example questions (half of the 12 in Chinese, half in English), then select the 4 best of them, and output them in the following format:
```json
{
"selected_questions": [
{
"question_id": 1,
"question_text": "问题1的完整内容...",
"question_type": "问题1的类型", # reaction_string; structure; performence; application
},
{
"question_id": 2,
"question_text": "问题2的完整内容...",
"question_type": "问题1的类型", # reaction_string; structure; performence; application
}, ...
]
}
```
"""
instruction = instruction.replace("{paper_info}", paper_info)
messages=[
{"role": "system", "content": ""},
{"role": "user", "content": instruction}
]
import time
start_time = time.time()
_response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
# reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
# print(f"Time: {time.time() - start_time}")
if _response == 'apierror' or _response == 'unexpectederror':
return _response
# Try to extract the JSON portion from the response
json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
if json_match:
json_str = json_match.group(1)
try:
questions_data = json.loads(json_str)
return questions_data
except json.JSONDecodeError:
# If JSON parsing fails, clean the string and try again
cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
try:
questions_data = json.loads(cleaned_json)
return questions_data
except:
return {"error": "Failed to parse JSON response", "raw_response": _response}
else:
# If no JSON block was found, return the raw response
return {"error": "No JSON format found in response", "raw_response": _response}
def generate_papers_synthesis_question(paper_info, max_retries=3, initial_backoff=1.0):
instruction = """
{paper_info}
Based on the information above, I need to write questions for materials science PhD students; the questions should test whether they have fully mastered the material's synthesis scheme, and whether they have fully mastered the precise mapping from the given material's structure and performance to its synthesis scheme. If you were setting the questions, how would you pose them?
Your questions should include sufficient information about the material's structure and performance, and they need to be hard and deep enough that the students must think and reason carefully before they can give an accurate synthesis scheme, organized into a formatted JSON synthesis scheme.
Since the questions are aimed at PhD students, they should be oriented toward research value, and the students must give precise values when answering with the material's synthesis scheme (including synthesis conditions such as reagents, precursors, vessels, and temperatures).
The parts of the question that serve as given conditions should be as explicit as possible rather than implicit (bear in mind that the students will not know the information above when they receive the question, so phrasings like "based on the given material structure and performance information" should be avoided).
First generate 6 example questions (half of the 6 in Chinese, half in English), then select the 2 best of them, and output them in the following format:
```json
{
"selected_questions": [
{
"question_id": 1,
"question_text": "问题1的完整内容...",
},
{
"question_id": 2,
"question_text": "问题2的完整内容...",
},
]
}
```
"""
instruction = instruction.replace("{paper_info}", paper_info)
messages=[
{"role": "system", "content": ""},
{"role": "user", "content": instruction}
]
import time
start_time = time.time()
_response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
# reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
# print(f"Time: {time.time() - start_time}")
if _response == 'apierror' or _response == 'unexpectederror':
return _response
# Try to extract the JSON portion from the response
json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
if json_match:
json_str = json_match.group(1)
try:
questions_data = json.loads(json_str)
return questions_data
except json.JSONDecodeError:
# If JSON parsing fails, clean the string and try again
cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
try:
questions_data = json.loads(cleaned_json)
return questions_data
except:
return {"error": "Failed to parse JSON response", "raw_response": _response}
else:
# If no JSON block was found, return the raw response
return {"error": "No JSON format found in response", "raw_response": _response}
def generate_function_call(messages, tools, max_retries=3, initial_backoff=1.0):
import time
start_time = time.time()
instruction = """
# Question
{question}
# Instructions
Before you answer the question above accurately, this is your only chance to call tools to obtain more information.
Please think about the question above as deeply as possible, and call as many of the tools provided to you as you can to look up information relevant to the question, rather than answering it directly.
Therefore, you need to issue several well-considered tool calls at once in your reply, so that the question above can be answered better.
Think and answer in the same language as the question.
"""
messages[0]["content"] = instruction.replace("{question}", messages[0]["content"])
_response, functions = utils.get_response_from_qwq(messages, model_name="qwq-32b", tools=tools, max_retries=max_retries, initial_backoff=initial_backoff)
# reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
# print(f"Time: {time.time() - start_time}")
# print(_response)
# if _response == 'apierror' or _response == 'unexpectederror':
# return _response
return _response, functions
def generate_obs_response(messages, max_retries=3, initial_backoff=1.0):
import time
start_time = time.time()
_reasoning_content, response = utils.get_response_from_deepseek_r1(messages, prefix=False, max_retries=max_retries, initial_backoff=initial_backoff)
return _reasoning_content, response
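All four question generators above duplicate the same extract-and-parse fallback for the model's fenced JSON block. A possible shared helper, sketched here (not part of this commit; the name extract_questions_json is hypothetical):

import json
import re

def extract_questions_json(response: str):
    # Pull the ```json fenced block out of the model response.
    match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
    if not match:
        return {"error": "No JSON format found in response", "raw_response": response}
    json_str = match.group(1)
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Second chance, as above: strip control whitespace and retry.
        cleaned = re.sub(r'[\n\r\t]', '', json_str)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            return {"error": "Failed to parse JSON response", "raw_response": response}

Each generator could then end with return extract_questions_json(_response) instead of repeating the block.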

800
generate_data/utils.py Executable file

@@ -0,0 +1,800 @@
"""
This script generates questions and answers from a given set of CIFs.
It uses the OpenAI API and MySQL for storing and retrieving data.
@author: Yutang Li
"""
import multiprocessing
import sqlite3
import tiktoken
import re
from fractions import Fraction
import numpy as np
import glob
import tqdm
import copy
import json
import time
import random
from openai import OpenAI, APIError, RateLimitError
from mysql.connector import pooling, Error
from collections import Counter
def get_response_from_deepseek_r1(messages: list[dict], prefix: bool = False, max_retries: int = 3, initial_backoff: float = 1.0):
"""
Get response from DeepSeek API with retry mechanism.
Args:
messages: List of message dictionaries
prefix: Whether to use the prefix URL
max_retries: Maximum number of retry attempts
initial_backoff: Initial backoff time in seconds
Returns:
Tuple of (reasoning_content, content) or error messages
"""
retries = 0
while retries <= max_retries:
try:
base_url = "https://api.deepseek.com/beta" if prefix else "https://vip.apiyi.com/v1"
api_key = "sk-59279cc16ec740089146ef9aef9c1671" if prefix else "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"
client = OpenAI(api_key=api_key, base_url=base_url)
# messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
response = client.chat.completions.create(
model="deepseek-r1",
messages=messages,
temperature=0.6
)
# reasoning_content = "null" if prefix else "<think>\n" + response.choices[0].message.model_extra['reasoning_content'] + "\n</think>\n"
reasoning_content = response.choices[0].message.content.split("</think>\n")[0].split("<think>\n")[-1]
content = response.choices[0].message.content.split("</think>\n")[-1]
return reasoning_content, content
except RateLimitError as rate_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for RateLimitError: {rate_error}")
return 'apierror', 'apierror'
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
time.sleep(backoff_time)
except APIError as api_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for APIError: {api_error}")
return 'apierror', 'apierror'
# Check if the error is retryable
error_str = str(api_error)
if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
time.sleep(backoff_time)
else:
# Non-retryable API error
print(f"Non-retryable API error: {api_error}")
return 'apierror', 'apierror'
except Exception as e:
print(f"generate_design_question Unexpected error: {e}")
return 'unexpectederror', 'unexpectederror'
def get_response_from_llm(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
"""
Get response from LLM API with retry mechanism.
Args:
messages: List of message dictionaries
model_name: Name of the model to use
tools: Optional list of tools to use
max_retries: Maximum number of retry attempts
initial_backoff: Initial backoff time in seconds
Returns:
Content of the response or error message
"""
retries = 0
while retries <= max_retries:
try:
client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
# messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
if tools is None:
response = client.chat.completions.create(
model=model_name,
messages=messages,
)
else:
response = client.chat.completions.create(
model=model_name,
messages=messages,
tools=tools,
tool_choice='auto',
parallel_tool_calls=True
)
content = response.choices[0].message.content
return content
except RateLimitError as rate_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for RateLimitError: {rate_error}")
return 'apierror'
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
time.sleep(backoff_time)
except APIError as api_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for APIError: {api_error}")
return 'apierror'
# Check if the error is retryable
error_str = str(api_error)
if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
time.sleep(backoff_time)
else:
# Non-retryable API error
print(f"Non-retryable API error: {api_error}")
return 'apierror'
except Exception as e:
print(f"generate_design_question Unexpected error: {e}")
return 'unexpectederror'
def get_response_from_qwq(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
"""
Get response from LLM API with retry mechanism.
Args:
messages: List of message dictionaries
model_name: Name of the model to use
tools: Optional list of tools to use
max_retries: Maximum number of retry attempts
initial_backoff: Initial backoff time in seconds
Returns:
Content of the response or error message
"""
retries = 0
while retries <= max_retries:
try:
# client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
# client = OpenAI(api_key="sk-df98afdc6b5b48db8195dcb4a68e804b", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
import random
if random.random() > 0.5:
client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
else:
client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
# messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
if tools is None:
response = client.chat.completions.create(
model=model_name,
messages=messages,
stream=True
)
else:
response = client.chat.completions.create(
model=model_name,
messages=messages,
tools=tools,
tool_choice='auto',
parallel_tool_calls=True,
stream=True
)
reasoning_content = "" # 定义完整思考过程
answer_content = "" # 定义完整回复
tool_info = [] # 存储工具调用信息
is_answering = False # 判断是否结束思考过程并开始回复
# print("="*20+"思考过程"+"="*20)
for chunk in response:
# if not chunk.choices:
# # Handle usage statistics
# print("\n"+"="*20+"Usage"+"="*20)
# print(chunk.usage)
# else:
delta = chunk.choices[0].delta
# Handle the model's reasoning process (chain-of-thought)
if hasattr(delta, 'reasoning_content') and delta.reasoning_content is not None:
reasoning_content += delta.reasoning_content
# print(delta.reasoning_content,end="",flush=True) # stream the reasoning in real time
# Handle the final reply content
else:
if not is_answering: # Print a header the first time the reply phase starts
is_answering = True
# print("\n"+"="*20+"回复内容"+"="*20)
if delta.content is not None:
answer_content += delta.content
# print(delta.content,end="",flush=True) # stream the reply
# Handle tool-call information (supports parallel tool calls)
if delta.tool_calls is not None:
for tool_call in delta.tool_calls:
index = tool_call.index # Tool-call index, used for parallel calls
# Grow the tool-info list as needed
while len(tool_info) <= index:
tool_info.append({})
# Collect the tool-call ID (for subsequent function calls)
# if tool_call.id:
# tool_info[index]['id'] = tool_info[index].get('id', '') + tool_call.id
# Collect the function name (for routing to the concrete function later)
if tool_call.function and tool_call.function.name:
tool_info[index]['name'] = tool_info[index].get('name', '') + tool_call.function.name
# Collect the function arguments (a JSON string; parsed later)
if tool_call.function and tool_call.function.arguments:
tool_info[index]['arguments'] = tool_info[index].get('arguments', '') + tool_call.function.arguments
tools_response = ""
for tool in tool_info:
tools_response += ("<tool_call>\n" + json.dumps(tool, ensure_ascii=False) + "\n</tool_call>\n")
response = "<think>\n" + reasoning_content + "\n</think>\n" + "<answer>\n" + answer_content + tools_response + "\n</answer>\n"
return response, tool_info
# return reasoning_content, answer_content, tool_info
except RateLimitError as rate_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for RateLimitError: {rate_error}")
return 'apierror', []
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
time.sleep(backoff_time)
except APIError as api_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for APIError: {api_error}")
return 'apierror', []
# Check if the error is retryable
error_str = str(api_error)
if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
time.sleep(backoff_time)
else:
# Non-retryable API error
print(f"Non-retryable API error: {api_error}")
return 'apierror', []
except Exception as e:
print(f"generate_design_question Unexpected error: {e}")
return 'unexpectederror', []
def read_json_file(file_path):
"""Read the json file and return its content."""
try:
with open(file_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
print(f"Error reading file {file_path}: {e}")
return None
################################## utils
def clean_all_repetitions_with_details(text, min_length=10, threshold=10):
"""
Clean up all kinds of repeated content in the text and return details.
Args:
- text: the text to clean
- min_length: minimum length of a repeated fragment
- threshold: threshold for treating content as repeated
Returns:
- cleaned_text: the cleaned text
- is_repetitive: whether repetition was detected
- repetition_details: details of the repeated content
"""
original_text = text
is_repetitive = False
repetition_details = []
# 1. First handle repetitions separated by newlines
if '\n' in text:
lines = text.split('\n')
unique_lines = []
line_counts = {}
for i, line in enumerate(lines):
normalized = line.strip().lower()
if not normalized:
unique_lines.append(line)
continue
line_counts[normalized] = line_counts.get(normalized, 0) + 1
if line_counts[normalized] <= threshold:
unique_lines.append(line)
# If this is the first time the threshold is exceeded, record the repetition details
if line_counts[normalized] == threshold + 1:
# Find the original form (preserving case)
original_form = None
for l in lines[:i]:
if l.strip().lower() == normalized:
original_form = l
break
if original_form is None:
original_form = line
repetition_details.append({
'type': 'line_repetition',
'repeated_string': original_form,
'repeat_count': line_counts[normalized]
})
if any(count > threshold for count in line_counts.values()):
text = '\n'.join(unique_lines)
is_repetitive = True
# 2. Handle consecutive repeated patterns within a single line
for length in range(min_length, 101):
pattern = r'(.{' + str(length) + r'})(\1)+'
while True:
match = re.search(pattern, text)
if not match:
break
repeated_part = match.group(1)
full_match = match.group(0)
# Count the repetitions
repeat_count = len(full_match) // len(repeated_part)
# Record the repetition details
repetition_details.append({
'type': 'inline_repetition',
'repeated_string': repeated_part,
'repeat_count': repeat_count,
'total_length': len(full_match),
'position': match.start()
})
text = text.replace(full_match, repeated_part)
is_repetitive = True
# 3. Handle sentence-level repetition
sentences = re.split(r'(?<=[.!?。?!])\s+', text)
if len(sentences) > 1:
sentence_counter = Counter(sentences)
for sentence, count in sentence_counter.items():
if count > threshold:
repetition_details.append({
'type': 'sentence_repetition',
'repeated_string': sentence,
'repeat_count': count
})
if any(count > threshold for count in sentence_counter.values()):
unique_sentences = []
seen_sentences = {}
for sentence in sentences:
seen_sentences[sentence] = seen_sentences.get(sentence, 0) + 1
if seen_sentences[sentence] <= threshold:
unique_sentences.append(sentence)
# Reassemble the text
text = ' '.join(unique_sentences)
is_repetitive = True
# 4. Handle shorter repetitions (if the methods above detected none)
if not is_repetitive and min_length > 5:
for length in range(5, min_length):
pattern = r'(.{' + str(length) + r'})(\1){2,}' # only act on at least 3 consecutive repeats
while True:
match = re.search(pattern, text)
if not match:
break
repeated_part = match.group(1)
full_match = match.group(0)
# Count the repetitions
repeat_count = len(full_match) // len(repeated_part)
# Record the repetition details
repetition_details.append({
'type': 'short_repetition',
'repeated_string': repeated_part,
'repeat_count': repeat_count,
'total_length': len(full_match),
'position': match.start()
})
text = text.replace(full_match, repeated_part)
is_repetitive = True
# Sort by repeated-string length and repetition type
repetition_details.sort(key=lambda x: (-len(x['repeated_string']), x['type']))
return text, is_repetitive or text != original_text, repetition_details
def create_table(table_name, connection_pool):
"""Create the required MySQL table if it does not exist."""
db = connection_pool.get_connection()
cursor = db.cursor()
create_table_query = f"""
CREATE TABLE IF NOT EXISTS {table_name} (
id INT AUTO_INCREMENT PRIMARY KEY,
mp_id TEXT,
question_model TEXT,
question TEXT,
answer_model TEXT,
answer TEXT,
answer_len INT
)
"""
cursor.execute(create_table_query)
db.commit()
cursor.close()
db.close()
def record_exists(mp_id, table_name, connection_pool):
"""Check if a mp_id already exists in the table."""
db = connection_pool.get_connection()
cursor = db.cursor()
query = f"SELECT * FROM {table_name} WHERE mp_id = %s"
cursor.execute(query, (mp_id,))
result = cursor.fetchone()
cursor.fetchall() # Ensure all results are processed
cursor.close()
db.close()
return result is not None
def insert_record(entry, table_name, connection_pool):
"""Insert a record into the MySQL table."""
db = None
cursor = None
try:
db = connection_pool.get_connection()
cursor = db.cursor()
insert_query = f"""
INSERT INTO {table_name}
(mp_id, question_model, question, answer_model, answer, answer_len)
VALUES (%s, %s, %s, %s, %s, %s)
"""
values = (
entry["mp_id"], entry["question_model"],
entry["question"], entry["answer_model"], entry["answer"], entry["answer_len"],
)
cursor.execute(insert_query, values)
db.commit()
except Error as e:
print(f"Error: {e}")
db.rollback()
finally:
# Ensure cursor is closed
if cursor:
cursor.close()
# Ensure connection is returned to the pool
if db:
db.close()
# Initialize SQLite database connection
def initialize_db():
conn = sqlite3.connect('multi_turns_data.db', check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS conversations (
mp_id TEXT PRIMARY KEY,
sample TEXT,
token_num INTEGER
)
''')
conn.commit()
return conn
# Save sample to SQLite database
def save_to_db(conn, mp_id, sample, total_token):
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO conversations (mp_id, sample, token_num)
VALUES (?, ?, ?)
''', (mp_id, str(sample), total_token))
conn.commit()
def read_cif_txt_file(file_path):
"""Read the markdown file and return its content."""
try:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
print(f"Error reading file {file_path}: {e}")
return None
def round_values(data, precision=3):
"""
Recursively round all numeric values in a nested structure to the given precision (three decimal places by default).
"""
if isinstance(data, dict): # dict: recurse into the values
return {key: round_values(value, precision) for key, value in data.items()}
elif isinstance(data, list): # list: recurse into each element
return [round_values(item, precision) for item in data]
elif isinstance(data, (int, float)): # number: round to the given precision
return round(data, precision)
else: # anything else: return unchanged
return data
def decimal_to_fraction(decimal_value, max_denominator=1000):
"""
Convert a decimal number to a fraction representation.
Args:
decimal_value: the decimal to convert
max_denominator: maximum value of the denominator, controls the precision
Returns:
A string representing the fraction
"""
frac = Fraction(decimal_value).limit_denominator(max_denominator)
return f"{frac.numerator}/{frac.denominator}"
def poscar_to_fractional_representation(poscar_content, max_denominator=1000):
"""
Convert the numeric values in a POSCAR file to fractional representation.
Args:
poscar_content: contents of the POSCAR file
max_denominator: maximum value of the denominator, controls the precision
Returns:
The converted POSCAR content, with values expressed as fractions
"""
lines = poscar_content.strip().split('\n')
result_lines = []
# Keep the system name
result_lines.append(lines[0])
# Keep the scaling factor
scaling_factor = float(lines[1])
result_lines.append(lines[1])
# Process the lattice vectors
for i in range(2, 5):
vector = [float(x) for x in lines[i].split()]
# Convert each component to a fraction
fractional_vector = [decimal_to_fraction(x, max_denominator) for x in vector]
result_lines.append(" " + " ".join(fractional_vector))
# Keep the element types and counts
if len(lines) > 5:
result_lines.append(lines[5])
if len(lines) > 6:
result_lines.append(lines[6])
# Keep the coordinate type
if len(lines) > 7:
result_lines.append(lines[7])
# Process the atomic coordinates
for i in range(8, len(lines)):
parts = lines[i].split()
if len(parts) >= 3:
# Convert the coordinates to fractions
coords = [float(parts[j]) for j in range(3)]
fractional_coords = [decimal_to_fraction(x, max_denominator) for x in coords]
# Build the new line
new_line = " " + " ".join(fractional_coords)
if len(parts) > 3:
new_line += " " + " ".join(parts[3:])
result_lines.append(new_line)
else:
# Keep non-coordinate lines
result_lines.append(lines[i])
return "\n".join(result_lines)
def remove_symmetry_equiv_xyz(cif_content):
"""
Remove the symmetry-operation section from a CIF file.
Args:
cif_content: the CIF file content as a string
Returns:
The cleaned CIF content string
"""
lines = cif_content.split('\n')
output_lines = []
i = 0
while i < len(lines):
line = lines[i].strip()
# Detect the start of a loop
if line == 'loop_':
# Look at the following lines to check whether this is a symmetry loop
next_lines = []
j = i + 1
while j < len(lines) and lines[j].strip().startswith('_'):
next_lines.append(lines[j].strip())
j += 1
# Check whether the loop contains the symmetry-operation tag
if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
# Skip the whole loop block
while i < len(lines):
if i + 1 >= len(lines):
break
next_line = lines[i + 1].strip()
# Check whether we have reached the next loop or data block
if next_line == 'loop_' or next_line.startswith('data_'):
break
# Check whether we have reached the atom-site section
if next_line.startswith('_atom_site_'):
break
i += 1
else:
# Not a symmetry loop; keep the loop_ line
output_lines.append(lines[i])
else:
# Not the start of a loop; keep the line as-is
output_lines.append(lines[i])
i += 1
return '\n'.join(output_lines)
def remove_null_values(d):
"""
Recursively remove key-value pairs with null (None) values from a dictionary.
Args:
d (dict): The dictionary to clean.
Returns:
dict: A new dictionary without null values.
"""
if not isinstance(d, dict):
raise ValueError("Input must be a dictionary")
_d = copy.deepcopy(d)
def recursive_remove(d):
cleaned_dict = {}
for key, value in d.items():
if isinstance(value, dict):
# Recursively clean nested dictionaries
nested_cleaned = recursive_remove(value)
if nested_cleaned: # Only add non-empty dictionaries
cleaned_dict[key] = nested_cleaned
elif value is not None and key != 'version':
cleaned_dict[key] = value
return cleaned_dict
clean_dict = recursive_remove(d)
if _d['cbm'] is None and _d['vbm'] is None and _d['band_gap'] is not None:
# clean_dict['band_gap'] = None
clean_dict.pop('band_gap')
return clean_dict
def get_extra_cif_info(path: str, fields_name: list):
"""Extract specific fields from the CIF description."""
basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites']
# metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'total_magnetization_normalized_vol', 'total_magnetization_normalized_formula_units', 'num_magnetic_sites', 'num_unique_magnetic_sites', 'types_of_magnetic_species', "decomposes_to"]
selected_fields = []
if fields_name[0] == 'all_fields':
selected_fields = basic_fields + energy_electronic_fields + metal_magentic_fields
# selected_fields = energy_electronic_fields + metal_magentic_fields
else:
for field in fields_name:
selected_fields.extend(locals().get(field, []))
with open(path, 'r') as f:
docs = json.load(f)
new_docs = {}
for field_name in selected_fields:
new_docs[field_name] = docs.get(field_name, '')
# new_docs['structure'] = {"lattice": docs['structure']['lattice']}
return new_docs
def extract_json(text):
"""Extract JSON content from a block of text using regex."""
json_pattern = re.compile(r'\\{(?:[^{}]|(?R))*\\}')
matches = json_pattern.search(text)
if matches:
json_str = matches.group(0)
try:
return json.loads(json_str)
except json.JSONDecodeError:
return None
return None
def extract_and_parse_json(response):
"""Extract and parse JSON from a response."""
json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
json_str = json_match.group(1) if json_match else response.strip()
json_str = re.sub(r'(\$[^\$]*\$)', lambda m: m.group(1).replace('\\', '\\\\'), json_str)
json_str = json_str.replace('\\"', '"').replace("\\'", "'")
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
print(f"JSON parse error: {e}")
return 'errformat'
# Count the tokens of the input messages
def count_message_tokens(messages, model_name):
encoding = tiktoken.encoding_for_model(model_name)
num_tokens = 0
num_tokens += len(encoding.encode(messages))
return num_tokens
def make_multi_turns_sharegpt_sample(humans: list[str], gpts: list[str], system: str="{SYSTEM}"):
sample = {}
conversations = []
if system is not None and system != "":
sample["system"] = system
assert len(humans) !=0, "human cannot be None"
assert len(gpts) == len(humans), "human and gpt must have the same length"
for human, gpt in zip(humans, gpts):
if human is not None and human != "":
assert gpt is not None, "gpt cannot be None"
assert gpt != "", "gpt cannot be empty"
# The order of the following two appends must not be changed
conversations.append({"from": "human", "value": human})
conversations.append({"from": "gpt", "value": gpt})
sample["conversations"] = conversations
return sample
##################################### utils
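Every retry loop in this module computes its sleep with the same exponential-backoff-with-jitter expression. Isolated, it looks like this (a sketch of the formula used above; the helper name is hypothetical):

import random

def backoff_seconds(initial_backoff: float, attempt: int) -> float:
    # Doubles on each attempt, scaled by a jitter factor drawn from [0.5, 1.5).
    return initial_backoff * (2 ** (attempt - 1)) * (0.5 + random.random())

for attempt in range(1, 4):
    print(f"attempt {attempt}: sleep ~{backoff_seconds(1.0, attempt):.2f}s")

With initial_backoff=1.0 the nominal waits are 1s, 2s, and 4s; the jitter spreads retries out so parallel workers do not all hit the API again at the same instant.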

File diff suppressed because it is too large.

0
mars_toolkit/__init__.py Normal file → Executable file

BIN
mars_toolkit/__pycache__/__init__.cpython-310.pyc Normal file → Executable file

Binary file not shown.

0
mars_toolkit/compute/__init__.py Normal file → Executable file

BIN
mars_toolkit/compute/__pycache__/__init__.cpython-310.pyc Normal file → Executable file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

0
mars_toolkit/compute/material_gen.py Normal file → Executable file

0
mars_toolkit/compute/property_pred.py Normal file → Executable file

0
mars_toolkit/compute/structure_opt.py Normal file → Executable file

0
mars_toolkit/core/__init__.py Normal file → Executable file

BIN
mars_toolkit/core/__pycache__/__init__.cpython-310.pyc Normal file → Executable file

Binary file not shown.

BIN
mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc Normal file → Executable file

Binary file not shown.

BIN
mars_toolkit/core/__pycache__/config.cpython-310.pyc Normal file → Executable file

Binary file not shown.


BIN
mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc Normal file → Executable file

Binary file not shown.

Binary file not shown.

0
mars_toolkit/core/__pycache__/utils.cpython-310.pyc Normal file → Executable file

0
mars_toolkit/core/cif_utils.py Normal file → Executable file

8
mars_toolkit/core/config.py Normal file → Executable file

@@ -22,12 +22,12 @@ class Config:
HTTPS_PROXY = 'http://192.168.168.1:20171'
# FairChem
FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/nfs/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
FMAX = 0.05
# MatterGen
MATTERGENMODEL_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
MATTERGEN_ROOT='/home/ubuntu/50T/lzy/mars-mcp/mattergen'
MATTERGENMODEL_ROOT = '/home/ubuntu/50T/nfs/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
MATTERGEN_ROOT='/home/ubuntu/50T/nfs/lzy/mars-mcp/mattergen'
MATTERGENMODEL_RESULT_PATH = 'results/'
# Dify
@@ -38,7 +38,7 @@ class Config:
SEARXNG_HOST="http://192.168.168.1:40032/"
# Visualization
VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'
VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/nfs/lzy/mars-mcp/outputs/cif_visualization'
@classmethod
def as_dict(cls) -> Dict[str, Any]:

0
mars_toolkit/core/llm_tools.py Normal file → Executable file

0
mars_toolkit/core/mattergen_wrapper.py Normal file → Executable file

0
mars_toolkit/misc/__init__.py Normal file → Executable file

BIN
mars_toolkit/misc/__pycache__/__init__.cpython-310.pyc Normal file → Executable file

Binary file not shown.


BIN
mars_toolkit/misc/__pycache__/misc_tools.cpython-310.pyc Normal file → Executable file

Binary file not shown.

0
mars_toolkit/misc/misc_tools.py Normal file → Executable file

0
mars_toolkit/query/__init__.py Normal file → Executable file

BIN
mars_toolkit/query/__pycache__/__init__.cpython-310.pyc Normal file → Executable file

Binary file not shown.

Binary file not shown.

BIN
mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc Normal file → Executable file

Binary file not shown.

BIN
mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc Normal file → Executable file

Binary file not shown.

BIN
mars_toolkit/query/__pycache__/web_search.cpython-310.pyc Normal file → Executable file

Binary file not shown.

0
mars_toolkit/query/dify_search.py Normal file → Executable file
View File

0
mars_toolkit/query/mp_query.py Normal file → Executable file
View File

0
mars_toolkit/query/oqmd_query.py Normal file → Executable file
View File

0
mars_toolkit/query/web_search.py Normal file → Executable file
View File

0
mars_toolkit/services/__init__.py Normal file → Executable file
View File

View File

View File

0
mars_toolkit/services/mattergen_service.py Normal file → Executable file
View File

0
mars_toolkit/visualization/__init__.py Normal file → Executable file
View File

View File

View File

View File

0
mattergen_api.py Normal file → Executable file
View File

View File

@@ -1,134 +0,0 @@
import requests
import json
import argparse
import sys


def generate_material(
    url="http://localhost:8051/generate_material",
    properties=None,
    batch_size=2,
    num_batches=1,
    diffusion_guidance_factor=2.0
):
    """
    Call the MatterGen API to generate crystal structures.

    Args:
        url: API endpoint URL
        properties: optional property constraints, e.g. {"dft_band_gap": 2.0}
        batch_size: number of structures generated per batch
        num_batches: number of batches
        diffusion_guidance_factor: controls how closely the generated structures match the target properties

    Returns:
        The generated structure content, or None on error
    """
    # Build the request payload
    payload = {
        "properties": properties,
        "batch_size": batch_size,
        "num_batches": num_batches,
        "diffusion_guidance_factor": diffusion_guidance_factor
    }
    print(f"Sending request to {url}")
    print(f"Request parameters: {json.dumps(payload, ensure_ascii=False, indent=2)}")
    try:
        # Add headers, including the accept header
        headers = {
            "Content-Type": "application/json",
            "accept": "application/json"
        }
        # Print the full request (for debugging)
        print(f"Full request URL: {url}")
        print(f"Request headers: {headers}")
        print(f"Request body: {json.dumps(payload)}")
        # Disable proxy settings
        proxies = {
            "http": None,
            "https": None
        }
        # Send the POST request with headers, proxies disabled, and a longer timeout
        response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
        # Print the response (for debugging)
        print(f"Response status code: {response.status_code}")
        print(f"Response headers: {dict(response.headers)}")
        print(f"Response content: {response.text[:500]}...")  # Print only the first 500 characters to keep output short
        # Check the response status
        if response.status_code == 200:
            result = response.json()
            if result["success"]:
                print("\nGeneration succeeded!")
                return result["content"]
            else:
                print(f"\nGeneration failed: {result['message']}")
                return None
        else:
            print(f"\nRequest failed, status code: {response.status_code}")
            print(f"Response content: {response.text}")
            return None
    except Exception as e:
        print(f"\nAn error occurred: {str(e)}")
        print(f"Error type: {type(e).__name__}")
        import traceback
        print(f"Traceback: {traceback.format_exc()}")
        return None


def main():
    """Command-line entry point."""
    parser = argparse.ArgumentParser(description="MatterGen API client example")
    # Command-line arguments
    parser.add_argument("--url", default="http://localhost:8051/generate_material",
                        help="MatterGen API endpoint URL")
    parser.add_argument("--property-name", default='dft_mag_density', help="Property name, e.g. dft_band_gap")
    parser.add_argument("--property-value", default=0.15, help="Property value, e.g. 2.0")
    parser.add_argument("--batch-size", type=int, default=2, help="Number of structures generated per batch")
    parser.add_argument("--num-batches", type=int, default=1, help="Number of batches")
    parser.add_argument("--guidance-factor", type=float, default=2.0,
                        help="Controls how closely the generated structures match the target properties")
    args = parser.parse_args()
    # Build the properties dictionary
    properties = None
    if args.property_name and args.property_value:
        try:
            # Try to convert the property value to a number
            try:
                value = float(args.property_value)
                # If the value is integral, convert it to an int
                if value.is_integer():
                    value = int(value)
            except ValueError:
                # If it cannot be converted to a number, keep it as a string
                value = args.property_value
            properties = {args.property_name: value}
        except Exception as e:
            print(f"Error parsing property value: {str(e)}")
            return
    # Call the API
    result = generate_material(
        url=args.url,
        properties=properties,
        batch_size=args.batch_size,
        num_batches=args.num_batches,
        diffusion_guidance_factor=args.guidance_factor
    )
    if result:
        print("\nGenerated structures:")
        print(result)


if __name__ == "__main__":
    main()
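
The deleted client above boils down to a single POST against the MatterGen service, so the same call can still be made directly. A minimal sketch, assuming the /generate_material endpoint keeps serving the payload shape shown above on port 8051:

import requests

# One-off request equivalent to the deleted CLI's defaults.
payload = {
    "properties": {"dft_mag_density": 0.15},
    "batch_size": 2,
    "num_batches": 1,
    "diffusion_guidance_factor": 2.0,
}
resp = requests.post(
    "http://localhost:8051/generate_material",
    json=payload,
    proxies={"http": None, "https": None},  # the old client also bypassed local proxies
    timeout=300,
)
resp.raise_for_status()
result = resp.json()
print(result["content"] if result.get("success") else result.get("message"))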

prompts/material_synthesis.py Executable file (167 lines added)

@@ -0,0 +1,167 @@
from typing import Dict, List, Optional
import mcp.types as types
from mcp.server.lowlevel import Server


def create_messages(
    properties: Dict[str, str] = None,
    batch_size: int = 2,
) -> list[types.PromptMessage]:
    """
    Create the prompt messages for material generation and synthesis.

    Args:
        properties: dictionary of material properties and their values, e.g. {"dft_band_gap": "2.0"}
        batch_size: number of materials to generate, default 2

    Returns:
        List of prompt messages
    """
    messages = []

    # System message defining the assistant's role and task
    system_message = """You are a professional materials scientist specializing in material generation and synthesis route design.
Your tasks are to:
1. Use the generate_material tool from mars_toolkit to generate materials that meet the property requirements given by the user
2. Systematically analyze the four elements of each generated material: composition, structure, processing, and properties
3. Design a scientifically sound synthesis plan for each generated material
4. Draw the synthesis flowchart of the material using mermaid syntax

Make sure your answer includes:
- Your understanding and analysis of the user's requirements
- The material structures generated with the generate_material tool
- A detailed four-element analysis of each generated material:
  * Composition: detailed chemical composition, element ratios, stoichiometry
  * Structure: crystal structure, space group, lattice parameters, atomic positions, coordination environment
  * Processing: feasible synthesis routes, process parameters, key control factors
  * Properties: expected physical, chemical, and mechanical properties and their relationship to the structure
- A detailed synthesis plan, including:
  * Raw material selection and purity requirements
  * Precise reaction conditions (temperature, pressure, time, atmosphere)
  * A step-by-step synthesis procedure with the rationale for each step
  * Potential challenges and their solutions
  * Suggested characterization methods
- A synthesis flowchart drawn with mermaid syntax that clearly shows the whole process from raw materials to the final product
"""
    messages.append(
        types.PromptMessage(
            role="system",
            content=types.TextContent(type="text", text=system_message),
        )
    )

    # Build the main prompt
    if properties and len(properties) > 0:
        properties_text = "\n".join([f"- {key}: {value}" for key, value in properties.items()])
        prompt = f"""Based on the following material property requirements, generate {batch_size} suitable materials and design their synthesis plans:

{properties_text}

Please proceed as follows:
1. Use the generate_material tool from mars_toolkit to generate materials, with batch_size={batch_size}
2. Perform a systematic four-element analysis of each generated material:
   - Composition: analyze the elemental composition, stoichiometry, and their theoretical basis in detail
   - Structure: analyze the crystal structure, space group, lattice parameters, atomic arrangement, and stability
   - Processing: discuss feasible synthesis routes, process parameters, and their scientific basis
   - Properties: predict the physical, chemical, and mechanical properties the material may have and its application prospects
3. Design a detailed synthesis plan for each material, including:
   - Raw material selection and purity requirements
   - Precise reaction condition parameters (temperature, pressure, time, atmosphere, etc.)
   - A step-by-step synthesis procedure with the rationale for each step
   - Potential challenges and their solutions
   - Recommended characterization methods
4. Draw the synthesis flowchart of each material with mermaid syntax, clearly showing the whole process from raw materials to the final product, including the key process parameters."""
    else:
        prompt = f"""Please generate {batch_size} innovative new materials and design their synthesis plans.

Please proceed as follows:
1. Use the generate_material tool from mars_toolkit to generate materials, with batch_size={batch_size}
2. Perform a systematic four-element analysis of each generated material:
   - Composition: analyze the elemental composition, stoichiometry, and their theoretical basis in detail
   - Structure: analyze the crystal structure, space group, lattice parameters, atomic arrangement, and stability
   - Processing: discuss feasible synthesis routes, process parameters, and their scientific basis
   - Properties: predict the physical, chemical, and mechanical properties the material may have and its application prospects
3. Design a detailed synthesis plan for each material, including:
   - Raw material selection and purity requirements
   - Precise reaction condition parameters (temperature, pressure, time, atmosphere, etc.)
   - A step-by-step synthesis procedure with the rationale for each step
   - Potential challenges and their solutions
   - Recommended characterization methods
4. Draw the synthesis flowchart of each material with mermaid syntax, clearly showing the whole process from raw materials to the final product, including the key process parameters."""

    messages.append(
        types.PromptMessage(
            role="user",
            content=types.TextContent(type="text", text=prompt)
        )
    )
    return messages


def register_prompt_handlers(app: Server):
    """
    Register the prompt handlers on the MCP server.

    Args:
        app: the MCP server instance
    """
    @app.list_prompts()
    async def list_prompts() -> list[types.Prompt]:
        return [
            types.Prompt(
                name="material_synthesis",
                description="Generate materials based on the four material elements (composition, structure, processing, properties), design synthesis plans, and draw the synthesis flowchart with mermaid",
                arguments=[
                    types.PromptArgument(
                        name="properties",
                        description="JSON string of material properties and their values, e.g. {\"dft_band_gap\": \"2.0\"}",
                        required=False,
                    ),
                    types.PromptArgument(
                        name="batch_size",
                        description="Number of materials to generate, default 2",
                        required=False,
                    ),
                ],
            )
        ]

    @app.get_prompt()
    async def get_prompt(
        name: str, arguments: dict[str, str] | None = None
    ) -> types.GetPromptResult:
        if name != "material_synthesis":
            raise ValueError(f"Unknown prompt: {name}")
        if arguments is None:
            arguments = {}
        # Parse the properties argument
        properties = {}
        if "properties" in arguments and arguments["properties"]:
            try:
                import json
                properties = json.loads(arguments["properties"])
            except json.JSONDecodeError:
                properties = {}
        # Parse the batch_size argument
        batch_size = 2  # default value
        if "batch_size" in arguments and arguments["batch_size"]:
            try:
                batch_size = int(arguments["batch_size"])
            except ValueError:
                pass  # keep the default value
        return types.GetPromptResult(
            messages=create_messages(properties=properties, batch_size=batch_size),
            description="Generate materials based on the four material elements (composition, structure, processing, properties), design synthesis plans, and draw the synthesis flowchart with mermaid",
        )
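
Once the server below is running, this prompt can be fetched over MCP. A minimal client sketch, assuming the Python MCP SDK's sse_client/ClientSession API and the default SSE endpoint exposed by server.py (http://localhost:5666/sse); arguments are passed as strings because get_prompt parses them itself:

import asyncio

from mcp import ClientSession
from mcp.client.sse import sse_client


async def fetch_prompt():
    async with sse_client("http://localhost:5666/sse") as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.get_prompt(
                "material_synthesis",
                {"properties": '{"dft_band_gap": "2.0"}', "batch_size": "2"},
            )
            for message in result.messages:
                # Each message carries TextContent, per create_messages above.
                print(message.role, message.content.text[:120])


asyncio.run(fetch_prompt())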

server.py Executable file (306 lines added)

@@ -0,0 +1,306 @@
"""Mars Toolkit MCP Server implementation."""
import anyio
import asyncio
import click
import json
import logging
import os
import sys
import traceback
from typing import Any, Dict, List, Optional, Union
from prompts.material_synthesis import create_messages
# 添加mars_toolkit模块的路径
sys.path.append('/home/ubuntu/50T/lzy/mars-mcp')
import mcp.types as types
from mcp.server.lowlevel import Server
# 导入提示词处理器
#from prompts.material_synthesis import register_prompt_handlers
# 导入Mars Toolkit工具函数
try:
# 获取当前时间
from mars_toolkit.misc.misc_tools import get_current_time
# 网络搜索
from mars_toolkit.query.web_search import search_online
# 从Materials Project查询材料属性
from mars_toolkit.query.mp_query import search_material_property_from_material_project
# 从Materials Project获取晶体结构
from mars_toolkit.query.mp_query import get_crystal_structures_from_materials_project
# 从化学式获取Materials Project ID
from mars_toolkit.query.mp_query import get_mpid_from_formula
# 优化晶体结构
from mars_toolkit.compute.structure_opt import optimize_crystal_structure
# 生成材料
from mars_toolkit.compute.material_gen import generate_material
# 从OQMD获取化学成分
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
# 从知识库检索
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
# 预测属性
from mars_toolkit.compute.property_pred import predict_properties
# 获取所有工具函数
from mars_toolkit import get_tools, get_tool_schemas
MARS_TOOLKIT_AVAILABLE = True
except ImportError as e:
print(f"警告: 无法导入Mars Toolkit: {e}", file=sys.stderr)
MARS_TOOLKIT_AVAILABLE = False
# 配置日志
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
app = Server("mars-toolkit-server")
async def call_mars_toolkit_function(func_name: str, arguments: Dict[str, Any]) -> Any:
"""
调用Mars Toolkit中的工具函数
Args:
func_name: 工具函数名称
arguments: 工具函数参数
Returns:
工具函数的执行结果
"""
if not MARS_TOOLKIT_AVAILABLE:
raise ValueError("Mars Toolkit不可用")
# 获取所有注册的工具函数
tools = get_tools()
# 检查函数名是否存在于工具函数字典中
if func_name not in tools:
raise ValueError(f"函数 '{func_name}' 不存在于工具函数字典中")
# 获取对应的工具函数
tool_func = tools[func_name]
# 调用工具函数
if asyncio.iscoroutinefunction(tool_func):
# 如果是异步函数使用await调用
result = await tool_func(**arguments)
print("result1",result)
else:
# 如果是同步函数,直接调用
result = tool_func(**arguments)
return result
def get_tool_schemas_dict() -> Dict[str, Dict[str, Any]]:
"""
获取所有工具函数的模式字典
Returns:
工具函数名称到模式的映射字典
"""
if not MARS_TOOLKIT_AVAILABLE:
return {}
schemas = get_tool_schemas()
schemas_dict = {}
for schema in schemas:
func_name = schema["function"]["name"]
schemas_dict[func_name] = schema
return schemas_dict
@click.command()
@click.option("--port", default=5666, help="Port to listen on for SSE")
@click.option(
"--transport",
type=click.Choice(["stdio", "sse"]),
default="sse",
help="Transport type",
)
def main(port: int, transport: str='SSE') -> int:
"""
Mars Toolkit MCP Server主函数
Args:
port: SSE传输的端口号
transport: 传输类型stdio或sse
Returns:
退出码
"""
if not MARS_TOOLKIT_AVAILABLE:
print("错误: Mars Toolkit不可用请确保已正确安装", file=sys.stderr)
return 1
# 获取工具函数模式字典
schemas_dict = get_tool_schemas_dict()
# 注册提示词处理器
#register_prompt_handlers(app)
@app.list_prompts()
async def list_prompts() -> list[types.Prompt]:
return [
types.Prompt(
name="material_synthesis",
description="生成材料并设计合成方案使用mermaid绘制合成流程图",
arguments=[
types.PromptArgument(
name="properties",
description="材料性质及其值的JSON字符串例如 {\"dft_band_gap\": \"2.0\"}",
required=False,
),
types.PromptArgument(
name="batch_size",
description="生成材料的数量默认为2",
required=False,
),
],
)
]
@app.get_prompt()
async def get_prompt(
name: str, arguments: dict[str, str] | None = None
) -> types.GetPromptResult:
if name != "material_synthesis":
raise ValueError(f"未知的提示词: {name}")
if arguments is None:
arguments = {}
# 解析properties参数
properties = {}
if "properties" in arguments and arguments["properties"]:
try:
import json
properties = json.loads(arguments["properties"])
except json.JSONDecodeError:
properties = {}
# 解析batch_size参数
batch_size = 2 # 默认值
if "batch_size" in arguments and arguments["batch_size"]:
try:
batch_size = int(arguments["batch_size"])
except ValueError:
pass # 使用默认值
return types.GetPromptResult(
messages=create_messages(properties=properties, batch_size=batch_size),
description="生成材料并设计合成方案使用mermaid绘制合成流程图",
)
@app.call_tool()
async def call_tool(
name: str, arguments: Dict[str, Any]
) -> List[types.TextContent | types.ImageContent | types.EmbeddedResource]:
"""
调用工具函数
Args:
name: 工具函数名称
arguments: 工具函数参数
Returns:
工具函数的执行结果
"""
try:
print(f"调用{name},参数为{arguments}")
result = await call_mars_toolkit_function(name, arguments)
print("result",result)
# 将结果转换为字符串
if isinstance(result, (dict, list)):
result_str = json.dumps(result, ensure_ascii=False, indent=2)
else:
result_str = str(result)
return [types.TextContent(type="text", text=result_str)]
except Exception as e:
error_msg = f"调用工具函数 {name} 时出错: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
@app.list_tools()
async def list_tools() -> List[types.Tool]:
"""
列出所有可用的工具函数
Returns:
工具函数列表
"""
tools = []
print("列举所有可用的工具函数")
for func_name, schema in schemas_dict.items():
# 获取函数描述
description = schema["function"].get("description", f"Mars Toolkit工具: {func_name}")
# 获取参数模式
parameters = schema["function"].get("parameters", {})
# 创建工具
tool = types.Tool(
name=func_name,
description=description,
inputSchema=parameters,
)
tools.append(tool)
return tools
if transport == "sse":
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.routing import Mount, Route
sse = SseServerTransport("/messages/")
async def handle_sse(request):
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await app.run(
streams[0], streams[1], app.create_initialization_options()
)
starlette_app = Starlette(
debug=True,
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)
import uvicorn
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
else:
from mcp.server.stdio import stdio_server
async def arun():
async with stdio_server() as streams:
await app.run(
streams[0], streams[1], app.create_initialization_options()
)
anyio.run(arun)
return 0
if __name__ == "__main__":
print(get_tool_schemas_dict())
main()
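
As a smoke test of the tool side, a client can list the registered tools and invoke one over the same transport. A minimal sketch, again assuming the Python MCP SDK client API, a server on the default port, and that get_current_time takes no arguments:

import asyncio

from mcp import ClientSession
from mcp.client.sse import sse_client


async def smoke_test():
    async with sse_client("http://localhost:5666/sse") as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # Mirrors list_tools() above: one entry per registered schema.
            tools = await session.list_tools()
            print([tool.name for tool in tools.tools])
            # call_tool() returns the TextContent produced by the server's call_tool handler.
            result = await session.call_tool("get_current_time", {})
            print(result.content[0].text)


asyncio.run(smoke_test())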

test_mars_toolkit.py Normal file → Executable file (4 changes)

@@ -155,7 +155,7 @@ def print_tool_schemas():
 
 if __name__ == "__main__":
     # Print the schemas of all tool functions
-    #print_tool_schemas()
+    print_tool_schemas()
 
     # List of tool functions to test
     tools_to_test = [
@@ -172,7 +172,7 @@ if __name__ == "__main__":
     ]
 
     # Choose the tool to test
-    tool_name = tools_to_test[1]  # test the search_online tool
+    tool_name = tools_to_test[2]  # test the search_online tool
 
     # Run the test
     result = asyncio.run(test_tool(tool_name))