diff --git a/.gitignore b/.gitignore index cd53b7e..1ce8b20 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ pyproject.toml /pretrained_models /mcp-python-sdk /.vscode + +/*filter_ok_questions_solutions_agent* diff --git a/__pycache__/execute_tool_copy.cpython-310.pyc b/__pycache__/execute_tool_copy.cpython-310.pyc new file mode 100644 index 0000000..5020866 Binary files /dev/null and b/__pycache__/execute_tool_copy.cpython-310.pyc differ diff --git a/execute_tool_copy.py b/execute_tool_copy.py index 40700a3..533fb8a 100644 --- a/execute_tool_copy.py +++ b/execute_tool_copy.py @@ -6,6 +6,8 @@ import jsonlines from mars_toolkit import * import threading import uuid + +from mars_toolkit.compute.material_gen import generate_material # Create a lock for file writing file_lock = threading.Lock() from mysql.connector import pooling @@ -180,153 +182,6 @@ async def process_retrieval_from_knowledge_base(data): -async def mattergen( - properties=None, - batch_size=2, - num_batches=1, - diffusion_guidance_factor=2.0 -): - """ - 调用MatterGen服务生成晶体结构 - - Args: - properties: 可选的属性约束,例如{"dft_band_gap": 2.0} - batch_size: 每批生成的结构数量 - num_batches: 批次数量 - diffusion_guidance_factor: 控制生成结构与目标属性的符合程度 - - Returns: - 生成的结构内容或错误信息 - """ - try: - # 导入MatterGenService - from mars_toolkit.services.mattergen_service import MatterGenService - - # 获取MatterGenService实例 - service = MatterGenService.get_instance() - - # 使用服务生成材料 - result = await service.generate( - properties=properties, - batch_size=batch_size, - num_batches=num_batches, - diffusion_guidance_factor=diffusion_guidance_factor - ) - - return result - except Exception as e: - import logging - logger = logging.getLogger(__name__) - logger.error(f"Error in mattergen: {e}") - import traceback - logger.error(traceback.format_exc()) - return f"Error generating material: {str(e)}" - -async def generate_material( - url="http://localhost:8051/generate_material", - properties=None, - batch_size=2, - num_batches=1, - diffusion_guidance_factor=2.0 -): - """ - 调用MatterGen API生成晶体结构 - - Args: - url: API端点URL - properties: 可选的属性约束,例如{"dft_band_gap": 2.0} - batch_size: 每批生成的结构数量 - num_batches: 批次数量 - diffusion_guidance_factor: 控制生成结构与目标属性的符合程度 - - Returns: - 生成的结构内容或错误信息 - """ - # 尝试使用本地MatterGen服务 - try: - print("尝试使用本地MatterGen服务...") - result = await mattergen( - properties=properties, - batch_size=batch_size, - num_batches=num_batches, - diffusion_guidance_factor=diffusion_guidance_factor - ) - if result and not result.startswith("Error"): - print("本地MatterGen服务生成成功!") - return result - else: - print(f"本地MatterGen服务生成失败,尝试使用API: {result}") - except Exception as e: - print(f"本地MatterGen服务出错,尝试使用API: {str(e)}") - - # 如果本地服务失败,回退到API调用 - # 规范化参数 - normalized_args = normalize_material_args({ - "properties": properties, - "batch_size": batch_size, - "num_batches": num_batches, - "diffusion_guidance_factor": diffusion_guidance_factor - }) - - # 构建请求负载 - payload = { - "properties": normalized_args["properties"], - "batch_size": normalized_args["batch_size"], - "num_batches": normalized_args["num_batches"], - "diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"] - } - - print(f"发送请求到 {url}") - print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}") - - try: - # 添加headers参数,包含accept头 - headers = { - "Content-Type": "application/json", - "accept": "application/json" - } - - # 打印完整请求信息(调试用) - print(f"完整请求URL: {url}") - print(f"请求头: {json.dumps(headers, indent=2)}") - print(f"请求体: {json.dumps(payload, indent=2)}") - - # 禁用代理设置 - proxies = { - "http": None, - "https": None - } - - # 发送POST请求,添加headers参数,禁用代理,增加超时时间 - response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300) - - # 打印响应信息(调试用) - print(f"响应状态码: {response.status_code}") - print(f"响应头: {dict(response.headers)}") - print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符,避免输出过长 - - # 检查响应状态 - if response.status_code == 200: - result = response.json() - - if result["success"]: - print("\n生成成功!") - return result["content"] - else: - print(f"\n生成失败: {result['message']}") - return None - else: - print(f"\n请求失败,状态码: {response.status_code}") - print(f"响应内容: {response.text}") - return None - - except Exception as e: - print(f"\n发生错误: {str(e)}") - print(f"错误类型: {type(e).__name__}") - import traceback - print(f"错误堆栈: {traceback.format_exc()}") - return None - async def execute_tool_from_dict(input_dict: dict): """ 从字典中提取工具函数名称和参数,并执行相应的工具函数 @@ -416,14 +271,14 @@ def worker(data, output_file_path): print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}") if func.get("name") == 'retrieval_from_knowledge_base': - delay_time = random.uniform(1, 5) - - time.sleep(delay_time) - result = asyncio.run(process_retrieval_from_knowledge_base(data)) - func_results.append({"function": func['name'], "result": result}) - # 格式化结果 - formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" - formatted_results.append(formatted_result) + pass + # delay_time = random.uniform(5, 10) + # time.sleep(delay_time) + # result = asyncio.run(process_retrieval_from_knowledge_base(data)) + # func_results.append({"function": func['name'], "result": result}) + # # 格式化结果 + # formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" + # formatted_results.append(formatted_result) elif func.get("name") == 'generate_material': # 规范化参数 @@ -438,30 +293,30 @@ def worker(data, output_file_path): # 规范化参数 normalized_args = normalize_material_args(arguments_data) - print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") - print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") - print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") + # print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") + # print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") + # print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") # 优先使用mattergen函数 try: - output = asyncio.run(generate_material(**normalized_args)) + # output = asyncio.run(generate_material(**normalized_args)) + output = generate_material(**normalized_args) # 添加延迟,模拟额外的工具函数调用 - # 随机延迟5-10秒 - delay_time = random.uniform(5, 10) - print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}") - time.sleep(delay_time) + # delay_time = random.uniform(5, 10) + # print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}") + # time.sleep(delay_time) - # 模拟其他工具函数调用的日志输出 - print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}") - time.sleep(0.5) - print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}") - time.sleep(0.5) - print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}") - time.sleep(0.5) - print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}") + # # 模拟其他工具函数调用的日志输出 + # print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}") + # time.sleep(0.5) + # print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}") + # time.sleep(0.5) + # print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}") + # time.sleep(0.5) + # print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}") except Exception as e: print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}") @@ -478,14 +333,15 @@ def worker(data, output_file_path): print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}") else: - delay_time = random.uniform(1, 5) - time.sleep(delay_time) - result = asyncio.run(execute_tool_from_dict(func)) - func_results.append({"function": func['name'], "result": result}) - # 格式化结果 - func_name = func.get("name") - formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" - formatted_results.append(formatted_result) + # delay_time = random.uniform(5, 10) + # time.sleep(delay_time) + pass + # result = asyncio.run(execute_tool_from_dict(func)) + # func_results.append({"function": func['name'], "result": result}) + # # 格式化结果 + # func_name = func.get("name") + # formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" + # formatted_results.append(formatted_result) # 将所有格式化后的结果连接起来 final_result = "\n\n\n".join(formatted_results) @@ -557,8 +413,8 @@ if __name__ == '__main__': print(len(datas)) # print() - output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" - main(datas, output_file, max_workers=16) + output_file = f"./filter_ok_questions_solutions_agent_mattergen_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" + main(datas, output_file, max_workers=1) # 示例1:使用正确的JSON格式 # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}' diff --git a/execute_tool_other_tools.py b/execute_tool_other_tools.py new file mode 100644 index 0000000..770fe44 --- /dev/null +++ b/execute_tool_other_tools.py @@ -0,0 +1,423 @@ +import json +import asyncio +import concurrent.futures + +import jsonlines +from mars_toolkit import * +import threading +import uuid + +from mars_toolkit.compute.material_gen import generate_material +# Create a lock for file writing +file_lock = threading.Lock() +from mysql.connector import pooling +from colorama import Fore, Back, Style, init +import time +import random +# 初始化colorama +init(autoreset=True) + +from typing import Dict, Union, Any, Optional + +def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]: + """ + 规范化传递给generate_material函数的参数格式。 + + 处理以下情况: + 1. properties参数可能是字符串形式的JSON,需要解析为字典 + 2. properties中的值可能需要转换为适当的类型(数字或字符串) + 3. 确保batch_size和num_batches是整数 + + Args: + arguments: 包含generate_material参数的字典 + + Returns: + 规范化后的参数字典 + """ + normalized_args = arguments.copy() + + # 处理properties参数 + if "properties" in normalized_args: + properties = normalized_args["properties"] + + # 如果properties是字符串,尝试解析为JSON + if isinstance(properties, str): + try: + properties = json.loads(properties) + except json.JSONDecodeError as e: + raise ValueError(f"无法解析properties JSON字符串: {e}") + + # 确保properties是字典 + if not isinstance(properties, dict): + raise ValueError(f"properties必须是字典或JSON字符串,而不是 {type(properties)}") + + # 处理properties中的值 + normalized_properties = {} + for key, value in properties.items(): + # 处理范围值,例如 "0.0-2.0" 或 "40-50" + if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"): + # 保持范围值为字符串格式 + normalized_properties[key] = value + elif isinstance(value, str) and value.startswith(">"): + # 保持大于值为字符串格式 + normalized_properties[key] = value + elif isinstance(value, str) and value.startswith("<"): + # 保持小于值为字符串格式 + normalized_properties[key] = value + elif isinstance(value, str) and value.lower() == "relaxor": + # 特殊值保持为字符串 + normalized_properties[key] = value + elif isinstance(value, str) and value.endswith("eV"): + # 带单位的值保持为字符串 + normalized_properties[key] = value + else: + # 尝试将值转换为数字 + try: + # 如果可以转换为浮点数 + float_value = float(value) + # 如果是整数,转换为整数 + if float_value.is_integer(): + normalized_properties[key] = int(float_value) + else: + normalized_properties[key] = float_value + except (ValueError, TypeError): + # 如果无法转换为数字,保持原值 + normalized_properties[key] = value + + normalized_args["properties"] = normalized_properties + + # 确保batch_size和num_batches是整数 + if "batch_size" in normalized_args: + try: + normalized_args["batch_size"] = int(normalized_args["batch_size"]) + except (ValueError, TypeError): + raise ValueError(f"batch_size必须是整数,而不是 {normalized_args['batch_size']}") + + if "num_batches" in normalized_args: + try: + normalized_args["num_batches"] = int(normalized_args["num_batches"]) + except (ValueError, TypeError): + raise ValueError(f"num_batches必须是整数,而不是 {normalized_args['num_batches']}") + + # 确保diffusion_guidance_factor是浮点数 + if "diffusion_guidance_factor" in normalized_args: + try: + normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"]) + except (ValueError, TypeError): + raise ValueError(f"diffusion_guidance_factor必须是数字,而不是 {normalized_args['diffusion_guidance_factor']}") + + return normalized_args + + +import requests +connection_pool = pooling.MySQLConnectionPool( + pool_name="mypool", + pool_size=32, + pool_reset_session=True, + host='localhost', + user='metadata_mat_papers', + password='siat-mic', + database='metadata_mat_papers' + ) + +async def process_retrieval_from_knowledge_base(data): + doi = data.get('doi') + mp_id = data.get('mp_id') + + # 检查是否提供了至少一个查询参数 + if doi is None and mp_id is None: + return "" # 如果没有提供查询参数,返回空字符串 + + # 构建SQL查询条件 + query = "SELECT * FROM mp_synthesis_scheme_info WHERE " + params = [] + + if doi is not None and mp_id is not None: + query += "doi = %s OR mp_id = %s" + params = [doi, mp_id] + elif doi is not None: + query += "doi = %s" + params = [doi] + else: # mp_id is not None + query += "mp_id = %s" + params = [mp_id] + + # 从数据库中查询匹配的记录 + conn = connection_pool.get_connection() + try: + cursor = conn.cursor(dictionary=True) + try: + cursor.execute(query, params) + result = cursor.fetchone() # 获取第一个匹配的记录 + finally: + cursor.close() + finally: + conn.close() + + # 检查是否找到匹配的记录 + if not result: + return "" # 如果没有找到匹配记录,返回空字符串 + + # 构建markdown格式的结果 + markdown_result = "" + + # 添加各个字段(除了doi和mp_id) + fields = [ + "target_material", + "reaction_string", + "chara_structure", + "chara_performance", + "chara_application", + "synthesis_schemes" + ] + + for field in fields: + # 获取字段内容 + field_content = result.get(field, "") + # 只有当字段内容不为空时才添加该字段 + if field_content and field_content.strip(): + markdown_result += f"\n## {field}\n{field_content}\n\n" + + return markdown_result # 直接返回markdown文本 + + + +async def execute_tool_from_dict(input_dict: dict): + """ + 从字典中提取工具函数名称和参数,并执行相应的工具函数 + + Args: + input_dict: 字典,例如: + {"name": "search_material_property_from_material_project", + "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"} + + Returns: + 工具函数的执行结果,如果工具函数不存在则返回错误信息 + """ + try: + # 解析输入字符串为字典 + # input_dict = json.loads(input_str) + + # 提取函数名和参数 + func_name = input_dict.get("name") + arguments_data = input_dict.get("arguments") + #print('func_name', func_name) + #print("argument", arguments_data) + if not func_name: + return {"status": "error", "message": "未提供函数名称"} + + # 获取所有注册的工具函数 + tools = get_tools() + + # 检查函数名是否存在于工具函数字典中 + if func_name not in tools: + return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"} + + # 获取对应的工具函数 + tool_func = tools[func_name] + + # 处理参数 + arguments = {} + if arguments_data: + # 检查arguments是字符串还是字典 + if isinstance(arguments_data, dict): + # 如果已经是字典,直接使用 + arguments = arguments_data + elif isinstance(arguments_data, str): + # 如果是字符串,尝试解析为JSON + try: + # 尝试直接解析为JSON对象 + arguments = json.loads(arguments_data) + except json.JSONDecodeError: + # 如果解析失败,可能是因为字符串中包含转义字符 + # 尝试修复常见的JSON字符串问题 + fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\') + try: + arguments = json.loads(fixed_str) + except json.JSONDecodeError: + # 如果仍然失败,尝试将字符串作为原始字符串处理 + arguments = {"raw_string": arguments_data} + + # 调用工具函数 + if asyncio.iscoroutinefunction(tool_func): + # 如果是异步函数,使用await调用 + result = await tool_func(**arguments) + else: + # 如果是同步函数,直接调用 + result = tool_func(**arguments) + # if func_name=='generate_material': + # print("xxxxx",result) + return result + + except json.JSONDecodeError as e: + return {"status": "error", "message": f"JSON解析错误: {str(e)}"} + except Exception as e: + return {"status": "error", "message": f"执行过程中出错: {str(e)}"} + + +def worker(data, output_file_path): + try: + func_contents = data["function_calls"] + func_results = [] + formatted_results = [] # 新增一个列表来存储格式化后的结果 + for func in func_contents: + func_name = func.get("name") + arguments_data = func.get("arguments") + + # 使用富文本打印函数名 + print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") + + # 使用富文本打印参数 + print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}") + + if func.get("name") == 'retrieval_from_knowledge_base': + pass + # delay_time = random.uniform(5, 10) + # time.sleep(delay_time) + result = asyncio.run(process_retrieval_from_knowledge_base(data)) + func_results.append({"function": func['name'], "result": result}) + # 格式化结果 + formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" + formatted_results.append(formatted_result) + + elif func.get("name") == 'generate_material': + # # 规范化参数 + # try: + # # 确保arguments_data是字典 + # if isinstance(arguments_data, str): + # try: + # arguments_data = json.loads(arguments_data) + # except json.JSONDecodeError as e: + # print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}") + # continue + + # # 规范化参数 + # normalized_args = normalize_material_args(arguments_data) + # # print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") + # # print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") + # # print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") + + # # 优先使用mattergen函数 + # try: + # # output = asyncio.run(generate_material(**normalized_args)) + # output = generate_material(**normalized_args) + + # # 添加延迟,模拟额外的工具函数调用 + + # # 随机延迟5-10秒 + # # delay_time = random.uniform(5, 10) + # # print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}") + # # time.sleep(delay_time) + + # # # 模拟其他工具函数调用的日志输出 + # # print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}") + # # time.sleep(0.5) + # # print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}") + # # time.sleep(0.5) + # # print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}") + # # time.sleep(0.5) + # # print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}") + + # except Exception as e: + # print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}") + + # # 将结果添加到func_results + # func_results.append({"function": func_name, "result": output}) + + # # 格式化结果 + # formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]" + # formatted_results.append(formatted_result) + # except Exception as e: + # print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}") + # import traceback + # print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}") + pass + else: + # delay_time = random.uniform(5, 10) + # time.sleep(delay_time) + + result = asyncio.run(execute_tool_from_dict(func)) + func_results.append({"function": func['name'], "result": result}) + # 格式化结果 + func_name = func.get("name") + formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" + formatted_results.append(formatted_result) + + # 将所有格式化后的结果连接起来 + final_result = "\n\n\n".join(formatted_results) + data['observation'] = final_result + + # 使用富文本打印开始和结束标记 + print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}") + print(data['observation']) + print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}") + with file_lock: + with jsonlines.open(output_file_path, mode='a') as writer: + writer.write(data) # observation . data + return f"Processed successfully" + + except Exception as e: + print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}") + return f"Error processing: {str(e)}" + + +def main(datas, output_file_path, max_workers=1): + import random + from tqdm import tqdm + import os + from mysql.connector import pooling, Error + + # 创建进度条 + pbar = tqdm(total=len(datas), desc="Processing papers") + + # 创建一个线程池 + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # 提交任务到执行器 + future_to_path = {} + for path in datas: + future = executor.submit(worker, path, output_file_path) + future_to_path[future] = path + + # 处理结果 + completed = 0 + failed = 0 + for future in concurrent.futures.as_completed(future_to_path): + path = future_to_path[future] + try: + result = future.result() + if "successfully" in result: + completed += 1 + else: + failed += 1 + # 更新进度条 + pbar.update(1) + # 每100个文件更新一次统计信息 + if (completed + failed) % 100 == 0: + pbar.set_postfix(completed=completed, failed=failed) + except Exception as e: + failed += 1 + pbar.update(1) + print(f"\nWorker for {path} generated an exception: {e}") + + pbar.close() + print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}") + + +if __name__ == '__main__': + import datetime + import jsonlines + datas = [] + with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader: + for obj in reader: + datas.append(obj) + + print(len(datas)) + # print() + output_file = f"./filter_ok_questions_solutions_agent_other_tools_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" + main(datas, output_file, max_workers=32) + + # 示例1:使用正确的JSON格式 + # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}' + # argument = json.loads(argument) + # print(json.dumps(argument, indent=2)) + # asyncio.run(mattergen(**argument)) diff --git a/mars_toolkit/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/__pycache__/__init__.cpython-310.pyc index e4e2215..6a1511f 100644 Binary files a/mars_toolkit/__pycache__/__init__.cpython-310.pyc and b/mars_toolkit/__pycache__/__init__.cpython-310.pyc differ diff --git a/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc b/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc index e929dd8..675755e 100644 Binary files a/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc and b/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc differ diff --git a/mars_toolkit/compute/material_gen.py b/mars_toolkit/compute/material_gen.py index b1fa653..7f89ec6 100644 --- a/mars_toolkit/compute/material_gen.py +++ b/mars_toolkit/compute/material_gen.py @@ -9,9 +9,18 @@ import asyncio import zipfile import shutil import re +import multiprocessing +from multiprocessing import Process, Queue from pathlib import Path from typing import Literal, Dict, Any, Tuple, Union, Optional, List +# 设置多进程启动方法为spawn,解决CUDA初始化错误 +try: + multiprocessing.set_start_method('spawn', force=True) +except RuntimeError: + # 如果已经设置过启动方法,会抛出RuntimeError + pass + from ase.optimize import FIRE from ase.filters import FrechetCellFilter from ase.atoms import Atoms @@ -33,6 +42,49 @@ from ..core.mattergen_wrapper import * logger = logging.getLogger(__name__) +def _process_generate_material_worker(args_queue, result_queue): + """ + 在新进程中处理材料生成的工作函数 + + Args: + args_queue: 包含生成参数的队列 + result_queue: 用于返回结果的队列 + """ + try: + # 配置日志 + import logging + logger = logging.getLogger(__name__) + logger.info("子进程开始执行材料生成...") + + # 从队列获取参数 + args = args_queue.get() + logger.info(f"子进程获取到参数: {args}") + + # 导入MatterGenService + from mars_toolkit.services.mattergen_service import MatterGenService + logger.info("子进程成功导入MatterGenService") + + # 获取MatterGenService实例 + service = MatterGenService.get_instance() + logger.info("子进程成功获取MatterGenService实例") + + # 使用服务生成材料 + logger.info("子进程开始调用generate方法...") + result = service.generate(**args) + logger.info("子进程generate方法调用完成") + + # 将结果放入结果队列 + result_queue.put(result) + logger.info("子进程材料生成完成,结果已放入队列") + except Exception as e: + # 如果发生错误,将错误信息放入结果队列 + import traceback + error_msg = f"材料生成过程中出错: {str(e)}\n{traceback.format_exc()}" + import logging + logging.getLogger(__name__).error(error_msg) + result_queue.put(f"Error: {error_msg}") + + def format_cif_content(content): """ Format CIF content by removing unnecessary headers and organizing each CIF file. @@ -233,7 +285,7 @@ def main( @llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints") -async def generate_material( +def generate_material( properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None, batch_size: int = 2, num_batches: int = 1, @@ -260,16 +312,45 @@ async def generate_material( Returns: Descriptive text with generated crystal structures in CIF format """ + # # 创建队列用于进程间通信 + # args_queue = Queue() + # result_queue = Queue() + + # # 将参数放入队列 + # args_queue.put({ + # "properties": properties, + # "batch_size": batch_size, + # "num_batches": num_batches, + # "diffusion_guidance_factor": diffusion_guidance_factor + # }) + + # # 创建并启动新进程 + # logger.info("启动新进程处理材料生成...") + # p = Process(target=_process_generate_material_worker, args=(args_queue, result_queue)) + # p.start() + + # # 等待进程完成并获取结果 + # p.join() + # result = result_queue.get() + + # # 检查结果是否为错误信息 + # if isinstance(result, str) and result.startswith("Error:"): + # # 记录错误日志 + # logger.error(result) + # 导入MatterGenService from mars_toolkit.services.mattergen_service import MatterGenService + logger.info("子进程成功导入MatterGenService") # 获取MatterGenService实例 service = MatterGenService.get_instance() + logger.info("子进程成功获取MatterGenService实例") # 使用服务生成材料 - return service.generate( - properties=properties, - batch_size=batch_size, - num_batches=num_batches, - diffusion_guidance_factor=diffusion_guidance_factor - ) + logger.info("子进程开始调用generate方法...") + result = service.generate(properties=properties, batch_size=batch_size, num_batches=num_batches, diffusion_guidance_factor=diffusion_guidance_factor) + logger.info("子进程generate方法调用完成") + if "Error generating structures" in result: + return f"Error: Invalid properties {properties}." + else: + return result diff --git a/mars_toolkit/core/__pycache__/config.cpython-310.pyc b/mars_toolkit/core/__pycache__/config.cpython-310.pyc index b6170e0..d11872d 100644 Binary files a/mars_toolkit/core/__pycache__/config.cpython-310.pyc and b/mars_toolkit/core/__pycache__/config.cpython-310.pyc differ diff --git a/mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc b/mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc index a0a2b88..33d3c64 100644 Binary files a/mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc and b/mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc differ diff --git a/mars_toolkit/core/config.py b/mars_toolkit/core/config.py index 4babf2c..886a480 100644 --- a/mars_toolkit/core/config.py +++ b/mars_toolkit/core/config.py @@ -35,7 +35,7 @@ class Config: DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA' # Searxng - SEARXNG_HOST="http://192.168.191.101:40032/" + SEARXNG_HOST="http://192.168.168.1:40032/" # Visualization VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization' diff --git a/mars_toolkit/query/__pycache__/web_search.cpython-310.pyc b/mars_toolkit/query/__pycache__/web_search.cpython-310.pyc index 8f99194..0da5f83 100644 Binary files a/mars_toolkit/query/__pycache__/web_search.cpython-310.pyc and b/mars_toolkit/query/__pycache__/web_search.cpython-310.pyc differ diff --git a/mars_toolkit/query/web_search.py b/mars_toolkit/query/web_search.py index 1aa6ab2..7a3c15e 100644 --- a/mars_toolkit/query/web_search.py +++ b/mars_toolkit/query/web_search.py @@ -5,6 +5,7 @@ This module provides functions for searching information on the web. """ import asyncio +import os from typing import Annotated, Dict, Any, List from langchain_community.utilities import SearxSearchWrapper @@ -28,6 +29,8 @@ async def search_online( Formatted string with search results (titles, snippets, links) """ # 确保 num_results 是整数 + os.environ['HTTP_PROXY'] = '' + os.environ['HTTPS_PROXY'] = '' try: num_results = int(num_results) except (TypeError, ValueError): diff --git a/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc b/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc index dfaa374..cb8e3dc 100644 Binary files a/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc and b/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc differ diff --git a/test_mars_toolkit.py b/test_mars_toolkit.py index ec87396..9adc5ca 100644 --- a/test_mars_toolkit.py +++ b/test_mars_toolkit.py @@ -62,7 +62,8 @@ async def test_tool(tool_name: str) -> str: elif tool_name == "generate_material": from mars_toolkit.compute.material_gen import generate_material # 使用简单的属性约束进行测试 - result = await generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1) + # result = await generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1) + result = generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1) elif tool_name == "fetch_chemical_composition_from_OQMD": from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD @@ -171,7 +172,7 @@ if __name__ == "__main__": ] # 选择要测试的工具 - tool_name = tools_to_test[6] # 测试 search_online 工具 + tool_name = tools_to_test[1] # 测试 search_online 工具 # 运行测试 result = asyncio.run(test_tool(tool_name))