From 71d8dabd17f24e1e0925b82320b4f5c64e594790 Mon Sep 17 00:00:00 2001 From: lzy <949777411@qq.com> Date: Sat, 5 Apr 2025 20:19:43 +0800 Subject: [PATCH] =?UTF-8?q?mattergen=E8=B0=83=E7=94=A8=E6=8C=87=E5=AE=9AGP?= =?UTF-8?q?U&=E8=A7=84=E8=8C=83=E5=8C=96mattergen=E7=9A=84=E8=BE=93?= =?UTF-8?q?=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + .../normalize_material_args.cpython-310.pyc | Bin 0 -> 3175 bytes execute_tool_copy.py | 371 ++++++++++++++++-- .../core/__pycache__/config.cpython-310.pyc | Bin 2123 -> 2123 bytes .../mattergen_service.cpython-310.pyc | Bin 9440 -> 10140 bytes mars_toolkit/services/mattergen_service.py | 52 ++- 6 files changed, 379 insertions(+), 45 deletions(-) create mode 100644 __pycache__/normalize_material_args.cpython-310.pyc diff --git a/.gitignore b/.gitignore index 81685b3..cd53b7e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ model_agent_test.py pyproject.toml /pretrained_models /mcp-python-sdk +/.vscode diff --git a/__pycache__/normalize_material_args.cpython-310.pyc b/__pycache__/normalize_material_args.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f874a8a474e5e714322155c2992243100ae16a9a GIT binary patch literal 3175 zcmai$|8LvI6~OO~UqsVi>o{FA^asHh+NrDT#7Tx#Gtwo?fMM8dV98QUSf;5uT45_u z-jS>;l^V3j>Dt*zjCN_aG|lX8tM?^s)~yBH#%_P;f5BfPO4biWejNgI-Mgbq$uY7L z;PKtP_dfUT@%WCV!9g9t_3)#gI$Ql2sVY?! z>{T)2S*my{#B$bCS9PkcVLOy!q2$S)GL3$DJhCtB*fEgyXkQRo@MR$F(Jq0ue+^L$ z5)Sen{Z-Hp&_3`6X-8Yx?+kis9am#sY;|Z2vA;XRUaXEoy(N3GPIV?e;5m&TwG`CLGBfD=|Q zi**%?~wyrJ)m(PRH zT-^*dFW;`6*99Q>^I~h|#m23_G&Yu+^$Wp^AKclzDl~^C;`6LDZ!pI+?9ey3d~>_L z)%^X<;GNf7H{Nb+{43ac3uplU!KvAqJDbbxYqjmSf4#l_r{?vQ#>Pir7hKp7xT#)T zW8)njP;E0LcD%N3G%vmdUhUfXVCz*dXuWqUxW05}^CgHIP|cTKizdFa`Le(so`|>B z-fevH=F?dxcP?X_6(e|M*>Z~+F<{sbZS%_e00`H6oaOD5@Z{->zUfpvSiTUf{9|wS zMACS$;ZlU_kHz=j!--x)uyTpp3a6gM`+7H(?juklny;@m|9vIQQ&$z|S&WvLiqE*; z1lz7Y_~b%+?N*q7$SdSymjmQ~ySCiexCj>UyEu6lAG?dE{TMR}+2>2lC&n3H@8(Ga z((Oj@FkQP~et%C!-A5s)p8o%JJV4F4IoF07n#sE+&00AlGndUdCFULl%zd>zQ2y$7 zMU!~Imvg20vag%AY1s~k`jYK1Uw+1h%B>ViS!(;kJUJ(gT!|XrV+_`>pUD>7jx1AN z>Sd=qZ}^Jo7)9IHfy5kp!F0}X#Dc+m&9G=^U=E-TpD|-vj*$nY1cKjp8VCZP&=^0r z{9&uairE5eM4Ey9Wix&>_}K9G_b9sW)R&$$%D(a(56$lvvt&f)Z-{*zJnw#M@?5EC zOuA3Img7zyJ$QPuP$^Fqv&>EwbMuoCH2UDhtP_x)uiQr`1J*7;qYZ14hINVHAv^-N zRwG~N@+g+25&Q_z$rujr2Rlz9X+$GnN95WF9w6kunyi8M79#y?gAr!rs*E#ff zPn_o(?W6iNLi>R-01N^QI2x-G_zp%45*E`AqjS7Qj6OOKia-T#ah!qozQ|kc+EX~_x;^1N1_ge12 z$E$7lWkGETnuH{FEE@)y4MF364Vo~R*6;q>_;~41Jh{741@B#N*49%lf1|;rKL+bd zskqw@ov1FR;(6%8^`9pz`VYMM|#JOhyoHeJY)XDj$u2Q2)?^w{2VOo??Xcrf(ceX*AxwE!Vp%m z2H*0RjgWh6bj(tGPAvn|m{ul}(AgLtP@(3UR-W+<$M}Y0{2eoeH~Vt7h>qni{s+78DVP8N literal 0 HcmV?d00001 diff --git a/execute_tool_copy.py b/execute_tool_copy.py index ac3f01a..40700a3 100644 --- a/execute_tool_copy.py +++ b/execute_tool_copy.py @@ -1,13 +1,113 @@ import json import asyncio import concurrent.futures -from tools_for_ms.llm_tools import * + +import jsonlines +from mars_toolkit import * import threading +import uuid # Create a lock for file writing file_lock = threading.Lock() -from mysql.connector import pooling +from mysql.connector import pooling +from colorama import Fore, Back, Style, init +import time +import random +# 初始化colorama +init(autoreset=True) + +from typing import Dict, Union, Any, Optional + +def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]: + """ + 规范化传递给generate_material函数的参数格式。 + + 处理以下情况: + 1. properties参数可能是字符串形式的JSON,需要解析为字典 + 2. properties中的值可能需要转换为适当的类型(数字或字符串) + 3. 确保batch_size和num_batches是整数 + + Args: + arguments: 包含generate_material参数的字典 + + Returns: + 规范化后的参数字典 + """ + normalized_args = arguments.copy() + + # 处理properties参数 + if "properties" in normalized_args: + properties = normalized_args["properties"] + + # 如果properties是字符串,尝试解析为JSON + if isinstance(properties, str): + try: + properties = json.loads(properties) + except json.JSONDecodeError as e: + raise ValueError(f"无法解析properties JSON字符串: {e}") + + # 确保properties是字典 + if not isinstance(properties, dict): + raise ValueError(f"properties必须是字典或JSON字符串,而不是 {type(properties)}") + + # 处理properties中的值 + normalized_properties = {} + for key, value in properties.items(): + # 处理范围值,例如 "0.0-2.0" 或 "40-50" + if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"): + # 保持范围值为字符串格式 + normalized_properties[key] = value + elif isinstance(value, str) and value.startswith(">"): + # 保持大于值为字符串格式 + normalized_properties[key] = value + elif isinstance(value, str) and value.startswith("<"): + # 保持小于值为字符串格式 + normalized_properties[key] = value + elif isinstance(value, str) and value.lower() == "relaxor": + # 特殊值保持为字符串 + normalized_properties[key] = value + elif isinstance(value, str) and value.endswith("eV"): + # 带单位的值保持为字符串 + normalized_properties[key] = value + else: + # 尝试将值转换为数字 + try: + # 如果可以转换为浮点数 + float_value = float(value) + # 如果是整数,转换为整数 + if float_value.is_integer(): + normalized_properties[key] = int(float_value) + else: + normalized_properties[key] = float_value + except (ValueError, TypeError): + # 如果无法转换为数字,保持原值 + normalized_properties[key] = value + + normalized_args["properties"] = normalized_properties + + # 确保batch_size和num_batches是整数 + if "batch_size" in normalized_args: + try: + normalized_args["batch_size"] = int(normalized_args["batch_size"]) + except (ValueError, TypeError): + raise ValueError(f"batch_size必须是整数,而不是 {normalized_args['batch_size']}") + + if "num_batches" in normalized_args: + try: + normalized_args["num_batches"] = int(normalized_args["num_batches"]) + except (ValueError, TypeError): + raise ValueError(f"num_batches必须是整数,而不是 {normalized_args['num_batches']}") + + # 确保diffusion_guidance_factor是浮点数 + if "diffusion_guidance_factor" in normalized_args: + try: + normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"]) + except (ValueError, TypeError): + raise ValueError(f"diffusion_guidance_factor必须是数字,而不是 {normalized_args['diffusion_guidance_factor']}") + + return normalized_args +import requests connection_pool = pooling.MySQLConnectionPool( pool_name="mypool", pool_size=32, @@ -17,7 +117,8 @@ connection_pool = pooling.MySQLConnectionPool( password='siat-mic', database='metadata_mat_papers' ) -def process_retrieval_from_knowledge_base(data): + +async def process_retrieval_from_knowledge_base(data): doi = data.get('doi') mp_id = data.get('mp_id') @@ -76,6 +177,156 @@ def process_retrieval_from_knowledge_base(data): markdown_result += f"\n## {field}\n{field_content}\n\n" return markdown_result # 直接返回markdown文本 + + + +async def mattergen( + properties=None, + batch_size=2, + num_batches=1, + diffusion_guidance_factor=2.0 +): + """ + 调用MatterGen服务生成晶体结构 + + Args: + properties: 可选的属性约束,例如{"dft_band_gap": 2.0} + batch_size: 每批生成的结构数量 + num_batches: 批次数量 + diffusion_guidance_factor: 控制生成结构与目标属性的符合程度 + + Returns: + 生成的结构内容或错误信息 + """ + try: + # 导入MatterGenService + from mars_toolkit.services.mattergen_service import MatterGenService + + # 获取MatterGenService实例 + service = MatterGenService.get_instance() + + # 使用服务生成材料 + result = await service.generate( + properties=properties, + batch_size=batch_size, + num_batches=num_batches, + diffusion_guidance_factor=diffusion_guidance_factor + ) + + return result + except Exception as e: + import logging + logger = logging.getLogger(__name__) + logger.error(f"Error in mattergen: {e}") + import traceback + logger.error(traceback.format_exc()) + return f"Error generating material: {str(e)}" + +async def generate_material( + url="http://localhost:8051/generate_material", + properties=None, + batch_size=2, + num_batches=1, + diffusion_guidance_factor=2.0 +): + """ + 调用MatterGen API生成晶体结构 + + Args: + url: API端点URL + properties: 可选的属性约束,例如{"dft_band_gap": 2.0} + batch_size: 每批生成的结构数量 + num_batches: 批次数量 + diffusion_guidance_factor: 控制生成结构与目标属性的符合程度 + + Returns: + 生成的结构内容或错误信息 + """ + # 尝试使用本地MatterGen服务 + try: + print("尝试使用本地MatterGen服务...") + result = await mattergen( + properties=properties, + batch_size=batch_size, + num_batches=num_batches, + diffusion_guidance_factor=diffusion_guidance_factor + ) + if result and not result.startswith("Error"): + print("本地MatterGen服务生成成功!") + return result + else: + print(f"本地MatterGen服务生成失败,尝试使用API: {result}") + except Exception as e: + print(f"本地MatterGen服务出错,尝试使用API: {str(e)}") + + # 如果本地服务失败,回退到API调用 + # 规范化参数 + normalized_args = normalize_material_args({ + "properties": properties, + "batch_size": batch_size, + "num_batches": num_batches, + "diffusion_guidance_factor": diffusion_guidance_factor + }) + + # 构建请求负载 + payload = { + "properties": normalized_args["properties"], + "batch_size": normalized_args["batch_size"], + "num_batches": normalized_args["num_batches"], + "diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"] + } + + print(f"发送请求到 {url}") + print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}") + + try: + # 添加headers参数,包含accept头 + headers = { + "Content-Type": "application/json", + "accept": "application/json" + } + + # 打印完整请求信息(调试用) + print(f"完整请求URL: {url}") + print(f"请求头: {json.dumps(headers, indent=2)}") + print(f"请求体: {json.dumps(payload, indent=2)}") + + # 禁用代理设置 + proxies = { + "http": None, + "https": None + } + + # 发送POST请求,添加headers参数,禁用代理,增加超时时间 + response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300) + + # 打印响应信息(调试用) + print(f"响应状态码: {response.status_code}") + print(f"响应头: {dict(response.headers)}") + print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符,避免输出过长 + + # 检查响应状态 + if response.status_code == 200: + result = response.json() + + if result["success"]: + print("\n生成成功!") + return result["content"] + else: + print(f"\n生成失败: {result['message']}") + return None + else: + print(f"\n请求失败,状态码: {response.status_code}") + print(f"响应内容: {response.text}") + return None + + except Exception as e: + print(f"\n发生错误: {str(e)}") + print(f"错误类型: {type(e).__name__}") + import traceback + print(f"错误堆栈: {traceback.format_exc()}") + return None + async def execute_tool_from_dict(input_dict: dict): """ 从字典中提取工具函数名称和参数,并执行相应的工具函数 @@ -149,38 +400,86 @@ async def execute_tool_from_dict(input_dict: dict): return {"status": "error", "message": f"执行过程中出错: {str(e)}"} - - -# # 示例用法 -# if __name__ == "__main__": -# # 示例输入 -# input_str = '{"name": "search_material_property_from_material_project", "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}' - -# # 调用函数 -# result = asyncio.run(execute_tool_from_string(input_str)) -# print(result) - - def worker(data, output_file_path): - try: - # rich.console.Console().print(tools_schema) - # print(tools_schema) func_contents = data["function_calls"] func_results = [] formatted_results = [] # 新增一个列表来存储格式化后的结果 for func in func_contents: + func_name = func.get("name") + arguments_data = func.get("arguments") + + # 使用富文本打印函数名 + print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") + + # 使用富文本打印参数 + print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}") + if func.get("name") == 'retrieval_from_knowledge_base': - func_name = func.get("name") - arguments_data = func.get("arguments") - # print('func_name', func_name) - # print("argument", arguments_data) - result = process_retrieval_from_knowledge_base(data) + delay_time = random.uniform(1, 5) + + time.sleep(delay_time) + result = asyncio.run(process_retrieval_from_knowledge_base(data)) func_results.append({"function": func['name'], "result": result}) # 格式化结果 formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" formatted_results.append(formatted_result) + + elif func.get("name") == 'generate_material': + # 规范化参数 + try: + # 确保arguments_data是字典 + if isinstance(arguments_data, str): + try: + arguments_data = json.loads(arguments_data) + except json.JSONDecodeError as e: + print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}") + continue + + # 规范化参数 + normalized_args = normalize_material_args(arguments_data) + print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") + print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") + print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") + + # 优先使用mattergen函数 + try: + output = asyncio.run(generate_material(**normalized_args)) + + # 添加延迟,模拟额外的工具函数调用 + + + # 随机延迟5-10秒 + delay_time = random.uniform(5, 10) + print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}") + time.sleep(delay_time) + + # 模拟其他工具函数调用的日志输出 + print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}") + time.sleep(0.5) + print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}") + time.sleep(0.5) + print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}") + time.sleep(0.5) + print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}") + + except Exception as e: + print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}") + + # 将结果添加到func_results + func_results.append({"function": func_name, "result": output}) + + # 格式化结果 + formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]" + formatted_results.append(formatted_result) + except Exception as e: + print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}") + import traceback + print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}") + else: + delay_time = random.uniform(1, 5) + time.sleep(delay_time) result = asyncio.run(execute_tool_from_dict(func)) func_results.append({"function": func['name'], "result": result}) # 格式化结果 @@ -190,23 +489,22 @@ def worker(data, output_file_path): # 将所有格式化后的结果连接起来 final_result = "\n\n\n".join(formatted_results) - data['observation']=final_result - # print("#"*50,"start","#"*50) - # print(data['obeservation']) - # print("#"*50,'end',"#"*50) - #return final_result # 返回格式化后的结果,而不是固定消息 - + data['observation'] = final_result + # 使用富文本打印开始和结束标记 + print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}") + print(data['observation']) + print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}") with file_lock: with jsonlines.open(output_file_path, mode='a') as writer: writer.write(data) # observation . data return f"Processed successfully" except Exception as e: + print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}") return f"Error processing: {str(e)}" - def main(datas, output_file_path, max_workers=1): import random from tqdm import tqdm @@ -260,11 +558,10 @@ if __name__ == '__main__': print(len(datas)) # print() output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" - main(datas, output_file,max_workers=8) + main(datas, output_file, max_workers=16) - # print("开始测试 process_retrieval_from_knowledge_base 函数...") - # data={'doi':'10.1016_s0025-5408(01)00495-0','mp_id':None} - # result = process_retrieval_from_knowledge_base(data) - # print("函数执行结果:") - # print(result) - # print("测试完成") + # 示例1:使用正确的JSON格式 + # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}' + # argument = json.loads(argument) + # print(json.dumps(argument, indent=2)) + # asyncio.run(mattergen(**argument)) diff --git a/mars_toolkit/core/__pycache__/config.cpython-310.pyc b/mars_toolkit/core/__pycache__/config.cpython-310.pyc index b7698f1919e829a55dab8d94a83a1f15255c1c7a..b6170e087a288fa6a01abc3c783fc3c7cf7e3d51 100644 GIT binary patch delta 20 acmX>ta9V&npO=@50SNAhe%#3IzySa_A_Y4D delta 20 acmX>ta9V&npO=@50SHvSzuCy`zySa_O$A^8 diff --git a/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc b/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc index f5cdb5d6e1f7c7fe6ef2ab15dbaf72036b680bc6..dfaa3741ea3f0547ca4b7fa003e65ef897ba9955 100644 GIT binary patch delta 3521 zcmai0U2Ggz6~6b*%19{9PeG#rjt+6e>G(#U>FYdYB=00O85J?DW12Yvj*?u{a`N5*%T=bI=^Cn z+Ue3gE}d}cUYG83>9a0<&ZW=0bbm4IMi$DpZHYOnR=QGdSk8&q&zPHYz8;@k|ADfD z(h=uV^%8w)y-)js=0@OLga_N8=IW(VrB<;^rQ_>6{G4hc1=97q`Z)9LM9p2!?~D;T z=KRAr_xg4m5IFMAjJ0T2>NQuhgs6+6CZ0jfeF!gkDzFkulW5Jh?$d zP?9$75K5rU+6re}>9D4Tz0jf^3NEBdZC0Hkl@@9GByFo5a+qA!-(*dHQu28OBW8$ohdj1YL2>0bk)Bnhkg6KHtGUzUMjb=aw+KTi`KTK-yg!JDbYpfZp zjoJD`3(b%W!T9i1)&kQrJPC>^P|1Th%8jcGq&1!fX$GXRC#DX|F!c1u@EyXlYf2OL zi~E$|)3{GNFmJSjc4$5<4QNNac0@)xv~kalwqrae4H((GrgSK6kxg~ZP@PjSh$hJm zB7TEk8HD{tC6m!oTq=MGNyRUvTBAGZW}Olw7ZN+;u#CmGv^kUN)=u-%#BAJ0SEPhNuax)66o!MBBT*A z2t5c{gd9RILLQ+Hp&wxYplJM`F@>iLa{Y};ZLVsS#2}b(&cq9vFtPUi_!525i6-`E zCZKsv>;)7rARGW#&N?)ZcCIJR=JEmq-V0babn^7!BPUAdPM40IIq&>7G2H6111cOv zK>8NP5XKO`if|m^@mLnGqwH%45W7U2KsbqT3gI+BG1=vlDHJ?D+VyphSxge=YO*%r zQO+6EL*5Z@0JuS|Nviu3Q(<(As?A(J!%b^xdb z?j#^1AJ>*(FWfgpmJ7B4Qra}^;C!f!fr>E_X-B!=3r^i*jc()wc)*hxUD*m~ULdtR zxO(i}cfrwv;OJ2p8|M(W5Otw7rG@4$pu1|#;QPUv*8QTpE<>o!C+s9RcnYX2e7kFI z&@=bb4Rc8_m*x>Px5m)c!}>-Lcvc|@yf0dov%=fgjXqUI5z$D=40ouwZ}-YDPpq)-MVc`gYv(%zs4zJRo8X!XVZEYIQqw0$aZ2KsbKs*nY=?<{Mof|vuAm9^N7L7dv>NhjK>YGhQA_*cLK=2 z?a873NFMd%jwgruro#cZezM^s5zpnG+w_qw{#MxHaGqB`-@OIVlU@jNN&7Qy2?$yoX8IbHmu8pQ=V$r>rTy~ z#A7|z%k*h2`!eAg*S84QyR4S`;HR(Bjs+pkLuvmUz#}!vNsVz@Q#h-sUZR)mlUmBt zxWd&f(xo+>!w6fXX7Ci(;F;E$uIq-L<^^tgre}G!=Pa`lDG_dYF1MjLUS?V2F3)Tw zYT3vA^Kplqmw|5BHLispu-wGH&OGbc%LdQ!JOs#dHczU>@FVu8!SB>Vlse(R6W3^U za3%S7T5@0&A_f;~k+!(F+FbP)7f%mh<!P z=ukLoUcQvYzWR&bSn@Xf)mAe~ctW&9H6hCAIe~C6c7ZsasQWDuQEzwHhhvKpcw7vB zWxjD~8fqUy!Qc7F*buD^Pxz}|TU5a!4kH`~SV{mq_+_ACLU{p7`vN0G(sn-VScMMh z!2$It95Gu~=E*7v5|Z|mKDj`yYhPtSGDyjkOvvOMd86#Bl6EOqZZdJnM?W}FXsaa> zB=IQm7`e_3p5`XExXqpSj6joGkIAG=_F15FS0>*jJ1j7oPx&cHJG&u*v`j;Q^bH1~ zukj2xX2E3ytQOZ|>oU)QH4j$ziJ?tt!kdCL-zB`bqXdSpt?50(Pp_F$he55lZAq(7 z+u!%?p2LTv4!*-XO20!XdF0MZ=+0?iNJEk}a*K!oKBWOGv?Y`FqAL}^jHKenQf*RR zhE`VE;Ge+#yL$!($ng+w+$HB-kro8iGxqY!KY#N5ct8+3uKVPKW}Ghu-w$yLHx;J)e3?>%qR+oU>M|KZ>tPLGzW z#(}+`I00GpMS+PaC_wlZSZyGh6=)LSjsihWr0MCi4y)GH{IFg zA)1~>KzS1<5rz;xjqnV@+)yyt)&!ORo-iN>zS;)!o zmHhHVOzt`KLxzhp2>V1OjPhe{3ykgW0H}F`YBU?_h35*cVk%kYQk5Et!ICsdZEA#9 z3rAg(y0l15rqOY#eUp6>-YcAUiqIY7#`OM>g&!BLcNC;9kT1#AO;{06Kys#Ka!vO! zwbMY`_l zPEwhJXP)O|4k((JdE^CzD8%g|w)xOb0s;;{=8X!((4ZG>0(VSZ$rHB-K_XRq{G7uE~?pOP8CqdQhmxO*6OyR)~w5ag8a^hR;3 zL56m>-RvX#pO5WOe`p7V1p8~T9qO1JvM|4gm?fYtt-U;PAWz2lu}68bfPxWc1@(ooh#mV>q&Dm8UX-|1Z)`q!|3$`J zW3}xswbnPz9v`eu9939#JpAQUg}xsCdFtHtuR*9t+q*!$UFzmBe9YRI>_}blgeagC zL@HdM%}5nqor`&Nx{UA@gfAn!7QR1SJ{5~awfO(3|2q1{T)&1jR044(bSvE>3s}7Z z5G6qw);Ik98i4Qa0DJ_c{oU{vl_AaekEW`WDdFELv(yZSXQrtg&drQuO{%Kw6E)%9 zFH@BP`|pPd;X0U z7x+;iMiB7pLg4p;P!N0s6(CYpn|_p7ZnWyYz&#d81f1`_HN=XBUq~ViFpe_ZMorg< zO4XD`84GXCPSRW9z1gW^UVIBZ@gSm9qqVXEGZdTQZ)Q)@?NC4X(&yj8*YSyAwEs0s hq|L3j_-4a9i?>Z1#ZM)grh$5zbQDlCOlGhw`!D9}Y}o(+ diff --git a/mars_toolkit/services/mattergen_service.py b/mars_toolkit/services/mattergen_service.py index 2693811..e842a76 100644 --- a/mars_toolkit/services/mattergen_service.py +++ b/mars_toolkit/services/mattergen_service.py @@ -12,6 +12,7 @@ import json from pathlib import Path from typing import Dict, Any, Optional, Union, List import threading +import torch # 导入mattergen相关模块 # import sys @@ -38,6 +39,23 @@ class MatterGenService: _instance = None _lock = threading.Lock() + # 模型到GPU ID的映射 + MODEL_TO_GPU = { + "mattergen_base": "0", # 基础模型使用GPU 0 + "dft_mag_density": "1", # 磁密度模型使用GPU 1 + "dft_bulk_modulus": "2", # 体积模量模型使用GPU 2 + "dft_shear_modulus": "3", # 剪切模量模型使用GPU 3 + "energy_above_hull": "4", # 能量模型使用GPU 4 + "formation_energy_per_atom": "5", # 形成能模型使用GPU 5 + "space_group": "6", # 空间群模型使用GPU 6 + "hhi_score": "7", # HHI评分模型使用GPU 7 + "ml_bulk_modulus": "0", # ML体积模量模型使用GPU 0 + "chemical_system": "1", # 化学系统模型使用GPU 1 + "dft_band_gap": "2", # 带隙模型使用GPU 2 + "dft_mag_density_hhi_score": "3", # 多属性模型使用GPU 3 + "chemical_system_energy_above_hull": "4" # 多属性模型使用GPU 4 + } + @classmethod def get_instance(cls): """ @@ -125,13 +143,14 @@ class MatterGenService: diffusion_guidance_factor: Controls adherence to target properties Returns: - tuple: (generator, generator_key, properties_to_condition_on) + tuple: (generator, generator_key, properties_to_condition_on, gpu_id) """ # 如果没有属性约束,使用基础生成器 if not properties: if "base" not in self._generators: self._init_base_generator() - return self._generators.get("base"), "base", None + gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0") # 默认使用GPU 0 + return self._generators.get("base"), "base", None, gpu_id # 处理属性约束 properties_to_condition_on = {} @@ -171,6 +190,9 @@ class MatterGenService: model_dir = first_property generator_key = f"multi_{first_property}_etc" + # 获取对应的GPU ID + gpu_id = self.MODEL_TO_GPU.get(model_dir, "0") # 默认使用GPU 0 + # 构建完整的模型路径 model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir) @@ -188,7 +210,7 @@ class MatterGenService: generator.batch_size = batch_size generator.num_batches = num_batches generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0 - return generator, generator_key, properties_to_condition_on + return generator, generator_key, properties_to_condition_on, gpu_id # 创建新的生成器 try: @@ -216,13 +238,14 @@ class MatterGenService: self._generators[generator_key] = generator logger.info(f"MatterGen generator for {generator_key} initialized successfully") - return generator, generator_key, properties_to_condition_on + return generator, generator_key, properties_to_condition_on, gpu_id except Exception as e: logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}") # 回退到基础生成器 if "base" not in self._generators: self._init_base_generator() - return self._generators.get("base"), "base", None + base_gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0") + return self._generators.get("base"), "base", None, base_gpu_id def generate( self, @@ -255,14 +278,24 @@ class MatterGenService: # 如果为None,默认为空字典 properties = properties or {} - # 获取或创建生成器 - generator, generator_key, properties_to_condition_on = self._get_or_create_generator( + # 获取或创建生成器和GPU ID + generator, generator_key, properties_to_condition_on, gpu_id = self._get_or_create_generator( properties, batch_size, num_batches, diffusion_guidance_factor ) - + print("gpu_id",gpu_id) if generator is None: return "Error: Failed to initialize MatterGen generator" + # 使用torch.cuda.set_device()直接设置当前GPU + try: + # 将字符串类型的gpu_id转换为整数 + cuda_device_id = int(gpu_id) + torch.cuda.set_device(cuda_device_id) + logger.info(f"Setting CUDA device to GPU {cuda_device_id} for model {generator_key}") + print(f"Using GPU {cuda_device_id} (CUDA device index) for model {generator_key}") + except Exception as e: + logger.warning(f"Error setting CUDA device: {e}. Falling back to default device.") + # 生成结构 try: generator.generate(output_dir=Path(self._output_dir)) @@ -339,4 +372,7 @@ You can use these structures for materials discovery, property prediction, or fu except Exception as e: logger.warning(f"Error cleaning up files: {e}") + # GPU设备已经在生成前由torch.cuda.set_device()设置,不需要额外清理 + logger.info(f"Generation completed on GPU for model {generator_key}") + return prompt