From 6b92e54a414ef24997cabe71cf1efc69f444f1ee Mon Sep 17 00:00:00 2001
From: lzy <949777411@qq.com>
Date: Wed, 16 Apr 2025 11:15:01 +0800
Subject: [PATCH] mcp: add data generation code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore                                    |    0
 __pycache__/api_key.cpython-310.pyc           |  Bin 0 -> 251 bytes
 __pycache__/execute_tool_copy.cpython-310.pyc |  Bin
 __pycache__/mattergen_wrapper.cpython-310.pyc |  Bin
 __pycache__/mattergen_wrapper.cpython-312.pyc |  Bin 1263 -> 0 bytes
 .../normalize_material_args.cpython-310.pyc   |  Bin
 agent_test.py                                 |    0
 api_key.py                                    |    0
 .../filter_generate_material_data.py          |   89 ++
 .../generate_data10000.py                     |  195 +--
 .../generate_tool_observation.py              |  161 +--
 generate_data/grpo_tools.py                   |   91 ++
 generate_data/grpo_utils.py                   |  294 +++++
 generate_data/utils.py                        |  800 +++++++++++
 mars_toolkit.log                              | 1172 -----------------
 mars_toolkit/__init__.py                      |    0
 .../__pycache__/__init__.cpython-310.pyc      |  Bin 1115 -> 1119 bytes
 mars_toolkit/compute/__init__.py              |    0
 .../__pycache__/__init__.cpython-310.pyc      |  Bin 598 -> 602 bytes
 .../__pycache__/material_gen.cpython-310.pyc  |  Bin 10463 -> 10467 bytes
 .../__pycache__/property_pred.cpython-310.pyc |  Bin 2152 -> 2156 bytes
 .../__pycache__/structure_opt.cpython-310.pyc |  Bin 5229 -> 5233 bytes
 mars_toolkit/compute/material_gen.py          |    0
 mars_toolkit/compute/property_pred.py         |    0
 mars_toolkit/compute/structure_opt.py         |    0
 mars_toolkit/core/__init__.py                 |    0
 .../core/__pycache__/__init__.cpython-310.pyc |  Bin 360 -> 364 bytes
 .../__pycache__/cif_utils.cpython-310.pyc     |  Bin 3394 -> 3398 bytes
 .../core/__pycache__/config.cpython-310.pyc   |  Bin 2121 -> 2141 bytes
 .../error_handlers.cpython-310.pyc            |  Bin
 .../__pycache__/llm_tools.cpython-310.pyc     |  Bin 5225 -> 5229 bytes
 .../mattergen_wrapper.cpython-310.pyc         |  Bin 1013 -> 1017 bytes
 .../core/__pycache__/utils.cpython-310.pyc    |  Bin
 mars_toolkit/core/cif_utils.py                |    0
 mars_toolkit/core/config.py                   |    8 +-
 mars_toolkit/core/llm_tools.py                |    0
 mars_toolkit/core/mattergen_wrapper.py        |    0
 mars_toolkit/misc/__init__.py                 |    0
 .../misc/__pycache__/__init__.cpython-310.pyc |  Bin 320 -> 324 bytes
 .../__pycache__/general_tools.cpython-310.pyc |  Bin
 .../__pycache__/misc_tools.cpython-310.pyc    |  Bin 1207 -> 1211 bytes
 mars_toolkit/misc/misc_tools.py               |    0
 mars_toolkit/query/__init__.py                |    0
 .../__pycache__/__init__.cpython-310.pyc      |  Bin 785 -> 789 bytes
 .../__pycache__/dify_search.cpython-310.pyc   |  Bin 1721 -> 1725 bytes
 .../__pycache__/mp_query.cpython-310.pyc      |  Bin 12944 -> 12948 bytes
 .../__pycache__/oqmd_query.cpython-310.pyc    |  Bin 3006 -> 3010 bytes
 .../__pycache__/web_search.cpython-310.pyc    |  Bin 2243 -> 2247 bytes
 mars_toolkit/query/dify_search.py             |    0
 mars_toolkit/query/mp_query.py                |    0
 mars_toolkit/query/oqmd_query.py              |    0
 mars_toolkit/query/web_search.py              |    0
 mars_toolkit/services/__init__.py             |    0
 .../__pycache__/__init__.cpython-310.pyc      |  Bin
 .../mattergen_service.cpython-310.pyc         |  Bin
 mars_toolkit/services/mattergen_service.py    |    0
 mars_toolkit/visualization/__init__.py        |    0
 .../__pycache__/__init__.cpython-310.pyc      |  Bin
 .../__pycache__/band_vis.cpython-310.pyc      |  Bin
 .../__pycache__/crystal_vis.cpython-310.pyc   |  Bin
 mattergen_api.py                              |    0
 mattergen_client_example.py                   |  134 --
 .../material_synthesis.cpython-310.pyc        |  Bin 0 -> 5363 bytes
 prompts/material_synthesis.py                 |  167 +++
 server.py                                     |  306 +++++
 test_mars_toolkit.py                          |    4 +-
 66 files changed, 1938 insertions(+), 1483 deletions(-)
 mode change 100644 => 100755 .gitignore
 create mode 100755 __pycache__/api_key.cpython-310.pyc
 mode change 100644 => 100755 __pycache__/execute_tool_copy.cpython-310.pyc
 mode change 100644 => 100755 __pycache__/mattergen_wrapper.cpython-310.pyc
 delete mode 100644 __pycache__/mattergen_wrapper.cpython-312.pyc
 mode change 100644 => 100755 __pycache__/normalize_material_args.cpython-310.pyc
 mode change 100644 => 100755 agent_test.py
 mode change 100644 => 100755 api_key.py
 create mode 100755 generate_data/filter_generate_material_data.py
 rename execute_tool_other_tools.py => generate_data/generate_data10000.py (71%)
 mode change 100644 => 100755
 rename execute_tool_copy.py => generate_data/generate_tool_observation.py (74%)
 mode change 100644 => 100755
 create mode 100755 generate_data/grpo_tools.py
 create mode 100755 generate_data/grpo_utils.py
 create mode 100755 generate_data/utils.py
 delete mode 100644 mars_toolkit.log
 mode change 100644 => 100755 mars_toolkit/__init__.py
 mode change 100644 => 100755 mars_toolkit/__pycache__/__init__.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/compute/__init__.py
 mode change 100644 => 100755 mars_toolkit/compute/__pycache__/__init__.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/compute/__pycache__/property_pred.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/compute/__pycache__/structure_opt.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/compute/material_gen.py
 mode change 100644 => 100755 mars_toolkit/compute/property_pred.py
 mode change 100644 => 100755 mars_toolkit/compute/structure_opt.py
 mode change 100644 => 100755 mars_toolkit/core/__init__.py
 mode change 100644 => 100755 mars_toolkit/core/__pycache__/__init__.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/core/__pycache__/cif_utils.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/core/__pycache__/config.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/core/__pycache__/error_handlers.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/core/__pycache__/llm_tools.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/core/__pycache__/utils.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/core/cif_utils.py
 mode change 100644 => 100755 mars_toolkit/core/config.py
 mode change 100644 => 100755 mars_toolkit/core/llm_tools.py
 mode change 100644 => 100755 mars_toolkit/core/mattergen_wrapper.py
 mode change 100644 => 100755 mars_toolkit/misc/__init__.py
 mode change 100644 => 100755 mars_toolkit/misc/__pycache__/__init__.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/misc/__pycache__/general_tools.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/misc/__pycache__/misc_tools.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/misc/misc_tools.py
 mode change 100644 => 100755 mars_toolkit/query/__init__.py
 mode change 100644 => 100755 mars_toolkit/query/__pycache__/__init__.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/query/__pycache__/dify_search.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/query/__pycache__/mp_query.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/query/__pycache__/oqmd_query.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/query/__pycache__/web_search.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/query/dify_search.py
 mode change 100644 => 100755 mars_toolkit/query/mp_query.py
 mode change 100644 => 100755 mars_toolkit/query/oqmd_query.py
 mode change 100644 => 100755 mars_toolkit/query/web_search.py
 mode change 100644 => 100755 mars_toolkit/services/__init__.py
 mode change 100644 => 100755 mars_toolkit/services/__pycache__/__init__.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/services/mattergen_service.py
 mode change 100644 => 100755 mars_toolkit/visualization/__init__.py
 mode change 100644 => 100755 mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc
 mode change 100644 => 100755 mars_toolkit/visualization/__pycache__/crystal_vis.cpython-310.pyc
 mode change 100644 => 100755 mattergen_api.py
 delete mode 100644 mattergen_client_example.py
 create mode 100755 prompts/__pycache__/material_synthesis.cpython-310.pyc
 create mode 100755 prompts/material_synthesis.py
 create mode 100755 server.py
 mode change 100644 => 100755 test_mars_toolkit.py

diff --git a/.gitignore b/.gitignore
old mode 100644
new mode 100755
diff --git a/__pycache__/api_key.cpython-310.pyc b/__pycache__/api_key.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..e268273c537a02ba231126de3b2e4d8f066b6221
GIT binary patch
literal 251
literal 0

diff --git a/__pycache__/execute_tool_copy.cpython-310.pyc b/__pycache__/execute_tool_copy.cpython-310.pyc
old mode 100644
new mode 100755
diff --git a/__pycache__/mattergen_wrapper.cpython-310.pyc b/__pycache__/mattergen_wrapper.cpython-310.pyc
old mode 100644
new mode 100755
diff --git a/__pycache__/mattergen_wrapper.cpython-312.pyc b/__pycache__/mattergen_wrapper.cpython-312.pyc
deleted file mode 100644
index e82e1d701806ac96af7f748d2d908ebaffd5e674..0000000000000000000000000000000000000000
GIT binary patch
literal 0
literal 1263

diff --git a/generate_data/filter_generate_material_data.py b/generate_data/filter_generate_material_data.py
new file mode 100755
+    if len(sys.argv) > 1:
+        file_path = sys.argv[1]
+
+    console.print(f"[bold blue]Processing file: {file_path}[/bold blue]")
+    filter_generate_material(file_path)
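Only the command-line entry point of `filter_generate_material_data.py` survives above: it takes an optional JSONL path from `argv` and hands it to `filter_generate_material`. A minimal self-contained sketch of that driver pattern follows; the filter predicate and default path here are stand-ins, not the file's actual logic:

```python
import sys
import jsonlines
from rich.console import Console

console = Console()

def filter_generate_material(file_path: str) -> None:
    """Stand-in filter: keep records that carry a non-empty observation."""
    kept = []
    with jsonlines.open(file_path) as reader:
        for record in reader:
            if record.get("observation"):  # hypothetical predicate
                kept.append(record)
    console.print(f"[green]Kept {len(kept)} records[/green]")

if __name__ == "__main__":
    file_path = "data.jsonl"  # hypothetical default path
    if len(sys.argv) > 1:
        file_path = sys.argv[1]

    console.print(f"[bold blue]Processing file: {file_path}[/bold blue]")
    filter_generate_material(file_path)
```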
diff --git a/execute_tool_other_tools.py b/generate_data/generate_data10000.py
old mode 100644
new mode 100755
similarity index 71%
rename from execute_tool_other_tools.py
rename to generate_data/generate_data10000.py
index 770fe44..36e1857
--- a/execute_tool_other_tools.py
+++ b/generate_data/generate_data10000.py
@@ -17,7 +17,7 @@ import random
 # Initialize colorama
 init(autoreset=True)
 
-from typing import Dict, Union, Any, Optional
+from typing import Dict, Union, Any, Optional, List
 
 def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
     """
@@ -110,6 +110,8 @@ def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
 
 import requests
 
+
+# Create the database connection pool (used only for the initial data load)
 connection_pool = pooling.MySQLConnectionPool(
     pool_name="mypool",
     pool_size=32,
@@ -120,7 +122,50 @@ connection_pool = pooling.MySQLConnectionPool(
     database='metadata_mat_papers'
 )
 
+# In-memory cache for rows loaded from the database
+# Layout: {doi: record, mp_id: record}
+memory_cache = {}
+
+def load_data_to_memory():
+    """
+    Load every row from the database into memory.
+    """
+    print(f"{Fore.CYAN}{Style.BRIGHT}Loading data from the database into memory...{Style.RESET_ALL}")
+
+    conn = connection_pool.get_connection()
+    try:
+        cursor = conn.cursor(dictionary=True)
+        try:
+            # Fetch every record
+            cursor.execute("SELECT * FROM mp_synthesis_scheme_info")
+            records = cursor.fetchall()
+
+            # Index each record in the in-memory cache
+            for record in records:
+                doi = record.get('doi')
+                mp_id = record.get('mp_id')
+
+                # Key by doi when present
+                if doi:
+                    memory_cache[doi] = record
+
+                # Key by mp_id when present
+                if mp_id:
+                    memory_cache[mp_id] = record
+
+            print(f"{Fore.GREEN}{Style.BRIGHT}Loaded {len(records)} records into memory{Style.RESET_ALL}")
+            print(f"{Fore.GREEN}{Style.BRIGHT}Number of keys in the memory cache: {len(memory_cache)}{Style.RESET_ALL}")
+
+        finally:
+            cursor.close()
+    finally:
+        conn.close()
+
+# Load the data into memory at program startup
+load_data_to_memory()
+
 async def process_retrieval_from_knowledge_base(data):
+
     doi = data.get('doi')
     mp_id = data.get('mp_id')
 
@@ -128,31 +173,12 @@ async def process_retrieval_from_knowledge_base(data):
     if doi is None and mp_id is None:
         return ""  # Return an empty string when no query key is provided
 
-    # Build the SQL query
-    query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
-    params = []
-
-    if doi is not None and mp_id is not None:
-        query += "doi = %s OR mp_id = %s"
-        params = [doi, mp_id]
-    elif doi is not None:
-        query += "doi = %s"
-        params = [doi]
-    else:  # mp_id is not None
-        query += "mp_id = %s"
-        params = [mp_id]
-
-    # Query the database for a matching record
-    conn = connection_pool.get_connection()
-    try:
-        cursor = conn.cursor(dictionary=True)
-        try:
-            cursor.execute(query, params)
-            result = cursor.fetchone()  # Take the first matching record
-        finally:
-            cursor.close()
-    finally:
-        conn.close()
+    # Look up a matching record in the in-memory cache
+    result = None
+    if doi is not None and doi in memory_cache:
+        result = memory_cache[doi]
+    elif mp_id is not None and mp_id in memory_cache:
+        result = memory_cache[mp_id]
 
     # Check whether a matching record was found
     if not result:
@@ -265,13 +291,13 @@ def worker(data, output_file_path):
             arguments_data = func.get("arguments")
 
             # Rich-text print of the function name
-            print(f"{Fore.CYAN}{Style.BRIGHT}[function name]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
+            #print(f"{Fore.CYAN}{Style.BRIGHT}[function name]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
             # Rich-text print of the arguments
-            print(f"{Fore.CYAN}{Style.BRIGHT}[arguments]{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
+            #print(f"{Fore.CYAN}{Style.BRIGHT}[arguments]{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
 
             if func.get("name") == 'retrieval_from_knowledge_base':
-                pass
+
                 # delay_time = random.uniform(5, 10)
                 # time.sleep(delay_time)
                 result = asyncio.run(process_retrieval_from_knowledge_base(data))
                 func_results.append({"function": func['name'], "result": result})
                 # Format the result
                 formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
                 formatted_results.append(formatted_result)
 
@@ -281,60 +307,39 @@ def worker(data, output_file_path):
             elif func.get("name") == 'generate_material':
-                # # Normalize the arguments
-                # try:
-                #     # Make sure arguments_data is a dict
-                #     if isinstance(arguments_data, str):
-                #         try:
-                #             arguments_data = json.loads(arguments_data)
-                #         except json.JSONDecodeError as e:
-                #             print(f"{Fore.RED}Failed to parse arguments_data as JSON: {e}{Style.RESET_ALL}")
-                #             continue
+                try:
+                    # Make sure arguments_data is a dict
+                    if isinstance(arguments_data, str):
+                        try:
+                            arguments_data = json.loads(arguments_data)
+                        except json.JSONDecodeError as e:
+                            #print(f"{Fore.RED}Failed to parse arguments_data as JSON: {e}{Style.RESET_ALL}")
+                            continue
 
-                #     # Normalize the arguments
-                #     normalized_args = normalize_material_args(arguments_data)
-                #     # print(f"{Fore.CYAN}{Style.BRIGHT}[function name]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
-                #     # print(f"{Fore.CYAN}{Style.BRIGHT}[raw arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
-                #     # print(f"{Fore.CYAN}{Style.BRIGHT}[normalized arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
+                    # Normalize the arguments
+                    normalized_args = normalize_material_args(arguments_data)
 
-                #     # Prefer the mattergen function
-                #     try:
-                #         # output = asyncio.run(generate_material(**normalized_args))
-                #         output = generate_material(**normalized_args)
+                    # Prefer the mattergen function
+                    try:
 
+                        output = generate_material(**normalized_args)
 
-                #         # Add a delay to simulate extra tool-function calls
-                #         # Random delay of 5-10 seconds
-                #         # delay_time = random.uniform(5, 10)
-                #         # print(f"{Fore.MAGENTA}{Style.BRIGHT}Running extra tool calls, expected to take {delay_time:.2f} s...{Style.RESET_ALL}")
-                #         # time.sleep(delay_time)
-                #         # # Simulate log output of other tool calls
-                #         # print(f"{Fore.BLUE}Analyzing the generated material structure...{Style.RESET_ALL}")
-                #         # time.sleep(0.5)
-                #         # print(f"{Fore.BLUE}Computing structural stability...{Style.RESET_ALL}")
-                #         # time.sleep(0.5)
-                #         # print(f"{Fore.BLUE}Validating the property constraints...{Style.RESET_ALL}")
-                #         # time.sleep(0.5)
-                #         # print(f"{Fore.GREEN}{Style.BRIGHT}Extra tool calls finished{Style.RESET_ALL}")
+                    except Exception as e:
+                        #print(f"{Fore.RED}mattergen failed, falling back to generate_material: {str(e)}{Style.RESET_ALL}")
+                        continue
+                    # Append the result to func_results
+                    func_results.append({"function": func_name, "result": output})
 
-                #     except Exception as e:
-                #         print(f"{Fore.RED}mattergen failed, falling back to generate_material: {str(e)}{Style.RESET_ALL}")
-
-                #     # Append the result to func_results
-                #     func_results.append({"function": func_name, "result": output})
-
-                #     # Format the result
-                #     formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
-                #     formatted_results.append(formatted_result)
-                # except Exception as e:
-                #     print(f"{Fore.RED}Error while handling generate_material arguments: {e}{Style.RESET_ALL}")
-                #     import traceback
-                #     print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
-                pass
+                    # Format the result
+                    formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
+                    formatted_results.append(formatted_result)
+                except Exception as e:
+                    #print(f"{Fore.RED}Error while handling generate_material arguments: {e}{Style.RESET_ALL}")
+                    import traceback
+                    #print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
+                    continue
 
             else:
-                # delay_time = random.uniform(5, 10)
-                # time.sleep(delay_time)
                 result = asyncio.run(execute_tool_from_dict(func))
                 func_results.append({"function": func['name'], "result": result})
@@ -347,17 +352,17 @@ def worker(data, output_file_path):
 
         # Join all formatted results
         final_result = "\n\n\n".join(formatted_results)
         data['observation'] = final_result
 
-        # Rich-text print of the begin/end markers
-        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT BEGIN {'#'*50}{Style.RESET_ALL}")
-        print(data['observation'])
-        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT END {'#'*50}{Style.RESET_ALL}")
+        # Rich-text print of the begin/end markers
+        # print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT BEGIN {'#'*50}{Style.RESET_ALL}")
+        # print(data['observation'])
+        # print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT END {'#'*50}{Style.RESET_ALL}")
 
         with file_lock:
             with jsonlines.open(output_file_path, mode='a') as writer:
                 writer.write(data)  # data now includes the observation
 
         return f"Processed successfully"
     except Exception as e:
-        print(f"{Fore.RED}{Style.BRIGHT}Error during processing: {str(e)}{Style.RESET_ALL}")
+        #print(f"{Fore.RED}{Style.BRIGHT}Error during processing: {str(e)}{Style.RESET_ALL}")
         return f"Error processing: {str(e)}"
 
 
@@ -365,7 +370,6 @@ def main(datas, output_file_path, max_workers=1):
     import random
     from tqdm import tqdm
     import os
-    from mysql.connector import pooling, Error
 
     # Create the progress bar
     pbar = tqdm(total=len(datas), desc="Processing papers")
@@ -403,21 +407,26 @@ def main(datas, output_file_path, max_workers=1):
 
     print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
 
+
 if __name__ == '__main__':
     import datetime
     import jsonlines
-    datas = []
+    datas_with_solution = []
+    datas_without_solution = []
+
     with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
         for obj in reader:
-            datas.append(obj)
+            if obj['solution'] != '':
+                datas_with_solution.append(obj)
+            else:
+                datas_without_solution.append(obj)
 
-    print(len(datas))
-    # print()
-    output_file = f"./filter_ok_questions_solutions_agent_other_tools_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
+    datas_with_solution = datas_with_solution[:5000]
+    datas_without_solution = datas_without_solution[:5000]
+
+    datas = datas_with_solution + datas_without_solution
+    import random
+    random.shuffle(datas)
+
+    output_file = f"./filter_ok_questions_solutions_agent_data10000_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
     main(datas, output_file, max_workers=32)
-
-    # Example 1: using a well-formed JSON string
-    # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
-    # argument = json.loads(argument)
-    # print(json.dumps(argument, indent=2))
-    # asyncio.run(mattergen(**argument))
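The hunks above swap per-call SQL queries for a dict that is filled once at startup and indexed under both `doi` and `mp_id`. A standalone sketch of that dual-key cache pattern, with an illustrative record:

```python
from typing import Dict, List, Optional

memory_cache: Dict[str, dict] = {}

def build_cache(records: List[dict]) -> None:
    """Index each record under every available key (doi and mp_id)."""
    for record in records:
        for key_field in ("doi", "mp_id"):
            key = record.get(key_field)
            if key:
                memory_cache[key] = record

def lookup(doi: Optional[str] = None, mp_id: Optional[str] = None) -> Optional[dict]:
    """O(1) lookup that mirrors the old `doi = %s OR mp_id = %s` SQL query."""
    if doi is not None and doi in memory_cache:
        return memory_cache[doi]
    if mp_id is not None and mp_id in memory_cache:
        return memory_cache[mp_id]
    return None

build_cache([{"doi": "10.1000/xyz", "mp_id": "mp-149", "scheme": "..."}])
assert lookup(mp_id="mp-149") is lookup(doi="10.1000/xyz")  # same object under both keys
```

The trade-off is one full table scan at startup plus the row stored once per key, in exchange for lock-free constant-time reads from all 32 worker threads.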
diff --git a/execute_tool_copy.py b/generate_data/generate_tool_observation.py
old mode 100644
new mode 100755
similarity index 74%
rename from execute_tool_copy.py
rename to generate_data/generate_tool_observation.py
index 533fb8a..02eb9e1
--- a/execute_tool_copy.py
+++ b/generate_data/generate_tool_observation.py
@@ -17,7 +17,7 @@ import random
 # Initialize colorama
 init(autoreset=True)
 
-from typing import Dict, Union, Any, Optional
+from typing import Dict, Union, Any, Optional, List
 
 def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
     """
@@ -110,6 +110,8 @@ def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
 
 import requests
 
+
+# Create the database connection pool (used only for the initial data load)
 connection_pool = pooling.MySQLConnectionPool(
     pool_name="mypool",
     pool_size=32,
@@ -120,7 +122,50 @@ connection_pool = pooling.MySQLConnectionPool(
     database='metadata_mat_papers'
 )
 
+# In-memory cache for rows loaded from the database
+# Layout: {doi: record, mp_id: record}
+memory_cache = {}
+
+def load_data_to_memory():
+    """
+    Load every row from the database into memory.
+    """
+    print(f"{Fore.CYAN}{Style.BRIGHT}Loading data from the database into memory...{Style.RESET_ALL}")
+
+    conn = connection_pool.get_connection()
+    try:
+        cursor = conn.cursor(dictionary=True)
+        try:
+            # Fetch every record
+            cursor.execute("SELECT * FROM mp_synthesis_scheme_info")
+            records = cursor.fetchall()
+
+            # Index each record in the in-memory cache
+            for record in records:
+                doi = record.get('doi')
+                mp_id = record.get('mp_id')
+
+                # Key by doi when present
+                if doi:
+                    memory_cache[doi] = record
+
+                # Key by mp_id when present
+                if mp_id:
+                    memory_cache[mp_id] = record
+
+            print(f"{Fore.GREEN}{Style.BRIGHT}Loaded {len(records)} records into memory{Style.RESET_ALL}")
+            print(f"{Fore.GREEN}{Style.BRIGHT}Number of keys in the memory cache: {len(memory_cache)}{Style.RESET_ALL}")
+
+        finally:
+            cursor.close()
+    finally:
+        conn.close()
+
+# Load the data into memory at program startup
+load_data_to_memory()
+
 async def process_retrieval_from_knowledge_base(data):
+
     doi = data.get('doi')
     mp_id = data.get('mp_id')
 
@@ -128,31 +173,12 @@ async def process_retrieval_from_knowledge_base(data):
     if doi is None and mp_id is None:
         return ""  # Return an empty string when no query key is provided
 
-    # Build the SQL query
-    query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
-    params = []
-
-    if doi is not None and mp_id is not None:
-        query += "doi = %s OR mp_id = %s"
-        params = [doi, mp_id]
-    elif doi is not None:
-        query += "doi = %s"
-        params = [doi]
-    else:  # mp_id is not None
-        query += "mp_id = %s"
-        params = [mp_id]
-
-    # Query the database for a matching record
-    conn = connection_pool.get_connection()
-    try:
-        cursor = conn.cursor(dictionary=True)
-        try:
-            cursor.execute(query, params)
-            result = cursor.fetchone()  # Take the first matching record
-        finally:
-            cursor.close()
-    finally:
-        conn.close()
+    # Look up a matching record in the in-memory cache
+    result = None
+    if doi is not None and doi in memory_cache:
+        result = memory_cache[doi]
+    elif mp_id is not None and mp_id in memory_cache:
+        result = memory_cache[mp_id]
 
     # Check whether a matching record was found
     if not result:
@@ -265,62 +291,43 @@ def worker(data, output_file_path):
             arguments_data = func.get("arguments")
 
             # Rich-text print of the function name
-            print(f"{Fore.CYAN}{Style.BRIGHT}[function name]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
+            #print(f"{Fore.CYAN}{Style.BRIGHT}[function name]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
             # Rich-text print of the arguments
-            print(f"{Fore.CYAN}{Style.BRIGHT}[arguments]{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
+            #print(f"{Fore.CYAN}{Style.BRIGHT}[arguments]{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
 
             if func.get("name") == 'retrieval_from_knowledge_base':
-                pass
+
                 # delay_time = random.uniform(5, 10)
                 # time.sleep(delay_time)
-                # result = asyncio.run(process_retrieval_from_knowledge_base(data))
-                # func_results.append({"function": func['name'], "result": result})
-                # # Format the result
-                # formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
-                # formatted_results.append(formatted_result)
+                result = asyncio.run(process_retrieval_from_knowledge_base(data))
+                func_results.append({"function": func['name'], "result": result})
+                # Format the result
+                formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
+                formatted_results.append(formatted_result)
 
             elif func.get("name") == 'generate_material':
-                # Normalize the arguments
                 try:
                     # Make sure arguments_data is a dict
                     if isinstance(arguments_data, str):
                         try:
                             arguments_data = json.loads(arguments_data)
                         except json.JSONDecodeError as e:
-                            print(f"{Fore.RED}Failed to parse arguments_data as JSON: {e}{Style.RESET_ALL}")
+                            #print(f"{Fore.RED}Failed to parse arguments_data as JSON: {e}{Style.RESET_ALL}")
                             continue
 
                     # Normalize the arguments
                     normalized_args = normalize_material_args(arguments_data)
-                    # print(f"{Fore.CYAN}{Style.BRIGHT}[function name]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
-                    # print(f"{Fore.CYAN}{Style.BRIGHT}[raw arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
-                    # print(f"{Fore.CYAN}{Style.BRIGHT}[normalized arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
 
                     # Prefer the mattergen function
                     try:
-                        # output = asyncio.run(generate_material(**normalized_args))
+
                         output = generate_material(**normalized_args)
 
-                        # Add a delay to simulate extra tool-function calls
-                        # Random delay of 5-10 seconds
-                        # delay_time = random.uniform(5, 10)
-                        # print(f"{Fore.MAGENTA}{Style.BRIGHT}Running extra tool calls, expected to take {delay_time:.2f} s...{Style.RESET_ALL}")
-                        # time.sleep(delay_time)
-
-                        # # Simulate log output of other tool calls
-                        # print(f"{Fore.BLUE}Analyzing the generated material structure...{Style.RESET_ALL}")
-                        # time.sleep(0.5)
-                        # print(f"{Fore.BLUE}Computing structural stability...{Style.RESET_ALL}")
-                        # time.sleep(0.5)
-                        # print(f"{Fore.BLUE}Validating the property constraints...{Style.RESET_ALL}")
-                        # time.sleep(0.5)
-                        # print(f"{Fore.GREEN}{Style.BRIGHT}Extra tool calls finished{Style.RESET_ALL}")
                     except Exception as e:
-                        print(f"{Fore.RED}mattergen failed, falling back to generate_material: {str(e)}{Style.RESET_ALL}")
-
+                        #print(f"{Fore.RED}mattergen failed, falling back to generate_material: {str(e)}{Style.RESET_ALL}")
+                        continue
                     # Append the result to func_results
                     func_results.append({"function": func_name, "result": output})
 
@@ -328,36 +335,34 @@ def worker(data, output_file_path):
                     # Format the result
                     formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
                     formatted_results.append(formatted_result)
                 except Exception as e:
-                    print(f"{Fore.RED}Error while handling generate_material arguments: {e}{Style.RESET_ALL}")
+                    #print(f"{Fore.RED}Error while handling generate_material arguments: {e}{Style.RESET_ALL}")
                     import traceback
-                    print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
-
+                    #print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
+                    continue
             else:
-                # delay_time = random.uniform(5, 10)
-                # time.sleep(delay_time)
-                pass
-                # result = asyncio.run(execute_tool_from_dict(func))
-                # func_results.append({"function": func['name'], "result": result})
-                # # Format the result
-                # func_name = func.get("name")
-                # formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
-                # formatted_results.append(formatted_result)
+
+                result = asyncio.run(execute_tool_from_dict(func))
+                func_results.append({"function": func['name'], "result": result})
+                # Format the result
+                func_name = func.get("name")
+                formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
+                formatted_results.append(formatted_result)
 
         # Join all formatted results
         final_result = "\n\n\n".join(formatted_results)
         data['observation'] = final_result
 
-        # Rich-text print of the begin/end markers
-        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT BEGIN {'#'*50}{Style.RESET_ALL}")
-        print(data['observation'])
-        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT END {'#'*50}{Style.RESET_ALL}")
+        # Rich-text print of the begin/end markers
+        # print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT BEGIN {'#'*50}{Style.RESET_ALL}")
+        # print(data['observation'])
+        # print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT END {'#'*50}{Style.RESET_ALL}")
 
         with file_lock:
             with jsonlines.open(output_file_path, mode='a') as writer:
                 writer.write(data)  # data now includes the observation
 
         return f"Processed successfully"
     except Exception as e:
-        print(f"{Fore.RED}{Style.BRIGHT}Error during processing: {str(e)}{Style.RESET_ALL}")
+        #print(f"{Fore.RED}{Style.BRIGHT}Error during processing: {str(e)}{Style.RESET_ALL}")
         return f"Error processing: {str(e)}"
 
 
@@ -365,7 +370,6 @@ def main(datas, output_file_path, max_workers=1):
     import random
     from tqdm import tqdm
     import os
-    from mysql.connector import pooling, Error
 
     # Create the progress bar
     pbar = tqdm(total=len(datas), desc="Processing papers")
@@ -409,12 +413,13 @@ if __name__ == '__main__':
     datas = []
     with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
         for obj in reader:
-            datas.append(obj)
+            #if obj['solution']!='':
+            datas.append(obj)
 
     print(len(datas))
     # print()
-    output_file = f"./filter_ok_questions_solutions_agent_mattergen_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
-    main(datas, output_file, max_workers=1)
+    output_file = f"./filter_ok_questions_solutions_agent_data_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
+    main(datas, output_file, max_workers=32)
 
     # Example 1: using a well-formed JSON string
     # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
     # argument = json.loads(argument)
     # print(json.dumps(argument, indent=2))
     # asyncio.run(mattergen(**argument))
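Both generator scripts above, and `grpo_tools.py` below, share the same fan-out shape: a `ThreadPoolExecutor` over records, with a `threading.Lock` serializing appends to one shared JSONL output. A condensed, self-contained sketch of that pattern (the worker body is a placeholder for the real tool calls):

```python
import concurrent.futures
import threading
import jsonlines
from tqdm import tqdm

file_lock = threading.Lock()

def worker(data: dict, output_file_path: str) -> str:
    try:
        data["observation"] = "..."  # placeholder for the real tool-call results
        with file_lock:  # jsonlines writers are not thread-safe; serialize appends
            with jsonlines.open(output_file_path, mode="a") as writer:
                writer.write(data)
        return "Processed successfully"
    except Exception as e:
        return f"Error processing: {e}"

def run(datas: list, output_file_path: str, max_workers: int = 32) -> None:
    pbar = tqdm(total=len(datas), desc="Processing papers")
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(worker, d, output_file_path) for d in datas]
        for future in concurrent.futures.as_completed(futures):
            future.result()  # surface any worker exception
            pbar.update(1)
    pbar.close()
```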
diff --git a/generate_data/grpo_tools.py b/generate_data/grpo_tools.py
new file mode 100755
index 0000000..e06cc70
--- /dev/null
+++ b/generate_data/grpo_tools.py
@@ -0,0 +1,91 @@
+import jsonlines
+import argparse
+import generate_data.utils as utils
+import glob
+import json
+from ase import io
+import tempfile
+import re
+from pymatgen.io.vasp import Poscar
+from pymatgen.io.cif import CifParser
+import threading
+import concurrent.futures
+import copy
+from grpo_utils import generate_design_question, generate_props_question, generate_obs_response
+# Create a lock for file writing
+file_lock = threading.Lock()
+
+
+def worker(data, output_file_path):
+    try:
+        messages = copy.deepcopy(data['messages'])
+        obs = data['observation']
+        # Strip the reasoning block and any tool-call blocks from the last assistant turn
+        messages[-1]['content'] = messages[-1]['content'].split("</think>")[-1].split("<tool_call>")[0]
+        messages.append({"role": "user", "content": obs})
+        data['messages'].append({"role": "user", "content": obs})
+        # print(messages)
+        # print(obs)
+
+        reasoning_content, response = generate_obs_response(messages)
+        data['messages'].append({"role": "assistant", "content": f"<think>\n{reasoning_content}\n</think>\n{response}\n"})
+        # Use the lock to safely write to the file
+        with file_lock:
+            with jsonlines.open(output_file_path, mode='a') as writer:
+                writer.write(messages)
+
+        return f"Processed successfully"
+    except Exception as e:
+        return f"Error processing: {str(e)}"
+
+
+def main(input_file_path, output_file_path, max_workers=1):
+    import random
+    from tqdm import tqdm
+    import os
+
+    datas = None
+    with jsonlines.open(input_file_path, mode='r') as reader:
+        datas = [line for line in reader]
+
+    # Create the progress bar
+    pbar = tqdm(total=len(datas), desc="Processing CIF files")
+
+    # Create a thread pool
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit the tasks to the executor
+        future_to_data = {}
+        for data in datas:
+            future = executor.submit(worker, data, output_file_path)
+            future_to_data[future] = data
+
+        # Collect the results
+        completed = 0
+        failed = 0
+        for future in concurrent.futures.as_completed(future_to_data):
+            data = future_to_data[future]
+            try:
+                result = future.result()
+                if "successfully" in result:
+                    completed += 1
+                else:
+                    failed += 1
+                # Update the progress bar
+                pbar.update(1)
+                # Refresh the statistics every 100 files
+                if (completed + failed) % 100 == 0:
+                    pbar.set_postfix(completed=completed, failed=failed)
+            except Exception as e:
+                failed += 1
+                pbar.update(1)
+                print(f"\nWorker for {data} generated an exception: {e}")
+
+    pbar.close()
+    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
+
+
+if __name__ == '__main__':
+    import datetime
+    origin_file = "/home/ubuntu/50T/lzy/mars-mcp/filter_ok_questions_solutions_agent_tools_20250408214808 copy.jsonl"
+    output_file = f"agent_questions_solutions_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
+    main(origin_file, output_file)
+
diff --git a/generate_data/grpo_utils.py b/generate_data/grpo_utils.py
new file mode 100755
index 0000000..9f16107
--- /dev/null
+++ b/generate_data/grpo_utils.py
@@ -0,0 +1,294 @@
+import jsonlines
+import argparse
+import generate_data.utils as utils
+import glob
+import json
+from ase import io
+import tempfile
+import re
+from pymatgen.io.vasp import Poscar
+from pymatgen.io.cif import CifParser
+import threading
+import concurrent.futures
+
+# Create a lock for file writing
+file_lock = threading.Lock()
+
+
+
+def generate_design_question(crystal_desc, cif_info, crystal_props, max_retries=3, initial_backoff=1.0):
+    instruction = """
+{crystal_desc}
+
+### The corresponding crystal structure data (CIF) is as follows:
+{cif_info}
+
+### The physicochemical properties of this crystal structure are:
+{crystal_props}
+
+Based on the information above, I need to write questions for a materials-science PhD exam. Each question must require candidates to answer with the complete CIF file given above. How would you phrase such questions?
+That is, the answer to every question we pose must be the complete CIF file mentioned above. Naturally, your question has to supply sufficient information about the crystal structure.
+That information should stay abstract and indirect rather than overly explicit: apart from the exact chemical formula, avoid too much precise numeric information, so that candidates have to reason their way to some of it, which raises the difficulty.
+Write half of the questions in Chinese and half in English, for better interaction with the model.
+
+First draft 10 candidate questions, then pick the 2 best and output them in the following format:
+```json
+{
+    "selected_questions": [
+        {
+            "question_id": 1,
+            "question_text": "Full text of question 1...",
+        },
+        {
+            "question_id": 2,
+            "question_text": "Full text of question 2...",
+        }
+    ]
+}
+```
+"""
+    instruction = instruction.replace("{crystal_desc}", crystal_desc).replace("{cif_info}", cif_info).replace("{crystal_props}", crystal_props)
+    messages = [
+        {"role": "system", "content": ""},
+        {"role": "user", "content": instruction}
+    ]
+    import time
+    start_time = time.time()
+    _response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
+    # reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
+    # print(f"Time: {time.time() - start_time}")
+    if _response == 'apierror' or _response == 'unexpectederror':
+        return _response
+    # Try to extract the JSON block from the response
+    json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
+    if json_match:
+        json_str = json_match.group(1)
+        try:
+            questions_data = json.loads(json_str)
+            return questions_data
+        except json.JSONDecodeError:
+            # If parsing fails, clean the string and try again
+            cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
+            try:
+                questions_data = json.loads(cleaned_json)
+                return questions_data
+            except:
+                return {"error": "Failed to parse JSON response", "raw_response": _response}
+    else:
+        # No JSON block found; return the raw response
+        return {"error": "No JSON format found in response", "raw_response": _response}
+
+
+
+def generate_props_question(crystal_desc, cif_info, crystal_props, max_retries=3, initial_backoff=1.0):
+    instruction = """
+{crystal_desc}
+
+### The corresponding crystal structure data (CIF) is as follows:
+{cif_info}
+
+### The physicochemical properties of this crystal structure are:
+{crystal_props}
+
+Based on the information above, I need to write questions for a materials-science PhD exam. Each question must require candidates to answer, from the CIF file, with the physicochemical properties listed above. How would you phrase such questions?
+That is, the answer to every question we pose must be the physicochemical properties mentioned above. Naturally, your question should include a <cif> tag standing in for the given CIF file.
+Candidates should analyze, through careful thought and reasoning over the given CIF file, all of the physicochemical properties of the crystalline material mentioned above, and answer with all of them in JSON format.
+Write half of the questions in Chinese and half in English, for better interaction with the model.
+
+Example questions:
+1. <cif>\n, based on the CIF file provided above, please xxx
+2. Based on the CIF file provided below, please xxx\n <cif>
+
+First draft 10 candidate questions, then pick the 2 best and output them in the following format:
+```json
+{
+    "selected_questions": [
+        {
+            "question_id": 1,
+            "question_text": "Full text of question 1...",
+        },
+        {
+            "question_id": 2,
+            "question_text": "Full text of question 2...",
+        }
+    ]
+}
+```
+"""
+    instruction = instruction.replace("{crystal_desc}", crystal_desc).replace("{cif_info}", cif_info).replace("{crystal_props}", crystal_props)
+    messages = [
+        {"role": "system", "content": ""},
+        {"role": "user", "content": instruction}
+    ]
+    import time
+    start_time = time.time()
+    _response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
+    # reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
+    # print(f"Time: {time.time() - start_time}")
+    if _response == 'apierror' or _response == 'unexpectederror':
+        return _response
+    # Try to extract the JSON block from the response
+    json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
+    if json_match:
+        json_str = json_match.group(1)
+        try:
+            questions_data = json.loads(json_str)
+            return questions_data
+        except json.JSONDecodeError:
+            # If parsing fails, clean the string and try again
+            cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
+            try:
+                questions_data = json.loads(cleaned_json)
+                return questions_data
+            except:
+                return {"error": "Failed to parse JSON response", "raw_response": _response}
+    else:
+        # No JSON block found; return the raw response
+        return {"error": "No JSON format found in response", "raw_response": _response}
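Each `generate_*_question` function in this file repeats the same response-parsing dance: pull the fenced ```json block out of the model reply, then fall back to stripping control characters when the first parse fails. Factored into one helper, the shared logic looks like this (a sketch; the originals inline it):

```python
import json
import re

def parse_json_block(raw_response: str):
    """Extract and parse a ```json ... ``` block from an LLM reply."""
    match = re.search(r"```json\s*(.*?)\s*```", raw_response, re.DOTALL)
    if not match:
        return {"error": "No JSON format found in response", "raw_response": raw_response}
    json_str = match.group(1)
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Retry after removing stray newlines/tabs that often break model-emitted JSON
        cleaned = re.sub(r"[\n\r\t]", "", json_str)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            return {"error": "Failed to parse JSON response", "raw_response": raw_response}

print(parse_json_block('Here you go:\n```json\n{"selected_questions": []}\n```'))
```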
+def generate_papers_other_question(paper_info, max_retries=3, initial_backoff=1.0):
+    instruction = """
+{paper_info}
+
+Based on the information above, I need to write questions for materials-science PhD students that test whether they have fully mastered this material's reaction equations, structure, performance, and applications. How would you phrase such questions?
+Your question should contain the appropriate information about the material and be self-contained (no key information may be missing when only the question is shown), yet be hard and deep enough that students must think and reason carefully before they can answer accurately.
+Since the questions target PhD students, they should be oriented toward real research value. Where reaction equations or specifics of structure, performance, and applications involve exact reagent amounts and similar values, require answers that are as numerically precise as possible (provided those values appear in the text above).
+
+
+First draft 12 candidate questions, half in Chinese and half in English, then pick the 4 best and output them in the following format:
+```json
+{
+    "selected_questions": [
+        {
+            "question_id": 1,
+            "question_text": "Full text of question 1...",
+            "question_type": "type of question 1",  # reaction_string; structure; performence; application
+        },
+        {
+            "question_id": 2,
+            "question_text": "Full text of question 2...",
+            "question_type": "type of question 2",  # reaction_string; structure; performence; application
+        }, ...
+    ]
+}
+```
+"""
+    instruction = instruction.replace("{paper_info}", paper_info)
+    messages = [
+        {"role": "system", "content": ""},
+        {"role": "user", "content": instruction}
+    ]
+    import time
+    start_time = time.time()
+    _response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
+    # reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
+    # print(f"Time: {time.time() - start_time}")
+    if _response == 'apierror' or _response == 'unexpectederror':
+        return _response
+    # Try to extract the JSON block from the response
+    json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
+    if json_match:
+        json_str = json_match.group(1)
+        try:
+            questions_data = json.loads(json_str)
+            return questions_data
+        except json.JSONDecodeError:
+            # If parsing fails, clean the string and try again
+            cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
+            try:
+                questions_data = json.loads(cleaned_json)
+                return questions_data
+            except:
+                return {"error": "Failed to parse JSON response", "raw_response": _response}
+    else:
+        # No JSON block found; return the raw response
+        return {"error": "No JSON format found in response", "raw_response": _response}
+
+
+
+def generate_papers_synthesis_question(paper_info, max_retries=3, initial_backoff=1.0):
+    instruction = """
+{paper_info}
+
+Based on the information above, I need to write questions for materials-science PhD students that test whether they have fully mastered this material's synthesis scheme, i.e. the precise mapping from a given material's structure and performance to its synthesis scheme. How would you phrase such questions?
+Your question should contain sufficient information about the material's structure and performance, and be hard and deep enough that students must think and reason carefully before they can give an accurate synthesis scheme, organized as a formatted JSON synthesis scheme.
+Since the questions target PhD students, they should be oriented toward real research value, and answers must give precise values for the synthesis conditions (reagents, precursors, vessels, temperatures, and so on).
+The conditioning information in the question must be as explicit as possible rather than implicit (remember that the students will not see the text above, so phrasings like "based on the given structure and performance information" should be avoided).
+
+First draft 6 candidate questions, half in Chinese and half in English, then pick the 2 best and output them in the following format:
+```json
+{
+    "selected_questions": [
+        {
+            "question_id": 1,
+            "question_text": "Full text of question 1...",
+        },
+        {
+            "question_id": 2,
+            "question_text": "Full text of question 2...",
+        },
+    ]
+}
+```
+"""
+    instruction = instruction.replace("{paper_info}", paper_info)
+    messages = [
+        {"role": "system", "content": ""},
+        {"role": "user", "content": instruction}
+    ]
+    import time
+    start_time = time.time()
+    _response = utils.get_response_from_llm(messages, model_name="deepseek-v3", max_retries=max_retries, initial_backoff=initial_backoff)
+    # reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
+    # print(f"Time: {time.time() - start_time}")
+    if _response == 'apierror' or _response == 'unexpectederror':
+        return _response
+    # Try to extract the JSON block from the response
+    json_match = re.search(r'```json\s*(.*?)\s*```', _response, re.DOTALL)
+    if json_match:
+        json_str = json_match.group(1)
+        try:
+            questions_data = json.loads(json_str)
+            return questions_data
+        except json.JSONDecodeError:
+            # If parsing fails, clean the string and try again
+            cleaned_json = re.sub(r'[\n\r\t]', '', json_str)
+            try:
+                questions_data = json.loads(cleaned_json)
+                return questions_data
+            except:
+                return {"error": "Failed to parse JSON response", "raw_response": _response}
+    else:
+        # No JSON block found; return the raw response
+        return {"error": "No JSON format found in response", "raw_response": _response}
+
+
+def generate_function_call(messages, tools, max_retries=3, initial_backoff=1.0):
+
+    import time
+    start_time = time.time()
+    instruction = """
+# Question
+{question}
+
+# Instructions
+Before answering the question above accurately, you have exactly one chance, right now, to call tools to gather more information.
+Think about the question above as deeply as you can, and call as many of the tools provided to you as are relevant to look up information about it, instead of answering the question directly.
+You therefore need to issue several well-considered tool calls at once in your reply, so that the question can be answered better.
+Think and answer in the same language as the question.
+"""
+    messages[0]["content"] = instruction.replace("{question}", messages[0]["content"])
+
+    _response, functions = utils.get_response_from_qwq(messages, model_name="qwq-32b", tools=tools, max_retries=max_retries, initial_backoff=initial_backoff)
+    # reasoning_content, _response = utils.get_response_from_deepseek_r1(messages, max_retries=max_retries, initial_backoff=initial_backoff)
+    # print(f"Time: {time.time() - start_time}")
+    # print(_response)
+    # if _response == 'apierror' or _response == 'unexpectederror':
+    #     return _response
+    return _response, functions
+
+
+
+def generate_obs_response(messages, max_retries=3, initial_backoff=1.0):
+    import time
+    start_time = time.time()
+    _reasoning_content, response = utils.get_response_from_deepseek_r1(messages, prefix=False, max_retries=max_retries, initial_backoff=initial_backoff)
+    return _reasoning_content, response
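All three client functions in `utils.py` below retry with the same exponential backoff plus jitter: a base delay of `initial_backoff * 2**(retries - 1)`, scaled by a random factor in [0.5, 1.5). Isolated, the schedule looks like this:

```python
import random
import time

def backoff_delay(retries: int, initial_backoff: float = 1.0) -> float:
    """Delay before retry attempt `retries` (1-based), as used in the retry loops below."""
    return initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())

# Attempt 1 waits ~0.5-1.5 s, attempt 2 ~1-3 s, attempt 3 ~2-6 s, ...
for attempt in range(1, 4):
    delay = backoff_delay(attempt)
    print(f"attempt {attempt}: sleeping {delay:.2f}s")
    time.sleep(delay)
```

The jitter term keeps 32 concurrent workers from retrying in lockstep after a shared rate-limit event.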
diff --git a/generate_data/utils.py b/generate_data/utils.py
new file mode 100755
index 0000000..544d958
--- /dev/null
+++ b/generate_data/utils.py
@@ -0,0 +1,800 @@
+"""
+This script generates questions and answers from a given set of CIFs.
+It uses the OpenAI API and MySQL for storing and retrieving data.
+@author: Yutang Li
+"""
+import multiprocessing
+import sqlite3
+import tiktoken
+import re
+from fractions import Fraction
+import numpy as np
+import glob
+import tqdm
+import copy
+import json
+import time
+import random
+from openai import OpenAI, APIError, RateLimitError
+from mysql.connector import pooling, Error
+from collections import Counter
+
+
+
+def get_response_from_deepseek_r1(messages: list[dict], prefix: bool = False, max_retries: int = 3, initial_backoff: float = 1.0):
+    """
+    Get response from DeepSeek API with retry mechanism.
+
+    Args:
+        messages: List of message dictionaries
+        prefix: Whether to use the prefix URL
+        max_retries: Maximum number of retry attempts
+        initial_backoff: Initial backoff time in seconds
+
+    Returns:
+        Tuple of (reasoning_content, content) or error messages
+    """
+    retries = 0
+    while retries <= max_retries:
+        try:
+            base_url = "https://api.deepseek.com/beta" if prefix else "https://vip.apiyi.com/v1"
+            api_key = "sk-59279cc16ec740089146ef9aef9c1671" if prefix else "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"
+
+            client = OpenAI(api_key=api_key, base_url=base_url)
+            # messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
+
+            response = client.chat.completions.create(
+                model="deepseek-r1",
+                messages=messages,
+                temperature=0.6
+            )
+
+            # Split the reply on the <think>...</think> block: the reasoning trace sits inside it,
+            # the final answer follows it.
+            # reasoning_content = "null" if prefix else "<think>\n" + response.choices[0].message.model_extra['reasoning_content'] + "\n</think>\n"
+            reasoning_content = response.choices[0].message.content.split("</think>")[0].split("<think>")[-1]
+            content = response.choices[0].message.content.split("</think>")[-1]
+            return reasoning_content, content
+
+        except RateLimitError as rate_error:
+            retries += 1
+            if retries > max_retries:
+                print(f"Max retries exceeded for RateLimitError: {rate_error}")
+                return 'apierror', 'apierror'
+
+            # Exponential backoff with jitter
+            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
+            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
+            time.sleep(backoff_time)
+
+        except APIError as api_error:
+            retries += 1
+            if retries > max_retries:
+                print(f"Max retries exceeded for APIError: {api_error}")
+                return 'apierror', 'apierror'
+
+            # Check if the error is retryable
+            error_str = str(api_error)
+            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
+                # Exponential backoff with jitter
+                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
+                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
+                time.sleep(backoff_time)
+            else:
+                # Non-retryable API error
+                print(f"Non-retryable API error: {api_error}")
+                return 'apierror', 'apierror'
+
+        except Exception as e:
+            print(f"generate_design_question Unexpected error: {e}")
+            return 'unexpectederror', 'unexpectederror'
+
+
+def get_response_from_llm(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
+    """
+    Get response from LLM API with retry mechanism.
+
+    Args:
+        messages: List of message dictionaries
+        model_name: Name of the model to use
+        tools: Optional list of tools to use
+        max_retries: Maximum number of retry attempts
+        initial_backoff: Initial backoff time in seconds
+
+    Returns:
+        Content of the response or error message
+    """
+    retries = 0
+    while retries <= max_retries:
+        try:
+            client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
+            # messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
+            if tools is None:
+                response = client.chat.completions.create(
+                    model=model_name,
+                    messages=messages,
+                )
+            else:
+                response = client.chat.completions.create(
+                    model=model_name,
+                    messages=messages,
+                    tools=tools,
+                    tool_choice='auto',
+                    parallel_tool_calls=True
+                )
+            content = response.choices[0].message.content
+            return content
+
+        except RateLimitError as rate_error:
+            retries += 1
+            if retries > max_retries:
+                print(f"Max retries exceeded for RateLimitError: {rate_error}")
+                return 'apierror'
+
+            # Exponential backoff with jitter
+            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
+            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
+            time.sleep(backoff_time)
+
+        except APIError as api_error:
+            retries += 1
+            if retries > max_retries:
+                print(f"Max retries exceeded for APIError: {api_error}")
+                return 'apierror'
+
+            # Check if the error is retryable
+            error_str = str(api_error)
+            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
+                # Exponential backoff with jitter
+                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
+                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
+                time.sleep(backoff_time)
+            else:
+                # Non-retryable API error
+                print(f"Non-retryable API error: {api_error}")
+                return 'apierror'
+
+        except Exception as e:
+            print(f"generate_design_question Unexpected error: {e}")
+            return 'unexpectederror'
+
+def get_response_from_qwq(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
+    """
+    Get response from LLM API with retry mechanism.
+
+    Args:
+        messages: List of message dictionaries
+        model_name: Name of the model to use
+        tools: Optional list of tools to use
+        max_retries: Maximum number of retry attempts
+        initial_backoff: Initial backoff time in seconds
+
+    Returns:
+        Content of the response or error message
+    """
+    retries = 0
+    while retries <= max_retries:
+        try:
+            # client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
+            # client = OpenAI(api_key="sk-df98afdc6b5b48db8195dcb4a68e804b", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
+            import random
+            if random.random() > 0.5:
+                client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
+            else:
+                client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
+            # messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
+            if tools is None:
+                response = client.chat.completions.create(
+                    model=model_name,
+                    messages=messages,
+                    stream=True
+                )
+            else:
+                response = client.chat.completions.create(
+                    model=model_name,
+                    messages=messages,
+                    tools=tools,
+                    tool_choice='auto',
+                    parallel_tool_calls=True,
+                    stream=True
+                )
+
+            reasoning_content = ""  # accumulates the full reasoning trace
+            answer_content = ""  # accumulates the full reply
+            tool_info = []  # collected tool-call info
+            is_answering = False  # whether reasoning has ended and the reply has started
+            # print("="*20 + "Reasoning" + "="*20)
+            for chunk in response:
+                # if not chunk.choices:
+                #     # Handle usage statistics
+                #     print("\n" + "="*20 + "Usage" + "="*20)
+                #     print(chunk.usage)
+                # else:
+                delta = chunk.choices[0].delta
+                # Handle the model's chain-of-thought stream
+                if hasattr(delta, 'reasoning_content') and delta.reasoning_content is not None:
+                    reasoning_content += delta.reasoning_content
+                    # print(delta.reasoning_content, end="", flush=True)  # stream the reasoning in real time
+
+                # Handle the final reply content
+                else:
+                    if not is_answering:  # print a banner the first time we enter the reply phase
+                        is_answering = True
+                        # print("\n" + "="*20 + "Reply" + "="*20)
+                    if delta.content is not None:
+                        answer_content += delta.content
+                        # print(delta.content, end="", flush=True)  # stream the reply
+
+                    # Handle tool-call deltas (parallel tool calls supported)
+                    if delta.tool_calls is not None:
+                        for tool_call in delta.tool_calls:
+                            index = tool_call.index  # index of this tool call, used for parallel calls
+
+                            # Grow the tool-info list on demand
+                            while len(tool_info) <= index:
+                                tool_info.append({})
+
+                            # Collect the tool-call id (for the follow-up function call)
+                            # if tool_call.id:
+                            #     tool_info[index]['id'] = tool_info[index].get('id', '') + tool_call.id
+
+                            # Collect the function name (for routing to the actual function)
+                            if tool_call.function and tool_call.function.name:
+                                tool_info[index]['name'] = tool_info[index].get('name', '') + tool_call.function.name
+
+                            # Collect the function arguments (a JSON string, parsed later)
+                            if tool_call.function and tool_call.function.arguments:
+                                tool_info[index]['arguments'] = tool_info[index].get('arguments', '') + tool_call.function.arguments
+
+            tools_response = ""
+            for tool in tool_info:
+                tools_response += ("<tool_call>\n" + json.dumps(tool, ensure_ascii=False) + "\n</tool_call>\n")
+            response = "<think>\n" + reasoning_content + "\n</think>\n" + "<answer>\n" + answer_content + tools_response + "\n</answer>\n"
+            return response, tool_info
+            # return reasoning_content, answer_content, tool_info
+
+        except RateLimitError as rate_error:
+            retries += 1
+            if retries > max_retries:
+                print(f"Max retries exceeded for RateLimitError: {rate_error}")
+                return 'apierror', []
+
+            # Exponential backoff with jitter
+            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
+            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
+            time.sleep(backoff_time)
+
+        except APIError as api_error:
+            retries += 1
+            if retries > max_retries:
+                print(f"Max retries exceeded for APIError: {api_error}")
+                return 'apierror', []
+
+            # Check if the error is retryable
+            error_str = str(api_error)
+            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
+                # Exponential backoff with jitter
+                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
+                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
+                time.sleep(backoff_time)
+            else:
+                # Non-retryable API error
+                print(f"Non-retryable API error: {api_error}")
+                return 'apierror', []
+
+        except Exception as e:
+            print(f"generate_design_question Unexpected error: {e}")
+            return 'unexpectederror', []
+
+
+
+def read_json_file(file_path):
+    """Read the json file and return its content."""
+    try:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return json.load(f)
+    except Exception as e:
+        print(f"Error reading file {file_path}: {e}")
+        return None
+
+
+
+################################## utils
处理句子级别的重复 + sentences = re.split(r'(?<=[.!?。?!])\s+', text) + if len(sentences) > 1: + sentence_counter = Counter(sentences) + + for sentence, count in sentence_counter.items(): + if count > threshold: + repetition_details.append({ + 'type': 'sentence_repetition', + 'repeated_string': sentence, + 'repeat_count': count + }) + + if any(count > threshold for count in sentence_counter.values()): + unique_sentences = [] + seen_sentences = {} + + for sentence in sentences: + seen_sentences[sentence] = seen_sentences.get(sentence, 0) + 1 + if seen_sentences[sentence] <= threshold: + unique_sentences.append(sentence) + + # 重新组合文本 + text = ' '.join(unique_sentences) + is_repetitive = True + + # 4. 处理更短的重复(如果前面的方法没有检测到重复) + if not is_repetitive and min_length > 5: + for length in range(5, min_length): + pattern = r'(.{' + str(length) + r'})(\1){2,}' # 至少重复3次才处理 + + while True: + match = re.search(pattern, text) + if not match: + break + + repeated_part = match.group(1) + full_match = match.group(0) + + # 计算重复次数 + repeat_count = len(full_match) // len(repeated_part) + + # 记录重复详情 + repetition_details.append({ + 'type': 'short_repetition', + 'repeated_string': repeated_part, + 'repeat_count': repeat_count, + 'total_length': len(full_match), + 'position': match.start() + }) + + text = text.replace(full_match, repeated_part) + is_repetitive = True + + # 按重复类型和长度排序 + repetition_details.sort(key=lambda x: (-len(x['repeated_string']), x['type'])) + + return text, is_repetitive or text != original_text, repetition_details + +def create_table(table_name, connection_pool): + """Create the required MySQL table if it does not exist.""" + db = connection_pool.get_connection() + cursor = db.cursor() + create_table_query = f""" + CREATE TABLE IF NOT EXISTS {table_name} ( + id INT AUTO_INCREMENT PRIMARY KEY, + mp_id TEXT, + question_model TEXT, + question TEXT, + answer_model TEXT, + answer TEXT, + answer_len INT + ) + """ + cursor.execute(create_table_query) + db.commit() + cursor.close() + db.close() + +def record_exists(mp_id, table_name, connection_pool): + """Check if a mp_id already exists in the table.""" + db = connection_pool.get_connection() + cursor = db.cursor() + query = f"SELECT * FROM {table_name} WHERE mp_id = %s" + cursor.execute(query, (mp_id,)) + result = cursor.fetchone() + cursor.fetchall() # Ensure all results are processed + cursor.close() + db.close() + return result is not None + +def insert_record(entry, table_name, connection_pool): + """Insert a record into the MySQL table.""" + db = None + cursor = None + try: + db = connection_pool.get_connection() + cursor = db.cursor() + + insert_query = f""" + INSERT INTO {table_name} + (mp_id, question_model, question, answer_model, answer, answer_len) + VALUES (%s, %s, %s, %s, %s, %s) + """ + values = ( + entry["mp_id"], entry["question_model"], + entry["question"], entry["answer_model"], entry["answer"], entry["answer_len"], + ) + cursor.execute(insert_query, values) + db.commit() + + except Error as e: + print(f"Error: {e}") + db.rollback() + finally: + # Ensure cursor is closed + if cursor: + cursor.close() + # Ensure connection is returned to the pool + if db: + db.close() + + +# Initialize SQLite database connection +def initialize_db(): + conn = sqlite3.connect('multi_turns_data.db', check_same_thread=False) + cursor = conn.cursor() + cursor.execute(''' + CREATE TABLE IF NOT EXISTS conversations ( + mp_id TEXT PRIMARY KEY, + sample TEXT, + token_num INTEGER + ) + ''') + conn.commit() + return conn + +# Save sample to SQLite database +def 
save_to_db(conn, mp_id, sample, total_token): + cursor = conn.cursor() + cursor.execute(''' + INSERT OR REPLACE INTO conversations (mp_id, sample, token_num) + VALUES (?, ?, ?) + ''', (mp_id, str(sample), total_token)) + conn.commit() + + +def read_cif_txt_file(file_path): + """Read the markdown file and return its content.""" + try: + with open(file_path, 'r', encoding='utf-8') as f: + return f.read() + except Exception as e: + print(f"Error reading file {file_path}: {e}") + return None + +def round_values(data, precision=3): + """ + 递归地将字典中的所有值保留三位小数 + """ + if isinstance(data, dict): # 如果是字典 + return {key: round_values(value) for key, value in data.items()} + elif isinstance(data, list): # 如果是列表,递归处理每个元素 + return [round_values(item) for item in data] + elif isinstance(data, (int, float)): # 如果是数字,保留三位小数 + return round(data, precision) + else: # 对其他类型,直接返回 + return data + + +def decimal_to_fraction(decimal_value, max_denominator=1000): + """ + 将小数转换为分数表示 + + 参数: + decimal_value: 要转换的小数 + max_denominator: 分母的最大值,用于控制精度 + + 返回: + 分数表示的字符串 + """ + frac = Fraction(decimal_value).limit_denominator(max_denominator) + return f"{frac.numerator}/{frac.denominator}" + +def poscar_to_fractional_representation(poscar_content, max_denominator=1000): + """ + 将POSCAR文件中的数值转换为分数表示 + + 参数: + poscar_content: POSCAR文件内容 + max_denominator: 分母的最大值,用于控制精度 + + 返回: + 转换后的POSCAR内容,数值以分数表示 + """ + lines = poscar_content.strip().split('\n') + result_lines = [] + + # 保留系统名称 + result_lines.append(lines[0]) + + # 保留缩放因子 + scaling_factor = float(lines[1]) + result_lines.append(lines[1]) + + # 处理晶格向量 + for i in range(2, 5): + vector = [float(x) for x in lines[i].split()] + # 将每个分量转换为分数 + fractional_vector = [decimal_to_fraction(x, max_denominator) for x in vector] + result_lines.append(" " + " ".join(fractional_vector)) + + # 保留元素类型和数量 + if len(lines) > 5: + result_lines.append(lines[5]) + if len(lines) > 6: + result_lines.append(lines[6]) + + # 保留坐标类型 + if len(lines) > 7: + result_lines.append(lines[7]) + + # 处理原子坐标 + for i in range(8, len(lines)): + parts = lines[i].split() + if len(parts) >= 3: + # 将坐标转换为分数 + coords = [float(parts[j]) for j in range(3)] + fractional_coords = [decimal_to_fraction(x, max_denominator) for x in coords] + + # 构建新行 + new_line = " " + " ".join(fractional_coords) + if len(parts) > 3: + new_line += " " + " ".join(parts[3:]) + result_lines.append(new_line) + else: + # 保留非坐标行 + result_lines.append(lines[i]) + + return "\n".join(result_lines) + + +def remove_symmetry_equiv_xyz(cif_content): + """ + 删除CIF文件中的对称性操作部分 + + 参数: + cif_content: CIF文件内容字符串 + + 返回: + 清理后的CIF内容字符串 + """ + lines = cif_content.split('\n') + output_lines = [] + + i = 0 + while i < len(lines): + line = lines[i].strip() + + # 检测循环开始 + if line == 'loop_': + # 查看下一行,检查是否是对称性循环 + next_lines = [] + j = i + 1 + while j < len(lines) and lines[j].strip().startswith('_'): + next_lines.append(lines[j].strip()) + j += 1 + + # 检查是否包含对称性操作标签 + if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines): + # 跳过整个循环块 + while i < len(lines): + if i + 1 >= len(lines): + break + + next_line = lines[i + 1].strip() + # 检查是否到达下一个循环或数据块 + if next_line == 'loop_' or next_line.startswith('data_'): + break + + # 检查是否到达原子位置部分 + if next_line.startswith('_atom_site_'): + break + + i += 1 + else: + # 不是对称性循环,保留loop_行 + output_lines.append(lines[i]) + else: + # 非循环开始行,直接保留 + output_lines.append(lines[i]) + + i += 1 + + return '\n'.join(output_lines) + + + +def remove_null_values(d): + """ + Recursively remove key-value pairs with null (None) values 
+def remove_null_values(d):
+    """
+    Recursively remove key-value pairs with null (None) values
+    from a dictionary.
+
+    Args:
+        d (dict): The dictionary to clean.
+
+    Returns:
+        dict: A new dictionary without null values.
+    """
+    if not isinstance(d, dict):
+        raise ValueError("Input must be a dictionary")
+    _d = copy.deepcopy(d)
+
+    def recursive_remove(d):
+        cleaned_dict = {}
+        for key, value in d.items():
+            if isinstance(value, dict):
+                # Recursively clean nested dictionaries
+                nested_cleaned = recursive_remove(value)
+                if nested_cleaned:  # Only add non-empty dictionaries
+                    cleaned_dict[key] = nested_cleaned
+            elif value is not None and key != 'version':
+                cleaned_dict[key] = value
+
+        return cleaned_dict
+
+    clean_dict = recursive_remove(d)
+    # A band gap reported without CBM/VBM is unreliable, so drop it as well
+    # (use .get/.pop defaults so missing keys cannot raise)
+    if _d.get('cbm') is None and _d.get('vbm') is None and _d.get('band_gap') is not None:
+        clean_dict.pop('band_gap', None)
+    return clean_dict
+
+
+def get_extra_cif_info(path: str, fields_name: list):
+    """Extract specific fields from the CIF description."""
+    basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
+    energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
+    metal_magnetic_fields = ['is_metal', 'is_magnetic', 'ordering', 'total_magnetization', 'num_magnetic_sites']
+
+    # Explicit mapping instead of the fragile locals() lookup; the old
+    # misspelling 'metal_magentic_fields' is kept as an alias for callers.
+    field_groups = {
+        'basic_fields': basic_fields,
+        'energy_electronic_fields': energy_electronic_fields,
+        'metal_magnetic_fields': metal_magnetic_fields,
+        'metal_magentic_fields': metal_magnetic_fields,
+    }
+
+    selected_fields = []
+    if fields_name[0] == 'all_fields':
+        selected_fields = basic_fields + energy_electronic_fields + metal_magnetic_fields
+    else:
+        for field in fields_name:
+            selected_fields.extend(field_groups.get(field, []))
+
+    with open(path, 'r') as f:
+        docs = json.load(f)
+
+    new_docs = {}
+    for field_name in selected_fields:
+        new_docs[field_name] = docs.get(field_name, '')
+
+    return new_docs
+
+def extract_json(text):
+    """Extract the first balanced JSON object from a block of text."""
+    # NOTE: the original recursive pattern r'\{(?:[^{}]|(?R))*\}' only works
+    # with the third-party `regex` module, not the stdlib `re`; a simple
+    # brace-counting scan achieves the same thing with the stdlib only.
+    start = text.find('{')
+    while start != -1:
+        depth = 0
+        for end in range(start, len(text)):
+            if text[end] == '{':
+                depth += 1
+            elif text[end] == '}':
+                depth -= 1
+                if depth == 0:
+                    try:
+                        return json.loads(text[start:end + 1])
+                    except json.JSONDecodeError:
+                        break
+        start = text.find('{', start + 1)
+    return None
+
+def extract_and_parse_json(response):
+    """Extract and parse JSON from a response."""
+    json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
+    json_str = json_match.group(1) if json_match else response.strip()
+    # Double the backslashes inside $...$ (LaTeX) spans so they survive json.loads
+    json_str = re.sub(r'(\$[^\$]*\$)', lambda m: m.group(1).replace('\\', '\\\\'), json_str)
+    json_str = json_str.replace('\\"', '"').replace("\\'", "'")
+    try:
+        return json.loads(json_str)
+    except json.JSONDecodeError as e:
+        print(f"JSON parse error: {e}")
+        return 'errformat'
+
+
+# Count the tokens in an input message string
+def count_message_tokens(messages, model_name):
+    encoding = tiktoken.encoding_for_model(model_name)
+    return len(encoding.encode(messages))
+
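+# Illustrative usage sketch (not part of the original patch): parsing a fenced
+# model reply with extract_and_parse_json; the payload values are made up.
+#
+#     >>> extract_and_parse_json('```json\n{"formula": "SiO2", "band_gap": 8.9}\n```')
+#     {'formula': 'SiO2', 'band_gap': 8.9}
+#     >>> extract_and_parse_json('no json here')   # also prints the decode error
+#     'errformat'
+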
+def make_multi_turns_sharegpt_sample(humans: list[str], gpts: list[str], system: str = "{SYSTEM}"):
+    sample = {}
+    conversations = []
+
+    if system is not None and system != "":
+        sample["system"] = system
+
+    assert len(humans) != 0, "humans cannot be empty"
+    assert len(gpts) == len(humans), "humans and gpts must have the same length"
+
+    for human, gpt in zip(humans, gpts):
+        if human is not None and human != "":
+            assert gpt is not None, "gpt cannot be None"
+            assert gpt != "", "gpt cannot be empty"
+            # The order below must not be changed: human first, then gpt
+            conversations.append({"from": "human", "value": human})
+            conversations.append({"from": "gpt", "value": gpt})
+
+    sample["conversations"] = conversations
+    return sample
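+
+# Illustrative usage sketch (not part of the original patch): a one-turn
+# ShareGPT-style sample; the strings are placeholders.
+#
+#     >>> make_multi_turns_sharegpt_sample(
+#     ...     humans=["Generate a material with a band gap of 2.0 eV."],
+#     ...     gpts=["Calling generate_material with dft_band_gap=2.0 ..."],
+#     ...     system="You are a materials-science assistant.")
+#     {'system': '...', 'conversations': [{'from': 'human', 'value': '...'},
+#                                         {'from': 'gpt', 'value': '...'}]}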
+
+
+##################################### utils
diff --git a/mars_toolkit.log b/mars_toolkit.log
deleted file mode 100644
index 4a0e17b..0000000
--- a/mars_toolkit.log
+++ /dev/null
@@ -1,1172 +0,0 @@
[1,172 deleted log lines omitted: six near-identical FairChem model-initialization blocks logged between 2025-04-02 11:35 and 12:47, each dumping the same eqV2_86M EquiformerV2 training configuration (dataset, loss functions, optimizer, slurm job eqV2_86M_ft_alexmptraj_e20_f10_s1_cos16), the EquiformerV2EnergyHead/EquiformerV2ForceHead deprecation warnings, the "No seed has been set" reproducibility warning, and "FairChem model initialized successfully"; the 12:45 run additionally logs "Failed to optimize structure: 'str' object has no attribute 'site_properties'".]
diff --git a/mars_toolkit/__init__.py b/mars_toolkit/__init__.py
old mode 100644
new mode 100755
[mode changes (100644 -> 100755) and GIT binary-patch deltas for the mars_toolkit compute/ and core/ __pycache__/*.pyc files and their .py sources omitted; the base85 payloads are build artifacts that were garbled in extraction and carry no reviewable content.]
diff --git a/mars_toolkit/core/config.py b/mars_toolkit/core/config.py
old mode 100644
new mode 100755
index 886a480..41f1684
--- a/mars_toolkit/core/config.py
+++ b/mars_toolkit/core/config.py
@@ -22,12 +22,12 @@ class Config:
     HTTPS_PROXY = 'http://192.168.168.1:20171'
 
     # FairChem
-    FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
+    FAIRCHEM_MODEL_PATH = '/home/ubuntu/50T/nfs/lzy/mars-mcp/pretrained_models/fairchem_ckpt/eqV2_86M_omat_mp_salex.pt'
     FMAX = 0.05
 
     # MatterGen
-    MATTERGENMODEL_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
-    MATTERGEN_ROOT='/home/ubuntu/50T/lzy/mars-mcp/mattergen'
+    MATTERGENMODEL_ROOT = '/home/ubuntu/50T/nfs/lzy/mars-mcp/pretrained_models/mattergen_ckpt'
+    MATTERGEN_ROOT='/home/ubuntu/50T/nfs/lzy/mars-mcp/mattergen'
     MATTERGENMODEL_RESULT_PATH = 'results/'
 
     # Dify
@@ -38,7 +38,7 @@ class Config:
     SEARXNG_HOST="http://192.168.168.1:40032/"
 
     # Visualization
-    VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'
+    VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/nfs/lzy/mars-mcp/outputs/cif_visualization'
 
     @classmethod
     def as_dict(cls) -> Dict[str, Any]:
diff --git a/mars_toolkit/core/llm_tools.py b/mars_toolkit/core/llm_tools.py
old mode 100644
new mode 100755
diff --git a/mars_toolkit/core/mattergen_wrapper.py b/mars_toolkit/core/mattergen_wrapper.py
old mode 100644
new mode 100755
[mode changes (100644 -> 100755) and GIT binary-patch deltas for the mars_toolkit misc/ and query/ packages and their __pycache__/*.pyc files omitted; the base85 payloads are build artifacts that were garbled in extraction.]
diff --git a/mars_toolkit/query/dify_search.py b/mars_toolkit/query/dify_search.py
old mode 100644
new mode 100755
diff --git a/mars_toolkit/query/mp_query.py b/mars_toolkit/query/mp_query.py
old mode 100644
new mode 100755
diff --git a/mars_toolkit/query/oqmd_query.py b/mars_toolkit/query/oqmd_query.py
old mode 100644
new mode 100755
diff --git a/mars_toolkit/query/web_search.py b/mars_toolkit/query/web_search.py
old mode 100644
new mode 100755
diff --git a/mars_toolkit/services/__init__.py b/mars_toolkit/services/__init__.py
old mode 100644
new mode 100755
diff --git a/mars_toolkit/services/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/services/__pycache__/__init__.cpython-310.pyc
old mode 100644
new mode 100755
diff --git a/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc b/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc
old mode 100644
new mode 100755
diff --git a/mars_toolkit/services/mattergen_service.py b/mars_toolkit/services/mattergen_service.py
old mode 100644
new mode 100755
diff --git a/mars_toolkit/visualization/__init__.py b/mars_toolkit/visualization/__init__.py
old mode 100644
new mode 100755
diff --git a/mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/visualization/__pycache__/__init__.cpython-310.pyc
old mode 100644
new mode 100755
diff --git a/mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc b/mars_toolkit/visualization/__pycache__/band_vis.cpython-310.pyc
old mode 100644
new mode 100755
diff --git a/mars_toolkit/visualization/__pycache__/crystal_vis.cpython-310.pyc b/mars_toolkit/visualization/__pycache__/crystal_vis.cpython-310.pyc
old mode 100644
new mode 100755
diff --git a/mattergen_api.py b/mattergen_api.py
old mode 100644
new mode 100755
diff --git a/mattergen_client_example.py b/mattergen_client_example.py
deleted file mode 100644
index 928237c..0000000
--- a/mattergen_client_example.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import requests
-import json
-import argparse
-import sys
-
-def generate_material(
-    url="http://localhost:8051/generate_material",
-    properties=None,
-    batch_size=2,
-    num_batches=1,
-    diffusion_guidance_factor=2.0
-):
-    """
-    Call the MatterGen API to generate crystal structures.
-
-    Args:
-        url: the API endpoint URL
-        properties: optional property constraints, e.g. {"dft_band_gap": 2.0}
-        batch_size: number of structures per batch
-        num_batches: number of batches
-        diffusion_guidance_factor: how strongly generation is steered toward the target properties
-
-    Returns:
-        The generated structure content, or None on error.
-    """
-    # Build the request payload
-    payload = {
-        "properties": properties,
-        "batch_size": batch_size,
-        "num_batches": num_batches,
-        "diffusion_guidance_factor": diffusion_guidance_factor
-    }
-
-    print(f"Sending request to {url}")
-    print(f"Request parameters: {json.dumps(payload, ensure_ascii=False, indent=2)}")
-
-    try:
-        # Include an accept header alongside the content type
-        headers = {
-            "Content-Type": "application/json",
-            "accept": "application/json"
-        }
-
-        # Print the full request (debugging aid)
-        print(f"Full request URL: {url}")
-        print(f"Request headers: {headers}")
-        print(f"Request body: {json.dumps(payload)}")
-
-        # Disable any proxy settings
-        proxies = {
-            "http": None,
-            "https": None
-        }
-
-        # POST with explicit headers, proxies disabled, and a generous timeout
-        response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
-
-        # Print the response (debugging aid)
-        print(f"Response status code: {response.status_code}")
-        print(f"Response headers: {dict(response.headers)}")
-        print(f"Response body: {response.text[:500]}...")  # first 500 chars only
-
-        # Check the response status
-        if response.status_code == 200:
-            result = response.json()
-
-            if result["success"]:
-                print("\nGeneration succeeded!")
-                return result["content"]
-            else:
-                print(f"\nGeneration failed: {result['message']}")
-                return None
-        else:
-            print(f"\nRequest failed with status code: {response.status_code}")
-            print(f"Response body: {response.text}")
-            return None
-
-    except Exception as e:
-        print(f"\nError: {str(e)}")
-        print(f"Error type: {type(e).__name__}")
-        import traceback
-        print(f"Traceback: {traceback.format_exc()}")
-        return None
-
-def main():
-    """Command-line entry point."""
-    parser = argparse.ArgumentParser(description="MatterGen API client example")
-
-    # Command-line arguments
-    parser.add_argument("--url", default="http://localhost:8051/generate_material",
-                        help="MatterGen API endpoint URL")
-    parser.add_argument("--property-name", default='dft_mag_density', help="Property name, e.g. dft_band_gap")
-    parser.add_argument("--property-value", default=0.15, help="Property value, e.g. 2.0")
-    parser.add_argument("--batch-size", type=int, default=2, help="Number of structures per batch")
-    parser.add_argument("--num-batches", type=int, default=1, help="Number of batches")
-    parser.add_argument("--guidance-factor", type=float, default=2.0,
-                        help="How strongly generation is steered toward the target properties")
-
-    args = parser.parse_args()
-
-    # Build the properties dictionary
-    properties = None
-    if args.property_name and args.property_value:
-        try:
-            # Try to interpret the property value as a number
-            try:
-                value = float(args.property_value)
-                # Use an int when the value is integral
-                if value.is_integer():
-                    value = int(value)
-            except ValueError:
-                # Fall back to the raw string if it is not numeric
-                value = args.property_value
-
-            properties = {args.property_name: value}
-        except Exception as e:
-            print(f"Error parsing property value: {str(e)}")
-            return
-
-    # Call the API
-    result = generate_material(
-        url=args.url,
-        properties=properties,
-        batch_size=args.batch_size,
-        num_batches=args.num_batches,
-        diffusion_guidance_factor=args.guidance_factor
-    )
-
-    if result:
-        print("\nGenerated structures:")
-        print(result)
-
-if __name__ == "__main__":
-    main()
diff --git a/prompts/__pycache__/material_synthesis.cpython-310.pyc b/prompts/__pycache__/material_synthesis.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..5e287f8614f6c33ac468d5bce0d3fe873d35a8cd
[GIT binary-patch literal (5363 bytes, base85 payload) omitted; this is the compiled bytecode for the new prompts/material_synthesis.py below and was garbled in extraction.]
diff --git a/prompts/material_synthesis.py b/prompts/material_synthesis.py
new file mode 100755
index 0000000..be92d3a
--- /dev/null
+++ b/prompts/material_synthesis.py
@@ -0,0 +1,167 @@
+from typing import Dict, List, Optional
+import mcp.types as types
+from mcp.server.lowlevel import Server
+
+
+def create_messages(
+    properties: Dict[str, str] = None,
+    batch_size: int = 2,
+) -> list[types.PromptMessage]:
+    """
+    Build the prompt messages for material generation and synthesis design.
+
+    Args:
+        properties: dict of material properties and values, e.g. {"dft_band_gap": "2.0"}
+        batch_size: number of materials to generate, defaults to 2
+
+    Returns:
+        The list of prompt messages.
+    """
+    messages = []
+
+    # System message defining the assistant's role and task
+    system_message = """You are a professional materials scientist specializing in material generation and synthesis-route design.
+Your tasks are to:
+1. Use the generate_material tool from mars_toolkit to generate materials that meet the user's property requirements
+2. Systematically analyze the four elements of each generated material: composition, structure, processing, and properties
+3. Design a scientifically sound synthesis route for each generated material
+4. Draw the synthesis flowchart in mermaid syntax
+
+Make sure your answer covers:
+- Your understanding and analysis of the user's requirements
+- The material structures produced by the generate_material tool
+- A detailed four-element analysis of each generated material:
+  * Composition: detailed chemical composition, element ratios, stoichiometry
+  * Structure: crystal structure, space group, lattice parameters, atomic positions, coordination environments
+  * Processing: feasible synthesis routes, process parameters, key control factors
+  * Properties: expected physical, chemical, and mechanical properties and their relation to the structure
+- A detailed synthesis plan, including:
+  * Raw-material selection and purity requirements
+  * Precise reaction conditions (temperature, pressure, time, atmosphere)
+  * A step-by-step synthesis procedure with the rationale for each step
+  * Likely challenges and their mitigations
+  * Suggested characterization methods
+- A mermaid flowchart that clearly traces the whole process from raw materials to the final product
+"""
+
+    messages.append(
+        types.PromptMessage(
+            role="system",
+            content=types.TextContent(type="text", text=system_message),
+        )
+    )
+
+    # Build the main prompt
+    if properties and len(properties) > 0:
+        properties_text = "\n".join([f"- {key}: {value}" for key, value in properties.items()])
+        prompt = f"""Based on the following material property requirements, generate {batch_size} suitable materials and design their synthesis routes:
+
+{properties_text}
+
+Proceed as follows:
+
+1. Use the generate_material tool from mars_toolkit with batch_size={batch_size}
+2. Perform a systematic four-element analysis of each generated material:
+   - Composition: element makeup, stoichiometry, and the reasoning behind them
+   - Structure: crystal structure, space group, lattice parameters, atomic arrangement, and stability
+   - Processing: feasible synthesis routes, process parameters, and their scientific basis
+   - Properties: likely physical, chemical, and mechanical properties and application prospects
+
+3. Design a detailed synthesis plan for each material, including:
+   - Raw-material selection and purity requirements
+   - Precise reaction conditions (temperature, pressure, time, atmosphere, etc.)
+   - A step-by-step synthesis procedure with the rationale for each step
+   - Likely challenges and their mitigations
+   - Recommended characterization methods
+
+4. Draw the synthesis flowchart in mermaid syntax, clearly tracing the whole process from raw materials to the final product, including the key process parameters."""
+    else:
+        prompt = f"""Generate {batch_size} innovative new materials and design their synthesis routes.
+
+Proceed as follows:
+
+1. Use the generate_material tool from mars_toolkit with batch_size={batch_size}
+2. Perform a systematic four-element analysis of each generated material:
+   - Composition: element makeup, stoichiometry, and the reasoning behind them
+   - Structure: crystal structure, space group, lattice parameters, atomic arrangement, and stability
+   - Processing: feasible synthesis routes, process parameters, and their scientific basis
+   - Properties: likely physical, chemical, and mechanical properties and application prospects
+
+3. Design a detailed synthesis plan for each material, including:
+   - Raw-material selection and purity requirements
+   - Precise reaction conditions (temperature, pressure, time, atmosphere, etc.)
+   - A step-by-step synthesis procedure with the rationale for each step
+   - Likely challenges and their mitigations
+   - Recommended characterization methods
+
+4. Draw the synthesis flowchart in mermaid syntax, clearly tracing the whole process from raw materials to the final product, including the key process parameters."""
+
+    messages.append(
+        types.PromptMessage(
+            role="user",
+            content=types.TextContent(type="text", text=prompt)
+        )
+    )
+
+    return messages
+
+
+def register_prompt_handlers(app: Server):
+    """
+    Register the prompt handlers on the MCP server.
+
+    Args:
+        app: the MCP server instance
+    """
+    @app.list_prompts()
+    async def list_prompts() -> list[types.Prompt]:
+        return [
+            types.Prompt(
+                name="material_synthesis",
+                description="Generate materials from the four material elements (composition, structure, processing, properties), design synthesis routes, and draw the synthesis flowchart in mermaid",
+                arguments=[
+                    types.PromptArgument(
+                        name="properties",
+                        description="JSON string of material properties and values, e.g. {\"dft_band_gap\": \"2.0\"}",
+                        required=False,
+                    ),
+                    types.PromptArgument(
+                        name="batch_size",
+                        description="Number of materials to generate, defaults to 2",
+                        required=False,
+                    ),
+                ],
+            )
+        ]
+
+    @app.get_prompt()
+    async def get_prompt(
+        name: str, arguments: dict[str, str] | None = None
+    ) -> types.GetPromptResult:
+        if name != "material_synthesis":
+            raise ValueError(f"Unknown prompt: {name}")
+
+        if arguments is None:
+            arguments = {}
+
+        # Parse the properties argument
+        properties = {}
+        if "properties" in arguments and arguments["properties"]:
+            try:
+                import json
+                properties = json.loads(arguments["properties"])
+            except json.JSONDecodeError:
+                properties = {}
+
+        # Parse the batch_size argument
+        batch_size = 2  # default
+        if "batch_size" in arguments and arguments["batch_size"]:
+            try:
+                batch_size = int(arguments["batch_size"])
+            except ValueError:
+                pass  # keep the default
+
+        return types.GetPromptResult(
+            messages=create_messages(properties=properties, batch_size=batch_size),
+            description="Generate materials from the four material elements (composition, structure, processing, properties), design synthesis routes, and draw the synthesis flowchart in mermaid",
+        )
diff --git a/server.py b/server.py
new file mode 100755
index 0000000..3a8528f
--- /dev/null
+++ b/server.py
@@ -0,0 +1,306 @@
+"""Mars Toolkit MCP Server implementation."""
+
+import anyio
+import asyncio
+import click
+import json
+import logging
+import os
+import sys
+import traceback
+from typing import Any, Dict, List, Optional, Union
+
+from prompts.material_synthesis import create_messages
+
+# Make the mars_toolkit package importable
+sys.path.append('/home/ubuntu/50T/lzy/mars-mcp')
+
+import mcp.types as types
+from mcp.server.lowlevel import Server
+
+# Prompt handlers (currently registered inline in main() instead)
+#from prompts.material_synthesis import register_prompt_handlers
+
+# Import the Mars Toolkit tool functions
+try:
+    # Current time
+    from mars_toolkit.misc.misc_tools import get_current_time
+    # Web search
+    from mars_toolkit.query.web_search import search_online
+    # Query material properties from Materials Project
+    from mars_toolkit.query.mp_query import search_material_property_from_material_project
+    # Fetch crystal structures from Materials Project
+    from mars_toolkit.query.mp_query import get_crystal_structures_from_materials_project
+    # Look up a Materials Project ID from a chemical formula
+    from mars_toolkit.query.mp_query import get_mpid_from_formula
+    # Optimize crystal structures
+    from mars_toolkit.compute.structure_opt import optimize_crystal_structure
+    # Generate materials
+    from mars_toolkit.compute.material_gen import generate_material
+    # Fetch chemical compositions from OQMD
+    from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
+    # Retrieve from the knowledge base
+    from mars_toolkit.query.dify_search import retrieval_from_knowledge_base
+    # Predict properties
+    from mars_toolkit.compute.property_pred import predict_properties
+
+    # Registry of all tool functions
+    from mars_toolkit import get_tools, get_tool_schemas
+
+    MARS_TOOLKIT_AVAILABLE = True
+except ImportError as e:
+    print(f"Warning: failed to import Mars Toolkit: {e}", file=sys.stderr)
+    MARS_TOOLKIT_AVAILABLE = False
+
+# Logging configuration
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+app = Server("mars-toolkit-server")
+
+
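+# Illustrative usage sketch (not part of the original patch): dispatching one
+# tool call through call_mars_toolkit_function, defined below, assuming the
+# toolkit imported successfully. The "formula" argument name is an assumption
+# for illustration; the real parameter names come from the tool schemas.
+#
+#     >>> import asyncio
+#     >>> asyncio.run(call_mars_toolkit_function(
+#     ...     "get_mpid_from_formula", {"formula": "SiO2"}))
+
+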
+async def call_mars_toolkit_function(func_name: str, arguments: Dict[str, Any]) -> Any:
+    """
+    Invoke a Mars Toolkit tool function by name.
+
+    Args:
+        func_name: the tool function's name
+        arguments: the tool function's arguments
+
+    Returns:
+        The tool function's result.
+    """
+    if not MARS_TOOLKIT_AVAILABLE:
+        raise ValueError("Mars Toolkit is not available")
+
+    # All registered tool functions
+    tools = get_tools()
+
+    # Make sure the name is registered
+    if func_name not in tools:
+        raise ValueError(f"Function '{func_name}' is not in the tool registry")
+
+    # Look up the tool function
+    tool_func = tools[func_name]
+
+    # Invoke it: await coroutines, call plain functions directly
+    if asyncio.iscoroutinefunction(tool_func):
+        result = await tool_func(**arguments)
+        print("result1", result)
+    else:
+        result = tool_func(**arguments)
+
+    return result
+
+
+def get_tool_schemas_dict() -> Dict[str, Dict[str, Any]]:
+    """
+    Collect the schemas of all tool functions.
+
+    Returns:
+        A mapping from tool function name to its schema.
+    """
+    if not MARS_TOOLKIT_AVAILABLE:
+        return {}
+
+    schemas = get_tool_schemas()
+    schemas_dict = {}
+
+    for schema in schemas:
+        func_name = schema["function"]["name"]
+        schemas_dict[func_name] = schema
+
+    return schemas_dict
+
+
+@click.command()
+@click.option("--port", default=5666, help="Port to listen on for SSE")
+@click.option(
+    "--transport",
+    type=click.Choice(["stdio", "sse"]),
+    default="sse",
+    help="Transport type",
+)
+def main(port: int, transport: str) -> int:
+    """
+    Entry point of the Mars Toolkit MCP server.
+
+    Args:
+        port: port number for the SSE transport
+        transport: transport type, stdio or sse
+
+    Returns:
+        The exit code.
+    """
+    if not MARS_TOOLKIT_AVAILABLE:
+        print("Error: Mars Toolkit is not available; make sure it is installed correctly", file=sys.stderr)
+        return 1
+
+    # Schema dictionary for all tool functions
+    schemas_dict = get_tool_schemas_dict()
+
+    # Prompt handlers are registered inline below instead of via register_prompt_handlers
+    #register_prompt_handlers(app)
+
+    @app.list_prompts()
+    async def list_prompts() -> list[types.Prompt]:
+        return [
+            types.Prompt(
+                name="material_synthesis",
+                description="Generate materials, design synthesis routes, and draw the synthesis flowchart in mermaid",
+                arguments=[
+                    types.PromptArgument(
+                        name="properties",
+                        description="JSON string of material properties and values, e.g. {\"dft_band_gap\": \"2.0\"}",
+                        required=False,
+                    ),
+                    types.PromptArgument(
+                        name="batch_size",
+                        description="Number of materials to generate, defaults to 2",
+                        required=False,
+                    ),
+                ],
+            )
+        ]
+
+    @app.get_prompt()
+    async def get_prompt(
+        name: str, arguments: dict[str, str] | None = None
+    ) -> types.GetPromptResult:
+        if name != "material_synthesis":
+            raise ValueError(f"Unknown prompt: {name}")
+
+        if arguments is None:
+            arguments = {}
+
+        # Parse the properties argument
+        properties = {}
+        if "properties" in arguments and arguments["properties"]:
+            try:
+                import json
+                properties = json.loads(arguments["properties"])
+            except json.JSONDecodeError:
+                properties = {}
+
+        # Parse the batch_size argument
+        batch_size = 2  # default
+        if "batch_size" in arguments and arguments["batch_size"]:
+            try:
+                batch_size = int(arguments["batch_size"])
+            except ValueError:
+                pass  # keep the default
+
+        return types.GetPromptResult(
+            messages=create_messages(properties=properties, batch_size=batch_size),
+            description="Generate materials, design synthesis routes, and draw the synthesis flowchart in mermaid",
+        )
+
+    @app.call_tool()
+    async def call_tool(
+        name: str, arguments: Dict[str, Any]
+    ) -> List[types.TextContent | types.ImageContent | types.EmbeddedResource]:
+        """
+        Invoke a tool function.
+
+        Args:
+            name: the tool function's name
+            arguments: the tool function's arguments
+
+        Returns:
+            The tool function's result.
+        """
+        try:
+            print(f"Calling {name} with arguments {arguments}")
+            result = await call_mars_toolkit_function(name, arguments)
+            print("result", result)
+            # Serialize the result to a string
+            if isinstance(result, (dict, list)):
+                result_str = json.dumps(result, ensure_ascii=False, indent=2)
+            else:
+                result_str = str(result)
+
+            return [types.TextContent(type="text", text=result_str)]
+        except Exception as e:
+            error_msg = f"Error calling tool function {name}: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_msg)
+            return [types.TextContent(type="text", text=error_msg)]
+
+    @app.list_tools()
+    async def list_tools() -> List[types.Tool]:
+        """
+        List all available tool functions.
+
+        Returns:
+            The list of tools.
+        """
+        tools = []
+        print("Listing all available tool functions")
+        for func_name, schema in schemas_dict.items():
+            # Function description
+            description = schema["function"].get("description", f"Mars Toolkit tool: {func_name}")
+
+            # Parameter schema
+            parameters = schema["function"].get("parameters", {})
+
+            # Build the tool entry
+            tool = types.Tool(
+                name=func_name,
+                description=description,
+                inputSchema=parameters,
+            )
+
+            tools.append(tool)
+
+        return tools
+
+    if transport == "sse":
+        from mcp.server.sse import SseServerTransport
+        from starlette.applications import Starlette
+        from starlette.routing import Mount, Route
+
+        sse = SseServerTransport("/messages/")
+
+        async def handle_sse(request):
+            async with sse.connect_sse(
+                request.scope, request.receive, request._send
+            ) as streams:
+                await app.run(
+                    streams[0], streams[1], app.create_initialization_options()
+                )
+
+        starlette_app = Starlette(
+            debug=True,
+            routes=[
+                Route("/sse", endpoint=handle_sse),
+                Mount("/messages/", app=sse.handle_post_message),
+            ],
+        )
+
+        import uvicorn
+
+        uvicorn.run(starlette_app, host="0.0.0.0", port=port)
+    else:
+        from mcp.server.stdio import stdio_server
+
+        async def arun():
+            async with stdio_server() as streams:
+                await app.run(
+                    streams[0], streams[1], app.create_initialization_options()
+                )
+
+        anyio.run(arun)
+
+    return 0
+
+
+if __name__ == "__main__":
+    print(get_tool_schemas_dict())
+    main()
+
diff --git a/test_mars_toolkit.py b/test_mars_toolkit.py
old mode 100644
new mode 100755
index 9adc5ca..baa1bc9
--- a/test_mars_toolkit.py
+++ b/test_mars_toolkit.py
@@ -155,7 +155,7 @@ def print_tool_schemas():
 
 if __name__ == "__main__":
     # Print the schemas of all tool functions
-    #print_tool_schemas()
+    print_tool_schemas()
 
     # List of tool functions to test
     tools_to_test = [
@@ -172,7 +172,7 @@ if __name__ == "__main__":
     ]
 
     # Select the tool to test
-    tool_name = tools_to_test[1]  # test the search_online tool
+    tool_name = tools_to_test[2]  # test the third tool in the list
 
     # Run the test
     result = asyncio.run(test_tool(tool_name))
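# Usage sketch (illustrative, not part of the patch): after applying it, the MCP
# server added in server.py can be started over either transport exposed by its
# click options, e.g.
#   python server.py --transport sse --port 5666
#   python server.py --transport stdio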