import json
import asyncio
import concurrent.futures
import jsonlines
from mars_toolkit import *  # provides get_tools() among others — star import kept as in original
import threading
import uuid
from mars_toolkit.compute.material_gen import generate_material
from mysql.connector import pooling
from colorama import Fore, Back, Style, init
import time
import random
from typing import Dict, Union, Any, Optional
import requests

# Lock serializing appends to the shared JSONL output file across worker threads.
file_lock = threading.Lock()

# Initialize colorama (autoreset restores default colors after each print).
init(autoreset=True)


def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize the argument dict passed to ``generate_material``.

    Handles:
      1. ``properties`` supplied as a JSON string instead of a dict.
      2. Property values coerced to int/float where possible, while range
         specs ("0.0-2.0", "40-50"), comparator strings (">x" / "<x"), the
         special token "relaxor" and unit-suffixed values ("...eV") are kept
         as strings.
      3. ``batch_size`` / ``num_batches`` coerced to int and
         ``diffusion_guidance_factor`` to float.

    Args:
        arguments: raw keyword arguments intended for ``generate_material``.

    Returns:
        A new dict with normalized values (input dict is not mutated).

    Raises:
        ValueError: if a value cannot be coerced to its required type.
    """
    normalized_args = arguments.copy()

    if "properties" in normalized_args:
        properties = normalized_args["properties"]

        # properties may arrive as a JSON-encoded string.
        if isinstance(properties, str):
            try:
                properties = json.loads(properties)
            except json.JSONDecodeError as e:
                raise ValueError(f"无法解析properties JSON字符串: {e}")

        if not isinstance(properties, dict):
            raise ValueError(f"properties必须是字典或JSON字符串,而不是 {type(properties)}")

        normalized_properties = {}
        for key, value in properties.items():
            # Range spec such as "0.0-2.0" or "40-50": keep as string.
            # The dash is searched from index 1 so a plain negative number
            # like "-5" is NOT mistaken for a range (bug fix); scientific
            # notation such as "1e-5" is still treated as a range here —
            # TODO confirm whether such inputs can occur.
            if isinstance(value, str) and "-" in value[1:] and not value.startswith(">") and not value.startswith("<"):
                normalized_properties[key] = value
            elif isinstance(value, str) and value.startswith(">"):
                # "greater than" constraint stays a string.
                normalized_properties[key] = value
            elif isinstance(value, str) and value.startswith("<"):
                # "less than" constraint stays a string.
                normalized_properties[key] = value
            elif isinstance(value, str) and value.lower() == "relaxor":
                # Special categorical value stays a string.
                normalized_properties[key] = value
            elif isinstance(value, str) and value.endswith("eV"):
                # Unit-carrying value stays a string.
                normalized_properties[key] = value
            else:
                # Anything else: try numeric coercion, preferring int when
                # the float has no fractional part.
                try:
                    float_value = float(value)
                    if float_value.is_integer():
                        normalized_properties[key] = int(float_value)
                    else:
                        normalized_properties[key] = float_value
                except (ValueError, TypeError):
                    # Not numeric — keep the original value untouched.
                    normalized_properties[key] = value

        normalized_args["properties"] = normalized_properties

    if "batch_size" in normalized_args:
        try:
            normalized_args["batch_size"] = int(normalized_args["batch_size"])
        except (ValueError, TypeError):
            raise ValueError(f"batch_size必须是整数,而不是 {normalized_args['batch_size']}")

    if "num_batches" in normalized_args:
        try:
            normalized_args["num_batches"] = int(normalized_args["num_batches"])
        except (ValueError, TypeError):
            raise ValueError(f"num_batches必须是整数,而不是 {normalized_args['num_batches']}")

    if "diffusion_guidance_factor" in normalized_args:
        try:
            normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"])
        except (ValueError, TypeError):
            raise ValueError(f"diffusion_guidance_factor必须是数字,而不是 {normalized_args['diffusion_guidance_factor']}")

    return normalized_args


# NOTE(review): database credentials are hard-coded in source — they should
# be moved to environment variables or a config file.
connection_pool = pooling.MySQLConnectionPool(
    pool_name="mypool",
    pool_size=32,
    pool_reset_session=True,
    host='localhost',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers',
)


async def process_retrieval_from_knowledge_base(data):
    """Look up one synthesis-scheme record by DOI and/or Materials Project id
    and render its non-empty fields as markdown.

    Args:
        data: mapping that may carry ``doi`` and/or ``mp_id`` keys.

    Returns:
        A markdown string, or "" when neither key is provided or no record
        matches.
    """
    doi = data.get('doi')
    mp_id = data.get('mp_id')

    if doi is None and mp_id is None:
        return ""  # nothing to look up

    # Build a parameterized WHERE clause from whichever keys are present.
    query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
    if doi is not None and mp_id is not None:
        query += "doi = %s OR mp_id = %s"
        params = [doi, mp_id]
    elif doi is not None:
        query += "doi = %s"
        params = [doi]
    else:  # mp_id is not None
        query += "mp_id = %s"
        params = [mp_id]

    conn = connection_pool.get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query, params)
            result = cursor.fetchone()  # only the first matching record is used
        finally:
            cursor.close()
    finally:
        conn.close()

    if not result:
        return ""  # no matching record

    # Render each non-empty column (doi/mp_id excluded) as a markdown section.
    fields = [
        "target_material",
        "reaction_string",
        "chara_structure",
        "chara_performance",
        "chara_application",
        "synthesis_schemes",
    ]
    markdown_result = ""
    for field in fields:
        field_content = result.get(field, "")
        # Skip empty or whitespace-only columns.
        if field_content and field_content.strip():
            markdown_result += f"\n## {field}\n{field_content}\n\n"

    return markdown_result


async def execute_tool_from_dict(input_dict: dict):
    """Resolve the tool named in ``input_dict["name"]`` from the registered
    tool table and invoke it with ``input_dict["arguments"]``.

    ``arguments`` may be a dict, a JSON string, or a JSON string with
    escaping damage; a string that cannot be repaired is forwarded as
    ``{"raw_string": ...}``.  Async tools are awaited, sync tools called
    directly.

    Returns:
        The tool's result, or a ``{"status": "error", "message": ...}`` dict
        on failure.
    """
    try:
        func_name = input_dict.get("name")
        arguments_data = input_dict.get("arguments")

        if not func_name:
            return {"status": "error", "message": "未提供函数名称"}

        # get_tools() is expected to come from the mars_toolkit star import —
        # it maps tool names to callables.
        tools = get_tools()
        if func_name not in tools:
            return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"}
        tool_func = tools[func_name]

        arguments = {}
        if arguments_data:
            if isinstance(arguments_data, dict):
                # Already a dict — use as-is.
                arguments = arguments_data
            elif isinstance(arguments_data, str):
                try:
                    arguments = json.loads(arguments_data)
                except json.JSONDecodeError:
                    # Repair the most common escaping damage and retry.
                    fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\')
                    try:
                        arguments = json.loads(fixed_str)
                    except json.JSONDecodeError:
                        # Still unparseable — hand the raw text to the tool.
                        arguments = {"raw_string": arguments_data}

        if asyncio.iscoroutinefunction(tool_func):
            result = await tool_func(**arguments)
        else:
            result = tool_func(**arguments)

        return result
    except json.JSONDecodeError as e:
        return {"status": "error", "message": f"JSON解析错误: {str(e)}"}
    except Exception as e:
        return {"status": "error", "message": f"执行过程中出错: {str(e)}"}


def worker(data, output_file_path):
    """Execute every call in ``data["function_calls"]``, store the joined
    formatted results in ``data['observation']`` and append the record to
    ``output_file_path`` (JSONL; the write is serialized by ``file_lock``).

    Returns:
        "Processed successfully" on success, "Error processing: ..." otherwise
        (``main`` keys off the word "successfully").
    """
    try:
        formatted_results = []
        for func in data["function_calls"]:
            func_name = func.get("name")
            arguments_data = func.get("arguments")

            # Rich-text trace of the call being executed.
            print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
            print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")

            if func_name == 'retrieval_from_knowledge_base':
                # NOTE(review): the whole record is passed here, not the
                # call's own arguments — the lookup keys (doi / mp_id)
                # presumably live on the record itself; confirm against the
                # input data schema.
                result = asyncio.run(process_retrieval_from_knowledge_base(data))
                formatted_results.append(f"[{func_name} content begin]{result}[{func_name} content end]")
            elif func_name == 'generate_material':
                # Intentionally skipped: material generation is disabled in
                # this run, so the call contributes nothing to the observation.
                pass
            else:
                result = asyncio.run(execute_tool_from_dict(func))
                formatted_results.append(f"[{func_name} content begin]{result}[{func_name} content end]")

        data['observation'] = "\n\n\n".join(formatted_results)

        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
        print(data['observation'])
        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")

        # Append under the lock so concurrent workers don't interleave lines.
        with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                writer.write(data)

        return f"Processed successfully"
    except Exception as e:
        print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
        return f"Error processing: {str(e)}"


def main(datas, output_file_path, max_workers=1):
    """Fan the records out to a thread pool of ``worker`` calls and track
    completed/failed counts on a progress bar.

    Args:
        datas: list of records, each holding a ``function_calls`` list.
        output_file_path: JSONL file that workers append results to.
        max_workers: thread-pool size (I/O-bound work, so threads suffice).
    """
    from tqdm import tqdm

    pbar = tqdm(total=len(datas), desc="Processing papers")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_path = {
            executor.submit(worker, item, output_file_path): item
            for item in datas
        }

        completed = 0
        failed = 0
        for future in concurrent.futures.as_completed(future_to_path):
            path = future_to_path[future]
            try:
                result = future.result()
                if "successfully" in result:
                    completed += 1
                else:
                    failed += 1
                pbar.update(1)
                # Refresh the stats display every 100 finished items.
                if (completed + failed) % 100 == 0:
                    pbar.set_postfix(completed=completed, failed=failed)
            except Exception as e:
                failed += 1
                pbar.update(1)
                print(f"\nWorker for {path} generated an exception: {e}")

    pbar.close()
    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")


if __name__ == '__main__':
    import datetime

    datas = []
    with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
        for obj in reader:
            datas.append(obj)
    print(len(datas))

    output_file = f"./filter_ok_questions_solutions_agent_other_tools_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
    main(datas, output_file, max_workers=32)