mcp,生成数据代码
This commit is contained in:
0
.gitignore
vendored
Normal file → Executable file
0
.gitignore
vendored
Normal file → Executable file
BIN
__pycache__/api_key.cpython-310.pyc
Executable file
BIN
__pycache__/api_key.cpython-310.pyc
Executable file
Binary file not shown.
0
__pycache__/execute_tool_copy.cpython-310.pyc
Normal file → Executable file
0
__pycache__/execute_tool_copy.cpython-310.pyc
Normal file → Executable file
0
__pycache__/mattergen_wrapper.cpython-310.pyc
Normal file → Executable file
0
__pycache__/mattergen_wrapper.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
0
__pycache__/normalize_material_args.cpython-310.pyc
Normal file → Executable file
0
__pycache__/normalize_material_args.cpython-310.pyc
Normal file → Executable file
0
agent_test.py
Normal file → Executable file
0
agent_test.py
Normal file → Executable file
0
api_key.py
Normal file → Executable file
0
api_key.py
Normal file → Executable file
89
generate_data/filter_generate_material_data.py
Executable file
89
generate_data/filter_generate_material_data.py
Executable file
@@ -0,0 +1,89 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
读取JSONL文件,打印包含'generate_material content begin'的行
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
def count_lines(file_path):
|
||||||
|
"""计算文件的总行数"""
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
return sum(1 for _ in f)
|
||||||
|
|
||||||
|
def filter_generate_material(file_path):
|
||||||
|
"""
|
||||||
|
读取JSONL文件,打印包含'generate_material content begin'的行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: JSONL文件路径
|
||||||
|
"""
|
||||||
|
# 计算文件总行数用于进度条
|
||||||
|
total_lines = count_lines(file_path)
|
||||||
|
console.print(f"[bold green]文件总行数: {total_lines}[/bold green]")
|
||||||
|
|
||||||
|
# 匹配计数
|
||||||
|
match_count = 0
|
||||||
|
|
||||||
|
# 使用进度条显示处理进度
|
||||||
|
with Progress(
|
||||||
|
TextColumn("[bold blue]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
) as progress:
|
||||||
|
# 创建进度条任务
|
||||||
|
task_id = progress.add_task("[bold green]处理JSONL文件...", total=total_lines)
|
||||||
|
|
||||||
|
# 逐行处理文件
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line_num, line in enumerate(f, 1):
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析JSON
|
||||||
|
data = json.loads(line.strip())
|
||||||
|
|
||||||
|
# 检查是否包含observation字段
|
||||||
|
# if 'observation' in data:
|
||||||
|
# # 检查observation是否包含目标字符串
|
||||||
|
# if 'generate_material content begin' in data['observation'] and '[generate_material content begin]Error' not in data['observation']:
|
||||||
|
|
||||||
|
# console.print(f"[bold cyan]第 {line_num} 行包含目标内容:[/bold cyan]")
|
||||||
|
# console.print(line.strip())
|
||||||
|
# console.print("-" * 80)
|
||||||
|
# print(data)
|
||||||
|
# match_count += 1
|
||||||
|
# return
|
||||||
|
if 'solution' in data:
|
||||||
|
print("line_num",line_num)
|
||||||
|
print("solution",data['solution'])
|
||||||
|
if data['solution'] != "":
|
||||||
|
print( data['solution'])
|
||||||
|
return
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
console.print(f"[bold red]第 {line_num} 行JSON解析错误[/bold red]")
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[bold red]处理第 {line_num} 行时出错: {str(e)}[/bold red]")
|
||||||
|
|
||||||
|
# 更新进度条
|
||||||
|
progress.update(task_id, advance=1)
|
||||||
|
|
||||||
|
console.print(f"[bold green]处理完成,共找到 {match_count} 行包含 'generate_material content begin'[/bold green]")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 默认文件路径
|
||||||
|
file_path ='/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl'
|
||||||
|
# "/home/ubuntu/50T/lzy/mars-mcp/mars-agent_data_20250408205427.jsonl"
|
||||||
|
|
||||||
|
# 如果提供了命令行参数,则使用命令行参数作为文件路径
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
file_path = sys.argv[1]
|
||||||
|
|
||||||
|
console.print(f"[bold blue]正在处理文件: {file_path}[/bold blue]")
|
||||||
|
filter_generate_material(file_path)
|
||||||
195
execute_tool_other_tools.py → generate_data/generate_data10000.py
Normal file → Executable file
195
execute_tool_other_tools.py → generate_data/generate_data10000.py
Normal file → Executable file
@@ -17,7 +17,7 @@ import random
|
|||||||
# 初始化colorama
|
# 初始化colorama
|
||||||
init(autoreset=True)
|
init(autoreset=True)
|
||||||
|
|
||||||
from typing import Dict, Union, Any, Optional
|
from typing import Dict, Union, Any, Optional, List
|
||||||
|
|
||||||
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -110,6 +110,8 @@ def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
# 创建数据库连接池(仅用于初始加载数据)
|
||||||
connection_pool = pooling.MySQLConnectionPool(
|
connection_pool = pooling.MySQLConnectionPool(
|
||||||
pool_name="mypool",
|
pool_name="mypool",
|
||||||
pool_size=32,
|
pool_size=32,
|
||||||
@@ -120,7 +122,50 @@ connection_pool = pooling.MySQLConnectionPool(
|
|||||||
database='metadata_mat_papers'
|
database='metadata_mat_papers'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 内存缓存,用于存储从数据库加载的数据
|
||||||
|
# 结构: {doi: record, mp_id: record}
|
||||||
|
memory_cache = {}
|
||||||
|
|
||||||
|
def load_data_to_memory():
|
||||||
|
"""
|
||||||
|
从数据库加载所有数据到内存中
|
||||||
|
"""
|
||||||
|
print(f"{Fore.CYAN}{Style.BRIGHT}正在从数据库加载数据到内存中...{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
conn = connection_pool.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
try:
|
||||||
|
# 查询所有记录
|
||||||
|
cursor.execute("SELECT * FROM mp_synthesis_scheme_info")
|
||||||
|
records = cursor.fetchall()
|
||||||
|
|
||||||
|
# 将记录添加到内存缓存中
|
||||||
|
for record in records:
|
||||||
|
doi = record.get('doi')
|
||||||
|
mp_id = record.get('mp_id')
|
||||||
|
|
||||||
|
# 使用doi作为键(如果存在)
|
||||||
|
if doi:
|
||||||
|
memory_cache[doi] = record
|
||||||
|
|
||||||
|
# 使用mp_id作为键(如果存在)
|
||||||
|
if mp_id:
|
||||||
|
memory_cache[mp_id] = record
|
||||||
|
|
||||||
|
print(f"{Fore.GREEN}{Style.BRIGHT}成功加载 {len(records)} 条记录到内存中{Style.RESET_ALL}")
|
||||||
|
print(f"{Fore.GREEN}{Style.BRIGHT}内存缓存中的键数量: {len(memory_cache)}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# 在程序启动时加载数据到内存中
|
||||||
|
load_data_to_memory()
|
||||||
|
|
||||||
async def process_retrieval_from_knowledge_base(data):
|
async def process_retrieval_from_knowledge_base(data):
|
||||||
|
|
||||||
doi = data.get('doi')
|
doi = data.get('doi')
|
||||||
mp_id = data.get('mp_id')
|
mp_id = data.get('mp_id')
|
||||||
|
|
||||||
@@ -128,31 +173,12 @@ async def process_retrieval_from_knowledge_base(data):
|
|||||||
if doi is None and mp_id is None:
|
if doi is None and mp_id is None:
|
||||||
return "" # 如果没有提供查询参数,返回空字符串
|
return "" # 如果没有提供查询参数,返回空字符串
|
||||||
|
|
||||||
# 构建SQL查询条件
|
# 从内存缓存中查询匹配的记录
|
||||||
query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
|
result = None
|
||||||
params = []
|
if doi is not None and doi in memory_cache:
|
||||||
|
result = memory_cache[doi]
|
||||||
if doi is not None and mp_id is not None:
|
elif mp_id is not None and mp_id in memory_cache:
|
||||||
query += "doi = %s OR mp_id = %s"
|
result = memory_cache[mp_id]
|
||||||
params = [doi, mp_id]
|
|
||||||
elif doi is not None:
|
|
||||||
query += "doi = %s"
|
|
||||||
params = [doi]
|
|
||||||
else: # mp_id is not None
|
|
||||||
query += "mp_id = %s"
|
|
||||||
params = [mp_id]
|
|
||||||
|
|
||||||
# 从数据库中查询匹配的记录
|
|
||||||
conn = connection_pool.get_connection()
|
|
||||||
try:
|
|
||||||
cursor = conn.cursor(dictionary=True)
|
|
||||||
try:
|
|
||||||
cursor.execute(query, params)
|
|
||||||
result = cursor.fetchone() # 获取第一个匹配的记录
|
|
||||||
finally:
|
|
||||||
cursor.close()
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
# 检查是否找到匹配的记录
|
# 检查是否找到匹配的记录
|
||||||
if not result:
|
if not result:
|
||||||
@@ -265,13 +291,13 @@ def worker(data, output_file_path):
|
|||||||
arguments_data = func.get("arguments")
|
arguments_data = func.get("arguments")
|
||||||
|
|
||||||
# 使用富文本打印函数名
|
# 使用富文本打印函数名
|
||||||
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
#print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
||||||
|
|
||||||
# 使用富文本打印参数
|
# 使用富文本打印参数
|
||||||
print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
|
#print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
|
||||||
|
|
||||||
if func.get("name") == 'retrieval_from_knowledge_base':
|
if func.get("name") == 'retrieval_from_knowledge_base':
|
||||||
pass
|
|
||||||
# delay_time = random.uniform(5, 10)
|
# delay_time = random.uniform(5, 10)
|
||||||
# time.sleep(delay_time)
|
# time.sleep(delay_time)
|
||||||
result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
||||||
@@ -281,60 +307,39 @@ def worker(data, output_file_path):
|
|||||||
formatted_results.append(formatted_result)
|
formatted_results.append(formatted_result)
|
||||||
|
|
||||||
elif func.get("name") == 'generate_material':
|
elif func.get("name") == 'generate_material':
|
||||||
# # 规范化参数
|
try:
|
||||||
# try:
|
# 确保arguments_data是字典
|
||||||
# # 确保arguments_data是字典
|
if isinstance(arguments_data, str):
|
||||||
# if isinstance(arguments_data, str):
|
try:
|
||||||
# try:
|
arguments_data = json.loads(arguments_data)
|
||||||
# arguments_data = json.loads(arguments_data)
|
except json.JSONDecodeError as e:
|
||||||
# except json.JSONDecodeError as e:
|
#print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
|
||||||
# print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
|
continue
|
||||||
# continue
|
|
||||||
|
|
||||||
# # 规范化参数
|
# 规范化参数
|
||||||
# normalized_args = normalize_material_args(arguments_data)
|
normalized_args = normalize_material_args(arguments_data)
|
||||||
# # print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
|
||||||
# # print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
|
||||||
# # print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
|
||||||
|
|
||||||
# # 优先使用mattergen函数
|
# 优先使用mattergen函数
|
||||||
# try:
|
try:
|
||||||
# # output = asyncio.run(generate_material(**normalized_args))
|
|
||||||
# output = generate_material(**normalized_args)
|
|
||||||
|
|
||||||
# # 添加延迟,模拟额外的工具函数调用
|
output = generate_material(**normalized_args)
|
||||||
|
|
||||||
# # 随机延迟5-10秒
|
|
||||||
# # delay_time = random.uniform(5, 10)
|
|
||||||
# # print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
|
|
||||||
# # time.sleep(delay_time)
|
|
||||||
|
|
||||||
# # # 模拟其他工具函数调用的日志输出
|
except Exception as e:
|
||||||
# # print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
|
#print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
|
||||||
# # time.sleep(0.5)
|
continue
|
||||||
# # print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
|
# 将结果添加到func_results
|
||||||
# # time.sleep(0.5)
|
func_results.append({"function": func_name, "result": output})
|
||||||
# # print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
|
|
||||||
# # time.sleep(0.5)
|
|
||||||
# # print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
|
|
||||||
|
|
||||||
# except Exception as e:
|
# 格式化结果
|
||||||
# print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
|
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
|
||||||
|
formatted_results.append(formatted_result)
|
||||||
# # 将结果添加到func_results
|
except Exception as e:
|
||||||
# func_results.append({"function": func_name, "result": output})
|
#print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
|
||||||
|
import traceback
|
||||||
# # 格式化结果
|
#print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
|
||||||
# formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
|
continue
|
||||||
# formatted_results.append(formatted_result)
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
|
|
||||||
# import traceback
|
|
||||||
# print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
|
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
# delay_time = random.uniform(5, 10)
|
|
||||||
# time.sleep(delay_time)
|
|
||||||
|
|
||||||
result = asyncio.run(execute_tool_from_dict(func))
|
result = asyncio.run(execute_tool_from_dict(func))
|
||||||
func_results.append({"function": func['name'], "result": result})
|
func_results.append({"function": func['name'], "result": result})
|
||||||
@@ -347,17 +352,17 @@ def worker(data, output_file_path):
|
|||||||
final_result = "\n\n\n".join(formatted_results)
|
final_result = "\n\n\n".join(formatted_results)
|
||||||
data['observation'] = final_result
|
data['observation'] = final_result
|
||||||
|
|
||||||
# 使用富文本打印开始和结束标记
|
#使用富文本打印开始和结束标记
|
||||||
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
|
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
|
||||||
print(data['observation'])
|
# print(data['observation'])
|
||||||
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
|
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
|
||||||
with file_lock:
|
with file_lock:
|
||||||
with jsonlines.open(output_file_path, mode='a') as writer:
|
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||||
writer.write(data) # observation . data
|
writer.write(data) # observation . data
|
||||||
return f"Processed successfully"
|
return f"Processed successfully"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
|
#print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
|
||||||
return f"Error processing: {str(e)}"
|
return f"Error processing: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
@@ -365,7 +370,6 @@ def main(datas, output_file_path, max_workers=1):
|
|||||||
import random
|
import random
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import os
|
import os
|
||||||
from mysql.connector import pooling, Error
|
|
||||||
|
|
||||||
# 创建进度条
|
# 创建进度条
|
||||||
pbar = tqdm(total=len(datas), desc="Processing papers")
|
pbar = tqdm(total=len(datas), desc="Processing papers")
|
||||||
@@ -403,21 +407,26 @@ def main(datas, output_file_path, max_workers=1):
|
|||||||
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import datetime
|
import datetime
|
||||||
import jsonlines
|
import jsonlines
|
||||||
datas = []
|
datas_with_solution = []
|
||||||
|
datas_without_solution = []
|
||||||
|
|
||||||
with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
|
with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
|
||||||
for obj in reader:
|
for obj in reader:
|
||||||
datas.append(obj)
|
if obj['solution']!='':
|
||||||
|
datas_with_solution.append(obj)
|
||||||
|
else:
|
||||||
|
datas_without_solution.append(obj)
|
||||||
|
|
||||||
print(len(datas))
|
datas_with_solution = datas_with_solution[:5000]
|
||||||
# print()
|
datas_without_solution = datas_without_solution[:5000]
|
||||||
output_file = f"./filter_ok_questions_solutions_agent_other_tools_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
|
||||||
|
datas = datas_with_solution + datas_without_solution
|
||||||
|
import random
|
||||||
|
random.shuffle(datas)
|
||||||
|
|
||||||
|
output_file = f"./filter_ok_questions_solutions_agent_data10000_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||||
main(datas, output_file, max_workers=32)
|
main(datas, output_file, max_workers=32)
|
||||||
|
|
||||||
# 示例1:使用正确的JSON格式
|
|
||||||
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
|
|
||||||
# argument = json.loads(argument)
|
|
||||||
# print(json.dumps(argument, indent=2))
|
|
||||||
# asyncio.run(mattergen(**argument))
|
|
||||||
161
execute_tool_copy.py → generate_data/generate_tool_observation.py
Normal file → Executable file
161
execute_tool_copy.py → generate_data/generate_tool_observation.py
Normal file → Executable file
@@ -17,7 +17,7 @@ import random
|
|||||||
# 初始化colorama
|
# 初始化colorama
|
||||||
init(autoreset=True)
|
init(autoreset=True)
|
||||||
|
|
||||||
from typing import Dict, Union, Any, Optional
|
from typing import Dict, Union, Any, Optional, List
|
||||||
|
|
||||||
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -110,6 +110,8 @@ def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
# 创建数据库连接池(仅用于初始加载数据)
|
||||||
connection_pool = pooling.MySQLConnectionPool(
|
connection_pool = pooling.MySQLConnectionPool(
|
||||||
pool_name="mypool",
|
pool_name="mypool",
|
||||||
pool_size=32,
|
pool_size=32,
|
||||||
@@ -120,7 +122,50 @@ connection_pool = pooling.MySQLConnectionPool(
|
|||||||
database='metadata_mat_papers'
|
database='metadata_mat_papers'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 内存缓存,用于存储从数据库加载的数据
|
||||||
|
# 结构: {doi: record, mp_id: record}
|
||||||
|
memory_cache = {}
|
||||||
|
|
||||||
|
def load_data_to_memory():
|
||||||
|
"""
|
||||||
|
从数据库加载所有数据到内存中
|
||||||
|
"""
|
||||||
|
print(f"{Fore.CYAN}{Style.BRIGHT}正在从数据库加载数据到内存中...{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
conn = connection_pool.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
try:
|
||||||
|
# 查询所有记录
|
||||||
|
cursor.execute("SELECT * FROM mp_synthesis_scheme_info")
|
||||||
|
records = cursor.fetchall()
|
||||||
|
|
||||||
|
# 将记录添加到内存缓存中
|
||||||
|
for record in records:
|
||||||
|
doi = record.get('doi')
|
||||||
|
mp_id = record.get('mp_id')
|
||||||
|
|
||||||
|
# 使用doi作为键(如果存在)
|
||||||
|
if doi:
|
||||||
|
memory_cache[doi] = record
|
||||||
|
|
||||||
|
# 使用mp_id作为键(如果存在)
|
||||||
|
if mp_id:
|
||||||
|
memory_cache[mp_id] = record
|
||||||
|
|
||||||
|
print(f"{Fore.GREEN}{Style.BRIGHT}成功加载 {len(records)} 条记录到内存中{Style.RESET_ALL}")
|
||||||
|
print(f"{Fore.GREEN}{Style.BRIGHT}内存缓存中的键数量: {len(memory_cache)}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# 在程序启动时加载数据到内存中
|
||||||
|
load_data_to_memory()
|
||||||
|
|
||||||
async def process_retrieval_from_knowledge_base(data):
|
async def process_retrieval_from_knowledge_base(data):
|
||||||
|
|
||||||
doi = data.get('doi')
|
doi = data.get('doi')
|
||||||
mp_id = data.get('mp_id')
|
mp_id = data.get('mp_id')
|
||||||
|
|
||||||
@@ -128,31 +173,12 @@ async def process_retrieval_from_knowledge_base(data):
|
|||||||
if doi is None and mp_id is None:
|
if doi is None and mp_id is None:
|
||||||
return "" # 如果没有提供查询参数,返回空字符串
|
return "" # 如果没有提供查询参数,返回空字符串
|
||||||
|
|
||||||
# 构建SQL查询条件
|
# 从内存缓存中查询匹配的记录
|
||||||
query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
|
result = None
|
||||||
params = []
|
if doi is not None and doi in memory_cache:
|
||||||
|
result = memory_cache[doi]
|
||||||
if doi is not None and mp_id is not None:
|
elif mp_id is not None and mp_id in memory_cache:
|
||||||
query += "doi = %s OR mp_id = %s"
|
result = memory_cache[mp_id]
|
||||||
params = [doi, mp_id]
|
|
||||||
elif doi is not None:
|
|
||||||
query += "doi = %s"
|
|
||||||
params = [doi]
|
|
||||||
else: # mp_id is not None
|
|
||||||
query += "mp_id = %s"
|
|
||||||
params = [mp_id]
|
|
||||||
|
|
||||||
# 从数据库中查询匹配的记录
|
|
||||||
conn = connection_pool.get_connection()
|
|
||||||
try:
|
|
||||||
cursor = conn.cursor(dictionary=True)
|
|
||||||
try:
|
|
||||||
cursor.execute(query, params)
|
|
||||||
result = cursor.fetchone() # 获取第一个匹配的记录
|
|
||||||
finally:
|
|
||||||
cursor.close()
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
# 检查是否找到匹配的记录
|
# 检查是否找到匹配的记录
|
||||||
if not result:
|
if not result:
|
||||||
@@ -265,62 +291,43 @@ def worker(data, output_file_path):
|
|||||||
arguments_data = func.get("arguments")
|
arguments_data = func.get("arguments")
|
||||||
|
|
||||||
# 使用富文本打印函数名
|
# 使用富文本打印函数名
|
||||||
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
#print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
||||||
|
|
||||||
# 使用富文本打印参数
|
# 使用富文本打印参数
|
||||||
print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
|
#print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
|
||||||
|
|
||||||
if func.get("name") == 'retrieval_from_knowledge_base':
|
if func.get("name") == 'retrieval_from_knowledge_base':
|
||||||
pass
|
|
||||||
# delay_time = random.uniform(5, 10)
|
# delay_time = random.uniform(5, 10)
|
||||||
# time.sleep(delay_time)
|
# time.sleep(delay_time)
|
||||||
# result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
||||||
# func_results.append({"function": func['name'], "result": result})
|
func_results.append({"function": func['name'], "result": result})
|
||||||
# # 格式化结果
|
# 格式化结果
|
||||||
# formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||||
# formatted_results.append(formatted_result)
|
formatted_results.append(formatted_result)
|
||||||
|
|
||||||
elif func.get("name") == 'generate_material':
|
elif func.get("name") == 'generate_material':
|
||||||
# 规范化参数
|
|
||||||
try:
|
try:
|
||||||
# 确保arguments_data是字典
|
# 确保arguments_data是字典
|
||||||
if isinstance(arguments_data, str):
|
if isinstance(arguments_data, str):
|
||||||
try:
|
try:
|
||||||
arguments_data = json.loads(arguments_data)
|
arguments_data = json.loads(arguments_data)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
|
#print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 规范化参数
|
# 规范化参数
|
||||||
normalized_args = normalize_material_args(arguments_data)
|
normalized_args = normalize_material_args(arguments_data)
|
||||||
# print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
|
||||||
# print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
|
||||||
# print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
|
||||||
|
|
||||||
# 优先使用mattergen函数
|
# 优先使用mattergen函数
|
||||||
try:
|
try:
|
||||||
# output = asyncio.run(generate_material(**normalized_args))
|
|
||||||
output = generate_material(**normalized_args)
|
output = generate_material(**normalized_args)
|
||||||
|
|
||||||
# 添加延迟,模拟额外的工具函数调用
|
|
||||||
|
|
||||||
# 随机延迟5-10秒
|
|
||||||
# delay_time = random.uniform(5, 10)
|
|
||||||
# print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
|
|
||||||
# time.sleep(delay_time)
|
|
||||||
|
|
||||||
# # 模拟其他工具函数调用的日志输出
|
|
||||||
# print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
|
|
||||||
# time.sleep(0.5)
|
|
||||||
# print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
|
|
||||||
# time.sleep(0.5)
|
|
||||||
# print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
|
|
||||||
# time.sleep(0.5)
|
|
||||||
# print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
|
#print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
|
||||||
|
continue
|
||||||
# 将结果添加到func_results
|
# 将结果添加到func_results
|
||||||
func_results.append({"function": func_name, "result": output})
|
func_results.append({"function": func_name, "result": output})
|
||||||
|
|
||||||
@@ -328,36 +335,34 @@ def worker(data, output_file_path):
|
|||||||
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
|
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
|
||||||
formatted_results.append(formatted_result)
|
formatted_results.append(formatted_result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
|
#print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
|
||||||
import traceback
|
import traceback
|
||||||
print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
|
#print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
# delay_time = random.uniform(5, 10)
|
|
||||||
# time.sleep(delay_time)
|
result = asyncio.run(execute_tool_from_dict(func))
|
||||||
pass
|
func_results.append({"function": func['name'], "result": result})
|
||||||
# result = asyncio.run(execute_tool_from_dict(func))
|
# 格式化结果
|
||||||
# func_results.append({"function": func['name'], "result": result})
|
func_name = func.get("name")
|
||||||
# # 格式化结果
|
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||||
# func_name = func.get("name")
|
formatted_results.append(formatted_result)
|
||||||
# formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
|
||||||
# formatted_results.append(formatted_result)
|
|
||||||
|
|
||||||
# 将所有格式化后的结果连接起来
|
# 将所有格式化后的结果连接起来
|
||||||
final_result = "\n\n\n".join(formatted_results)
|
final_result = "\n\n\n".join(formatted_results)
|
||||||
data['observation'] = final_result
|
data['observation'] = final_result
|
||||||
|
|
||||||
# 使用富文本打印开始和结束标记
|
#使用富文本打印开始和结束标记
|
||||||
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
|
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
|
||||||
print(data['observation'])
|
# print(data['observation'])
|
||||||
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
|
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
|
||||||
with file_lock:
|
with file_lock:
|
||||||
with jsonlines.open(output_file_path, mode='a') as writer:
|
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||||
writer.write(data) # observation . data
|
writer.write(data) # observation . data
|
||||||
return f"Processed successfully"
|
return f"Processed successfully"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
|
#print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
|
||||||
return f"Error processing: {str(e)}"
|
return f"Error processing: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
@@ -365,7 +370,6 @@ def main(datas, output_file_path, max_workers=1):
|
|||||||
import random
|
import random
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import os
|
import os
|
||||||
from mysql.connector import pooling, Error
|
|
||||||
|
|
||||||
# 创建进度条
|
# 创建进度条
|
||||||
pbar = tqdm(total=len(datas), desc="Processing papers")
|
pbar = tqdm(total=len(datas), desc="Processing papers")
|
||||||
@@ -409,12 +413,13 @@ if __name__ == '__main__':
|
|||||||
datas = []
|
datas = []
|
||||||
with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
|
with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
|
||||||
for obj in reader:
|
for obj in reader:
|
||||||
datas.append(obj)
|
#if obj['solution']!='':
|
||||||
|
datas.append(obj)
|
||||||
|
|
||||||
print(len(datas))
|
print(len(datas))
|
||||||
# print()
|
# print()
|
||||||
output_file = f"./filter_ok_questions_solutions_agent_mattergen_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
output_file = f"./filter_ok_questions_solutions_agent_data_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||||
main(datas, output_file, max_workers=1)
|
main(datas, output_file, max_workers=32)
|
||||||
|
|
||||||
# 示例1:使用正确的JSON格式
|
# 示例1:使用正确的JSON格式
|
||||||
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
|
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
|
||||||
91
generate_data/grpo_tools.py
Executable file
91
generate_data/grpo_tools.py
Executable file
@@ -0,0 +1,91 @@
|
|||||||
|
import jsonlines
|
||||||
|
import argparse
|
||||||
|
import generate_data.utils as utils
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
from ase import io
|
||||||
|
import tempfile
|
||||||
|
import re
|
||||||
|
from pymatgen.io.vasp import Poscar
|
||||||
|
from pymatgen.io.cif import CifParser
|
||||||
|
import threading
|
||||||
|
import concurrent.futures
|
||||||
|
import copy
|
||||||
|
from grpo_utils import generate_design_question, generate_props_question, generate_obs_response
|
||||||
|
# Create a lock for file writing
|
||||||
|
file_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def worker(data, output_file_path):
|
||||||
|
try:
|
||||||
|
messages = copy.deepcopy(data['messages'])
|
||||||
|
obs = data['observation']
|
||||||
|
messages[-1]['content'] = messages[-1]['content'].split("<answer>")[-1].split("</answer>")[0]
|
||||||
|
messages.append({"role": "user", "content": obs})
|
||||||
|
data['messages'].append({"role": "user", "content": obs})
|
||||||
|
# print(messages)
|
||||||
|
# print(obs)
|
||||||
|
|
||||||
|
reasoning_content, response = generate_obs_response(messages)
|
||||||
|
data['messages'].append({"role": "assistant", "content": f"<think>\n{reasoning_content}</think>\n<answer>\n{response}</answer>\n"})
|
||||||
|
# Use the lock to safely write to the file
|
||||||
|
with file_lock:
|
||||||
|
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||||
|
writer.write(messages)
|
||||||
|
|
||||||
|
return f"Processed successfully"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error processing: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def main(input_file_path, output_file_path, max_workers=1):
|
||||||
|
import random
|
||||||
|
from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
|
||||||
|
datas = None
|
||||||
|
with jsonlines.open(input_file_path, mode='r') as reader:
|
||||||
|
datas = [line for line in reader]
|
||||||
|
|
||||||
|
# 创建进度条
|
||||||
|
pbar = tqdm(total=len(datas), desc="Processing CIF files")
|
||||||
|
|
||||||
|
# 创建一个线程池
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
# 提交任务到执行器
|
||||||
|
future_to_data = {}
|
||||||
|
for data in datas:
|
||||||
|
future = executor.submit(worker, data, output_file_path)
|
||||||
|
future_to_data[future] = data
|
||||||
|
|
||||||
|
# 处理结果
|
||||||
|
completed = 0
|
||||||
|
failed = 0
|
||||||
|
for future in concurrent.futures.as_completed(future_to_data):
|
||||||
|
data = future_to_data[future]
|
||||||
|
try:
|
||||||
|
result = future.result()
|
||||||
|
if "successfully" in result:
|
||||||
|
completed += 1
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
# 更新进度条
|
||||||
|
pbar.update(1)
|
||||||
|
# 每100个文件更新一次统计信息
|
||||||
|
if (completed + failed) % 100 == 0:
|
||||||
|
pbar.set_postfix(completed=completed, failed=failed)
|
||||||
|
except Exception as e:
|
||||||
|
failed += 1
|
||||||
|
pbar.update(1)
|
||||||
|
print(f"\nWorker for {data} generated an exception: {e}")
|
||||||
|
|
||||||
|
pbar.close()
|
||||||
|
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import datetime
|
||||||
|
origin_file = "/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl"
|
||||||
|
output_file = f"agent_questions_solutions_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||||
|
main(origin_file, output_file)
|
||||||
|
|
||||||
294
generate_data/grpo_utils.py
Executable file
294
generate_data/grpo_utils.py
Executable file
@@ -0,0 +1,294 @@
|
|||||||
|
import jsonlines
|
||||||
|
import argparse
|
||||||
|
import generate_data.utils as utils
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
from ase import io
|
||||||
|
import tempfile
|
||||||
|
import re
|
||||||
|
from pymatgen.io.vasp import Poscar
|
||||||
|
from pymatgen.io.cif import CifParser
|
||||||
|
import threading
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
# Create a lock for file writing
|
||||||
|
file_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def generate_design_question(crystal_desc, cif_info, crystal_props, max_retries=3, initial_backoff=1.0):
|
||||||
|
instruction = """
|
||||||
|
{crystal_desc}
|
||||||
|
|
||||||
|
### 对应的晶体结构数据(CIF)如下:
|
||||||
|
{cif_info}
|
||||||
|
|
||||||
|
### 该晶体结构的物理化学性质为:
|
||||||
|
{crystal_props}
|
||||||
|
|
||||||
|
根据如上信息,我现在需要给材料科学的博士考试出题,问题要求博士们回答出上文中的完整CIF文件,如果是你你会如何出题?
|
||||||
|
也就是说,要求我们提出的问题的答案是上文中提及的完整CIF文件。当然,你的问题必须给定充足的该晶体结构的相关信息。
|
||||||
|
但是相关信息应该抽象和隐晦,避免过于直白,除明确的化学表达式外,尽量避免过多的精确信息,让博士考生们可以通过推理得到某些信息以增加问题的难度。
|
||||||
|
问题的语言一半是中文,一半是英文,以便更好地与模型进行交互。
|
||||||
|
|
||||||
|
请先生成10个问题示例,再挑选2个最好的问题示例并遵循如下格式输出:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"selected_questions": [
|
||||||
|
{
|
||||||
|
"question_id": 1,
|
||||||
|
"question_text": "问题1的完整内容...",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"question_id": 2,
|
||||||
|
"question_text": "问题2的完整内容...",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
instruction = instruction.replace("{crystal_desc}", crystal_desc).replace("{cif_info}", cif_info).replace("{crystal_props}", crystal_props)
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": ""},
|
||||||
|
{"role": "user", "content": instruction}
|
||||||
|
]
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
_response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
|
||||||
|
# reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
|
||||||
|
# print(f"Time: {time.time() - start_time}")
|
||||||
|
if _response == 'apierror' or _response == 'unexpectederror':
|
||||||
|
return _response
|
||||||
|
# 尝试从响应中提取JSON部分
|
||||||
|
json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
|
||||||
|
if json_match:
|
||||||
|
json_str = json_match.group(1)
|
||||||
|
try:
|
||||||
|
questions_data = json.loads(json_str)
|
||||||
|
return questions_data
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 如果JSON解析失败,尝试清理字符串后再次解析
|
||||||
|
cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
|
||||||
|
try:
|
||||||
|
questions_data = json.loads(cleaned_json)
|
||||||
|
return questions_data
|
||||||
|
except:
|
||||||
|
return {"error": "Failed to parse JSON response", "raw_response": _response}
|
||||||
|
else:
|
||||||
|
# 如果没有找到JSON格式,返回原始响应
|
||||||
|
return {"error": "No JSON format found in response", "raw_response": _response}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def generate_props_question(crystal_desc, cif_info, crystal_props, max_retries=3, initial_backoff=1.0):
|
||||||
|
instruction = """
|
||||||
|
{crystal_desc}
|
||||||
|
|
||||||
|
### 对应的晶体结构数据(CIF)如下:
|
||||||
|
{cif_info}
|
||||||
|
|
||||||
|
### 该晶体结构的物理化学性质为:
|
||||||
|
{crystal_props}
|
||||||
|
|
||||||
|
根据如上信息,我现在需要给材料科学的博士考试出题,问题要求博士们根据CIF文件回答出上文中的物理化学性质,如果是你你会如何出题?
|
||||||
|
也就是说,要求我们提出的问题的答案是上文中提及的物理化学性质。当然,你的问题必须尽量包含一个<placeholder>标签代表给定的CIF文件。
|
||||||
|
让博士考生们根据给定的CIF文件通过深入思考和推理去分析该种晶体材料在上文所提及的全部物理化学性质,并用JSON格式回答全部的物理化学性质。
|
||||||
|
问题的语言一半是中文,一半是英文,以便更好地与模型进行交互。
|
||||||
|
|
||||||
|
示例的问题:
|
||||||
|
1. <placeholder>\n,根据上文提供的CIF文件,请你xxx
|
||||||
|
2. 根据下文提供的CIF文件,请你xxx\n <<placeholder>>
|
||||||
|
|
||||||
|
请先生成10个问题示例,再挑选2个最好的问题示例并遵循如下格式输出:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"selected_questions": [
|
||||||
|
{
|
||||||
|
"question_id": 1,
|
||||||
|
"question_text": "问题1的完整内容...",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"question_id": 2,
|
||||||
|
"question_text": "问题2的完整内容...",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
instruction = instruction.replace("{crystal_desc}", crystal_desc).replace("{cif_info}", cif_info).replace("{crystal_props}", crystal_props)
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": ""},
|
||||||
|
{"role": "user", "content": instruction}
|
||||||
|
]
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
_response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
|
||||||
|
# reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
|
||||||
|
# print(f"Time: {time.time() - start_time}")
|
||||||
|
if _response == 'apierror' or _response == 'unexpectederror':
|
||||||
|
return _response
|
||||||
|
# 尝试从响应中提取JSON部分
|
||||||
|
json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
|
||||||
|
if json_match:
|
||||||
|
json_str = json_match.group(1)
|
||||||
|
try:
|
||||||
|
questions_data = json.loads(json_str)
|
||||||
|
return questions_data
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 如果JSON解析失败,尝试清理字符串后再次解析
|
||||||
|
cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
|
||||||
|
try:
|
||||||
|
questions_data = json.loads(cleaned_json)
|
||||||
|
return questions_data
|
||||||
|
except:
|
||||||
|
return {"error": "Failed to parse JSON response", "raw_response": _response}
|
||||||
|
else:
|
||||||
|
# 如果没有找到JSON格式,返回原始响应
|
||||||
|
return {"error": "No JSON format found in response", "raw_response": _response}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_papers_other_question(paper_info, max_retries=3, initial_backoff=1.0):
|
||||||
|
instruction = """
|
||||||
|
{paper_info}
|
||||||
|
|
||||||
|
根据如上信息,我现在需要给材料科学的博士学生出题,问题要求考察博士对该材料的反应方程式、结构、性能和应用是否完全掌握,如果是你你会怎么出题?
|
||||||
|
你的问题里面应该包含该材料相关的合适的信息,且是自包含的(在只有问题的情况下问题中的关键信息不遗漏),但问题需要有难度和深度,需要博士生们深入思考和推理后才能作为准确的回答。
|
||||||
|
由于问题面向博士,因此,提出的问题需要一定的科研价值导向。涉及到反应方程式、关于结构、性能和应用等方面的具体试剂量等信息时,要求他们尽可能给出精确的数值(前提是这些数值在上文中存在)。
|
||||||
|
|
||||||
|
|
||||||
|
请先生成12个问题示例,12个问题的语言一半是中文,一半是英文,再挑选4个最好的问题示例并遵循如下格式输出:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"selected_questions": [
|
||||||
|
{
|
||||||
|
"question_id": 1,
|
||||||
|
"question_text": "问题1的完整内容...",
|
||||||
|
"question_type": "问题1的类型", # reaction_string; structure; performence; application
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"question_id": 2,
|
||||||
|
"question_text": "问题2的完整内容...",
|
||||||
|
"question_type": "问题1的类型", # reaction_string; structure; performence; application
|
||||||
|
}, ...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
instruction = instruction.replace("{paper_info}", paper_info)
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": ""},
|
||||||
|
{"role": "user", "content": instruction}
|
||||||
|
]
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
_response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
|
||||||
|
# reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
|
||||||
|
# print(f"Time: {time.time() - start_time}")
|
||||||
|
if _response == 'apierror' or _response == 'unexpectederror':
|
||||||
|
return _response
|
||||||
|
# 尝试从响应中提取JSON部分
|
||||||
|
json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
|
||||||
|
if json_match:
|
||||||
|
json_str = json_match.group(1)
|
||||||
|
try:
|
||||||
|
questions_data = json.loads(json_str)
|
||||||
|
return questions_data
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 如果JSON解析失败,尝试清理字符串后再次解析
|
||||||
|
cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
|
||||||
|
try:
|
||||||
|
questions_data = json.loads(cleaned_json)
|
||||||
|
return questions_data
|
||||||
|
except:
|
||||||
|
return {"error": "Failed to parse JSON response", "raw_response": _response}
|
||||||
|
else:
|
||||||
|
# 如果没有找到JSON格式,返回原始响应
|
||||||
|
return {"error": "No JSON format found in response", "raw_response": _response}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def generate_papers_synthesis_question(paper_info, max_retries=3, initial_backoff=1.0):
|
||||||
|
instruction = """
|
||||||
|
{paper_info}
|
||||||
|
|
||||||
|
根据如上信息,我现在需要给材料科学的博士学生出题,问题要求考察博士是否完全掌握该材料的合成方案,是否完全掌握给定材料的结构和性能到合成方案的精准映射关系,如果是你你会怎么出题?
|
||||||
|
你的问题里面应该包含该材料充分的结构和性能信息,问题需要有难度和深度,需要博士生们深入思考和推理后才能给出准确的合成方案并整理成JSON格式的格式化合成方案。
|
||||||
|
由于问题面向博士,因此,提出的问题需要一定的科研价值导向,并且要求博士在回答该材料的合成方案时给出精确的数值(包括试剂、前驱体、容器、温度等合成条件)。
|
||||||
|
问题中作为条件信息的部分需要尽可能的在问题中明确而不是隐晦(你要考虑到博士们拿到问题的时候并不知道上文中的信息,所以类似“基于给定的材料结构和性能信息”这种问法应该尽量避免)。
|
||||||
|
|
||||||
|
请先生成6个问题示例,6个问题的语言一半是中文,一半是英文,再挑选2个最好的问题示例并遵循如下格式输出:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"selected_questions": [
|
||||||
|
{
|
||||||
|
"question_id": 1,
|
||||||
|
"question_text": "问题1的完整内容...",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"question_id": 2,
|
||||||
|
"question_text": "问题2的完整内容...",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
instruction = instruction.replace("{paper_info}", paper_info)
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": ""},
|
||||||
|
{"role": "user", "content": instruction}
|
||||||
|
]
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
_response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
|
||||||
|
# reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
|
||||||
|
# print(f"Time: {time.time() - start_time}")
|
||||||
|
if _response == 'apierror' or _response == 'unexpectederror':
|
||||||
|
return _response
|
||||||
|
# 尝试从响应中提取JSON部分
|
||||||
|
json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
|
||||||
|
if json_match:
|
||||||
|
json_str = json_match.group(1)
|
||||||
|
try:
|
||||||
|
questions_data = json.loads(json_str)
|
||||||
|
return questions_data
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 如果JSON解析失败,尝试清理字符串后再次解析
|
||||||
|
cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
|
||||||
|
try:
|
||||||
|
questions_data = json.loads(cleaned_json)
|
||||||
|
return questions_data
|
||||||
|
except:
|
||||||
|
return {"error": "Failed to parse JSON response", "raw_response": _response}
|
||||||
|
else:
|
||||||
|
# 如果没有找到JSON格式,返回原始响应
|
||||||
|
return {"error": "No JSON format found in response", "raw_response": _response}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_function_call(messages, tools, max_retries=3, initial_backoff=1.0):
|
||||||
|
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
instruction = """
|
||||||
|
# 问题
|
||||||
|
{question}
|
||||||
|
|
||||||
|
# 指令
|
||||||
|
在准确的回答上述问题之前,你只有现在这一次机会允许你调用工具以获取更多信息。
|
||||||
|
请尽可能深入思考上述问题,并尽可能的调用多个提供给你的工具查询该问题的相关信息,而不是直接回答该问题。
|
||||||
|
因此,你需要在回答中一次给出多个经过思考后的工具调用,以便更好地回答上述问题。
|
||||||
|
思考和回答时使用和问题相同的语言。
|
||||||
|
"""
|
||||||
|
messages[0]["content"] = instruction.replace("{question}", messages[0]["content"])
|
||||||
|
|
||||||
|
_response, functions = utils.get_response_from_qwq(messages, model_name="qwq-32b", tools=tools, max_retries=max_retries, initial_backoff=initial_backoff)
|
||||||
|
# reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
|
||||||
|
# print(f"Time: {time.time() - start_time}")
|
||||||
|
# print(_response)
|
||||||
|
# if _response == 'apierror' or _response == 'unexpectederror':
|
||||||
|
# return _response
|
||||||
|
return _response, functions
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def generate_obs_response(messages, max_retries=3, initial_backoff=1.0):
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
_reasoning_content, response = utils.get_response_from_deepseek_r1(messages, prefix=False, max_retries=max_retries, initial_backoff=initial_backoff)
|
||||||
|
return _reasoning_content, response
|
||||||
800
generate_data/utils.py
Executable file
800
generate_data/utils.py
Executable file
@@ -0,0 +1,800 @@
|
|||||||
|
"""
|
||||||
|
This script generates questions and answers from a given set of CIFs.
|
||||||
|
It uses the OpenAI API and MySQL for storing and retrieving data.
|
||||||
|
@author: Yutang Li
|
||||||
|
"""
|
||||||
|
import multiprocessing
|
||||||
|
import sqlite3
|
||||||
|
import tiktoken
|
||||||
|
import re
|
||||||
|
from fractions import Fraction
|
||||||
|
import numpy as np
|
||||||
|
import glob
|
||||||
|
import tqdm
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
from openai import OpenAI, APIError, RateLimitError
|
||||||
|
from mysql.connector import pooling, Error
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_response_from_deepseek_r1(messages: list[dict], prefix: bool = False, max_retries: int = 3, initial_backoff: float = 1.0):
|
||||||
|
"""
|
||||||
|
Get response from DeepSeek API with retry mechanism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dictionaries
|
||||||
|
prefix: Whether to use the prefix URL
|
||||||
|
max_retries: Maximum number of retry attempts
|
||||||
|
initial_backoff: Initial backoff time in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (reasoning_content, content) or error messages
|
||||||
|
"""
|
||||||
|
retries = 0
|
||||||
|
while retries <= max_retries:
|
||||||
|
try:
|
||||||
|
base_url = "https://api.deepseek.com/beta" if prefix else "https://vip.apiyi.com/v1"
|
||||||
|
api_key = "sk-59279cc16ec740089146ef9aef9c1671" if prefix else "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"
|
||||||
|
|
||||||
|
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||||
|
# messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="deepseek-r1",
|
||||||
|
messages=messages,
|
||||||
|
temperature=0.6
|
||||||
|
)
|
||||||
|
|
||||||
|
# reasoning_content = "null" if prefix else "<think>\n" + response.choices[0].message.model_extra['reasoning_content'] + "\n</think>\n"
|
||||||
|
reasoning_content = response.choices[0].message.content.split("</think>\n")[0].split("<think>\n")[-1]
|
||||||
|
content = response.choices[0].message.content.split("</think>\n")[-1]
|
||||||
|
return reasoning_content, content
|
||||||
|
|
||||||
|
except RateLimitError as rate_error:
|
||||||
|
retries += 1
|
||||||
|
if retries > max_retries:
|
||||||
|
print(f"Max retries exceeded for RateLimitError: {rate_error}")
|
||||||
|
return 'apierror', 'apierror'
|
||||||
|
|
||||||
|
# Exponential backoff with jitter
|
||||||
|
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
|
||||||
|
print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
|
||||||
|
time.sleep(backoff_time)
|
||||||
|
|
||||||
|
except APIError as api_error:
|
||||||
|
retries += 1
|
||||||
|
if retries > max_retries:
|
||||||
|
print(f"Max retries exceeded for APIError: {api_error}")
|
||||||
|
return 'apierror', 'apierror'
|
||||||
|
|
||||||
|
# Check if the error is retryable
|
||||||
|
error_str = str(api_error)
|
||||||
|
if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
|
||||||
|
# Exponential backoff with jitter
|
||||||
|
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
|
||||||
|
print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
|
||||||
|
time.sleep(backoff_time)
|
||||||
|
else:
|
||||||
|
# Non-retryable API error
|
||||||
|
print(f"Non-retryable API error: {api_error}")
|
||||||
|
return 'apierror', 'apierror'
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"generate_design_question Unexpected error: {e}")
|
||||||
|
return 'unexpectederror', 'unexpectederror'
|
||||||
|
|
||||||
|
|
||||||
|
def get_response_from_llm(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
|
||||||
|
"""
|
||||||
|
Get response from LLM API with retry mechanism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dictionaries
|
||||||
|
model_name: Name of the model to use
|
||||||
|
tools: Optional list of tools to use
|
||||||
|
max_retries: Maximum number of retry attempts
|
||||||
|
initial_backoff: Initial backoff time in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Content of the response or error message
|
||||||
|
"""
|
||||||
|
retries = 0
|
||||||
|
while retries <= max_retries:
|
||||||
|
try:
|
||||||
|
client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
|
||||||
|
# messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
|
||||||
|
if tools is None:
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice='auto',
|
||||||
|
parallel_tool_calls=True
|
||||||
|
)
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
return content
|
||||||
|
|
||||||
|
except RateLimitError as rate_error:
|
||||||
|
retries += 1
|
||||||
|
if retries > max_retries:
|
||||||
|
print(f"Max retries exceeded for RateLimitError: {rate_error}")
|
||||||
|
return 'apierror'
|
||||||
|
|
||||||
|
# Exponential backoff with jitter
|
||||||
|
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
|
||||||
|
print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
|
||||||
|
time.sleep(backoff_time)
|
||||||
|
|
||||||
|
except APIError as api_error:
|
||||||
|
retries += 1
|
||||||
|
if retries > max_retries:
|
||||||
|
print(f"Max retries exceeded for APIError: {api_error}")
|
||||||
|
return 'apierror'
|
||||||
|
|
||||||
|
# Check if the error is retryable
|
||||||
|
error_str = str(api_error)
|
||||||
|
if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
|
||||||
|
# Exponential backoff with jitter
|
||||||
|
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
|
||||||
|
print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
|
||||||
|
time.sleep(backoff_time)
|
||||||
|
else:
|
||||||
|
# Non-retryable API error
|
||||||
|
print(f"Non-retryable API error: {api_error}")
|
||||||
|
return 'apierror'
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"generate_design_question Unexpected error: {e}")
|
||||||
|
return 'unexpectederror'
|
||||||
|
|
||||||
|
def get_response_from_qwq(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
|
||||||
|
"""
|
||||||
|
Get response from LLM API with retry mechanism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dictionaries
|
||||||
|
model_name: Name of the model to use
|
||||||
|
tools: Optional list of tools to use
|
||||||
|
max_retries: Maximum number of retry attempts
|
||||||
|
initial_backoff: Initial backoff time in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Content of the response or error message
|
||||||
|
"""
|
||||||
|
retries = 0
|
||||||
|
while retries <= max_retries:
|
||||||
|
try:
|
||||||
|
# client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
|
||||||
|
# client = OpenAI(api_key="sk-df98afdc6b5b48db8195dcb4a68e804b", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||||
|
import random
|
||||||
|
if random.random() > 0.5:
|
||||||
|
client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||||
|
else:
|
||||||
|
client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||||
|
# messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
|
||||||
|
if tools is None:
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
|
messages=messages,
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice='auto',
|
||||||
|
parallel_tool_calls=True,
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning_content = "" # 定义完整思考过程
|
||||||
|
answer_content = "" # 定义完整回复
|
||||||
|
tool_info = [] # 存储工具调用信息
|
||||||
|
is_answering = False # 判断是否结束思考过程并开始回复
|
||||||
|
# print("="*20+"思考过程"+"="*20)
|
||||||
|
for chunk in response:
|
||||||
|
# if not chunk.choices:
|
||||||
|
# # 处理用量统计信息
|
||||||
|
# print("\n"+"="*20+"Usage"+"="*20)
|
||||||
|
# print(chunk.usage)
|
||||||
|
# else:
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
# 处理AI的思考过程(链式推理)
|
||||||
|
if hasattr(delta, 'reasoning_content') and delta.reasoning_content is not None:
|
||||||
|
reasoning_content += delta.reasoning_content
|
||||||
|
# print(delta.reasoning_content,end="",flush=True) # 实时输出思考过程
|
||||||
|
|
||||||
|
# 处理最终回复内容
|
||||||
|
else:
|
||||||
|
if not is_answering: # 首次进入回复阶段时打印标题
|
||||||
|
is_answering = True
|
||||||
|
# print("\n"+"="*20+"回复内容"+"="*20)
|
||||||
|
if delta.content is not None:
|
||||||
|
answer_content += delta.content
|
||||||
|
# print(delta.content,end="",flush=True) # 流式输出回复内容
|
||||||
|
|
||||||
|
# 处理工具调用信息(支持并行工具调用)
|
||||||
|
if delta.tool_calls is not None:
|
||||||
|
for tool_call in delta.tool_calls:
|
||||||
|
index = tool_call.index # 工具调用索引,用于并行调用
|
||||||
|
|
||||||
|
# 动态扩展工具信息存储列表
|
||||||
|
while len(tool_info) <= index:
|
||||||
|
tool_info.append({})
|
||||||
|
|
||||||
|
# 收集工具调用ID(用于后续函数调用)
|
||||||
|
# if tool_call.id:
|
||||||
|
# tool_info[index]['id'] = tool_info[index].get('id', '') + tool_call.id
|
||||||
|
|
||||||
|
# 收集函数名称(用于后续路由到具体函数)
|
||||||
|
if tool_call.function and tool_call.function.name:
|
||||||
|
tool_info[index]['name'] = tool_info[index].get('name', '') + tool_call.function.name
|
||||||
|
|
||||||
|
# 收集函数参数(JSON字符串格式,需要后续解析)
|
||||||
|
if tool_call.function and tool_call.function.arguments:
|
||||||
|
tool_info[index]['arguments'] = tool_info[index].get('arguments', '') + tool_call.function.arguments
|
||||||
|
|
||||||
|
tools_response = ""
|
||||||
|
for tool in tool_info:
|
||||||
|
tools_response += ("<tool_call>\n" + json.dumps(tool, ensure_ascii=False) + "\n</tool_call>\n")
|
||||||
|
response = "<think>\n" + reasoning_content + "\n</think>\n" + "<answer>\n" + answer_content + tools_response + "\n</answer>\n"
|
||||||
|
return response, tool_info
|
||||||
|
# return reasoning_content, answer_content, tool_info
|
||||||
|
|
||||||
|
except RateLimitError as rate_error:
|
||||||
|
retries += 1
|
||||||
|
if retries > max_retries:
|
||||||
|
print(f"Max retries exceeded for RateLimitError: {rate_error}")
|
||||||
|
return 'apierror', []
|
||||||
|
|
||||||
|
# Exponential backoff with jitter
|
||||||
|
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
|
||||||
|
print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
|
||||||
|
time.sleep(backoff_time)
|
||||||
|
|
||||||
|
except APIError as api_error:
|
||||||
|
retries += 1
|
||||||
|
if retries > max_retries:
|
||||||
|
print(f"Max retries exceeded for APIError: {api_error}")
|
||||||
|
return 'apierror', []
|
||||||
|
|
||||||
|
# Check if the error is retryable
|
||||||
|
error_str = str(api_error)
|
||||||
|
if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
|
||||||
|
# Exponential backoff with jitter
|
||||||
|
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
|
||||||
|
print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
|
||||||
|
time.sleep(backoff_time)
|
||||||
|
else:
|
||||||
|
# Non-retryable API error
|
||||||
|
print(f"Non-retryable API error: {api_error}")
|
||||||
|
return 'apierror', []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"generate_design_question Unexpected error: {e}")
|
||||||
|
return 'unexpectederror', []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def read_json_file(file_path):
|
||||||
|
"""Read the json file and return its content."""
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
return json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading file {file_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
################################## utils
|
||||||
|
|
||||||
|
def clean_all_repetitions_with_details(text, min_length=10, threshold=10):
|
||||||
|
"""
|
||||||
|
综合清理文本中的各种重复内容,并返回详细信息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- text: 要清理的文本
|
||||||
|
- min_length: 最小重复片段长度
|
||||||
|
- threshold: 重复内容的阈值
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- cleaned_text: 清理后的文本
|
||||||
|
- is_repetitive: 是否检测到重复
|
||||||
|
- repetition_details: 重复内容的详细信息
|
||||||
|
"""
|
||||||
|
original_text = text
|
||||||
|
is_repetitive = False
|
||||||
|
repetition_details = []
|
||||||
|
|
||||||
|
# 1. 首先处理有换行符的重复
|
||||||
|
if '\n' in text:
|
||||||
|
lines = text.split('\n')
|
||||||
|
unique_lines = []
|
||||||
|
line_counts = {}
|
||||||
|
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
normalized = line.strip().lower()
|
||||||
|
if not normalized:
|
||||||
|
unique_lines.append(line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
line_counts[normalized] = line_counts.get(normalized, 0) + 1
|
||||||
|
|
||||||
|
if line_counts[normalized] <= threshold:
|
||||||
|
unique_lines.append(line)
|
||||||
|
|
||||||
|
# 如果这是第一次超过阈值,记录重复详情
|
||||||
|
if line_counts[normalized] == threshold + 1:
|
||||||
|
# 找到原始形式(保留大小写)
|
||||||
|
original_form = None
|
||||||
|
for l in lines[:i]:
|
||||||
|
if l.strip().lower() == normalized:
|
||||||
|
original_form = l
|
||||||
|
break
|
||||||
|
|
||||||
|
if original_form is None:
|
||||||
|
original_form = line
|
||||||
|
|
||||||
|
repetition_details.append({
|
||||||
|
'type': 'line_repetition',
|
||||||
|
'repeated_string': original_form,
|
||||||
|
'repeat_count': line_counts[normalized]
|
||||||
|
})
|
||||||
|
|
||||||
|
if any(count > threshold for count in line_counts.values()):
|
||||||
|
text = '\n'.join(unique_lines)
|
||||||
|
is_repetitive = True
|
||||||
|
|
||||||
|
# 2. 处理同一行内的连续重复模式
|
||||||
|
for length in range(min_length, 101):
|
||||||
|
pattern = r'(.{' + str(length) + r'})(\1)+'
|
||||||
|
|
||||||
|
while True:
|
||||||
|
match = re.search(pattern, text)
|
||||||
|
if not match:
|
||||||
|
break
|
||||||
|
|
||||||
|
repeated_part = match.group(1)
|
||||||
|
full_match = match.group(0)
|
||||||
|
|
||||||
|
# 计算重复次数
|
||||||
|
repeat_count = len(full_match) // len(repeated_part)
|
||||||
|
|
||||||
|
# 记录重复详情
|
||||||
|
repetition_details.append({
|
||||||
|
'type': 'inline_repetition',
|
||||||
|
'repeated_string': repeated_part,
|
||||||
|
'repeat_count': repeat_count,
|
||||||
|
'total_length': len(full_match),
|
||||||
|
'position': match.start()
|
||||||
|
})
|
||||||
|
|
||||||
|
text = text.replace(full_match, repeated_part)
|
||||||
|
is_repetitive = True
|
||||||
|
|
||||||
|
# 3. 处理句子级别的重复
|
||||||
|
sentences = re.split(r'(?<=[.!?。?!])\s+', text)
|
||||||
|
if len(sentences) > 1:
|
||||||
|
sentence_counter = Counter(sentences)
|
||||||
|
|
||||||
|
for sentence, count in sentence_counter.items():
|
||||||
|
if count > threshold:
|
||||||
|
repetition_details.append({
|
||||||
|
'type': 'sentence_repetition',
|
||||||
|
'repeated_string': sentence,
|
||||||
|
'repeat_count': count
|
||||||
|
})
|
||||||
|
|
||||||
|
if any(count > threshold for count in sentence_counter.values()):
|
||||||
|
unique_sentences = []
|
||||||
|
seen_sentences = {}
|
||||||
|
|
||||||
|
for sentence in sentences:
|
||||||
|
seen_sentences[sentence] = seen_sentences.get(sentence, 0) + 1
|
||||||
|
if seen_sentences[sentence] <= threshold:
|
||||||
|
unique_sentences.append(sentence)
|
||||||
|
|
||||||
|
# 重新组合文本
|
||||||
|
text = ' '.join(unique_sentences)
|
||||||
|
is_repetitive = True
|
||||||
|
|
||||||
|
# 4. 处理更短的重复(如果前面的方法没有检测到重复)
|
||||||
|
if not is_repetitive and min_length > 5:
|
||||||
|
for length in range(5, min_length):
|
||||||
|
pattern = r'(.{' + str(length) + r'})(\1){2,}' # 至少重复3次才处理
|
||||||
|
|
||||||
|
while True:
|
||||||
|
match = re.search(pattern, text)
|
||||||
|
if not match:
|
||||||
|
break
|
||||||
|
|
||||||
|
repeated_part = match.group(1)
|
||||||
|
full_match = match.group(0)
|
||||||
|
|
||||||
|
# 计算重复次数
|
||||||
|
repeat_count = len(full_match) // len(repeated_part)
|
||||||
|
|
||||||
|
# 记录重复详情
|
||||||
|
repetition_details.append({
|
||||||
|
'type': 'short_repetition',
|
||||||
|
'repeated_string': repeated_part,
|
||||||
|
'repeat_count': repeat_count,
|
||||||
|
'total_length': len(full_match),
|
||||||
|
'position': match.start()
|
||||||
|
})
|
||||||
|
|
||||||
|
text = text.replace(full_match, repeated_part)
|
||||||
|
is_repetitive = True
|
||||||
|
|
||||||
|
# 按重复类型和长度排序
|
||||||
|
repetition_details.sort(key=lambda x: (-len(x['repeated_string']), x['type']))
|
||||||
|
|
||||||
|
return text, is_repetitive or text != original_text, repetition_details
|
||||||
|
|
||||||
|
def create_table(table_name, connection_pool):
|
||||||
|
"""Create the required MySQL table if it does not exist."""
|
||||||
|
db = connection_pool.get_connection()
|
||||||
|
cursor = db.cursor()
|
||||||
|
create_table_query = f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {table_name} (
|
||||||
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
mp_id TEXT,
|
||||||
|
question_model TEXT,
|
||||||
|
question TEXT,
|
||||||
|
answer_model TEXT,
|
||||||
|
answer TEXT,
|
||||||
|
answer_len INT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
cursor.execute(create_table_query)
|
||||||
|
db.commit()
|
||||||
|
cursor.close()
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
def record_exists(mp_id, table_name, connection_pool):
|
||||||
|
"""Check if a mp_id already exists in the table."""
|
||||||
|
db = connection_pool.get_connection()
|
||||||
|
cursor = db.cursor()
|
||||||
|
query = f"SELECT * FROM {table_name} WHERE mp_id = %s"
|
||||||
|
cursor.execute(query, (mp_id,))
|
||||||
|
result = cursor.fetchone()
|
||||||
|
cursor.fetchall() # Ensure all results are processed
|
||||||
|
cursor.close()
|
||||||
|
db.close()
|
||||||
|
return result is not None
|
||||||
|
|
||||||
|
def insert_record(entry, table_name, connection_pool):
|
||||||
|
"""Insert a record into the MySQL table."""
|
||||||
|
db = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
db = connection_pool.get_connection()
|
||||||
|
cursor = db.cursor()
|
||||||
|
|
||||||
|
insert_query = f"""
|
||||||
|
INSERT INTO {table_name}
|
||||||
|
(mp_id, question_model, question, answer_model, answer, answer_len)
|
||||||
|
VALUES (%s, %s, %s, %s, %s, %s)
|
||||||
|
"""
|
||||||
|
values = (
|
||||||
|
entry["mp_id"], entry["question_model"],
|
||||||
|
entry["question"], entry["answer_model"], entry["answer"], entry["answer_len"],
|
||||||
|
)
|
||||||
|
cursor.execute(insert_query, values)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
except Error as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
db.rollback()
|
||||||
|
finally:
|
||||||
|
# Ensure cursor is closed
|
||||||
|
if cursor:
|
||||||
|
cursor.close()
|
||||||
|
# Ensure connection is returned to the pool
|
||||||
|
if db:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize SQLite database connection
|
||||||
|
def initialize_db():
|
||||||
|
conn = sqlite3.connect('multi_turns_data.db', check_same_thread=False)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute('''
|
||||||
|
CREATE TABLE IF NOT EXISTS conversations (
|
||||||
|
mp_id TEXT PRIMARY KEY,
|
||||||
|
sample TEXT,
|
||||||
|
token_num INTEGER
|
||||||
|
)
|
||||||
|
''')
|
||||||
|
conn.commit()
|
||||||
|
return conn
|
||||||
|
|
||||||
|
# Save sample to SQLite database
|
||||||
|
def save_to_db(conn, mp_id, sample, total_token):
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute('''
|
||||||
|
INSERT OR REPLACE INTO conversations (mp_id, sample, token_num)
|
||||||
|
VALUES (?, ?, ?)
|
||||||
|
''', (mp_id, str(sample), total_token))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def read_cif_txt_file(file_path):
|
||||||
|
"""Read the markdown file and return its content."""
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
return f.read()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading file {file_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def round_values(data, precision=3):
|
||||||
|
"""
|
||||||
|
递归地将字典中的所有值保留三位小数
|
||||||
|
"""
|
||||||
|
if isinstance(data, dict): # 如果是字典
|
||||||
|
return {key: round_values(value) for key, value in data.items()}
|
||||||
|
elif isinstance(data, list): # 如果是列表,递归处理每个元素
|
||||||
|
return [round_values(item) for item in data]
|
||||||
|
elif isinstance(data, (int, float)): # 如果是数字,保留三位小数
|
||||||
|
return round(data, precision)
|
||||||
|
else: # 对其他类型,直接返回
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def decimal_to_fraction(decimal_value, max_denominator=1000):
|
||||||
|
"""
|
||||||
|
将小数转换为分数表示
|
||||||
|
|
||||||
|
参数:
|
||||||
|
decimal_value: 要转换的小数
|
||||||
|
max_denominator: 分母的最大值,用于控制精度
|
||||||
|
|
||||||
|
返回:
|
||||||
|
分数表示的字符串
|
||||||
|
"""
|
||||||
|
frac = Fraction(decimal_value).limit_denominator(max_denominator)
|
||||||
|
return f"{frac.numerator}/{frac.denominator}"
|
||||||
|
|
||||||
|
def poscar_to_fractional_representation(poscar_content, max_denominator=1000):
|
||||||
|
"""
|
||||||
|
将POSCAR文件中的数值转换为分数表示
|
||||||
|
|
||||||
|
参数:
|
||||||
|
poscar_content: POSCAR文件内容
|
||||||
|
max_denominator: 分母的最大值,用于控制精度
|
||||||
|
|
||||||
|
返回:
|
||||||
|
转换后的POSCAR内容,数值以分数表示
|
||||||
|
"""
|
||||||
|
lines = poscar_content.strip().split('\n')
|
||||||
|
result_lines = []
|
||||||
|
|
||||||
|
# 保留系统名称
|
||||||
|
result_lines.append(lines[0])
|
||||||
|
|
||||||
|
# 保留缩放因子
|
||||||
|
scaling_factor = float(lines[1])
|
||||||
|
result_lines.append(lines[1])
|
||||||
|
|
||||||
|
# 处理晶格向量
|
||||||
|
for i in range(2, 5):
|
||||||
|
vector = [float(x) for x in lines[i].split()]
|
||||||
|
# 将每个分量转换为分数
|
||||||
|
fractional_vector = [decimal_to_fraction(x, max_denominator) for x in vector]
|
||||||
|
result_lines.append(" " + " ".join(fractional_vector))
|
||||||
|
|
||||||
|
# 保留元素类型和数量
|
||||||
|
if len(lines) > 5:
|
||||||
|
result_lines.append(lines[5])
|
||||||
|
if len(lines) > 6:
|
||||||
|
result_lines.append(lines[6])
|
||||||
|
|
||||||
|
# 保留坐标类型
|
||||||
|
if len(lines) > 7:
|
||||||
|
result_lines.append(lines[7])
|
||||||
|
|
||||||
|
# 处理原子坐标
|
||||||
|
for i in range(8, len(lines)):
|
||||||
|
parts = lines[i].split()
|
||||||
|
if len(parts) >= 3:
|
||||||
|
# 将坐标转换为分数
|
||||||
|
coords = [float(parts[j]) for j in range(3)]
|
||||||
|
fractional_coords = [decimal_to_fraction(x, max_denominator) for x in coords]
|
||||||
|
|
||||||
|
# 构建新行
|
||||||
|
new_line = " " + " ".join(fractional_coords)
|
||||||
|
if len(parts) > 3:
|
||||||
|
new_line += " " + " ".join(parts[3:])
|
||||||
|
result_lines.append(new_line)
|
||||||
|
else:
|
||||||
|
# 保留非坐标行
|
||||||
|
result_lines.append(lines[i])
|
||||||
|
|
||||||
|
return "\n".join(result_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_symmetry_equiv_xyz(cif_content):
|
||||||
|
"""
|
||||||
|
删除CIF文件中的对称性操作部分
|
||||||
|
|
||||||
|
参数:
|
||||||
|
cif_content: CIF文件内容字符串
|
||||||
|
|
||||||
|
返回:
|
||||||
|
清理后的CIF内容字符串
|
||||||
|
"""
|
||||||
|
lines = cif_content.split('\n')
|
||||||
|
output_lines = []
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < len(lines):
|
||||||
|
line = lines[i].strip()
|
||||||
|
|
||||||
|
# 检测循环开始
|
||||||
|
if line == 'loop_':
|
||||||
|
# 查看下一行,检查是否是对称性循环
|
||||||
|
next_lines = []
|
||||||
|
j = i + 1
|
||||||
|
while j < len(lines) and lines[j].strip().startswith('_'):
|
||||||
|
next_lines.append(lines[j].strip())
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
# 检查是否包含对称性操作标签
|
||||||
|
if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
|
||||||
|
# 跳过整个循环块
|
||||||
|
while i < len(lines):
|
||||||
|
if i + 1 >= len(lines):
|
||||||
|
break
|
||||||
|
|
||||||
|
next_line = lines[i + 1].strip()
|
||||||
|
# 检查是否到达下一个循环或数据块
|
||||||
|
if next_line == 'loop_' or next_line.startswith('data_'):
|
||||||
|
break
|
||||||
|
|
||||||
|
# 检查是否到达原子位置部分
|
||||||
|
if next_line.startswith('_atom_site_'):
|
||||||
|
break
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
# 不是对称性循环,保留loop_行
|
||||||
|
output_lines.append(lines[i])
|
||||||
|
else:
|
||||||
|
# 非循环开始行,直接保留
|
||||||
|
output_lines.append(lines[i])
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return '\n'.join(output_lines)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def remove_null_values(d):
|
||||||
|
"""
|
||||||
|
Recursively remove key-value pairs with null (None) values from a dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
d (dict): The dictionary to clean.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A new dictionary without null values.
|
||||||
|
"""
|
||||||
|
if not isinstance(d, dict):
|
||||||
|
raise ValueError("Input must be a dictionary")
|
||||||
|
_d = copy.deepcopy(d)
|
||||||
|
|
||||||
|
def recursive_remove(d):
|
||||||
|
cleaned_dict = {}
|
||||||
|
for key, value in d.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# Recursively clean nested dictionaries
|
||||||
|
nested_cleaned = recursive_remove(value)
|
||||||
|
if nested_cleaned: # Only add non-empty dictionaries
|
||||||
|
cleaned_dict[key] = nested_cleaned
|
||||||
|
elif value is not None and key != 'version':
|
||||||
|
cleaned_dict[key] = value
|
||||||
|
|
||||||
|
return cleaned_dict
|
||||||
|
|
||||||
|
clean_dict = recursive_remove(d)
|
||||||
|
if _d['cbm'] is None and _d['vbm'] is None and _d['band_gap'] is not None:
|
||||||
|
# clean_dict['band_gap'] = None
|
||||||
|
clean_dict.pop('band_gap')
|
||||||
|
return clean_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_extra_cif_info(path: str, fields_name: list):
|
||||||
|
"""Extract specific fields from the CIF description."""
|
||||||
|
basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
|
||||||
|
energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
|
||||||
|
metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites']
|
||||||
|
# metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'total_magnetization_normalized_vol', 'total_magnetization_normalized_formula_units', 'num_magnetic_sites', 'num_unique_magnetic_sites', 'types_of_magnetic_species', "decomposes_to"]
|
||||||
|
|
||||||
|
selected_fields = []
|
||||||
|
if fields_name[0] == 'all_fields':
|
||||||
|
selected_fields = basic_fields + energy_electronic_fields + metal_magentic_fields
|
||||||
|
# selected_fields = energy_electronic_fields + metal_magentic_fields
|
||||||
|
else:
|
||||||
|
for field in fields_name:
|
||||||
|
selected_fields.extend(locals().get(field, []))
|
||||||
|
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
docs = json.load(f)
|
||||||
|
|
||||||
|
new_docs = {}
|
||||||
|
for field_name in selected_fields:
|
||||||
|
new_docs[field_name] = docs.get(field_name, '')
|
||||||
|
|
||||||
|
# new_docs['structure'] = {"lattice": docs['structure']['lattice']}
|
||||||
|
return new_docs
|
||||||
|
|
||||||
|
def extract_json(text):
|
||||||
|
"""Extract JSON content from a block of text using regex."""
|
||||||
|
json_pattern = re.compile(r'\\{(?:[^{}]|(?R))*\\}')
|
||||||
|
matches = json_pattern.search(text)
|
||||||
|
if matches:
|
||||||
|
json_str = matches.group(0)
|
||||||
|
try:
|
||||||
|
return json.loads(json_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def extract_and_parse_json(response):
|
||||||
|
"""Extract and parse JSON from a response."""
|
||||||
|
json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
|
||||||
|
json_str = json_match.group(1) if json_match else response.strip()
|
||||||
|
json_str = re.sub(r'(\$[^\$]*\$)', lambda m: m.group(1).replace('\\', '\\\\'), json_str)
|
||||||
|
json_str = json_str.replace('\\"', '"').replace("\\'", "'")
|
||||||
|
try:
|
||||||
|
return json.loads(json_str)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"JSON parse error: {e}")
|
||||||
|
return 'errformat'
|
||||||
|
|
||||||
|
|
||||||
|
# 计算输入消息的tokens
|
||||||
|
def count_message_tokens(messages, model_name):
|
||||||
|
encoding = tiktoken.encoding_for_model(model_name)
|
||||||
|
num_tokens = 0
|
||||||
|
|
||||||
|
num_tokens += len(encoding.encode(messages))
|
||||||
|
|
||||||
|
return num_tokens
|
||||||
|
|
||||||
|
def make_multi_turns_sharegpt_sample(humans: list[str], gpts: list[str], system: str="{SYSTEM}"):
|
||||||
|
sample = {}
|
||||||
|
conversations = []
|
||||||
|
|
||||||
|
if system is not None and system != "":
|
||||||
|
sample["system"] = system
|
||||||
|
|
||||||
|
assert len(humans) !=0, "human cannot be None"
|
||||||
|
assert len(gpts) == len(humans), "human and gpt must have the same length"
|
||||||
|
|
||||||
|
for human, gpt in zip(humans, gpts):
|
||||||
|
if human is not None and human != "":
|
||||||
|
assert gpt is not None, "gpt cannot be None"
|
||||||
|
assert gpt != "", "gpt cannot be empty"
|
||||||
|
# 下列顺序不可改
|
||||||
|
conversations.append({"from": "human", "value": human})
|
||||||
|
conversations.append({"from": "gpt", "value": gpt})
|
||||||
|
|
||||||
|
sample["conversations"] = conversations
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
##################################### utils
|
||||||
1172
mars_toolkit.log
1172
mars_toolkit.log
File diff suppressed because it is too large
Load Diff
0
mars_toolkit/__init__.py
Normal file → Executable file
0
mars_toolkit/__init__.py
Normal file → Executable file
BIN
mars_toolkit/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
0
mars_toolkit/compute/__init__.py
Normal file → Executable file
0
mars_toolkit/compute/__init__.py
Normal file → Executable file
BIN
mars_toolkit/compute/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/compute/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
BIN
mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
BIN
mars_toolkit/compute/__pycache__/property_pred.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/compute/__pycache__/property_pred.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
BIN
mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
0
mars_toolkit/compute/material_gen.py
Normal file → Executable file
0
mars_toolkit/compute/material_gen.py
Normal file → Executable file
0
mars_toolkit/compute/property_pred.py
Normal file → Executable file
0
mars_toolkit/compute/property_pred.py
Normal file → Executable file
0
mars_toolkit/compute/structure_opt.py
Normal file → Executable file
0
mars_toolkit/compute/structure_opt.py
Normal file → Executable file
0
mars_toolkit/core/__init__.py
Normal file → Executable file
0
mars_toolkit/core/__init__.py
Normal file → Executable file
BIN
mars_toolkit/core/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/core/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/config.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/core/__pycache__/config.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
0
mars_toolkit/core/__pycache__/error_handlers.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/core/__pycache__/error_handlers.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
BIN
mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
0
mars_toolkit/core/__pycache__/utils.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/core/__pycache__/utils.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/core/cif_utils.py
Normal file → Executable file
0
mars_toolkit/core/cif_utils.py
Normal file → Executable file
8
mars_toolkit/core/config.py
Normal file → Executable file
8
mars_toolkit/core/config.py
Normal file → Executable file
@@ -22,12 +22,12 @@ class Config:
|
|||||||
HTTPS_PROXY = 'http://192.168.168.1:20171'
|
HTTPS_PROXY = 'http://192.168.168.1:20171'
|
||||||
|
|
||||||
# FairChem
|
# FairChem
|
||||||
FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
|
FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/nfs/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
|
||||||
FMAX = 0.05
|
FMAX = 0.05
|
||||||
|
|
||||||
# MatterGen
|
# MatterGen
|
||||||
MATTERGENMODEL_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
|
MATTERGENMODEL_ROOT = '/home/ubuntu/50T/nfs/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
|
||||||
MATTERGEN_ROOT='/home/ubuntu/50T/lzy/mars-mcp/mattergen'
|
MATTERGEN_ROOT='/home/ubuntu/50T/nfs/lzy/mars-mcp/mattergen'
|
||||||
MATTERGENMODEL_RESULT_PATH = 'results/'
|
MATTERGENMODEL_RESULT_PATH = 'results/'
|
||||||
|
|
||||||
# Dify
|
# Dify
|
||||||
@@ -38,7 +38,7 @@ class Config:
|
|||||||
SEARXNG_HOST="http://192.168.168.1:40032/"
|
SEARXNG_HOST="http://192.168.168.1:40032/"
|
||||||
|
|
||||||
# Visualization
|
# Visualization
|
||||||
VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'
|
VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/nfs/lzy/mars-mcp/outputs/cif_visualization'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def as_dict(cls) -> Dict[str, Any]:
|
def as_dict(cls) -> Dict[str, Any]:
|
||||||
|
|||||||
0
mars_toolkit/core/llm_tools.py
Normal file → Executable file
0
mars_toolkit/core/llm_tools.py
Normal file → Executable file
0
mars_toolkit/core/mattergen_wrapper.py
Normal file → Executable file
0
mars_toolkit/core/mattergen_wrapper.py
Normal file → Executable file
0
mars_toolkit/misc/__init__.py
Normal file → Executable file
0
mars_toolkit/misc/__init__.py
Normal file → Executable file
BIN
mars_toolkit/misc/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/misc/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
0
mars_toolkit/misc/__pycache__/general_tools.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/misc/__pycache__/general_tools.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/misc/__pycache__/misc_tools.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/misc/__pycache__/misc_tools.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
0
mars_toolkit/misc/misc_tools.py
Normal file → Executable file
0
mars_toolkit/misc/misc_tools.py
Normal file → Executable file
0
mars_toolkit/query/__init__.py
Normal file → Executable file
0
mars_toolkit/query/__init__.py
Normal file → Executable file
BIN
mars_toolkit/query/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/query/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
BIN
mars_toolkit/query/__pycache__/web_search.cpython-310.pyc
Normal file → Executable file
BIN
mars_toolkit/query/__pycache__/web_search.cpython-310.pyc
Normal file → Executable file
Binary file not shown.
0
mars_toolkit/query/dify_search.py
Normal file → Executable file
0
mars_toolkit/query/dify_search.py
Normal file → Executable file
0
mars_toolkit/query/mp_query.py
Normal file → Executable file
0
mars_toolkit/query/mp_query.py
Normal file → Executable file
0
mars_toolkit/query/oqmd_query.py
Normal file → Executable file
0
mars_toolkit/query/oqmd_query.py
Normal file → Executable file
0
mars_toolkit/query/web_search.py
Normal file → Executable file
0
mars_toolkit/query/web_search.py
Normal file → Executable file
0
mars_toolkit/services/__init__.py
Normal file → Executable file
0
mars_toolkit/services/__init__.py
Normal file → Executable file
0
mars_toolkit/services/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/services/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/services/mattergen_service.py
Normal file → Executable file
0
mars_toolkit/services/mattergen_service.py
Normal file → Executable file
0
mars_toolkit/visualization/__init__.py
Normal file → Executable file
0
mars_toolkit/visualization/__init__.py
Normal file → Executable file
0
mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/visualization/__pycache__/crystal_vis.cpython-310.pyc
Normal file → Executable file
0
mars_toolkit/visualization/__pycache__/crystal_vis.cpython-310.pyc
Normal file → Executable file
0
mattergen_api.py
Normal file → Executable file
0
mattergen_api.py
Normal file → Executable file
@@ -1,134 +0,0 @@
|
|||||||
import requests
|
|
||||||
import json
|
|
||||||
import argparse
|
|
||||||
import sys
|
|
||||||
|
|
||||||
def generate_material(
|
|
||||||
url="http://localhost:8051/generate_material",
|
|
||||||
properties=None,
|
|
||||||
batch_size=2,
|
|
||||||
num_batches=1,
|
|
||||||
diffusion_guidance_factor=2.0
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
调用MatterGen API生成晶体结构
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: API端点URL
|
|
||||||
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
|
|
||||||
batch_size: 每批生成的结构数量
|
|
||||||
num_batches: 批次数量
|
|
||||||
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
生成的结构内容或错误信息
|
|
||||||
"""
|
|
||||||
# 构建请求负载
|
|
||||||
payload = {
|
|
||||||
"properties": properties ,
|
|
||||||
"batch_size": batch_size,
|
|
||||||
"num_batches": num_batches,
|
|
||||||
"diffusion_guidance_factor": diffusion_guidance_factor
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"发送请求到 {url}")
|
|
||||||
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 添加headers参数,包含accept头
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"accept": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 打印完整请求信息(调试用)
|
|
||||||
print(f"完整请求URL: {url}")
|
|
||||||
print(f"请求头: {headers}")
|
|
||||||
print(f"请求体: {json.dumps(payload)}")
|
|
||||||
|
|
||||||
# 禁用代理设置
|
|
||||||
proxies = {
|
|
||||||
"http": None,
|
|
||||||
"https": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送POST请求,添加headers参数,禁用代理,增加超时时间
|
|
||||||
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
|
|
||||||
|
|
||||||
# 打印响应信息(调试用)
|
|
||||||
print(f"响应状态码: {response.status_code}")
|
|
||||||
print(f"响应头: {dict(response.headers)}")
|
|
||||||
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符,避免输出过长
|
|
||||||
|
|
||||||
# 检查响应状态
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
if result["success"]:
|
|
||||||
print("\n生成成功!")
|
|
||||||
return result["content"]
|
|
||||||
else:
|
|
||||||
print(f"\n生成失败: {result['message']}")
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
print(f"\n请求失败,状态码: {response.status_code}")
|
|
||||||
print(f"响应内容: {response.text}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n发生错误: {str(e)}")
|
|
||||||
print(f"错误类型: {type(e).__name__}")
|
|
||||||
import traceback
|
|
||||||
print(f"错误堆栈: {traceback.format_exc()}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""命令行入口函数"""
|
|
||||||
parser = argparse.ArgumentParser(description="MatterGen API客户端示例")
|
|
||||||
|
|
||||||
# 添加命令行参数
|
|
||||||
parser.add_argument("--url", default="http://localhost:8051/generate_material",
|
|
||||||
help="MatterGen API端点URL")
|
|
||||||
parser.add_argument("--property-name", default='dft_mag_density',help="属性名称,例如dft_band_gap")
|
|
||||||
parser.add_argument("--property-value",default=0.15,help="属性值,例如2.0")
|
|
||||||
parser.add_argument("--batch-size", type=int, default=2, help="每批生成的结构数量")
|
|
||||||
parser.add_argument("--num-batches", type=int, default=1, help="批次数量")
|
|
||||||
parser.add_argument("--guidance-factor", type=float, default=2.0,
|
|
||||||
help="控制生成结构与目标属性的符合程度")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# 构建属性字典
|
|
||||||
properties = None
|
|
||||||
if args.property_name and args.property_value:
|
|
||||||
try:
|
|
||||||
# 尝试将属性值转换为数字
|
|
||||||
try:
|
|
||||||
value = float(args.property_value)
|
|
||||||
# 如果是整数,转换为整数
|
|
||||||
if value.is_integer():
|
|
||||||
value = int(value)
|
|
||||||
except ValueError:
|
|
||||||
# 如果无法转换为数字,保持为字符串
|
|
||||||
value = args.property_value
|
|
||||||
|
|
||||||
properties = {args.property_name: value}
|
|
||||||
except Exception as e:
|
|
||||||
print(f"解析属性值时出错: {str(e)}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 调用API
|
|
||||||
result = generate_material(
|
|
||||||
url=args.url,
|
|
||||||
properties=properties,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
num_batches=args.num_batches,
|
|
||||||
diffusion_guidance_factor=args.guidance_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
if result:
|
|
||||||
print("\n生成的结构:")
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
BIN
prompts/__pycache__/material_synthesis.cpython-310.pyc
Executable file
BIN
prompts/__pycache__/material_synthesis.cpython-310.pyc
Executable file
Binary file not shown.
167
prompts/material_synthesis.py
Executable file
167
prompts/material_synthesis.py
Executable file
@@ -0,0 +1,167 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
import mcp.types as types
|
||||||
|
from mcp.server.lowlevel import Server
|
||||||
|
|
||||||
|
|
||||||
|
def create_messages(
|
||||||
|
properties: Dict[str, str] = None,
|
||||||
|
batch_size: int = 2,
|
||||||
|
) -> list[types.PromptMessage]:
|
||||||
|
"""
|
||||||
|
创建用于材料生成和合成的提示词消息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
properties: 材料性质及其值的字典,例如 {"dft_band_gap": "2.0"}
|
||||||
|
batch_size: 生成材料的数量,默认为2
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提示词消息列表
|
||||||
|
"""
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
# 系统消息,定义助手的角色和任务
|
||||||
|
system_message = """你是一位专业的材料科学家,擅长材料生成和合成方案设计。
|
||||||
|
你的任务是:
|
||||||
|
1. 根据用户提供的材料性质要求,使用mars_toolkit中的generate_materials工具生成符合要求的材料
|
||||||
|
2. 系统地分析生成的材料的四要素:成分、结构、工艺和性能
|
||||||
|
3. 为生成的材料设计科学合理的合成方案
|
||||||
|
4. 使用mermaid语法绘制材料的合成流程图
|
||||||
|
|
||||||
|
请确保你的回答包含以下内容:
|
||||||
|
- 对用户需求的理解和分析
|
||||||
|
- 使用generate_material工具生成的材料结构
|
||||||
|
- 生成材料的四要素详细分析:
|
||||||
|
* 成分(Composition):详细的化学成分、元素比例、化学计量比
|
||||||
|
* 结构(Structure):晶体结构、空间群、晶格参数、原子位置、配位环境
|
||||||
|
* 工艺(Processing):可行的合成路线、工艺参数、关键控制因素
|
||||||
|
* 性能(Properties):预期的物理、化学、机械性能及其与结构的关系
|
||||||
|
- 详细的合成方案,包括:
|
||||||
|
* 原料选择及纯度要求
|
||||||
|
* 精确的反应条件(温度、压力、时间、气氛)
|
||||||
|
* 分步骤的合成流程及每步的理论依据
|
||||||
|
* 可能的挑战及解决方案
|
||||||
|
* 表征方法建议
|
||||||
|
- 使用mermaid语法绘制的合成流程图,清晰展示从原料到最终产品的全过程
|
||||||
|
"""
|
||||||
|
|
||||||
|
messages.append(
|
||||||
|
types.PromptMessage(
|
||||||
|
role="system",
|
||||||
|
content=types.TextContent(type="text", text=system_message),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建主提示词
|
||||||
|
if properties and len(properties) > 0:
|
||||||
|
properties_text = "\n".join([f"- {key}: {value}" for key, value in properties.items()])
|
||||||
|
prompt = f"""请根据以下材料性质要求,生成{batch_size}个合适的材料并设计其合成方案:
|
||||||
|
|
||||||
|
{properties_text}
|
||||||
|
|
||||||
|
请按照以下步骤进行:
|
||||||
|
|
||||||
|
1. 使用mars_toolkit中的generate_material工具生成材料,参数设置为batch_size={batch_size}
|
||||||
|
2. 对生成的每种材料进行系统的四要素分析:
|
||||||
|
- 成分:详细分析元素组成、化学计量比及其理论依据
|
||||||
|
- 结构:分析晶体结构、空间群、晶格参数、原子排布及其稳定性
|
||||||
|
- 工艺:探讨可行的合成路线、工艺参数及其科学依据
|
||||||
|
- 性能:预测材料可能具有的物理、化学、机械性能及其应用前景
|
||||||
|
|
||||||
|
3. 为每种材料设计详细的合成方案,包括:
|
||||||
|
- 原料选择及纯度要求
|
||||||
|
- 精确的反应条件参数(温度、压力、时间、气氛等)
|
||||||
|
- 分步骤的合成流程及每步的理论依据
|
||||||
|
- 可能遇到的挑战及解决方案
|
||||||
|
- 推荐的表征方法
|
||||||
|
|
||||||
|
4. 使用mermaid语法绘制材料的合成流程图,清晰展示从原料到最终产品的全过程,包括关键工艺参数。"""
|
||||||
|
else:
|
||||||
|
prompt = f"""请生成{batch_size}种具有创新性的新型材料并设计其合成方案。
|
||||||
|
|
||||||
|
请按照以下步骤进行:
|
||||||
|
|
||||||
|
1. 使用mars_toolkit中的generate_material工具生成材料,参数设置为batch_size={batch_size}
|
||||||
|
2. 对生成的每种材料进行系统的四要素分析:
|
||||||
|
- 成分:详细分析元素组成、化学计量比及其理论依据
|
||||||
|
- 结构:分析晶体结构、空间群、晶格参数、原子排布及其稳定性
|
||||||
|
- 工艺:探讨可行的合成路线、工艺参数及其科学依据
|
||||||
|
- 性能:预测材料可能具有的物理、化学、机械性能及其应用前景
|
||||||
|
|
||||||
|
3. 为每种材料设计详细的合成方案,包括:
|
||||||
|
- 原料选择及纯度要求
|
||||||
|
- 精确的反应条件参数(温度、压力、时间、气氛等)
|
||||||
|
- 分步骤的合成流程及每步的理论依据
|
||||||
|
- 可能遇到的挑战及解决方案
|
||||||
|
- 推荐的表征方法
|
||||||
|
|
||||||
|
4. 使用mermaid语法绘制材料的合成流程图,清晰展示从原料到最终产品的全过程,包括关键工艺参数。"""
|
||||||
|
|
||||||
|
messages.append(
|
||||||
|
types.PromptMessage(
|
||||||
|
role="user",
|
||||||
|
content=types.TextContent(type="text", text=prompt)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def register_prompt_handlers(app: Server):
|
||||||
|
"""
|
||||||
|
注册提示词处理器到MCP服务器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: MCP服务器实例
|
||||||
|
"""
|
||||||
|
@app.list_prompts()
|
||||||
|
async def list_prompts() -> list[types.Prompt]:
|
||||||
|
return [
|
||||||
|
types.Prompt(
|
||||||
|
name="material_synthesis",
|
||||||
|
description="基于材料四要素(成分、结构、工艺、性能)生成材料并设计合成方案,使用mermaid绘制合成流程图",
|
||||||
|
arguments=[
|
||||||
|
types.PromptArgument(
|
||||||
|
name="properties",
|
||||||
|
description="材料性质及其值的JSON字符串,例如 {\"dft_band_gap\": \"2.0\"}",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
types.PromptArgument(
|
||||||
|
name="batch_size",
|
||||||
|
description="生成材料的数量,默认为2",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@app.get_prompt()
|
||||||
|
async def get_prompt(
|
||||||
|
name: str, arguments: dict[str, str] | None = None
|
||||||
|
) -> types.GetPromptResult:
|
||||||
|
if name != "material_synthesis":
|
||||||
|
raise ValueError(f"未知的提示词: {name}")
|
||||||
|
|
||||||
|
if arguments is None:
|
||||||
|
arguments = {}
|
||||||
|
|
||||||
|
# 解析properties参数
|
||||||
|
properties = {}
|
||||||
|
if "properties" in arguments and arguments["properties"]:
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
properties = json.loads(arguments["properties"])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
properties = {}
|
||||||
|
|
||||||
|
# 解析batch_size参数
|
||||||
|
batch_size = 2 # 默认值
|
||||||
|
if "batch_size" in arguments and arguments["batch_size"]:
|
||||||
|
try:
|
||||||
|
batch_size = int(arguments["batch_size"])
|
||||||
|
except ValueError:
|
||||||
|
pass # 使用默认值
|
||||||
|
|
||||||
|
return types.GetPromptResult(
|
||||||
|
messages=create_messages(properties=properties, batch_size=batch_size),
|
||||||
|
description="基于材料四要素(成分、结构、工艺、性能)生成材料并设计合成方案,使用mermaid绘制合成流程图",
|
||||||
|
)
|
||||||
306
server.py
Executable file
306
server.py
Executable file
@@ -0,0 +1,306 @@
|
|||||||
|
"""Mars Toolkit MCP Server implementation."""
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
import asyncio
|
||||||
|
import click
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from prompts.material_synthesis import create_messages
|
||||||
|
|
||||||
|
# 添加mars_toolkit模块的路径
|
||||||
|
sys.path.append('/home/ubuntu/50T/lzy/mars-mcp')
|
||||||
|
|
||||||
|
import mcp.types as types
|
||||||
|
from mcp.server.lowlevel import Server
|
||||||
|
|
||||||
|
# 导入提示词处理器
|
||||||
|
#from prompts.material_synthesis import register_prompt_handlers
|
||||||
|
|
||||||
|
# 导入Mars Toolkit工具函数
|
||||||
|
try:
|
||||||
|
# 获取当前时间
|
||||||
|
from mars_toolkit.misc.misc_tools import get_current_time
|
||||||
|
# 网络搜索
|
||||||
|
from mars_toolkit.query.web_search import search_online
|
||||||
|
# 从Materials Project查询材料属性
|
||||||
|
from mars_toolkit.query.mp_query import search_material_property_from_material_project
|
||||||
|
# 从Materials Project获取晶体结构
|
||||||
|
from mars_toolkit.query.mp_query import get_crystal_structures_from_materials_project
|
||||||
|
# 从化学式获取Materials Project ID
|
||||||
|
from mars_toolkit.query.mp_query import get_mpid_from_formula
|
||||||
|
# 优化晶体结构
|
||||||
|
from mars_toolkit.compute.structure_opt import optimize_crystal_structure
|
||||||
|
# 生成材料
|
||||||
|
from mars_toolkit.compute.material_gen import generate_material
|
||||||
|
# 从OQMD获取化学成分
|
||||||
|
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||||
|
# 从知识库检索
|
||||||
|
from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
|
||||||
|
# 预测属性
|
||||||
|
from mars_toolkit.compute.property_pred import predict_properties
|
||||||
|
|
||||||
|
# 获取所有工具函数
|
||||||
|
from mars_toolkit import get_tools, get_tool_schemas
|
||||||
|
|
||||||
|
MARS_TOOLKIT_AVAILABLE = True
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"警告: 无法导入Mars Toolkit: {e}", file=sys.stderr)
|
||||||
|
MARS_TOOLKIT_AVAILABLE = False
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
app = Server("mars-toolkit-server")
|
||||||
|
|
||||||
|
|
||||||
|
async def call_mars_toolkit_function(func_name: str, arguments: Dict[str, Any]) -> Any:
|
||||||
|
"""
|
||||||
|
调用Mars Toolkit中的工具函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func_name: 工具函数名称
|
||||||
|
arguments: 工具函数参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具函数的执行结果
|
||||||
|
"""
|
||||||
|
if not MARS_TOOLKIT_AVAILABLE:
|
||||||
|
raise ValueError("Mars Toolkit不可用")
|
||||||
|
|
||||||
|
# 获取所有注册的工具函数
|
||||||
|
tools = get_tools()
|
||||||
|
|
||||||
|
# 检查函数名是否存在于工具函数字典中
|
||||||
|
if func_name not in tools:
|
||||||
|
raise ValueError(f"函数 '{func_name}' 不存在于工具函数字典中")
|
||||||
|
|
||||||
|
# 获取对应的工具函数
|
||||||
|
tool_func = tools[func_name]
|
||||||
|
|
||||||
|
# 调用工具函数
|
||||||
|
if asyncio.iscoroutinefunction(tool_func):
|
||||||
|
# 如果是异步函数,使用await调用
|
||||||
|
result = await tool_func(**arguments)
|
||||||
|
print("result1",result)
|
||||||
|
else:
|
||||||
|
# 如果是同步函数,直接调用
|
||||||
|
result = tool_func(**arguments)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_schemas_dict() -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取所有工具函数的模式字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具函数名称到模式的映射字典
|
||||||
|
"""
|
||||||
|
if not MARS_TOOLKIT_AVAILABLE:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
schemas = get_tool_schemas()
|
||||||
|
schemas_dict = {}
|
||||||
|
|
||||||
|
for schema in schemas:
|
||||||
|
func_name = schema["function"]["name"]
|
||||||
|
schemas_dict[func_name] = schema
|
||||||
|
|
||||||
|
return schemas_dict
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.option("--port", default=5666, help="Port to listen on for SSE")
|
||||||
|
@click.option(
|
||||||
|
"--transport",
|
||||||
|
type=click.Choice(["stdio", "sse"]),
|
||||||
|
default="sse",
|
||||||
|
help="Transport type",
|
||||||
|
)
|
||||||
|
def main(port: int, transport: str='SSE') -> int:
|
||||||
|
"""
|
||||||
|
Mars Toolkit MCP Server主函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
port: SSE传输的端口号
|
||||||
|
transport: 传输类型,stdio或sse
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
退出码
|
||||||
|
"""
|
||||||
|
if not MARS_TOOLKIT_AVAILABLE:
|
||||||
|
print("错误: Mars Toolkit不可用,请确保已正确安装", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
# 获取工具函数模式字典
|
||||||
|
schemas_dict = get_tool_schemas_dict()
|
||||||
|
|
||||||
|
# 注册提示词处理器
|
||||||
|
#register_prompt_handlers(app)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@app.list_prompts()
|
||||||
|
async def list_prompts() -> list[types.Prompt]:
|
||||||
|
return [
|
||||||
|
types.Prompt(
|
||||||
|
name="material_synthesis",
|
||||||
|
description="生成材料并设计合成方案,使用mermaid绘制合成流程图",
|
||||||
|
arguments=[
|
||||||
|
types.PromptArgument(
|
||||||
|
name="properties",
|
||||||
|
description="材料性质及其值的JSON字符串,例如 {\"dft_band_gap\": \"2.0\"}",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
types.PromptArgument(
|
||||||
|
name="batch_size",
|
||||||
|
description="生成材料的数量,默认为2",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@app.get_prompt()
|
||||||
|
async def get_prompt(
|
||||||
|
name: str, arguments: dict[str, str] | None = None
|
||||||
|
) -> types.GetPromptResult:
|
||||||
|
if name != "material_synthesis":
|
||||||
|
raise ValueError(f"未知的提示词: {name}")
|
||||||
|
|
||||||
|
if arguments is None:
|
||||||
|
arguments = {}
|
||||||
|
|
||||||
|
# 解析properties参数
|
||||||
|
properties = {}
|
||||||
|
if "properties" in arguments and arguments["properties"]:
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
properties = json.loads(arguments["properties"])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
properties = {}
|
||||||
|
|
||||||
|
# 解析batch_size参数
|
||||||
|
batch_size = 2 # 默认值
|
||||||
|
if "batch_size" in arguments and arguments["batch_size"]:
|
||||||
|
try:
|
||||||
|
batch_size = int(arguments["batch_size"])
|
||||||
|
except ValueError:
|
||||||
|
pass # 使用默认值
|
||||||
|
|
||||||
|
return types.GetPromptResult(
|
||||||
|
messages=create_messages(properties=properties, batch_size=batch_size),
|
||||||
|
description="生成材料并设计合成方案,使用mermaid绘制合成流程图",
|
||||||
|
)
|
||||||
|
@app.call_tool()
|
||||||
|
async def call_tool(
|
||||||
|
name: str, arguments: Dict[str, Any]
|
||||||
|
) -> List[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||||
|
"""
|
||||||
|
调用工具函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具函数名称
|
||||||
|
arguments: 工具函数参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具函数的执行结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
print(f"调用{name},参数为{arguments}")
|
||||||
|
result = await call_mars_toolkit_function(name, arguments)
|
||||||
|
print("result",result)
|
||||||
|
# 将结果转换为字符串
|
||||||
|
if isinstance(result, (dict, list)):
|
||||||
|
result_str = json.dumps(result, ensure_ascii=False, indent=2)
|
||||||
|
else:
|
||||||
|
result_str = str(result)
|
||||||
|
|
||||||
|
return [types.TextContent(type="text", text=result_str)]
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"调用工具函数 {name} 时出错: {str(e)}\n{traceback.format_exc()}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return [types.TextContent(type="text", text=error_msg)]
|
||||||
|
|
||||||
|
@app.list_tools()
|
||||||
|
async def list_tools() -> List[types.Tool]:
|
||||||
|
"""
|
||||||
|
列出所有可用的工具函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具函数列表
|
||||||
|
"""
|
||||||
|
tools = []
|
||||||
|
print("列举所有可用的工具函数")
|
||||||
|
for func_name, schema in schemas_dict.items():
|
||||||
|
# 获取函数描述
|
||||||
|
description = schema["function"].get("description", f"Mars Toolkit工具: {func_name}")
|
||||||
|
|
||||||
|
# 获取参数模式
|
||||||
|
parameters = schema["function"].get("parameters", {})
|
||||||
|
|
||||||
|
# 创建工具
|
||||||
|
tool = types.Tool(
|
||||||
|
name=func_name,
|
||||||
|
description=description,
|
||||||
|
inputSchema=parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
tools.append(tool)
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
if transport == "sse":
|
||||||
|
from mcp.server.sse import SseServerTransport
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.routing import Mount, Route
|
||||||
|
|
||||||
|
sse = SseServerTransport("/messages/")
|
||||||
|
|
||||||
|
async def handle_sse(request):
|
||||||
|
async with sse.connect_sse(
|
||||||
|
request.scope, request.receive, request._send
|
||||||
|
) as streams:
|
||||||
|
await app.run(
|
||||||
|
streams[0], streams[1], app.create_initialization_options()
|
||||||
|
)
|
||||||
|
|
||||||
|
starlette_app = Starlette(
|
||||||
|
debug=True,
|
||||||
|
routes=[
|
||||||
|
Route("/sse", endpoint=handle_sse),
|
||||||
|
Mount("/messages/", app=sse.handle_post_message),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
|
||||||
|
else:
|
||||||
|
from mcp.server.stdio import stdio_server
|
||||||
|
|
||||||
|
async def arun():
|
||||||
|
async with stdio_server() as streams:
|
||||||
|
await app.run(
|
||||||
|
streams[0], streams[1], app.create_initialization_options()
|
||||||
|
)
|
||||||
|
|
||||||
|
anyio.run(arun)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(get_tool_schemas_dict())
|
||||||
|
main()
|
||||||
|
|
||||||
4
test_mars_toolkit.py
Normal file → Executable file
4
test_mars_toolkit.py
Normal file → Executable file
@@ -155,7 +155,7 @@ def print_tool_schemas():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 打印所有工具函数的模式
|
# 打印所有工具函数的模式
|
||||||
#print_tool_schemas()
|
print_tool_schemas()
|
||||||
|
|
||||||
# 测试工具函数列表
|
# 测试工具函数列表
|
||||||
tools_to_test = [
|
tools_to_test = [
|
||||||
@@ -172,7 +172,7 @@ if __name__ == "__main__":
|
|||||||
]
|
]
|
||||||
|
|
||||||
# 选择要测试的工具
|
# 选择要测试的工具
|
||||||
tool_name = tools_to_test[1] # 测试 search_online 工具
|
tool_name = tools_to_test[2] # 测试 search_online 工具
|
||||||
|
|
||||||
# 运行测试
|
# 运行测试
|
||||||
result = asyncio.run(test_tool(tool_name))
|
result = asyncio.run(test_tool(tool_name))
|
||||||
|
|||||||
Reference in New Issue
Block a user