mcp: data generation code

lzy
2025-04-16 11:15:01 +08:00
parent 72045e5cfe
commit 6b92e54a41
66 changed files with 1938 additions and 1483 deletions


@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Read a JSONL file and print the records whose 'solution' field is non-empty.
"""
import json
import sys

from rich.console import Console
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn

console = Console()


def count_lines(file_path):
    """Count the total number of lines in the file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return sum(1 for _ in f)


def filter_generate_material(file_path):
    """
    Read a JSONL file and print the records with a non-empty 'solution' field.

    Args:
        file_path: path to the JSONL file
    """
    # Count the lines up front so the progress bar has a total
    total_lines = count_lines(file_path)
    console.print(f"[bold green]Total lines in file: {total_lines}[/bold green]")
    # Number of matching records
    match_count = 0
    # Show progress while processing
    with Progress(
        TextColumn("[bold blue]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TimeRemainingColumn(),
    ) as progress:
        task_id = progress.add_task("[bold green]Processing JSONL file...", total=total_lines)
        # Process the file line by line
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                try:
                    data = json.loads(line.strip())
                    # Alternative filter (disabled): match on the observation field
                    # if 'observation' in data:
                    #     if ('generate_material content begin' in data['observation']
                    #             and '[generate_material content begin]Error' not in data['observation']):
                    #         console.print(f"[bold cyan]Line {line_num} contains the target content:[/bold cyan]")
                    #         console.print(line.strip())
                    #         match_count += 1
                    if 'solution' in data and data['solution'] != "":
                        console.print(f"[bold cyan]Line {line_num}:[/bold cyan]")
                        print(data['solution'])
                        match_count += 1
                except json.JSONDecodeError:
                    console.print(f"[bold red]JSON parse error on line {line_num}[/bold red]")
                except Exception as e:
                    console.print(f"[bold red]Error while processing line {line_num}: {str(e)}[/bold red]")
                # Advance the progress bar
                progress.update(task_id, advance=1)
    console.print(f"[bold green]Done: {match_count} records have a non-empty 'solution'[/bold green]")


if __name__ == "__main__":
    # Default file path
    file_path = '/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl'
    # "/home/ubuntu/50T/lzy/mars-mcp/mars-agent_data_20250408205427.jsonl"
    # Use a command-line argument as the file path when given
    if len(sys.argv) > 1:
        file_path = sys.argv[1]
    console.print(f"[bold blue]Processing file: {file_path}[/bold blue]")
    filter_generate_material(file_path)
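
# For reference, a sketch of the JSONL shape this filter expects. Only the
# 'solution' key is actually inspected; the other field names below are
# illustrative assumptions, not taken from the real data files:
#   {"question": "...", "solution": "Anneal the precursor at 900 C ..."}
#   {"question": "...", "solution": ""}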


@@ -0,0 +1,432 @@
import json
import asyncio
import concurrent.futures
import threading
import random
from typing import Dict, Any

import jsonlines
from colorama import Fore, Style, init
from mysql.connector import pooling

from mars_toolkit import *
from mars_toolkit.compute.material_gen import generate_material

# Create a lock for file writing
file_lock = threading.Lock()

# Initialize colorama
init(autoreset=True)


def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalize the arguments passed to generate_material.

    Handles the following cases:
    1. 'properties' may be a JSON string that needs to be parsed into a dict.
    2. Values inside 'properties' may need coercion to the proper type (number or string).
    3. 'batch_size' and 'num_batches' must be integers.

    Args:
        arguments: dict of generate_material arguments

    Returns:
        The normalized argument dict.
    """
    normalized_args = arguments.copy()
    # Handle the 'properties' argument
    if "properties" in normalized_args:
        properties = normalized_args["properties"]
        # If 'properties' is a string, try to parse it as JSON
        if isinstance(properties, str):
            try:
                properties = json.loads(properties)
            except json.JSONDecodeError as e:
                raise ValueError(f"Cannot parse 'properties' JSON string: {e}")
        # Ensure 'properties' is a dict
        if not isinstance(properties, dict):
            raise ValueError(f"'properties' must be a dict or a JSON string, not {type(properties)}")
        # Normalize each value in 'properties'
        normalized_properties = {}
        for key, value in properties.items():
            if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"):
                # Range values such as "0.0-2.0" or "40-50" stay as strings
                normalized_properties[key] = value
            elif isinstance(value, str) and value.startswith(">"):
                # Greater-than values stay as strings
                normalized_properties[key] = value
            elif isinstance(value, str) and value.startswith("<"):
                # Less-than values stay as strings
                normalized_properties[key] = value
            elif isinstance(value, str) and value.lower() == "relaxor":
                # Special sentinel value stays as a string
                normalized_properties[key] = value
            elif isinstance(value, str) and value.endswith("eV"):
                # Values carrying a unit stay as strings
                normalized_properties[key] = value
            else:
                # Otherwise try to coerce the value to a number
                try:
                    float_value = float(value)
                    # Use int when the value is integral
                    if float_value.is_integer():
                        normalized_properties[key] = int(float_value)
                    else:
                        normalized_properties[key] = float_value
                except (ValueError, TypeError):
                    # Not numeric: keep the original value
                    normalized_properties[key] = value
        normalized_args["properties"] = normalized_properties
    # Ensure batch_size and num_batches are integers
    if "batch_size" in normalized_args:
        try:
            normalized_args["batch_size"] = int(normalized_args["batch_size"])
        except (ValueError, TypeError):
            raise ValueError(f"batch_size must be an integer, not {normalized_args['batch_size']}")
    if "num_batches" in normalized_args:
        try:
            normalized_args["num_batches"] = int(normalized_args["num_batches"])
        except (ValueError, TypeError):
            raise ValueError(f"num_batches must be an integer, not {normalized_args['num_batches']}")
    # Ensure diffusion_guidance_factor is a float
    if "diffusion_guidance_factor" in normalized_args:
        try:
            normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"])
        except (ValueError, TypeError):
            raise ValueError(f"diffusion_guidance_factor must be a number, not {normalized_args['diffusion_guidance_factor']}")
    return normalized_args
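
# Illustrative example of the normalization above (values invented):
#   normalize_material_args({
#       "properties": '{"band_gap": "1.5", "energy_above_hull": "0.0-0.1"}',
#       "batch_size": "4",
#   })
#   -> {"properties": {"band_gap": 1.5, "energy_above_hull": "0.0-0.1"},
#       "batch_size": 4}
# Numeric strings are coerced, while range strings such as "0.0-0.1" are kept
# verbatim for the downstream generator to interpret.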

# Database connection pool (only used for the initial data load)
connection_pool = pooling.MySQLConnectionPool(
    pool_name="mypool",
    pool_size=32,
    pool_reset_session=True,
    host='localhost',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)

# In-memory cache for the records loaded from the database
# Layout: {doi: record, mp_id: record}
memory_cache = {}


def load_data_to_memory():
    """Load all records from the database into memory."""
    print(f"{Fore.CYAN}{Style.BRIGHT}Loading data from the database into memory...{Style.RESET_ALL}")
    conn = connection_pool.get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            # Fetch all records
            cursor.execute("SELECT * FROM mp_synthesis_scheme_info")
            records = cursor.fetchall()
            # Index each record by doi and mp_id when present
            for record in records:
                doi = record.get('doi')
                mp_id = record.get('mp_id')
                if doi:
                    memory_cache[doi] = record
                if mp_id:
                    memory_cache[mp_id] = record
            print(f"{Fore.GREEN}{Style.BRIGHT}Loaded {len(records)} records into memory{Style.RESET_ALL}")
            print(f"{Fore.GREEN}{Style.BRIGHT}Number of keys in the cache: {len(memory_cache)}{Style.RESET_ALL}")
        finally:
            cursor.close()
    finally:
        conn.close()


# Load the data into memory at startup
load_data_to_memory()


async def process_retrieval_from_knowledge_base(data):
    doi = data.get('doi')
    mp_id = data.get('mp_id')
    # At least one lookup key must be provided
    if doi is None and mp_id is None:
        return ""
    # Look up a matching record in the in-memory cache
    result = None
    if doi is not None and doi in memory_cache:
        result = memory_cache[doi]
    elif mp_id is not None and mp_id in memory_cache:
        result = memory_cache[mp_id]
    # Return an empty string when nothing matches
    if not result:
        return ""
    # Build a markdown-formatted result from all fields except doi and mp_id
    markdown_result = ""
    fields = [
        "target_material",
        "reaction_string",
        "chara_structure",
        "chara_performance",
        "chara_application",
        "synthesis_schemes"
    ]
    for field in fields:
        field_content = result.get(field, "")
        # Only include non-empty fields
        if field_content and field_content.strip():
            markdown_result += f"\n## {field}\n{field_content}\n\n"
    return markdown_result  # Return the markdown text directly


async def execute_tool_from_dict(input_dict: dict):
    """
    Extract the tool-function name and arguments from a dict and execute the tool.

    Args:
        input_dict: a dict such as:
            {"name": "search_material_property_from_material_project",
             "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}

    Returns:
        The tool function's result, or an error message if the tool does not exist.
    """
    try:
        # Extract the function name and the arguments
        func_name = input_dict.get("name")
        arguments_data = input_dict.get("arguments")
        if not func_name:
            return {"status": "error", "message": "No function name provided"}
        # Get all registered tool functions
        tools = get_tools()
        # Check that the function is registered
        if func_name not in tools:
            return {"status": "error", "message": f"Function '{func_name}' is not in the tool registry"}
        tool_func = tools[func_name]
        # Parse the arguments
        arguments = {}
        if arguments_data:
            if isinstance(arguments_data, dict):
                # Already a dict: use as-is
                arguments = arguments_data
            elif isinstance(arguments_data, str):
                # A string: try to parse it as JSON
                try:
                    arguments = json.loads(arguments_data)
                except json.JSONDecodeError:
                    # Parsing failed, possibly due to escape characters;
                    # try to repair common JSON-string issues
                    fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\')
                    try:
                        arguments = json.loads(fixed_str)
                    except json.JSONDecodeError:
                        # Still failing: fall back to passing the raw string
                        arguments = {"raw_string": arguments_data}
        # Call the tool function, awaiting it when it is a coroutine
        if asyncio.iscoroutinefunction(tool_func):
            result = await tool_func(**arguments)
        else:
            result = tool_func(**arguments)
        return result
    except json.JSONDecodeError as e:
        return {"status": "error", "message": f"JSON parse error: {str(e)}"}
    except Exception as e:
        return {"status": "error", "message": f"Error during execution: {str(e)}"}


def worker(data, output_file_path):
    try:
        func_contents = data["function_calls"]
        func_results = []
        formatted_results = []  # Collects the formatted results
        for func in func_contents:
            func_name = func.get("name")
            arguments_data = func.get("arguments")
            if func_name == 'retrieval_from_knowledge_base':
                result = asyncio.run(process_retrieval_from_knowledge_base(data))
                func_results.append({"function": func_name, "result": result})
                formatted_results.append(f"[{func_name} content begin]{result}[{func_name} content end]")
            elif func_name == 'generate_material':
                try:
                    # Make sure arguments_data is a dict
                    if isinstance(arguments_data, str):
                        try:
                            arguments_data = json.loads(arguments_data)
                        except json.JSONDecodeError:
                            # Unparseable arguments: skip this call
                            continue
                    # Normalize the arguments
                    normalized_args = normalize_material_args(arguments_data)
                    try:
                        output = generate_material(**normalized_args)
                    except Exception:
                        # generate_material failed: skip this call
                        continue
                    func_results.append({"function": func_name, "result": output})
                    formatted_results.append(f"[{func_name} content begin]{output}[{func_name} content end]")
                except Exception:
                    # Error while preparing generate_material arguments: skip this call
                    continue
            else:
                result = asyncio.run(execute_tool_from_dict(func))
                func_results.append({"function": func_name, "result": result})
                formatted_results.append(f"[{func_name} content begin]{result}[{func_name} content end]")
        # Join all formatted results into the observation
        final_result = "\n\n\n".join(formatted_results)
        data['observation'] = final_result
        with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                writer.write(data)
        return "Processed successfully"
    except Exception as e:
        return f"Error processing: {str(e)}"


def main(datas, output_file_path, max_workers=1):
    from tqdm import tqdm
    # Progress bar
    pbar = tqdm(total=len(datas), desc="Processing papers")
    # Thread pool
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit the tasks
        future_to_path = {}
        for path in datas:
            future = executor.submit(worker, path, output_file_path)
            future_to_path[future] = path
        # Collect the results
        completed = 0
        failed = 0
        for future in concurrent.futures.as_completed(future_to_path):
            path = future_to_path[future]
            try:
                result = future.result()
                if "successfully" in result:
                    completed += 1
                else:
                    failed += 1
                pbar.update(1)
                # Refresh the statistics every 100 records
                if (completed + failed) % 100 == 0:
                    pbar.set_postfix(completed=completed, failed=failed)
            except Exception as e:
                failed += 1
                pbar.update(1)
                print(f"\nWorker for {path} generated an exception: {e}")
    pbar.close()
    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")


if __name__ == '__main__':
    import datetime
    datas_with_solution = []
    datas_without_solution = []
    with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
        for obj in reader:
            if obj['solution'] != '':
                datas_with_solution.append(obj)
            else:
                datas_without_solution.append(obj)
    # Balance the sample: up to 5000 records with a solution and 5000 without
    datas_with_solution = datas_with_solution[:5000]
    datas_without_solution = datas_without_solution[:5000]
    datas = datas_with_solution + datas_without_solution
    random.shuffle(datas)
    output_file = f"./filter_ok_questions_solutions_agent_data10000_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
    main(datas, output_file, max_workers=32)
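
# Sketch of an input record consumed by worker() (contents invented; only the
# fields read by this script are shown). 'function_calls' drives tool
# execution, 'doi'/'mp_id' feed retrieval_from_knowledge_base, and
# 'observation' is written back by worker():
#   {"doi": "10.1000/example", "mp_id": "mp-123", "solution": "...",
#    "function_calls": [{"name": "generate_material",
#                        "arguments": "{\"properties\": {\"chemical_system\": \"V-Zn-O\"}, \"batch_size\": 1}"}]}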


@@ -0,0 +1,428 @@
import json
import asyncio
import concurrent.futures
import threading
import random
from typing import Dict, Any

import jsonlines
from colorama import Fore, Style, init
from mysql.connector import pooling

from mars_toolkit import *
from mars_toolkit.compute.material_gen import generate_material

# Create a lock for file writing
file_lock = threading.Lock()

# Initialize colorama
init(autoreset=True)


def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalize the arguments passed to generate_material.

    Handles the following cases:
    1. 'properties' may be a JSON string that needs to be parsed into a dict.
    2. Values inside 'properties' may need coercion to the proper type (number or string).
    3. 'batch_size' and 'num_batches' must be integers.

    Args:
        arguments: dict of generate_material arguments

    Returns:
        The normalized argument dict.
    """
    normalized_args = arguments.copy()
    # Handle the 'properties' argument
    if "properties" in normalized_args:
        properties = normalized_args["properties"]
        # If 'properties' is a string, try to parse it as JSON
        if isinstance(properties, str):
            try:
                properties = json.loads(properties)
            except json.JSONDecodeError as e:
                raise ValueError(f"Cannot parse 'properties' JSON string: {e}")
        # Ensure 'properties' is a dict
        if not isinstance(properties, dict):
            raise ValueError(f"'properties' must be a dict or a JSON string, not {type(properties)}")
        # Normalize each value in 'properties'
        normalized_properties = {}
        for key, value in properties.items():
            if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"):
                # Range values such as "0.0-2.0" or "40-50" stay as strings
                normalized_properties[key] = value
            elif isinstance(value, str) and value.startswith(">"):
                # Greater-than values stay as strings
                normalized_properties[key] = value
            elif isinstance(value, str) and value.startswith("<"):
                # Less-than values stay as strings
                normalized_properties[key] = value
            elif isinstance(value, str) and value.lower() == "relaxor":
                # Special sentinel value stays as a string
                normalized_properties[key] = value
            elif isinstance(value, str) and value.endswith("eV"):
                # Values carrying a unit stay as strings
                normalized_properties[key] = value
            else:
                # Otherwise try to coerce the value to a number
                try:
                    float_value = float(value)
                    # Use int when the value is integral
                    if float_value.is_integer():
                        normalized_properties[key] = int(float_value)
                    else:
                        normalized_properties[key] = float_value
                except (ValueError, TypeError):
                    # Not numeric: keep the original value
                    normalized_properties[key] = value
        normalized_args["properties"] = normalized_properties
    # Ensure batch_size and num_batches are integers
    if "batch_size" in normalized_args:
        try:
            normalized_args["batch_size"] = int(normalized_args["batch_size"])
        except (ValueError, TypeError):
            raise ValueError(f"batch_size must be an integer, not {normalized_args['batch_size']}")
    if "num_batches" in normalized_args:
        try:
            normalized_args["num_batches"] = int(normalized_args["num_batches"])
        except (ValueError, TypeError):
            raise ValueError(f"num_batches must be an integer, not {normalized_args['num_batches']}")
    # Ensure diffusion_guidance_factor is a float
    if "diffusion_guidance_factor" in normalized_args:
        try:
            normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"])
        except (ValueError, TypeError):
            raise ValueError(f"diffusion_guidance_factor must be a number, not {normalized_args['diffusion_guidance_factor']}")
    return normalized_args


# Database connection pool (only used for the initial data load)
connection_pool = pooling.MySQLConnectionPool(
    pool_name="mypool",
    pool_size=32,
    pool_reset_session=True,
    host='localhost',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)

# In-memory cache for the records loaded from the database
# Layout: {doi: record, mp_id: record}
memory_cache = {}


def load_data_to_memory():
    """Load all records from the database into memory."""
    print(f"{Fore.CYAN}{Style.BRIGHT}Loading data from the database into memory...{Style.RESET_ALL}")
    conn = connection_pool.get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            # Fetch all records
            cursor.execute("SELECT * FROM mp_synthesis_scheme_info")
            records = cursor.fetchall()
            # Index each record by doi and mp_id when present
            for record in records:
                doi = record.get('doi')
                mp_id = record.get('mp_id')
                if doi:
                    memory_cache[doi] = record
                if mp_id:
                    memory_cache[mp_id] = record
            print(f"{Fore.GREEN}{Style.BRIGHT}Loaded {len(records)} records into memory{Style.RESET_ALL}")
            print(f"{Fore.GREEN}{Style.BRIGHT}Number of keys in the cache: {len(memory_cache)}{Style.RESET_ALL}")
        finally:
            cursor.close()
    finally:
        conn.close()


# Load the data into memory at startup
load_data_to_memory()


async def process_retrieval_from_knowledge_base(data):
    doi = data.get('doi')
    mp_id = data.get('mp_id')
    # At least one lookup key must be provided
    if doi is None and mp_id is None:
        return ""
    # Look up a matching record in the in-memory cache
    result = None
    if doi is not None and doi in memory_cache:
        result = memory_cache[doi]
    elif mp_id is not None and mp_id in memory_cache:
        result = memory_cache[mp_id]
    # Return an empty string when nothing matches
    if not result:
        return ""
    # Build a markdown-formatted result from all fields except doi and mp_id
    markdown_result = ""
    fields = [
        "target_material",
        "reaction_string",
        "chara_structure",
        "chara_performance",
        "chara_application",
        "synthesis_schemes"
    ]
    for field in fields:
        field_content = result.get(field, "")
        # Only include non-empty fields
        if field_content and field_content.strip():
            markdown_result += f"\n## {field}\n{field_content}\n\n"
    return markdown_result  # Return the markdown text directly


async def execute_tool_from_dict(input_dict: dict):
    """
    Extract the tool-function name and arguments from a dict and execute the tool.

    Args:
        input_dict: a dict such as:
            {"name": "search_material_property_from_material_project",
             "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}

    Returns:
        The tool function's result, or an error message if the tool does not exist.
    """
    try:
        # Extract the function name and the arguments
        func_name = input_dict.get("name")
        arguments_data = input_dict.get("arguments")
        if not func_name:
            return {"status": "error", "message": "No function name provided"}
        # Get all registered tool functions
        tools = get_tools()
        # Check that the function is registered
        if func_name not in tools:
            return {"status": "error", "message": f"Function '{func_name}' is not in the tool registry"}
        tool_func = tools[func_name]
        # Parse the arguments
        arguments = {}
        if arguments_data:
            if isinstance(arguments_data, dict):
                # Already a dict: use as-is
                arguments = arguments_data
            elif isinstance(arguments_data, str):
                # A string: try to parse it as JSON
                try:
                    arguments = json.loads(arguments_data)
                except json.JSONDecodeError:
                    # Parsing failed, possibly due to escape characters;
                    # try to repair common JSON-string issues
                    fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\')
                    try:
                        arguments = json.loads(fixed_str)
                    except json.JSONDecodeError:
                        # Still failing: fall back to passing the raw string
                        arguments = {"raw_string": arguments_data}
        # Call the tool function, awaiting it when it is a coroutine
        if asyncio.iscoroutinefunction(tool_func):
            result = await tool_func(**arguments)
        else:
            result = tool_func(**arguments)
        return result
    except json.JSONDecodeError as e:
        return {"status": "error", "message": f"JSON parse error: {str(e)}"}
    except Exception as e:
        return {"status": "error", "message": f"Error during execution: {str(e)}"}


def worker(data, output_file_path):
    try:
        func_contents = data["function_calls"]
        func_results = []
        formatted_results = []  # Collects the formatted results
        for func in func_contents:
            func_name = func.get("name")
            arguments_data = func.get("arguments")
            if func_name == 'retrieval_from_knowledge_base':
                result = asyncio.run(process_retrieval_from_knowledge_base(data))
                func_results.append({"function": func_name, "result": result})
                formatted_results.append(f"[{func_name} content begin]{result}[{func_name} content end]")
            elif func_name == 'generate_material':
                try:
                    # Make sure arguments_data is a dict
                    if isinstance(arguments_data, str):
                        try:
                            arguments_data = json.loads(arguments_data)
                        except json.JSONDecodeError:
                            # Unparseable arguments: skip this call
                            continue
                    # Normalize the arguments
                    normalized_args = normalize_material_args(arguments_data)
                    try:
                        output = generate_material(**normalized_args)
                    except Exception:
                        # generate_material failed: skip this call
                        continue
                    func_results.append({"function": func_name, "result": output})
                    formatted_results.append(f"[{func_name} content begin]{output}[{func_name} content end]")
                except Exception:
                    # Error while preparing generate_material arguments: skip this call
                    continue
            else:
                result = asyncio.run(execute_tool_from_dict(func))
                func_results.append({"function": func_name, "result": result})
                formatted_results.append(f"[{func_name} content begin]{result}[{func_name} content end]")
        # Join all formatted results into the observation
        final_result = "\n\n\n".join(formatted_results)
        data['observation'] = final_result
        with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                writer.write(data)
        return "Processed successfully"
    except Exception as e:
        return f"Error processing: {str(e)}"


def main(datas, output_file_path, max_workers=1):
    from tqdm import tqdm
    # Progress bar
    pbar = tqdm(total=len(datas), desc="Processing papers")
    # Thread pool
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit the tasks
        future_to_path = {}
        for path in datas:
            future = executor.submit(worker, path, output_file_path)
            future_to_path[future] = path
        # Collect the results
        completed = 0
        failed = 0
        for future in concurrent.futures.as_completed(future_to_path):
            path = future_to_path[future]
            try:
                result = future.result()
                if "successfully" in result:
                    completed += 1
                else:
                    failed += 1
                pbar.update(1)
                # Refresh the statistics every 100 records
                if (completed + failed) % 100 == 0:
                    pbar.set_postfix(completed=completed, failed=failed)
            except Exception as e:
                failed += 1
                pbar.update(1)
                print(f"\nWorker for {path} generated an exception: {e}")
    pbar.close()
    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")


if __name__ == '__main__':
    import datetime
    datas = []
    with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
        for obj in reader:
            # if obj['solution'] != '':
            datas.append(obj)
    print(len(datas))
    output_file = f"./filter_ok_questions_solutions_agent_data_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
    main(datas, output_file, max_workers=32)

# Example 1: call with a well-formed JSON argument string
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
# argument = json.loads(argument)
# print(json.dumps(argument, indent=2))
# asyncio.run(mattergen(**argument))

generate_data/grpo_tools.py Executable file

@@ -0,0 +1,91 @@
import copy
import threading
import concurrent.futures

import jsonlines

from grpo_utils import generate_design_question, generate_props_question, generate_obs_response

# Create a lock for file writing
file_lock = threading.Lock()


def worker(data, output_file_path):
    try:
        messages = copy.deepcopy(data['messages'])
        obs = data['observation']
        # Keep only the <answer>...</answer> part of the last assistant turn
        messages[-1]['content'] = messages[-1]['content'].split("<answer>")[-1].split("</answer>")[0]
        messages.append({"role": "user", "content": obs})
        data['messages'].append({"role": "user", "content": obs})
        reasoning_content, response = generate_obs_response(messages)
        data['messages'].append({"role": "assistant", "content": f"<think>\n{reasoning_content}</think>\n<answer>\n{response}</answer>\n"})
        # Use the lock to safely write to the file
        with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                writer.write(messages)
        return "Processed successfully"
    except Exception as e:
        return f"Error processing: {str(e)}"


def main(input_file_path, output_file_path, max_workers=1):
    from tqdm import tqdm
    datas = None
    with jsonlines.open(input_file_path, mode='r') as reader:
        datas = [line for line in reader]
    # Progress bar
    pbar = tqdm(total=len(datas), desc="Processing records")
    # Thread pool
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit the tasks
        future_to_data = {}
        for data in datas:
            future = executor.submit(worker, data, output_file_path)
            future_to_data[future] = data
        # Collect the results
        completed = 0
        failed = 0
        for future in concurrent.futures.as_completed(future_to_data):
            data = future_to_data[future]
            try:
                result = future.result()
                if "successfully" in result:
                    completed += 1
                else:
                    failed += 1
                pbar.update(1)
                # Refresh the statistics every 100 records
                if (completed + failed) % 100 == 0:
                    pbar.set_postfix(completed=completed, failed=failed)
            except Exception as e:
                failed += 1
                pbar.update(1)
                print(f"\nWorker for {data} generated an exception: {e}")
    pbar.close()
    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")


if __name__ == '__main__':
    import datetime
    origin_file = "/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl"
    output_file = f"agent_questions_solutions_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
    main(origin_file, output_file)
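
# Sketch of an input record consumed by worker() above (contents invented;
# only the fields this script reads are shown):
#   {"messages": [{"role": "user", "content": "question ..."},
#                 {"role": "assistant",
#                  "content": "<think>...</think>\n<answer>tool calls ...</answer>"}],
#    "observation": "[generate_material content begin]...[generate_material content end]"}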

generate_data/grpo_utils.py Executable file

@@ -0,0 +1,294 @@
import json
import re
import threading

import generate_data.utils as utils

# Create a lock for file writing
file_lock = threading.Lock()


def _extract_questions_json(_response):
    """Pull the ```json ... ``` block out of a model response and parse it."""
    json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
    if not json_match:
        # No JSON block found: report the raw response
        return {"error": "No JSON format found in response", "raw_response": _response}
    json_str = json_match.group(1)
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Parsing failed: strip control whitespace and retry
        cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
        try:
            return json.loads(cleaned_json)
        except Exception:
            return {"error": "Failed to parse JSON response", "raw_response": _response}


def generate_design_question(crystal_desc, cif_info, crystal_props, max_retries=3, initial_backoff=1.0):
    instruction = """
{crystal_desc}
### The corresponding crystal structure data (CIF):
{cif_info}
### Physicochemical properties of this crystal structure:
{crystal_props}
Based on the information above, I need to set exam questions for materials-science PhD candidates, and the questions must require them to answer with the complete CIF file quoted above. How would you set such questions?
In other words, the answer to each question must be the complete CIF file mentioned above. Your question must of course supply sufficient information about the crystal structure.
That information, however, should be abstract and indirect rather than overly explicit; apart from the exact chemical formula, avoid giving too much precise data, so that candidates have to reason their way to some of the information, which raises the difficulty.
Write half of the questions in Chinese and half in English, for better interaction with the model.
Please first generate 10 example questions, then select the 2 best ones and output them in the following format:
```json
{
    "selected_questions": [
        {
            "question_id": 1,
            "question_text": "Full text of question 1...",
        },
        {
            "question_id": 2,
            "question_text": "Full text of question 2...",
        }
    ]
}
```
"""
    instruction = instruction.replace("{crystal_desc}", crystal_desc).replace("{cif_info}", cif_info).replace("{crystal_props}", crystal_props)
    messages = [
        {"role": "system", "content": ""},
        {"role": "user", "content": instruction}
    ]
    _response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
    if _response == 'apierror' or _response == 'unexpectederror':
        return _response
    # Try to extract the JSON part of the response
    return _extract_questions_json(_response)


def generate_props_question(crystal_desc, cif_info, crystal_props, max_retries=3, initial_backoff=1.0):
    instruction = """
{crystal_desc}
### The corresponding crystal structure data (CIF):
{cif_info}
### Physicochemical properties of this crystal structure:
{crystal_props}
Based on the information above, I need to set exam questions for materials-science PhD candidates that require them to answer, from the CIF file, the physicochemical properties listed above. How would you set such questions?
In other words, the answer to each question must be the physicochemical properties mentioned above. Your question should contain one <placeholder> tag standing for the given CIF file.
The candidates should analyze, from the given CIF file and through careful thought and reasoning, all of the physicochemical properties of this crystal material mentioned above, and answer all of them in JSON format.
Write half of the questions in Chinese and half in English, for better interaction with the model.
Example questions:
1. <placeholder>\nBased on the CIF file provided above, please xxx
2. Based on the CIF file provided below, please xxx\n <<placeholder>>
Please first generate 10 example questions, then select the 2 best ones and output them in the following format:
```json
{
    "selected_questions": [
        {
            "question_id": 1,
            "question_text": "Full text of question 1...",
        },
        {
            "question_id": 2,
            "question_text": "Full text of question 2...",
        }
    ]
}
```
"""
    instruction = instruction.replace("{crystal_desc}", crystal_desc).replace("{cif_info}", cif_info).replace("{crystal_props}", crystal_props)
    messages = [
        {"role": "system", "content": ""},
        {"role": "user", "content": instruction}
    ]
    _response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
    if _response == 'apierror' or _response == 'unexpectederror':
        return _response
    # Try to extract the JSON part of the response
    return _extract_questions_json(_response)


def generate_papers_other_question(paper_info, max_retries=3, initial_backoff=1.0):
    instruction = """
{paper_info}
Based on the information above, I need to set exam questions for materials-science PhD students that test whether they have fully mastered this material's reaction equation, structure, performance and applications. How would you set such questions?
Your questions should contain the appropriate material-related information and be self-contained (no key information missing when only the question is given), but they must be hard and deep, so that students can answer accurately only after careful thought and reasoning.
Since the questions target PhD students, they should be oriented toward research value. Where reaction equations or specific reagent amounts concerning structure, performance and applications are involved, require answers to give precise values (provided those values appear in the text above).
Please first generate 12 example questions (half in Chinese, half in English), then select the 4 best ones and output them in the following format:
```json
{
    "selected_questions": [
        {
            "question_id": 1,
            "question_text": "Full text of question 1...",
            "question_type": "Type of question 1", # reaction_string; structure; performance; application
        },
        {
            "question_id": 2,
            "question_text": "Full text of question 2...",
            "question_type": "Type of question 2", # reaction_string; structure; performance; application
        }, ...
    ]
}
```
"""
    instruction = instruction.replace("{paper_info}", paper_info)
    messages = [
        {"role": "system", "content": ""},
        {"role": "user", "content": instruction}
    ]
    _response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
    if _response == 'apierror' or _response == 'unexpectederror':
        return _response
    # Try to extract the JSON part of the response
    return _extract_questions_json(_response)


def generate_papers_synthesis_question(paper_info, max_retries=3, initial_backoff=1.0):
    instruction = """
{paper_info}
Based on the information above, I need to set exam questions for materials-science PhD students that test whether they have fully mastered this material's synthesis scheme, and the precise mapping from the given structure and performance to that synthesis scheme. How would you set such questions?
Your questions should contain sufficient structure and performance information about the material. They must be hard and deep, so that students can give an accurate synthesis scheme only after careful thought and reasoning, and must ask for the scheme organized as a formatted JSON synthesis scheme.
Since the questions target PhD students, they should be oriented toward research value, and answers describing the synthesis scheme must give precise values (reagents, precursors, vessels, temperatures and other synthesis conditions).
The conditions given in the question should be stated as explicitly as possible rather than implicitly (remember that the students will not have seen the text above, so phrasings such as "based on the given structure and performance information" should be avoided).
Please first generate 6 example questions (half in Chinese, half in English), then select the 2 best ones and output them in the following format:
```json
{
    "selected_questions": [
        {
            "question_id": 1,
            "question_text": "Full text of question 1...",
        },
        {
            "question_id": 2,
            "question_text": "Full text of question 2...",
        },
    ]
}
```
"""
    instruction = instruction.replace("{paper_info}", paper_info)
    messages = [
        {"role": "system", "content": ""},
        {"role": "user", "content": instruction}
    ]
    _response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
    if _response == 'apierror' or _response == 'unexpectederror':
        return _response
    # Try to extract the JSON part of the response
    return _extract_questions_json(_response)


def generate_function_call(messages, tools, max_retries=3, initial_backoff=1.0):
    instruction = """
# Question
{question}
# Instructions
Before answering the question above accurately, this is your one and only chance to call tools to gather more information.
Think about the question as deeply as possible and call as many of the provided tools as you can to look up information relevant to it, instead of answering it directly.
You therefore need to emit several considered tool calls at once in your reply, so that the question can be answered better afterwards.
Think and answer in the same language as the question.
"""
    messages[0]["content"] = instruction.replace("{question}", messages[0]["content"])
    _response, functions = utils.get_response_from_qwq(messages, model_name="qwq-32b", tools=tools, max_retries=max_retries, initial_backoff=initial_backoff)
    return _response, functions


def generate_obs_response(messages, max_retries=3, initial_backoff=1.0):
    _reasoning_content, response = utils.get_response_from_deepseek_r1(messages, prefix=False, max_retries=max_retries, initial_backoff=initial_backoff)
    return _reasoning_content, response
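
# Minimal usage sketch (assumes generate_data.utils is importable and its API
# credentials are valid; the inputs are placeholders):
#   questions = generate_design_question(crystal_desc="...", cif_info="...", crystal_props="...")
#   if isinstance(questions, dict) and "selected_questions" in questions:
#       for q in questions["selected_questions"]:
#           print(q["question_id"], q["question_text"])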

generate_data/utils.py Executable file

@@ -0,0 +1,800 @@
"""
This script generates questions and answers from a given set of CIFs.
It uses the OpenAI API and MySQL for storing and retrieving data.
@author: Yutang Li
"""
import multiprocessing
import sqlite3
import tiktoken
import re
from fractions import Fraction
import numpy as np
import glob
import tqdm
import copy
import json
import time
import random
from openai import OpenAI, APIError, RateLimitError
from mysql.connector import pooling, Error
from collections import Counter
def get_response_from_deepseek_r1(messages: list[dict], prefix: bool = False, max_retries: int = 3, initial_backoff: float = 1.0):
    """
    Get response from DeepSeek API with retry mechanism.

    Args:
        messages: List of message dictionaries
        prefix: Whether to use the prefix URL
        max_retries: Maximum number of retry attempts
        initial_backoff: Initial backoff time in seconds

    Returns:
        Tuple of (reasoning_content, content) or error messages
    """
    retries = 0
    while retries <= max_retries:
        try:
            base_url = "https://api.deepseek.com/beta" if prefix else "https://vip.apiyi.com/v1"
            api_key = "sk-59279cc16ec740089146ef9aef9c1671" if prefix else "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"
            client = OpenAI(api_key=api_key, base_url=base_url)
            response = client.chat.completions.create(
                model="deepseek-r1",
                messages=messages,
                temperature=0.6
            )
            # Split the <think>...</think> block off the final answer
            reasoning_content = response.choices[0].message.content.split("</think>\n")[0].split("<think>\n")[-1]
            content = response.choices[0].message.content.split("</think>\n")[-1]
            return reasoning_content, content
        except RateLimitError as rate_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for RateLimitError: {rate_error}")
                return 'apierror', 'apierror'
            # Exponential backoff with jitter
            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
            time.sleep(backoff_time)
        except APIError as api_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for APIError: {api_error}")
                return 'apierror', 'apierror'
            # Check if the error is retryable
            error_str = str(api_error)
            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
                # Exponential backoff with jitter
                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
                time.sleep(backoff_time)
            else:
                # Non-retryable API error
                print(f"Non-retryable API error: {api_error}")
                return 'apierror', 'apierror'
        except Exception as e:
            print(f"get_response_from_deepseek_r1 unexpected error: {e}")
            return 'unexpectederror', 'unexpectederror'


def get_response_from_llm(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
    """
    Get response from LLM API with retry mechanism.

    Args:
        messages: List of message dictionaries
        model_name: Name of the model to use
        tools: Optional list of tools to use
        max_retries: Maximum number of retry attempts
        initial_backoff: Initial backoff time in seconds

    Returns:
        Content of the response or error message
    """
    retries = 0
    while retries <= max_retries:
        try:
            client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
            if tools is None:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                )
            else:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    tools=tools,
                    tool_choice='auto',
                    parallel_tool_calls=True
                )
            content = response.choices[0].message.content
            return content
        except RateLimitError as rate_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for RateLimitError: {rate_error}")
                return 'apierror'
            # Exponential backoff with jitter
            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
            time.sleep(backoff_time)
        except APIError as api_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for APIError: {api_error}")
                return 'apierror'
            # Check if the error is retryable
            error_str = str(api_error)
            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
                # Exponential backoff with jitter
                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
                time.sleep(backoff_time)
            else:
                # Non-retryable API error
                print(f"Non-retryable API error: {api_error}")
                return 'apierror'
        except Exception as e:
            print(f"get_response_from_llm unexpected error: {e}")
            return 'unexpectederror'
def get_response_from_qwq(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
    """
    Get response from LLM API with retry mechanism.

    Args:
        messages: List of message dictionaries
        model_name: Name of the model to use
        tools: Optional list of tools to use
        max_retries: Maximum number of retry attempts
        initial_backoff: Initial backoff time in seconds

    Returns:
        Tuple of (formatted response string, tool-call info) or error values
    """
    retries = 0
    while retries <= max_retries:
        try:
            # Spread the load over two API keys at random
            if random.random() > 0.5:
                client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
            else:
                client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
            if tools is None:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    stream=True
                )
            else:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    tools=tools,
                    tool_choice='auto',
                    parallel_tool_calls=True,
                    stream=True
                )
            reasoning_content = ""  # Full chain-of-thought text
            answer_content = ""     # Full reply text
            tool_info = []          # Collected tool-call info
            is_answering = False    # Whether the reply phase has started
            for chunk in response:
                # Skip chunks without choices (e.g. the trailing usage chunk)
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                # Accumulate the model's reasoning (chain of thought)
                if hasattr(delta, 'reasoning_content') and delta.reasoning_content is not None:
                    reasoning_content += delta.reasoning_content
                # Accumulate the final reply
                else:
                    if not is_answering:
                        is_answering = True
                    if delta.content is not None:
                        answer_content += delta.content
                    # Collect tool-call info (parallel tool calls supported)
                    if delta.tool_calls is not None:
                        for tool_call in delta.tool_calls:
                            index = tool_call.index  # Tool-call index, used for parallel calls
                            # Grow the tool-info list on demand
                            while len(tool_info) <= index:
                                tool_info.append({})
                            # Collect the function name (used to route to the actual function)
                            if tool_call.function and tool_call.function.name:
                                tool_info[index]['name'] = tool_info[index].get('name', '') + tool_call.function.name
                            # Collect the function arguments (a JSON string, parsed later)
                            if tool_call.function and tool_call.function.arguments:
                                tool_info[index]['arguments'] = tool_info[index].get('arguments', '') + tool_call.function.arguments
            tools_response = ""
            for tool in tool_info:
                tools_response += ("<tool_call>\n" + json.dumps(tool, ensure_ascii=False) + "\n</tool_call>\n")
            response = "<think>\n" + reasoning_content + "\n</think>\n" + "<answer>\n" + answer_content + tools_response + "\n</answer>\n"
            return response, tool_info
        except RateLimitError as rate_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for RateLimitError: {rate_error}")
                return 'apierror', []
            # Exponential backoff with jitter
            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
            time.sleep(backoff_time)
        except APIError as api_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for APIError: {api_error}")
                return 'apierror', []
            # Check if the error is retryable
            error_str = str(api_error)
            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
                # Exponential backoff with jitter
                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
                time.sleep(backoff_time)
            else:
                # Non-retryable API error
                print(f"Non-retryable API error: {api_error}")
                return 'apierror', []
        except Exception as e:
            print(f"get_response_from_qwq unexpected error: {e}")
            return 'unexpectederror', []
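
# Shape of a successful get_response_from_qwq return value (illustrative):
#   response string: "<think>\n...\n</think>\n<answer>\n...<tool_call>\n{...}\n</tool_call>\n\n</answer>\n"
#   tool_info: [{"name": "some_tool", "arguments": "{\"param\": \"value\"}"}]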
def read_json_file(file_path):
    """Read the json file and return its content."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return None
################################## utils
def clean_all_repetitions_with_details(text, min_length=10, threshold=10):
    """
    Clean all kinds of repeated content from text and return the details.

    Args:
        text: the text to clean
        min_length: minimum length of a repeated fragment
        threshold: how many repeats are tolerated before cleaning

    Returns:
        cleaned_text: the cleaned text
        is_repetitive: whether any repetition was detected
        repetition_details: details of the repeated content
    """
    original_text = text
    is_repetitive = False
    repetition_details = []
    # 1. First handle repeats across newline-separated lines
    if '\n' in text:
        lines = text.split('\n')
        unique_lines = []
        line_counts = {}
        for i, line in enumerate(lines):
            normalized = line.strip().lower()
            if not normalized:
                unique_lines.append(line)
                continue
            line_counts[normalized] = line_counts.get(normalized, 0) + 1
            if line_counts[normalized] <= threshold:
                unique_lines.append(line)
            # The first time the threshold is exceeded, record the details
            if line_counts[normalized] == threshold + 1:
                # Find the original form (preserving case)
                original_form = None
                for l in lines[:i]:
                    if l.strip().lower() == normalized:
                        original_form = l
                        break
                if original_form is None:
                    original_form = line
                repetition_details.append({
                    'type': 'line_repetition',
                    'repeated_string': original_form,
                    'repeat_count': line_counts[normalized]
                })
        if any(count > threshold for count in line_counts.values()):
            text = '\n'.join(unique_lines)
            is_repetitive = True
    # 2. Handle consecutive repeated patterns within a line
    for length in range(min_length, 101):
        pattern = r'(.{' + str(length) + r'})(\1)+'
        while True:
            match = re.search(pattern, text)
            if not match:
                break
            repeated_part = match.group(1)
            full_match = match.group(0)
            # Count how many times the fragment repeats
            repeat_count = len(full_match) // len(repeated_part)
            repetition_details.append({
                'type': 'inline_repetition',
                'repeated_string': repeated_part,
                'repeat_count': repeat_count,
                'total_length': len(full_match),
                'position': match.start()
            })
            text = text.replace(full_match, repeated_part)
            is_repetitive = True
    # 3. Handle sentence-level repetition
    sentences = re.split(r'(?<=[.!?。?!])\s+', text)
    if len(sentences) > 1:
        sentence_counter = Counter(sentences)
        for sentence, count in sentence_counter.items():
            if count > threshold:
                repetition_details.append({
                    'type': 'sentence_repetition',
                    'repeated_string': sentence,
                    'repeat_count': count
                })
        if any(count > threshold for count in sentence_counter.values()):
            unique_sentences = []
            seen_sentences = {}
            for sentence in sentences:
                seen_sentences[sentence] = seen_sentences.get(sentence, 0) + 1
                if seen_sentences[sentence] <= threshold:
                    unique_sentences.append(sentence)
            # Reassemble the text
            text = ' '.join(unique_sentences)
            is_repetitive = True
    # 4. Handle shorter repeats (only if nothing was detected above)
    if not is_repetitive and min_length > 5:
        for length in range(5, min_length):
            pattern = r'(.{' + str(length) + r'})(\1){2,}'  # Only handle at least 3 consecutive repeats
            while True:
                match = re.search(pattern, text)
                if not match:
                    break
                repeated_part = match.group(1)
                full_match = match.group(0)
                repeat_count = len(full_match) // len(repeated_part)
                repetition_details.append({
                    'type': 'short_repetition',
                    'repeated_string': repeated_part,
                    'repeat_count': repeat_count,
                    'total_length': len(full_match),
                    'position': match.start()
                })
                text = text.replace(full_match, repeated_part)
                is_repetitive = True
    # Sort by fragment length (longest first), then by type
    repetition_details.sort(key=lambda x: (-len(x['repeated_string']), x['type']))
    return text, is_repetitive or text != original_text, repetition_details
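
# Illustrative example: a 10-character unit repeated 12 times in a row is
# collapsed to a single copy and reported as an 'inline_repetition':
#   text = "abcdefghij" * 12
#   cleaned, was_repetitive, details = clean_all_repetitions_with_details(text)
#   # cleaned == "abcdefghij"; was_repetitive is True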
def create_table(table_name, connection_pool):
    """Create the required MySQL table if it does not exist."""
    db = connection_pool.get_connection()
    cursor = db.cursor()
    create_table_query = f"""
    CREATE TABLE IF NOT EXISTS {table_name} (
        id INT AUTO_INCREMENT PRIMARY KEY,
        mp_id TEXT,
        question_model TEXT,
        question TEXT,
        answer_model TEXT,
        answer TEXT,
        answer_len INT
    )
    """
    cursor.execute(create_table_query)
    db.commit()
    cursor.close()
    db.close()


def record_exists(mp_id, table_name, connection_pool):
    """Check if a mp_id already exists in the table."""
    db = connection_pool.get_connection()
    cursor = db.cursor()
    query = f"SELECT * FROM {table_name} WHERE mp_id = %s"
    cursor.execute(query, (mp_id,))
    result = cursor.fetchone()
    cursor.fetchall()  # Ensure all results are processed
    cursor.close()
    db.close()
    return result is not None


def insert_record(entry, table_name, connection_pool):
    """Insert a record into the MySQL table."""
    db = None
    cursor = None
    try:
        db = connection_pool.get_connection()
        cursor = db.cursor()
        insert_query = f"""
        INSERT INTO {table_name}
        (mp_id, question_model, question, answer_model, answer, answer_len)
        VALUES (%s, %s, %s, %s, %s, %s)
        """
        values = (
            entry["mp_id"], entry["question_model"],
            entry["question"], entry["answer_model"], entry["answer"], entry["answer_len"],
        )
        cursor.execute(insert_query, values)
        db.commit()
    except Error as e:
        print(f"Error: {e}")
        if db:
            db.rollback()
    finally:
        # Ensure cursor is closed
        if cursor:
            cursor.close()
        # Ensure connection is returned to the pool
        if db:
            db.close()


# Initialize SQLite database connection
def initialize_db():
    conn = sqlite3.connect('multi_turns_data.db', check_same_thread=False)
    cursor = conn.cursor()
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS conversations (
            mp_id TEXT PRIMARY KEY,
            sample TEXT,
            token_num INTEGER
        )
    ''')
    conn.commit()
    return conn


# Save sample to SQLite database
def save_to_db(conn, mp_id, sample, total_token):
    cursor = conn.cursor()
    cursor.execute('''
        INSERT OR REPLACE INTO conversations (mp_id, sample, token_num)
        VALUES (?, ?, ?)
    ''', (mp_id, str(sample), total_token))
    conn.commit()


def read_cif_txt_file(file_path):
    """Read the markdown file and return its content."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return None
def round_values(data, precision=3):
    """
    Recursively round every numeric value in a structure to `precision` decimals.
    """
    if isinstance(data, dict):  # Dict: recurse into each value
        return {key: round_values(value, precision) for key, value in data.items()}
    elif isinstance(data, list):  # List: recurse into each element
        return [round_values(item, precision) for item in data]
    elif isinstance(data, (int, float)):  # Number: round it
        return round(data, precision)
    else:  # Anything else: return unchanged
        return data


def decimal_to_fraction(decimal_value, max_denominator=1000):
    """
    Convert a decimal value to a fraction string.

    Args:
        decimal_value: the decimal to convert
        max_denominator: upper bound on the denominator, controlling precision

    Returns:
        The fraction as a string.
    """
    frac = Fraction(decimal_value).limit_denominator(max_denominator)
    return f"{frac.numerator}/{frac.denominator}"
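
# Examples: decimal_to_fraction(0.25) -> "1/4"; decimal_to_fraction(1/3) -> "1/3"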
def poscar_to_fractional_representation(poscar_content, max_denominator=1000):
    """
    Convert the numeric values in a POSCAR file to fraction notation.

    Args:
        poscar_content: contents of the POSCAR file
        max_denominator: upper bound on the denominator, controlling precision

    Returns:
        The converted POSCAR content with values written as fractions.
    """
    lines = poscar_content.strip().split('\n')
    result_lines = []
    # Keep the system name
    result_lines.append(lines[0])
    # Keep the scaling factor
    scaling_factor = float(lines[1])
    result_lines.append(lines[1])
    # Convert the lattice vectors
    for i in range(2, 5):
        vector = [float(x) for x in lines[i].split()]
        # Convert each component to a fraction
        fractional_vector = [decimal_to_fraction(x, max_denominator) for x in vector]
        result_lines.append(" " + " ".join(fractional_vector))
    # Keep element types and counts
    if len(lines) > 5:
        result_lines.append(lines[5])
    if len(lines) > 6:
        result_lines.append(lines[6])
    # Keep the coordinate type
    if len(lines) > 7:
        result_lines.append(lines[7])
    # Convert the atomic coordinates
    for i in range(8, len(lines)):
        parts = lines[i].split()
        if len(parts) >= 3:
            # Convert the coordinates to fractions
            coords = [float(parts[j]) for j in range(3)]
            fractional_coords = [decimal_to_fraction(x, max_denominator) for x in coords]
            # Build the new line
            new_line = " " + " ".join(fractional_coords)
            if len(parts) > 3:
                new_line += " " + " ".join(parts[3:])
            result_lines.append(new_line)
        else:
            # Keep non-coordinate lines unchanged
            result_lines.append(lines[i])
    return "\n".join(result_lines)


def remove_symmetry_equiv_xyz(cif_content):
    """
    Remove the symmetry-operation section from a CIF file.

    Args:
        cif_content: the CIF file content as a string

    Returns:
        The cleaned CIF content string.
    """
    lines = cif_content.split('\n')
    output_lines = []
    i = 0
    while i < len(lines):
        line = lines[i].strip()
        # Detect the start of a loop
        if line == 'loop_':
            # Peek at the following tag lines to see whether this is the symmetry loop
            next_lines = []
            j = i + 1
            while j < len(lines) and lines[j].strip().startswith('_'):
                next_lines.append(lines[j].strip())
                j += 1
            # Does the loop contain the symmetry-operation tag?
            if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
                # Skip the whole loop block
                while i < len(lines):
                    if i + 1 >= len(lines):
                        break
                    next_line = lines[i + 1].strip()
                    # Stop at the next loop or data block
                    if next_line == 'loop_' or next_line.startswith('data_'):
                        break
                    # Stop at the atom-site section
                    if next_line.startswith('_atom_site_'):
                        break
                    i += 1
            else:
                # Not the symmetry loop: keep the loop_ line
                output_lines.append(lines[i])
        else:
            # Not a loop start: keep the line
            output_lines.append(lines[i])
        i += 1
    return '\n'.join(output_lines)
def remove_null_values(d):
    """
    Recursively remove key-value pairs with null (None) values from a dictionary.

    Args:
        d (dict): The dictionary to clean.

    Returns:
        dict: A new dictionary without null values.
    """
    if not isinstance(d, dict):
        raise ValueError("Input must be a dictionary")
    _d = copy.deepcopy(d)

    def recursive_remove(d):
        cleaned_dict = {}
        for key, value in d.items():
            if isinstance(value, dict):
                # Recursively clean nested dictionaries
                nested_cleaned = recursive_remove(value)
                if nested_cleaned:  # Only add non-empty dictionaries
                    cleaned_dict[key] = nested_cleaned
            elif value is not None and key != 'version':
                cleaned_dict[key] = value
        return cleaned_dict

    clean_dict = recursive_remove(d)
    # A band gap reported without cbm and vbm is dropped as unreliable
    if _d.get('cbm') is None and _d.get('vbm') is None and _d.get('band_gap') is not None:
        clean_dict.pop('band_gap')
    return clean_dict
def get_extra_cif_info(path: str, fields_name: list):
    """Extract specific fields from the CIF description."""
    basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
    energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
    metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites']
    selected_fields = []
    if fields_name[0] == 'all_fields':
        selected_fields = basic_fields + energy_electronic_fields + metal_magentic_fields
    else:
        # Look up each requested field group by its variable name
        for field in fields_name:
            selected_fields.extend(locals().get(field, []))
    with open(path, 'r') as f:
        docs = json.load(f)
    new_docs = {}
    for field_name in selected_fields:
        new_docs[field_name] = docs.get(field_name, '')
    return new_docs
def extract_json(text):
    """Extract JSON content from a block of text by matching balanced braces."""
    # The standard `re` module cannot match nested braces, so scan for a
    # balanced top-level {...} block instead.
    start = text.find('{')
    if start == -1:
        return None
    depth = 0
    for i in range(start, len(text)):
        if text[i] == '{':
            depth += 1
        elif text[i] == '}':
            depth -= 1
            if depth == 0:
                try:
                    return json.loads(text[start:i + 1])
                except json.JSONDecodeError:
                    return None
    return None
def extract_and_parse_json(response):
    """Extract and parse JSON from a response."""
    json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
    json_str = json_match.group(1) if json_match else response.strip()
    # Escape backslashes inside $...$ math spans, then unescape quotes
    json_str = re.sub(r'(\$[^\$]*\$)', lambda m: m.group(1).replace('\\', '\\\\'), json_str)
    json_str = json_str.replace('\\"', '"').replace("\\'", "'")
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"JSON parse error: {e}")
        return 'errformat'
# Count the tokens in an input string
def count_message_tokens(text, model_name):
    encoding = tiktoken.encoding_for_model(model_name)
    # Note: the argument is a single string, not a list of message dicts
    return len(encoding.encode(text))
def make_multi_turns_sharegpt_sample(humans: list[str], gpts: list[str], system: str = "{SYSTEM}"):
    sample = {}
    conversations = []
    if system is not None and system != "":
        sample["system"] = system
    assert len(humans) != 0, "humans cannot be empty"
    assert len(gpts) == len(humans), "humans and gpts must have the same length"
    for human, gpt in zip(humans, gpts):
        if human is not None and human != "":
            assert gpt is not None, "gpt cannot be None"
            assert gpt != "", "gpt cannot be empty"
            # The order of the two appends below must not change
            conversations.append({"from": "human", "value": human})
            conversations.append({"from": "gpt", "value": gpt})
    sample["conversations"] = conversations
    return sample
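
# Example usage (placeholder contents):
#   sample = make_multi_turns_sharegpt_sample(
#       humans=["What is the space group?", "And the band gap?"],
#       gpts=["P2_1/c", "about 1.8 eV"],
#       system="You are a materials-science assistant.")
#   # sample["conversations"] alternates human/gpt turns in order.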
##################################### utils