import json
import asyncio
import concurrent.futures
import sys
# Make the mars-mcp package importable without installation.
sys.path.append('/home/ubuntu/sas0/lzy/mars-mcp/')
import jsonlines
from mars_toolkit import *
import threading
import uuid
from mars_toolkit.compute.material_gen import generate_material
# Lock serializing appends to the shared output JSONL file across worker threads.
file_lock = threading.Lock()
from mysql.connector import pooling
from colorama import Fore, Back, Style, init
import time
import random
# Initialize colorama (autoreset restores the default style after each print).
init(autoreset=True)
from typing import Dict, Union, Any, Optional, List
import re
def extract_tool_calls(text):
    """Extract every JSON payload matched by the tool-call pattern in *text*.

    NOTE(review): the regex below matches a newline followed by a non-greedy
    group with nothing after it, so the captured group is always empty — the
    tool-call delimiter tags appear to have been stripped from this file.
    Confirm the original delimiters against the upstream version.

    Args:
        text (str): raw model output that may contain tool-call JSON.

    Returns:
        list: a parsed dict for each payload, or an error string for any
        payload that fails to parse as JSON.
    """
    # re.DOTALL lets '.' match newlines so a payload may span lines.
    pattern = r'\n(.*?)'
    extracted = []
    for hit in re.finditer(pattern, text, re.DOTALL):
        payload = hit.group(1).strip()
        try:
            extracted.append(json.loads(payload))
        except json.JSONDecodeError as e:
            extracted.append(f"无法解析JSON: {e},问题字符串{payload}")
    return extracted
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize the argument dict passed to ``generate_material``.

    Handles three cases:
    1. ``properties`` may arrive as a JSON string and is parsed to a dict.
    2. Property values are coerced to int/float where possible, except for
       range strings ("0.0-2.0"), comparisons (">5", "<5"), the special
       value "relaxor", and values carrying an "eV" unit suffix, which stay
       strings.
    3. ``batch_size``/``num_batches`` are coerced to int and
       ``diffusion_guidance_factor`` to float.

    Args:
        arguments: raw keyword arguments for generate_material.

    Returns:
        A normalized copy of *arguments* (the input is not mutated).

    Raises:
        ValueError: if properties is not a dict / valid JSON, or if a count
        or guidance-factor field cannot be coerced to its numeric type.
    """
    normalized = dict(arguments)

    if "properties" in normalized:
        props = normalized["properties"]
        # A JSON-encoded string is decoded first.
        if isinstance(props, str):
            try:
                props = json.loads(props)
            except json.JSONDecodeError as e:
                raise ValueError(f"无法解析properties JSON字符串: {e}")
        if not isinstance(props, dict):
            raise ValueError(f"properties必须是字典或JSON字符串,而不是 {type(props)}")

        cleaned = {}
        for name, raw in props.items():
            # Strings that encode ranges, comparisons, the "relaxor" keyword,
            # or an eV-suffixed quantity are preserved verbatim.
            preserve = isinstance(raw, str) and (
                "-" in raw
                or raw.startswith(">")
                or raw.startswith("<")
                or raw.lower() == "relaxor"
                or raw.endswith("eV")
            )
            if preserve:
                cleaned[name] = raw
            else:
                try:
                    as_float = float(raw)
                except (ValueError, TypeError):
                    # Not numeric at all: keep the original value.
                    cleaned[name] = raw
                else:
                    # Whole numbers become ints, everything else stays float.
                    cleaned[name] = int(as_float) if as_float.is_integer() else as_float
        normalized["properties"] = cleaned

    # Count fields must be integers.
    for field in ("batch_size", "num_batches"):
        if field in normalized:
            try:
                normalized[field] = int(normalized[field])
            except (ValueError, TypeError):
                raise ValueError(f"{field}必须是整数,而不是 {normalized[field]}")

    # The guidance factor must be a float.
    if "diffusion_guidance_factor" in normalized:
        try:
            normalized["diffusion_guidance_factor"] = float(normalized["diffusion_guidance_factor"])
        except (ValueError, TypeError):
            raise ValueError(f"diffusion_guidance_factor必须是数字,而不是 {normalized['diffusion_guidance_factor']}")

    return normalized
import requests
# Connection pool used only for the one-time bulk load of the in-memory cache.
# NOTE(review): credentials are hard-coded; move them to config/env vars.
connection_pool = pooling.MySQLConnectionPool(
    pool_name="mypool",
    pool_size=32,
    pool_reset_session=True,
    host='localhost',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)
# In-memory cache of rows loaded from the database.
# Layout: {doi: record, mp_id: record} — one record may appear under both keys.
memory_cache = {}
def load_data_to_memory():
    """Load every row of ``mp_synthesis_scheme_info`` into the in-memory cache.

    Populates the module-level ``memory_cache`` dict, keying each record by
    both its ``doi`` and its ``mp_id`` (when present) so a record can be
    found by either identifier. Connection and cursor are always released.
    """
    print(f"{Fore.CYAN}{Style.BRIGHT}正在从数据库加载数据到内存中...{Style.RESET_ALL}")
    connection = connection_pool.get_connection()
    try:
        cur = connection.cursor(dictionary=True)
        try:
            cur.execute("SELECT * FROM mp_synthesis_scheme_info")
            rows = cur.fetchall()
            for row in rows:
                # Index the same record under every identifier it carries.
                for key in (row.get('doi'), row.get('mp_id')):
                    if key:
                        memory_cache[key] = row
            print(f"{Fore.GREEN}{Style.BRIGHT}成功加载 {len(rows)} 条记录到内存中{Style.RESET_ALL}")
            print(f"{Fore.GREEN}{Style.BRIGHT}内存缓存中的键数量: {len(memory_cache)}{Style.RESET_ALL}")
        finally:
            cur.close()
    finally:
        connection.close()
# Populate the in-memory cache once at startup.
# NOTE(review): this runs a full-table DB query as a module import side effect.
load_data_to_memory()
async def process_retrieval_from_knowledge_base(data):
    """Look up a cached synthesis record by doi or mp_id and render it as markdown.

    Args:
        data: dict that may carry a 'doi' and/or 'mp_id' key.

    Returns:
        str: the record's non-empty fields rendered as ``## <field>`` markdown
        sections, or "" when no identifier is given or nothing matches.
    """
    doi, mp_id = data.get('doi'), data.get('mp_id')
    # At least one identifier is required.
    if doi is None and mp_id is None:
        return ""

    # doi takes precedence over mp_id when both are present in the cache.
    record = None
    if doi is not None and doi in memory_cache:
        record = memory_cache[doi]
    elif mp_id is not None and mp_id in memory_cache:
        record = memory_cache[mp_id]
    if not record:
        return ""

    # Render every non-empty field (doi/mp_id excluded) as a markdown section.
    sections = []
    for field in ("target_material", "reaction_string", "chara_structure",
                  "chara_performance", "chara_application", "synthesis_schemes"):
        content = record.get(field, "")
        if content and content.strip():
            sections.append(f"\n## {field}\n{content}\n\n")
    return "".join(sections)
async def execute_tool_from_dict(input_dict: dict):
    """
    Extract a tool-function name and arguments from *input_dict* and invoke
    the matching registered tool.

    Args:
        input_dict: e.g.
            {"name": "search_material_property_from_material_project",
             "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}

    Returns:
        The tool's result wrapped as
        "[<name> content begin]...[<name> content end]", an error string when
        the tool is unknown, or an error dict when the call itself fails.
    """
    try:
        func_name = input_dict.get("name")
        arguments_data = input_dict.get("arguments")
        if not func_name:
            return {"status": "error", "message": "未提供函数名称"}
        # Look the function up among all registered tools.
        tools = get_tools()
        if func_name not in tools:
            return f"函数 '{func_name}' 不存在于工具函数字典中"
        tool_func = tools[func_name]
        # Normalize arguments: accept a ready dict or a JSON-encoded string.
        arguments = {}
        if arguments_data:
            if isinstance(arguments_data, dict):
                arguments = arguments_data
            elif isinstance(arguments_data, str):
                try:
                    arguments = json.loads(arguments_data)
                except json.JSONDecodeError:
                    # The string may carry escaped quotes/backslashes; repair
                    # the common escaping problems and retry.
                    fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\')
                    try:
                        arguments = json.loads(fixed_str)
                    except json.JSONDecodeError:
                        # Still not JSON: pass the raw string through.
                        arguments = {"raw_string": arguments_data}
        # Invoke the tool, awaiting it when it is a coroutine function.
        try:
            if asyncio.iscoroutinefunction(tool_func):
                result = await tool_func(**arguments)
            else:
                result = tool_func(**arguments)
        except Exception as e:
            # BUG FIX: the f-string previously read "str{e}", rendering a
            # literal "str" before the exception text.
            result = f'工具函数调用时出错:{str(e)}'
        formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
        return formatted_result
    except json.JSONDecodeError as e:
        return {"status": "error", "message": f"JSON解析错误: {str(e)}"}
    except Exception as e:
        return {"status": "error", "message": f"执行过程中出错: {str(e)}"}
def worker(data, output_file_path):
    """Process one conversation record: execute any tool calls found in its
    last message, append the tool results as a new user turn, then append the
    record to *output_file_path* (JSONL, serialized by the global file lock).

    NOTE(review): the split separator "" and the containment check '' below
    look like tool-call delimiter tags stripped from this file —
    ``str.split("")`` raises ValueError at runtime, so every record currently
    hits the error path. Restore the original delimiters before deploying.

    Args:
        data: conversation record with a 'messages' list.
        output_file_path: JSONL file to append the (possibly extended) record to.

    Returns:
        "Processed successfully" on success, or "Error processing: <reason>".
    """
    try:
        tool_call_str = data['messages'][-1]['content'].split("")[-1]
        if '' in tool_call_str:
            func_contents = extract_tool_calls(tool_call_str)
            formatted_results = []  # formatted output of every tool call
            for func in func_contents:
                if isinstance(func, Dict):
                    func_name = func.get("name")
                    arguments_data = func.get("arguments")
                    if func.get("name") == 'retrieval_from_knowledge_base':
                        # Knowledge-base lookups are answered from the in-memory cache.
                        result = asyncio.run(process_retrieval_from_knowledge_base(data))
                        formatted_results.append(
                            f"[{func_name} content begin]{result}[{func_name} content end]")
                    elif func.get("name") == 'generate_material':
                        try:
                            # The arguments may arrive as a JSON-encoded string.
                            if isinstance(arguments_data, str):
                                try:
                                    arguments_data = json.loads(arguments_data)
                                except json.JSONDecodeError:
                                    continue
                            normalized_args = normalize_material_args(arguments_data)
                            # BUG FIX: previously a failed generate_material call
                            # stored an error message that was immediately
                            # clobbered by a format using the undefined `output`,
                            # raising NameError and silently dropping the record.
                            try:
                                output = generate_material(**normalized_args)
                                formatted_results.append(
                                    f"[{func_name} content begin]{output}[{func_name} content end]")
                            except Exception as e:
                                formatted_results.append(
                                    f"调用时出错,请检查输入的参数,异常为{e}")
                        except Exception:
                            # Unusable arguments: skip this tool call.
                            continue
                    else:
                        # Any other registered tool goes through the generic executor.
                        formatted_results.append(asyncio.run(execute_tool_from_dict(func)))
                else:
                    # extract_tool_calls already produced an error string.
                    formatted_results.append(func)
            # Join all tool outputs and append them as the next user turn.
            final_result = "\n\n\n".join(formatted_results)
            data['messages'].append({"role": "user", "content": final_result})
        # Both the tool-call and no-tool-call paths persist the record the same way.
        with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                writer.write(data)
        return f"Processed successfully"
    except Exception as e:
        return f"Error processing: {str(e)}"
def main(datas, output_file_path, max_workers=1):
    """Fan the records in *datas* out to a thread pool of ``worker`` calls,
    appending results to *output_file_path* and reporting progress.

    Args:
        datas: list of conversation records (dicts) to process.
        output_file_path: JSONL file the workers append to.
        max_workers: thread-pool size.
    """
    # NOTE: unused function-local imports (random, os) removed.
    from tqdm import tqdm

    pbar = tqdm(total=len(datas), desc="Processing papers")
    completed = 0
    failed = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit every record and remember which future belongs to which input.
        future_to_path = {
            executor.submit(worker, path, output_file_path): path for path in datas
        }
        for future in concurrent.futures.as_completed(future_to_path):
            path = future_to_path[future]
            try:
                result = future.result()
                # worker returns "Processed successfully" on the happy path.
                if "successfully" in result:
                    completed += 1
                else:
                    failed += 1
                pbar.update(1)
                # Refresh the stats shown on the bar every 100 records.
                if (completed + failed) % 100 == 0:
                    pbar.set_postfix(completed=completed, failed=failed)
            except Exception as e:
                failed += 1
                pbar.update(1)
                print(f"\nWorker for {path} generated an exception: {e}")
    pbar.close()
    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
if __name__ == '__main__':
    import datetime
    import jsonlines
    datas = []
    total_count=0
    filtered_count=0
    # Load every conversation record from the source JSONL file.
    with jsonlines.open('/home/ubuntu/sas0/lzy/mars-mcp/generate_data/agent_questions_solutions_qwq1.jsonl') as reader:
        for obj in reader:
            datas.append(obj)
    # Count how many records end with a tool call.
    # NOTE(review): `'' in tool_call_str` is always True, so filtered_count
    # always equals total_count — the tool-call delimiter tag appears to have
    # been stripped from this file; confirm against the original source.
    for data in datas:
        tool_call_str=data['messages'][-1]['content'].split("\n")[-1]
        if '' in tool_call_str:
            filtered_count+=1
        total_count+=1
    print("total count",total_count)
    print("filtered count",filtered_count)
    # for data in datas[:5]:
    #     tool_call_str=data['messages'][-1]['content'].split("\n")[-1].split("")[0]
    #     tool_call_dict_list=extract_tool_calls(tool_call_str)
    #     for tool_call_dict in tool_call_dict_list:
    #         print("tool name",tool_call_dict['name'])
    #         print("tool arguments",tool_call_dict['arguments'])
    #         print("xxx")
    #     print("==="*20)
    #     # print()
    # exit()
    # Timestamped output file name so reruns never clobber earlier results.
    output_file = f"./agent_questions_solution_turn2_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
    main(datas, output_file, max_workers=48)
    # Example: invoking material generation with a correctly formatted JSON argument
    # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
    # argument = json.loads(argument)
    # print(json.dumps(argument, indent=2))
    # asyncio.run(mattergen(**argument))