import json
import asyncio
import concurrent.futures
import threading

import jsonlines  # moved to top level: worker() needs it even when this module is imported, not run

from tools_for_ms.llm_tools import *
from mysql.connector import pooling

# Lock serializing appends to the shared output file across worker threads.
file_lock = threading.Lock()

# NOTE(review): credentials are hard-coded and the pool is created at import
# time; consider loading from environment/config and creating lazily.
connection_pool = pooling.MySQLConnectionPool(
    pool_name="mypool",
    pool_size=32,
    pool_reset_session=True,
    host='localhost',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers',
)


def process_retrieval_from_knowledge_base(data):
    """Look up one synthesis-scheme record by ``doi`` and/or ``mp_id``.

    Args:
        data: dict that may contain ``'doi'`` and/or ``'mp_id'`` keys.

    Returns:
        A markdown string with one ``## field`` section per non-empty field,
        or ``""`` when no query key was supplied or no row matched.
    """
    doi = data.get('doi')
    mp_id = data.get('mp_id')

    # At least one lookup key is required.
    if doi is None and mp_id is None:
        return ""

    # Build a parameterized WHERE clause from whichever keys are present.
    query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
    params = []
    if doi is not None and mp_id is not None:
        query += "doi = %s OR mp_id = %s"
        params = [doi, mp_id]
    elif doi is not None:
        query += "doi = %s"
        params = [doi]
    else:  # mp_id is not None
        query += "mp_id = %s"
        params = [mp_id]

    # Fetch the first matching row, always returning the pooled connection.
    conn = connection_pool.get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query, params)
            result = cursor.fetchone()  # first match only
        finally:
            cursor.close()
    finally:
        conn.close()

    if not result:
        return ""

    # Render every non-empty field (doi/mp_id excluded) as a markdown section.
    fields = [
        "target_material",
        "reaction_string",
        "chara_structure",
        "chara_performance",
        "chara_application",
        "synthesis_schemes",
    ]
    markdown_result = ""
    for field in fields:
        field_content = result.get(field, "")
        if field_content and field_content.strip():
            markdown_result += f"\n## {field}\n{field_content}\n\n"
    return markdown_result


async def execute_tool_from_dict(input_dict: dict):
    """Resolve and invoke a registered tool function from a call descriptor.

    Args:
        input_dict: e.g. ``{"name": "search_material_property_from_material_project",
            "arguments": "{\\"formula\\": \\"Th3Pd5\\", \\"is_stable\\": \\"true\\"}"}``
            where ``arguments`` may be a dict or a JSON string.

    Returns:
        The tool function's result, or an error dict
        ``{"status": "error", "message": ...}`` when the name is missing,
        unknown, or execution fails.
    """
    try:
        func_name = input_dict.get("name")
        arguments_data = input_dict.get("arguments")

        if not func_name:
            return {"status": "error", "message": "未提供函数名称"}

        # Look the tool up in the registry provided by tools_for_ms.llm_tools.
        tools = get_tools()
        if func_name not in tools:
            return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"}
        tool_func = tools[func_name]

        # Normalize arguments: accept a dict, a JSON string, a JSON string
        # with doubled escapes, or fall back to wrapping the raw string.
        arguments = {}
        if arguments_data:
            if isinstance(arguments_data, dict):
                arguments = arguments_data
            elif isinstance(arguments_data, str):
                try:
                    arguments = json.loads(arguments_data)
                except json.JSONDecodeError:
                    # Common failure mode: over-escaped quotes/backslashes.
                    fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\')
                    try:
                        arguments = json.loads(fixed_str)
                    except json.JSONDecodeError:
                        arguments = {"raw_string": arguments_data}

        # Await coroutine tools; call plain functions directly.
        if asyncio.iscoroutinefunction(tool_func):
            result = await tool_func(**arguments)
        else:
            result = tool_func(**arguments)
        return result
    except json.JSONDecodeError as e:
        return {"status": "error", "message": f"JSON解析错误: {str(e)}"}
    except Exception as e:
        return {"status": "error", "message": f"执行过程中出错: {str(e)}"}


def worker(data, output_file_path):
    """Execute every function call listed in ``data['function_calls']``,
    attach the concatenated formatted results to ``data``, and append the
    record to ``output_file_path`` (jsonlines, guarded by ``file_lock``).

    Returns:
        ``"Processed successfully"`` on success, otherwise an
        ``"Error processing: ..."`` string (never raises).
    """
    try:
        func_contents = data["function_calls"]
        func_results = []
        formatted_results = []

        for func in func_contents:
            func_name = func.get("name")
            # The knowledge-base tool is special-cased: it is driven by the
            # record's own doi/mp_id, not by the call's arguments.
            if func_name == 'retrieval_from_knowledge_base':
                result = process_retrieval_from_knowledge_base(data)
            else:
                result = asyncio.run(execute_tool_from_dict(func))
            func_results.append({"function": func['name'], "result": result})
            formatted_results.append(
                f"[{func_name} content begin]{result}[{func_name} content end]"
            )

        final_result = "\n\n\n".join(formatted_results)
        # NOTE(review): key is misspelled ('observation') but is part of the
        # output-file schema — downstream readers must change in lockstep.
        data['obeservation'] = final_result

        # Serialize appends: multiple worker threads share one output file.
        with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                writer.write(data)
        return f"Processed successfully"
    except Exception as e:
        return f"Error processing: {str(e)}"


def main(datas, output_file_path, max_workers=1):
    """Process every record in ``datas`` through ``worker`` using a thread
    pool, with a tqdm progress bar and completed/failed counters.

    Args:
        datas: list of record dicts, each with a ``'function_calls'`` list.
        output_file_path: jsonlines file the workers append results to.
        max_workers: thread-pool size (default 1).
    """
    from tqdm import tqdm

    pbar = tqdm(total=len(datas), desc="Processing papers")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit one task per record; remember which future maps to which record.
        future_to_path = {}
        for path in datas:
            future = executor.submit(worker, path, output_file_path)
            future_to_path[future] = path

        completed = 0
        failed = 0
        for future in concurrent.futures.as_completed(future_to_path):
            path = future_to_path[future]
            try:
                result = future.result()
                # worker() encodes success/failure in its return string.
                if "successfully" in result:
                    completed += 1
                else:
                    failed += 1
                pbar.update(1)
                # Refresh the counters every 100 finished records.
                if (completed + failed) % 100 == 0:
                    pbar.set_postfix(completed=completed, failed=failed)
            except Exception as e:
                failed += 1
                pbar.update(1)
                print(f"\nWorker for {path} generated an exception: {e}")

    pbar.close()
    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")


if __name__ == '__main__':
    import datetime

    datas = []
    with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
        for obj in reader:
            datas.append(obj)
    print(len(datas))

    output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
    main(datas, output_file, max_workers=8)