Files
mars-mcp/generate_data/generate_data10000.py

434 lines
16 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 原始数据分为两类 一种是带solution的一种是没有solution的这个是提取了各5000条
import json
import asyncio
import concurrent.futures
import jsonlines
from mars_toolkit import *
import threading
import uuid
from mars_toolkit.compute.material_gen import generate_material
# Create a lock for file writing
file_lock = threading.Lock()
from mysql.connector import pooling
from colorama import Fore, Back, Style, init
import time
import random
# 初始化colorama
init(autoreset=True)
from typing import Dict, Union, Any, Optional, List
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
规范化传递给generate_material函数的参数格式。
处理以下情况:
1. properties参数可能是字符串形式的JSON需要解析为字典
2. properties中的值可能需要转换为适当的类型数字或字符串
3. 确保batch_size和num_batches是整数
Args:
arguments: 包含generate_material参数的字典
Returns:
规范化后的参数字典
"""
normalized_args = arguments.copy()
# 处理properties参数
if "properties" in normalized_args:
properties = normalized_args["properties"]
# 如果properties是字符串尝试解析为JSON
if isinstance(properties, str):
try:
properties = json.loads(properties)
except json.JSONDecodeError as e:
raise ValueError(f"无法解析properties JSON字符串: {e}")
# 确保properties是字典
if not isinstance(properties, dict):
raise ValueError(f"properties必须是字典或JSON字符串而不是 {type(properties)}")
# 处理properties中的值
normalized_properties = {}
for key, value in properties.items():
# 处理范围值,例如 "0.0-2.0" 或 "40-50"
if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"):
# 保持范围值为字符串格式
normalized_properties[key] = value
elif isinstance(value, str) and value.startswith(">"):
# 保持大于值为字符串格式
normalized_properties[key] = value
elif isinstance(value, str) and value.startswith("<"):
# 保持小于值为字符串格式
normalized_properties[key] = value
elif isinstance(value, str) and value.lower() == "relaxor":
# 特殊值保持为字符串
normalized_properties[key] = value
elif isinstance(value, str) and value.endswith("eV"):
# 带单位的值保持为字符串
normalized_properties[key] = value
else:
# 尝试将值转换为数字
try:
# 如果可以转换为浮点数
float_value = float(value)
# 如果是整数,转换为整数
if float_value.is_integer():
normalized_properties[key] = int(float_value)
else:
normalized_properties[key] = float_value
except (ValueError, TypeError):
# 如果无法转换为数字,保持原值
normalized_properties[key] = value
normalized_args["properties"] = normalized_properties
# 确保batch_size和num_batches是整数
if "batch_size" in normalized_args:
try:
normalized_args["batch_size"] = int(normalized_args["batch_size"])
except (ValueError, TypeError):
raise ValueError(f"batch_size必须是整数而不是 {normalized_args['batch_size']}")
if "num_batches" in normalized_args:
try:
normalized_args["num_batches"] = int(normalized_args["num_batches"])
except (ValueError, TypeError):
raise ValueError(f"num_batches必须是整数而不是 {normalized_args['num_batches']}")
# 确保diffusion_guidance_factor是浮点数
if "diffusion_guidance_factor" in normalized_args:
try:
normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"])
except (ValueError, TypeError):
raise ValueError(f"diffusion_guidance_factor必须是数字而不是 {normalized_args['diffusion_guidance_factor']}")
return normalized_args
import requests
# 创建数据库连接池(仅用于初始加载数据)
connection_pool = pooling.MySQLConnectionPool(
pool_name="mypool",
pool_size=32,
pool_reset_session=True,
host='localhost',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers'
)
# 内存缓存,用于存储从数据库加载的数据
# 结构: {doi: record, mp_id: record}
memory_cache = {}
def load_data_to_memory():
"""
从数据库加载所有数据到内存中
"""
print(f"{Fore.CYAN}{Style.BRIGHT}正在从数据库加载数据到内存中...{Style.RESET_ALL}")
conn = connection_pool.get_connection()
try:
cursor = conn.cursor(dictionary=True)
try:
# 查询所有记录
cursor.execute("SELECT * FROM mp_synthesis_scheme_info")
records = cursor.fetchall()
# 将记录添加到内存缓存中
for record in records:
doi = record.get('doi')
mp_id = record.get('mp_id')
# 使用doi作为键如果存在
if doi:
memory_cache[doi] = record
# 使用mp_id作为键如果存在
if mp_id:
memory_cache[mp_id] = record
print(f"{Fore.GREEN}{Style.BRIGHT}成功加载 {len(records)} 条记录到内存中{Style.RESET_ALL}")
print(f"{Fore.GREEN}{Style.BRIGHT}内存缓存中的键数量: {len(memory_cache)}{Style.RESET_ALL}")
finally:
cursor.close()
finally:
conn.close()
# 在程序启动时加载数据到内存中
load_data_to_memory()
async def process_retrieval_from_knowledge_base(data):
doi = data.get('doi')
mp_id = data.get('mp_id')
# 检查是否提供了至少一个查询参数
if doi is None and mp_id is None:
return "" # 如果没有提供查询参数,返回空字符串
# 从内存缓存中查询匹配的记录
result = None
if doi is not None and doi in memory_cache:
result = memory_cache[doi]
elif mp_id is not None and mp_id in memory_cache:
result = memory_cache[mp_id]
# 检查是否找到匹配的记录
if not result:
return "" # 如果没有找到匹配记录,返回空字符串
# 构建markdown格式的结果
markdown_result = ""
# 添加各个字段除了doi和mp_id
fields = [
"target_material",
"reaction_string",
"chara_structure",
"chara_performance",
"chara_application",
"synthesis_schemes"
]
for field in fields:
# 获取字段内容
field_content = result.get(field, "")
# 只有当字段内容不为空时才添加该字段
if field_content and field_content.strip():
markdown_result += f"\n## {field}\n{field_content}\n\n"
return markdown_result # 直接返回markdown文本
async def execute_tool_from_dict(input_dict: dict):
"""
从字典中提取工具函数名称和参数,并执行相应的工具函数
Args:
input_dict: 字典,例如:
{"name": "search_material_property_from_material_project",
"arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}
Returns:
工具函数的执行结果,如果工具函数不存在则返回错误信息
"""
try:
# 解析输入字符串为字典
# input_dict = json.loads(input_str)
# 提取函数名和参数
func_name = input_dict.get("name")
arguments_data = input_dict.get("arguments")
#print('func_name', func_name)
#print("argument", arguments_data)
if not func_name:
return {"status": "error", "message": "未提供函数名称"}
# 获取所有注册的工具函数
tools = get_tools()
# 检查函数名是否存在于工具函数字典中
if func_name not in tools:
return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"}
# 获取对应的工具函数
tool_func = tools[func_name]
# 处理参数
arguments = {}
if arguments_data:
# 检查arguments是字符串还是字典
if isinstance(arguments_data, dict):
# 如果已经是字典,直接使用
arguments = arguments_data
elif isinstance(arguments_data, str):
# 如果是字符串尝试解析为JSON
try:
# 尝试直接解析为JSON对象
arguments = json.loads(arguments_data)
except json.JSONDecodeError:
# 如果解析失败,可能是因为字符串中包含转义字符
# 尝试修复常见的JSON字符串问题
fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\')
try:
arguments = json.loads(fixed_str)
except json.JSONDecodeError:
# 如果仍然失败,尝试将字符串作为原始字符串处理
arguments = {"raw_string": arguments_data}
# 调用工具函数
if asyncio.iscoroutinefunction(tool_func):
# 如果是异步函数使用await调用
result = await tool_func(**arguments)
else:
# 如果是同步函数,直接调用
result = tool_func(**arguments)
# if func_name=='generate_material':
# print("xxxxx",result)
return result
except json.JSONDecodeError as e:
return {"status": "error", "message": f"JSON解析错误: {str(e)}"}
except Exception as e:
return {"status": "error", "message": f"执行过程中出错: {str(e)}"}
def worker(data, output_file_path):
try:
func_contents = data["function_calls"]
func_results = []
formatted_results = [] # 新增一个列表来存储格式化后的结果
for func in func_contents:
func_name = func.get("name")
arguments_data = func.get("arguments")
# 使用富文本打印函数名
#print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
# 使用富文本打印参数
#print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
if func.get("name") == 'retrieval_from_knowledge_base':
# delay_time = random.uniform(5, 10)
# time.sleep(delay_time)
result = asyncio.run(process_retrieval_from_knowledge_base(data))
func_results.append({"function": func['name'], "result": result})
# 格式化结果
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result)
elif func.get("name") == 'generate_material':
try:
# 确保arguments_data是字典
if isinstance(arguments_data, str):
try:
arguments_data = json.loads(arguments_data)
except json.JSONDecodeError as e:
#print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
continue
# 规范化参数
normalized_args = normalize_material_args(arguments_data)
# 优先使用mattergen函数
try:
output = generate_material(**normalized_args)
except Exception as e:
#print(f"{Fore.RED}mattergen出错尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
continue
# 将结果添加到func_results
func_results.append({"function": func_name, "result": output})
# 格式化结果
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
formatted_results.append(formatted_result)
except Exception as e:
#print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
import traceback
#print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
continue
else:
result = asyncio.run(execute_tool_from_dict(func))
func_results.append({"function": func['name'], "result": result})
# 格式化结果
func_name = func.get("name")
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result)
# 将所有格式化后的结果连接起来
final_result = "\n\n\n".join(formatted_results)
data['observation'] = final_result
#使用富文本打印开始和结束标记
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
# print(data['observation'])
# print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data) # observation . data
return f"Processed successfully"
except Exception as e:
#print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
return f"Error processing: {str(e)}"
def main(datas, output_file_path, max_workers=1):
import random
from tqdm import tqdm
import os
# 创建进度条
pbar = tqdm(total=len(datas), desc="Processing papers")
# 创建一个线程池
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交任务到执行器
future_to_path = {}
for path in datas:
future = executor.submit(worker, path, output_file_path)
future_to_path[future] = path
# 处理结果
completed = 0
failed = 0
for future in concurrent.futures.as_completed(future_to_path):
path = future_to_path[future]
try:
result = future.result()
if "successfully" in result:
completed += 1
else:
failed += 1
# 更新进度条
pbar.update(1)
# 每100个文件更新一次统计信息
if (completed + failed) % 100 == 0:
pbar.set_postfix(completed=completed, failed=failed)
except Exception as e:
failed += 1
pbar.update(1)
print(f"\nWorker for {path} generated an exception: {e}")
pbar.close()
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
if __name__ == '__main__':
import datetime
import jsonlines
datas_with_solution = []
datas_without_solution = []
with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
for obj in reader:
if obj['solution']!='':
datas_with_solution.append(obj)
else:
datas_without_solution.append(obj)
datas_with_solution = datas_with_solution[:5000]
datas_without_solution = datas_without_solution[:5000]
datas = datas_with_solution + datas_without_solution
import random
random.shuffle(datas)
output_file = f"./filter_ok_questions_solutions_agent_data10000_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(datas, output_file, max_workers=32)