import json
import asyncio
import concurrent.futures
import threading
import uuid
import time
import random
from typing import Dict, Union, Any, Optional

import jsonlines
import requests
from mars_toolkit import *
from mysql.connector import pooling
from colorama import Fore, Back, Style, init

# Lock serializing appends to the shared JSONL output file across worker threads.
file_lock = threading.Lock()

# Initialize colorama (autoreset restores default colors after each print).
init(autoreset=True)


def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize the argument dict passed to generate_material.

    Handles the following cases:
      1. ``properties`` may arrive as a JSON string and is parsed to a dict.
      2. Values inside ``properties`` are coerced to int/float when possible;
         range ("40-50"), comparison (">x" / "<x"), "relaxor" and "...eV"
         values are kept as strings.
      3. ``batch_size`` / ``num_batches`` are coerced to int and
         ``diffusion_guidance_factor`` to float.

    Args:
        arguments: Raw keyword arguments intended for generate_material.

    Returns:
        A normalized shallow copy of ``arguments``.

    Raises:
        ValueError: If ``properties`` cannot be parsed / is not a dict, or a
            numeric field cannot be coerced.
    """
    normalized_args = arguments.copy()

    if "properties" in normalized_args:
        properties = normalized_args["properties"]

        # The caller (an LLM tool call) sometimes emits properties as an
        # embedded JSON string rather than a dict.
        if isinstance(properties, str):
            try:
                properties = json.loads(properties)
            except json.JSONDecodeError as e:
                raise ValueError(f"无法解析properties JSON字符串: {e}")

        if not isinstance(properties, dict):
            raise ValueError(f"properties必须是字典或JSON字符串,而不是 {type(properties)}")

        normalized_properties = {}
        for key, value in properties.items():
            # NOTE(review): the "-" test below also matches plain negative
            # numbers (e.g. "-1.5"), which therefore stay strings — confirm
            # this is intended before tightening it.
            if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"):
                # Range values such as "0.0-2.0" or "40-50" stay strings.
                normalized_properties[key] = value
            elif isinstance(value, str) and value.startswith(">"):
                # Greater-than constraints stay strings.
                normalized_properties[key] = value
            elif isinstance(value, str) and value.startswith("<"):
                # Less-than constraints stay strings.
                normalized_properties[key] = value
            elif isinstance(value, str) and value.lower() == "relaxor":
                # Special categorical value stays a string.
                normalized_properties[key] = value
            elif isinstance(value, str) and value.endswith("eV"):
                # Values carrying a unit stay strings.
                normalized_properties[key] = value
            else:
                try:
                    float_value = float(value)
                    # Collapse whole floats to int (2.0 -> 2).
                    if float_value.is_integer():
                        normalized_properties[key] = int(float_value)
                    else:
                        normalized_properties[key] = float_value
                except (ValueError, TypeError):
                    # Not numeric: keep the original value untouched.
                    normalized_properties[key] = value

        normalized_args["properties"] = normalized_properties

    if "batch_size" in normalized_args:
        try:
            normalized_args["batch_size"] = int(normalized_args["batch_size"])
        except (ValueError, TypeError):
            raise ValueError(f"batch_size必须是整数,而不是 {normalized_args['batch_size']}")

    if "num_batches" in normalized_args:
        try:
            normalized_args["num_batches"] = int(normalized_args["num_batches"])
        except (ValueError, TypeError):
            raise ValueError(f"num_batches必须是整数,而不是 {normalized_args['num_batches']}")

    if "diffusion_guidance_factor" in normalized_args:
        try:
            normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"])
        except (ValueError, TypeError):
            raise ValueError(f"diffusion_guidance_factor必须是数字,而不是 {normalized_args['diffusion_guidance_factor']}")

    return normalized_args


# Shared MySQL connection pool used by the knowledge-base lookups.
# NOTE(review): credentials are hard-coded here — consider moving them to
# environment variables / a config file.
connection_pool = pooling.MySQLConnectionPool(
    pool_name="mypool",
    pool_size=32,
    pool_reset_session=True,
    host='localhost',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)


async def process_retrieval_from_knowledge_base(data):
    """Look up synthesis-scheme metadata for a paper/material.

    Queries ``mp_synthesis_scheme_info`` by ``doi`` and/or ``mp_id`` taken
    from ``data`` and renders the first matching row as markdown sections.

    Args:
        data: Dict that may carry 'doi' and/or 'mp_id' keys.

    Returns:
        Markdown text of the matched record's non-empty content fields, or
        "" when no lookup key was supplied or no record matched.
    """
    doi = data.get('doi')
    mp_id = data.get('mp_id')

    # Nothing to query on.
    if doi is None and mp_id is None:
        return ""

    # Build the WHERE clause with %s placeholders (parameterized query —
    # values are never interpolated into the SQL string).
    query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
    params = []
    if doi is not None and mp_id is not None:
        query += "doi = %s OR mp_id = %s"
        params = [doi, mp_id]
    elif doi is not None:
        query += "doi = %s"
        params = [doi]
    else:  # mp_id is not None
        query += "mp_id = %s"
        params = [mp_id]

    # Borrow a pooled connection; the nested try/finally guarantees both the
    # cursor and the connection are released even on query failure.
    conn = connection_pool.get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query, params)
            result = cursor.fetchone()  # first matching record only
        finally:
            cursor.close()
    finally:
        conn.close()

    if not result:
        return ""

    markdown_result = ""
    # Content fields to render (doi / mp_id themselves are excluded).
    fields = [
        "target_material",
        "reaction_string",
        "chara_structure",
        "chara_performance",
        "chara_application",
        "synthesis_schemes"
    ]
    for field in fields:
        field_content = result.get(field, "")
        # Skip fields that are missing or whitespace-only.
        if field_content and field_content.strip():
            markdown_result += f"\n## {field}\n{field_content}\n\n"

    return markdown_result


async def mattergen(
    properties=None,
    batch_size=2,
    num_batches=1,
    diffusion_guidance_factor=2.0
):
    """Generate crystal structures through the in-process MatterGen service.

    Args:
        properties: Optional property constraints, e.g. {"dft_band_gap": 2.0}.
        batch_size: Structures generated per batch.
        num_batches: Number of batches.
        diffusion_guidance_factor: How strongly generation follows the
            target properties.

    Returns:
        The generated structure content, or an "Error generating material: ..."
        string on failure (this function never raises).
    """
    try:
        # Imported lazily so the module still loads when the service
        # package is unavailable; failures fall through to the except below.
        from mars_toolkit.services.mattergen_service import MatterGenService

        service = MatterGenService.get_instance()
        result = await service.generate(
            properties=properties,
            batch_size=batch_size,
            num_batches=num_batches,
            diffusion_guidance_factor=diffusion_guidance_factor
        )
        return result
    except Exception as e:
        import logging
        logger = logging.getLogger(__name__)
        logger.error(f"Error in mattergen: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return f"Error generating material: {str(e)}"


async def generate_material(
    url="http://localhost:8051/generate_material",
    properties=None,
    batch_size=2,
    num_batches=1,
    diffusion_guidance_factor=2.0
):
    """Generate crystal structures, preferring the local service over HTTP.

    First tries the in-process MatterGen service (``mattergen``); if that
    raises or returns an error string, falls back to POSTing the normalized
    arguments to the MatterGen HTTP API.

    Args:
        url: HTTP API endpoint used for the fallback call.
        properties: Optional property constraints, e.g. {"dft_band_gap": 2.0}.
        batch_size: Structures generated per batch.
        num_batches: Number of batches.
        diffusion_guidance_factor: How strongly generation follows the
            target properties.

    Returns:
        Generated structure content on success, otherwise None.
    """
    # Attempt 1: local MatterGen service.
    try:
        print("尝试使用本地MatterGen服务...")
        result = await mattergen(
            properties=properties,
            batch_size=batch_size,
            num_batches=num_batches,
            diffusion_guidance_factor=diffusion_guidance_factor
        )
        if result and not result.startswith("Error"):
            print("本地MatterGen服务生成成功!")
            return result
        else:
            print(f"本地MatterGen服务生成失败,尝试使用API: {result}")
    except Exception as e:
        print(f"本地MatterGen服务出错,尝试使用API: {str(e)}")

    # Attempt 2: HTTP API fallback with normalized arguments.
    normalized_args = normalize_material_args({
        "properties": properties,
        "batch_size": batch_size,
        "num_batches": num_batches,
        "diffusion_guidance_factor": diffusion_guidance_factor
    })

    payload = {
        "properties": normalized_args["properties"],
        "batch_size": normalized_args["batch_size"],
        "num_batches": normalized_args["num_batches"],
        "diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"]
    }

    print(f"发送请求到 {url}")
    print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")

    try:
        headers = {
            "Content-Type": "application/json",
            "accept": "application/json"
        }

        # Debug dump of the outgoing request.
        print(f"完整请求URL: {url}")
        print(f"请求头: {json.dumps(headers, indent=2)}")
        print(f"请求体: {json.dumps(payload, indent=2)}")

        # Bypass any system proxy for the localhost call.
        proxies = {
            "http": None,
            "https": None
        }

        # Generation is slow — allow a generous timeout.
        response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)

        # Debug dump of the response (body truncated to 500 chars).
        print(f"响应状态码: {response.status_code}")
        print(f"响应头: {dict(response.headers)}")
        print(f"响应内容: {response.text[:500]}...")

        if response.status_code == 200:
            result = response.json()
            if result["success"]:
                print("\n生成成功!")
                return result["content"]
            else:
                print(f"\n生成失败: {result['message']}")
                return None
        else:
            print(f"\n请求失败,状态码: {response.status_code}")
            print(f"响应内容: {response.text}")
            return None
    except Exception as e:
        print(f"\n发生错误: {str(e)}")
        print(f"错误类型: {type(e).__name__}")
        import traceback
        print(f"错误堆栈: {traceback.format_exc()}")
        return None


async def execute_tool_from_dict(input_dict: dict):
    """Dispatch a tool call described by a dict to the registered tool.

    Example input::

        {"name": "search_material_property_from_material_project",
         "arguments": "{\\"formula\\": \\"Th3Pd5\\", \\"is_stable\\": \\"true\\"}"}

    Args:
        input_dict: Dict with a "name" key (tool function name) and an
            optional "arguments" key (dict or JSON string).

    Returns:
        The tool function's result, or a {"status": "error", "message": ...}
        dict when the name is missing/unknown or execution fails.
    """
    try:
        func_name = input_dict.get("name")
        arguments_data = input_dict.get("arguments")

        if not func_name:
            return {"status": "error", "message": "未提供函数名称"}

        # Registry of all tool functions (provided by mars_toolkit).
        tools = get_tools()
        if func_name not in tools:
            return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"}

        tool_func = tools[func_name]

        # Arguments may arrive as a dict, a JSON string, a JSON string with
        # broken escaping, or an unparseable raw string.
        arguments = {}
        if arguments_data:
            if isinstance(arguments_data, dict):
                arguments = arguments_data
            elif isinstance(arguments_data, str):
                try:
                    arguments = json.loads(arguments_data)
                except json.JSONDecodeError:
                    # Repair common over-escaping, then retry.
                    fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\')
                    try:
                        arguments = json.loads(fixed_str)
                    except json.JSONDecodeError:
                        # Last resort: hand the raw string to the tool.
                        arguments = {"raw_string": arguments_data}

        # Await coroutines; call plain functions directly.
        if asyncio.iscoroutinefunction(tool_func):
            result = await tool_func(**arguments)
        else:
            result = tool_func(**arguments)

        return result

    except json.JSONDecodeError as e:
        return {"status": "error", "message": f"JSON解析错误: {str(e)}"}
    except Exception as e:
        return {"status": "error", "message": f"执行过程中出错: {str(e)}"}


def worker(data, output_file_path):
    """Execute every tool call recorded in ``data`` and persist the result.

    For each entry in ``data["function_calls"]`` the matching tool is run
    (knowledge-base retrieval, material generation, or generic dispatch),
    the formatted outputs are joined into ``data['observation']``, and the
    enriched record is appended to ``output_file_path`` under ``file_lock``.

    Args:
        data: One input record containing a "function_calls" list.
        output_file_path: JSONL file shared by all worker threads.

    Returns:
        "Processed successfully" on success, "Error processing: ..." on failure.
    """
    try:
        func_contents = data["function_calls"]
        func_results = []
        formatted_results = []  # per-call text wrapped in begin/end markers

        for func in func_contents:
            func_name = func.get("name")
            arguments_data = func.get("arguments")

            print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
            print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")

            if func.get("name") == 'retrieval_from_knowledge_base':
                # Small random delay to spread concurrent DB load.
                delay_time = random.uniform(1, 5)
                time.sleep(delay_time)
                result = asyncio.run(process_retrieval_from_knowledge_base(data))
                func_results.append({"function": func['name'], "result": result})
                formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
                formatted_results.append(formatted_result)

            elif func.get("name") == 'generate_material':
                try:
                    # Arguments may be a JSON string; parse before normalizing.
                    if isinstance(arguments_data, str):
                        try:
                            arguments_data = json.loads(arguments_data)
                        except json.JSONDecodeError as e:
                            print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
                            continue

                    normalized_args = normalize_material_args(arguments_data)
                    print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
                    print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
                    print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")

                    # FIX: pre-bind output so a failure inside the try below
                    # cannot raise NameError at the append sites (the original
                    # code referenced an unbound 'output' after the except).
                    output = None
                    try:
                        output = asyncio.run(generate_material(**normalized_args))

                        # Simulated post-processing delay (5-10 s) and logs.
                        delay_time = random.uniform(5, 10)
                        print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
                        time.sleep(delay_time)
                        print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
                        time.sleep(0.5)
                        print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
                        time.sleep(0.5)
                        print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
                        time.sleep(0.5)
                        print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
                    except Exception as e:
                        print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
                        # Record the failure instead of losing this call.
                        output = f"Error generating material: {str(e)}"

                    func_results.append({"function": func_name, "result": output})
                    formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
                    formatted_results.append(formatted_result)

                except Exception as e:
                    print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
                    import traceback
                    print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")

            else:
                # Generic tool dispatch with a small random delay.
                delay_time = random.uniform(1, 5)
                time.sleep(delay_time)
                result = asyncio.run(execute_tool_from_dict(func))
                func_results.append({"function": func['name'], "result": result})
                func_name = func.get("name")
                formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
                formatted_results.append(formatted_result)

        # Join all formatted call outputs into a single observation string.
        final_result = "\n\n\n".join(formatted_results)
        data['observation'] = final_result

        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
        print(data['observation'])
        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")

        # Append the enriched record under the lock — jsonlines writers are
        # not safe for concurrent appends from multiple threads.
        with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                writer.write(data)

        return f"Processed successfully"
    except Exception as e:
        print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
        return f"Error processing: {str(e)}"


def main(datas, output_file_path, max_workers=1):
    """Process all records concurrently with a thread pool.

    Args:
        datas: List of input records, each handed to ``worker``.
        output_file_path: JSONL file that workers append results to.
        max_workers: Thread-pool size.
    """
    from tqdm import tqdm

    pbar = tqdm(total=len(datas), desc="Processing papers")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Fan out one worker task per record.
        future_to_path = {}
        for path in datas:
            future = executor.submit(worker, path, output_file_path)
            future_to_path[future] = path

        completed = 0
        failed = 0
        for future in concurrent.futures.as_completed(future_to_path):
            path = future_to_path[future]
            try:
                result = future.result()
                # worker() signals success via its return string.
                if "successfully" in result:
                    completed += 1
                else:
                    failed += 1
                pbar.update(1)
                # Refresh the postfix stats every 100 finished records.
                if (completed + failed) % 100 == 0:
                    pbar.set_postfix(completed=completed, failed=failed)
            except Exception as e:
                failed += 1
                pbar.update(1)
                print(f"\nWorker for {path} generated an exception: {e}")

    pbar.close()
    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")


if __name__ == '__main__':
    import datetime

    # Load every record from the input JSONL file.
    datas = []
    with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
        for obj in reader:
            datas.append(obj)
    print(len(datas))

    # Timestamped output file so repeated runs never clobber each other.
    output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
    main(datas, output_file, max_workers=16)

    # Example invocation kept for reference:
    # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
    # argument = json.loads(argument)
    # print(json.dumps(argument, indent=2))
    # asyncio.run(mattergen(**argument))