From 72045e5cfebd66c048d1fca9171b0064627f40be Mon Sep 17 00:00:00 2001 From: lzy <949777411@qq.com> Date: Sun, 6 Apr 2025 20:35:13 +0800 Subject: [PATCH] =?UTF-8?q?=E7=94=9F=E6=88=90=E6=95=B0=E6=8D=AE=EF=BC=9Ama?= =?UTF-8?q?ttergen=E6=94=B9=E6=88=90=E4=BA=86=E5=90=8C=E6=AD=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 + __pycache__/execute_tool_copy.cpython-310.pyc | Bin 0 -> 12157 bytes execute_tool_copy.py | 218 ++------- execute_tool_other_tools.py | 423 ++++++++++++++++++ .../__pycache__/__init__.cpython-310.pyc | Bin 1115 -> 1115 bytes .../__pycache__/material_gen.cpython-310.pyc | Bin 8998 -> 10463 bytes mars_toolkit/compute/material_gen.py | 95 +++- .../core/__pycache__/config.cpython-310.pyc | Bin 2123 -> 2121 bytes .../mattergen_wrapper.cpython-310.pyc | Bin 1013 -> 1013 bytes mars_toolkit/core/config.py | 2 +- .../__pycache__/web_search.cpython-310.pyc | Bin 2166 -> 2243 bytes mars_toolkit/query/web_search.py | 3 + .../mattergen_service.cpython-310.pyc | Bin 10140 -> 10140 bytes test_mars_toolkit.py | 5 +- 14 files changed, 557 insertions(+), 191 deletions(-) create mode 100644 __pycache__/execute_tool_copy.cpython-310.pyc create mode 100644 execute_tool_other_tools.py diff --git a/.gitignore b/.gitignore index cd53b7e..1ce8b20 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ pyproject.toml /pretrained_models /mcp-python-sdk /.vscode + +/*filter_ok_questions_solutions_agent* diff --git a/__pycache__/execute_tool_copy.cpython-310.pyc b/__pycache__/execute_tool_copy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5020866c2202629d7163c90a55bbc460e4b167f0 GIT binary patch literal 12157 zcmeHNYj7J?mhP8Yt(Gjyj^o5hAc_+wQ6P2_hGEOg%uWFFm<@qV0@-L_Rf@Xp$k8kP z7{{VTIWGthCmx;yNkC?p1ZE(C85RNw!*0z!X8-N4t*zc4J6l+?{fpY#N7XVs_B*#F zSxyWy`+ueC+qdsM_uTh6=R5c2TUt~OzYX8|_o1l=IPMerX#NH8@jX1|hdjq=oW^Ik zxcBDcJo^^nB6~}532!0e+vVlUahaw2<9=N(E0bJ2fNwFQW`pq{Pc%{{lx>N(WW(`r zwl&_Gjl?4?ozJwz+gRE5c)K3ae40GYvmAeBQMM!AkzE{Lob8NvW|zd5AYIWr^rd>| zpwPu>fw}mgq^Vl)s1)yFyq3k6>A~UUT1aaF$FSC_MewwZ`{OILc5Ts7E`D8r4|D_; z2RgNwwnSTcl#gE@=xU^G(3WY-k+xD>p+^S=?KDVsU-@(9xDzXy+ggMOpF_6UE3Mx#gX8SW`>oaD;A+S*C}G^edQE5`fu)%se! z4}Dxe4jvl6lRdE;XK7Zq$9R|DfWTkK=QHWtP)u^AJM)I_O1C9b`(58oYczw8bS`b} z0$CtwTgWwaiwN&Xr!3d^KrWrnx#HH`sH@ypu<()0fY2DSvwF@lUD?nr+sLJ`UK&3I zF8nvFNwNvg8Ppd`Z~N5+4;u_z{kA@E;M(KL&+lDGu`MRB5gRj^)54 zhmkKmB$Rw58T2x%@nOy9$Y{&Mj$BqXd6GB1nBR*zlHszpF(bM5-S($1o3O~UqXIJ3!pS&_1xlOtHqIz@FlxsDcx(sa+hn;e9KI5pQqs_{8IZ z@eF!Y_3<3xv=DKa;0=F`N2tal#CR;K^O)dA}fBET^ zmsrl`-e|$d7j(l)>!w$>a_ZvDk&miBzF2wdh3PlXOh#iRQ&0cCm7qCdZW{m7p6WwyEkd2_9e`8QLj9GB4=k4EI~KX z+Uj%ff`QlHt;Ueq&q~tpHTHCWv~uEDW#Z@8jvgB{YOL24R&)p3|4n=v(@>c>Nwlm` zQ|DcEZYb@w5E0cE%GF;!=Z(``R+L7u(Z&AgcKh3)oy)I$d~9~|vN!%1SB%HJ0vP|9 zu@h65#z7Hv`He3BU6;Sr^&5I7c`$Fdg8rbZ&W#gm$lPaxxi?(%nFX6h)Lw^1HGO|$ z!7eR5IB1)Yp@|_ottE3QJu#R}LFU>wgU#3GHZi5CpVyjLn=7UAg;7^co9Ucs5mQ$* zEyI5VZu>mua+8 zDzCeXn=b5eyyMi9lZB+*4~1s2#NqYK^lwjAwl?3krb5fK>6jnYRTlTRfx`|_scF3RdwQcGG%YH6a7gupS2ikVJY8?)(D%a{5CakrhNGR2)Rb(0pDD=~(WZPLsYlBPM5H#Aov=6jQ-K2KkVXb|xH8jslxae9a+ zjji#}#$Jr+WyoL7=8VOT;Jq~w)`nCPhe*#-9&c=5dGr=F-!0JH4mA9bP?B~?&^qNZqpalRxMBhZkNBRbNgz;lTEK;wt5kjwHU zETtTDxfdHD4dQF0)?(0selyG3E0&ZYXu?Iz*3N)Y!ODp6A5w@JA9w`TGM}gi6Z+As4pLTuOLL#lXTyf*hfo%h~?}~1S-nrwx zd!tA;M|0Lb-NY_2Q~UI+j?FTdkM6#EV8=kT*onf???ijd=zTkgIIXd-qEb&Us&lDi zoT;=)$^0@#W3taVZkSp#m*~nV+b|(# z{IxylDuX(dQa%R}lgi{xyr`jtSjd%VQ5xUD;xoRB!1W!lbz{_(3ur$?5yV?AXZ$JU zz}PhH3~I-Y->>CIa*0}kJebxq5OY5Ec3aqsnq$F8Wun6iHW8N|2%HqQ2@0bHIs{B& z(8y;K`*Zn`jIIsoP?x4Y3R3en1j3lgt9(S1AP^)zB0yB|iuV^}o;H4nr$^;Gg${ve zBoP875E4Q{Cohaig0Tl|W?=_&vt%Yu5JnNeCP2EQmeom)rOH;2j7_52&B=fx*y#R8{cOYIO{RCuZ|e=y>+K`c z&%SdQwLGCSs^KkSZE7iY(wAfw@( zm#bsvr$0PbJ^KpO-PEt1s+>FgNVhgzn5)6UV1tf2#7~Ij`f_m`Yc*u=3ck%7qVL$;>|c;>^X1Qy))O4_#ypHfY^I2t*n~ zyi+%Rm`>?$NKB;~i;-`1Qm_QU_F@N9!_jmu3QdAZfi2e`EiU&`Yla=91m!ZrxIc=G zxQ-MlQ=a?NRxdOM)nN8EI6%3uB;#S!;f83(H%tztZY{cjfrBZXAw}1p$qx;YFbaTn z8{(vl5=?(360T2YmPo)dk|}*J0aBGoCM%&IOc~2iM=ZEd@4V{$^~UU?UT>fvMa`0_ zPH-Ng6fzo`c7CAImG_Zhu16q@3GE?45hd(n?;rYH_CvYTo<|twmytp6nXw>^uVY0GMniJCY`_y)0rc z`y8LK5^7A)B*4*?p!E^0VFo-AUwj6Lt9K8*D9-yn(ihRE%5p?SyfPiEOaQK!#p6P~eW2>B@tT*eLB45Az(=)|;Q(Zw;sl_m zmug+(zETjSmN~&;C}?sDuof+c)CvxV=6VjA6-a3TpX-jnVFWdP{aKNXB}!^50m!GI z?qNjGm}}84&iH4PS>=tk*4>5iWh1^Cxohg+lSZA42~hT2)dE&qy+s}s%?PNFyJxIzBn;`w$rR}jyN&9*Z~`dYs{jiU zy&+SnR^e;!LUZv(0Vz$%FgTDb00co?yFLodG<=C%Ofr^()b;PvlN!JkJflYbAbdKm z-%4loyj`pEneobp&rU!2UUlryboo%N$_h)ZP?fCDwQr`h`=jaq1plIlxt;wDtHQL= z2h(Jv>b!|`{WZ(fCOBNn^JcF<>iV-Nl^oJ-YN=Z9XKiYm+r$cEHo)667;5eEK0JR> zf0UUHtPx&%rTkd6d>q5m?)nNq3Kr9B1}#BjF9oD{8BFzt%BR`cg6S%T4zt{Z$5$#4 zR1LyGVA&>YY>eYdDq;TycUe3 zVN()nHw*b3RQ^J0z3m$P`)?>@ItZZN7b`WzPU=1Md019>nQsSvzn&p|!BopF?!uBLMZ+e#36n+Lk-sD&0Xv-iiac3>4ewlM5AJD|TL&36X2IJ6Fb!~@pi z4|$^pIkH2?0sgaq+#UshGP&0!IKX|5U}P)-ZgvrF_M06z!E?8;ZwcQSIMxA?ncE$< zMQ#Ihm2cQTy5DoU!g;O`1fQfg4p#t>4Y2@tTazF@PXa&kD zfIDwI3L7Vqh86P~f074kX@R8!=Qj1JIj5dhMot&kx?KK5>3#^2GG%OC&jfV4CO`f<1Gq0xgCvZE(#BIf5vy?_0VVhL#}EhM{2H=?RGm@kOS58Q8$DwK`pPNGao+= zpohic=~GBu(zvmZaW`HJ{V-T%8eRDI#3d86B;t(~6kJDy@J<&p07I@n2{%kCop%?d z%~aliUP|ZmK|9ABE=)|v0>&{a5};JiBN%SXX@Wc|u8-WHrW;`1P$KhuvVp<$LAYb# zeQ6TB)Z8YZ)H%=E2@XL(sC3z?Tix3kL^INLXDcj`vsi3{dR&cuD9A8zts-m{VqiD zqH>2rZXF2G+$spZMZ=3ne+khSguv`zz3K$_R`OrgdH|qD@C6V8RF4L6dKmIz3=%FO zrw>5K{PFm|jb3E_=%GzZ0V4qjwM7IHn@94hdO8*;8Ctkxafpp@m$n`XB}57ZqtL?U zBA_jXqEZ;|Wvo700i6Jq_A|g%d7cPq4JP$8#%d-n#nwJpwn!*YA<~ z`Y+L0^zYp*+NyGt42@E&gPo^!-^y95PHn<0cMq?|>BDMm&A4!?r4(@@mOGnNaw;M7=<=b}LbJyLw+`x{3oddfPTeod<5=hHxy9 z&$@m)mnK^lKooc$XG|GgFBs#*L^SLiJBfjJp-{j=^^CEF>@ac1z#Se0cBnZF$NBHw zde^}AU0ZQ?YTE+?9#n3iTV})H#66$R!6Hj%^aMP_``rLJH*oO^{sjrQM{pWHV&Hr| z)@IyKEsDUdBW%bD6p}bKUX)RcYMpYMiIBb{bdv&t)7U@7IlI$P6r_J_%?S$6Y|d z^zRjW#jt>bFf#u5rHDlUnpMF&S6}9#;x83nFW-;90HZM^uR*y^5f8obL{Vk~GYccq zQjyN%;eO}WQ_2@1Sp|;{@Lx*DZ9w3%jPrLs`jz^XEDB#z*v`z$HJ;P{OXCU}nGzAy ze`H#cFl44BKr4(&C^eXr?4*tNPTE)+kTTw4M$AejIqReIG)J;zGU*75kCTbvA11%F zKo=HIjg$hw+p5FiP`L#M$xwOZjdu7+v4jI;M|E&X>c9pKW(uA#m4&egLl5*jK!JRi zx545%czatWxl*gr0!a1F8({eUr=bCls$ zO4HPF9F#-7-v%`l8GRN?2$wC3UDtB3+MrtLcst}oU=BB%hml7k2X2Gn!JEvBX4S_o zE56dK^w6euI*JFs7KaRuuoJHH)1WwGt=sE|Sdw)>%fVo2T1mTeS! zAzJK;)@7d-P3EFFYVr=Na6GoLsTp-Pz0sYXACHMPx&=mZOV>cI-Wa`;)c|MOBI7mc z%|{4aA&>hu*-`UArWZSzPRaAeTxVS|13{o72PfnP+@Y?25N=28Xt#aWKAhrf_XD;K zu$wG6KbK8=Z#HeQmF$L+#j_Z5iC!Ops>|=*|;POp2MpL+c;HK=ibuegr_2_>;m_EL?%p~8Pl!JH58Dkv#s zlj+>&Br$FSor3+UU{^;(9H_&87RH`u$LdHG@UX*m$czYNhy;j8-wY;xA^gXR|5jEk z<&A#=8#+ZKzKKLppLZq5Oq!eew(Z%)?gVY#viW92_UjO9Nl4;^X{7ofY;oj?dsnnJ z63nNGdk^q*6Tt5!;5Qs>m~F17c2#jZO%#*Lp^(X zvOQYQ-97jA>@3P&Y2&ZZ=9sjdxy9LXH6EfApOL}qnq;1l#5OSptf#!%=hjPaZMhi# zfU?F#1a7GL;1+tW;4p%#q55P&a~G*QtsOb=+X!I>RmCvci4D?lG)1s4dscrvf(OlWehTP zA8b;nO}dM%$W27A+zP|R&gkF8dgRPu1PYd<=gRfoA-H!VPujvHH7kTDeC}D;ub^#Fjbkg?zQ*tB|x5j$!G){ukPQF*^VN literal 0 HcmV?d00001 diff --git a/execute_tool_copy.py b/execute_tool_copy.py index 40700a3..533fb8a 100644 --- a/execute_tool_copy.py +++ b/execute_tool_copy.py @@ -6,6 +6,8 @@ import jsonlines from mars_toolkit import * import threading import uuid + +from mars_toolkit.compute.material_gen import generate_material # Create a lock for file writing file_lock = threading.Lock() from mysql.connector import pooling @@ -180,153 +182,6 @@ async def process_retrieval_from_knowledge_base(data): -async def mattergen( - properties=None, - batch_size=2, - num_batches=1, - diffusion_guidance_factor=2.0 -): - """ - 调用MatterGen服务生成晶体结构 - - Args: - properties: 可选的属性约束,例如{"dft_band_gap": 2.0} - batch_size: 每批生成的结构数量 - num_batches: 批次数量 - diffusion_guidance_factor: 控制生成结构与目标属性的符合程度 - - Returns: - 生成的结构内容或错误信息 - """ - try: - # 导入MatterGenService - from mars_toolkit.services.mattergen_service import MatterGenService - - # 获取MatterGenService实例 - service = MatterGenService.get_instance() - - # 使用服务生成材料 - result = await service.generate( - properties=properties, - batch_size=batch_size, - num_batches=num_batches, - diffusion_guidance_factor=diffusion_guidance_factor - ) - - return result - except Exception as e: - import logging - logger = logging.getLogger(__name__) - logger.error(f"Error in mattergen: {e}") - import traceback - logger.error(traceback.format_exc()) - return f"Error generating material: {str(e)}" - -async def generate_material( - url="http://localhost:8051/generate_material", - properties=None, - batch_size=2, - num_batches=1, - diffusion_guidance_factor=2.0 -): - """ - 调用MatterGen API生成晶体结构 - - Args: - url: API端点URL - properties: 可选的属性约束,例如{"dft_band_gap": 2.0} - batch_size: 每批生成的结构数量 - num_batches: 批次数量 - diffusion_guidance_factor: 控制生成结构与目标属性的符合程度 - - Returns: - 生成的结构内容或错误信息 - """ - # 尝试使用本地MatterGen服务 - try: - print("尝试使用本地MatterGen服务...") - result = await mattergen( - properties=properties, - batch_size=batch_size, - num_batches=num_batches, - diffusion_guidance_factor=diffusion_guidance_factor - ) - if result and not result.startswith("Error"): - print("本地MatterGen服务生成成功!") - return result - else: - print(f"本地MatterGen服务生成失败,尝试使用API: {result}") - except Exception as e: - print(f"本地MatterGen服务出错,尝试使用API: {str(e)}") - - # 如果本地服务失败,回退到API调用 - # 规范化参数 - normalized_args = normalize_material_args({ - "properties": properties, - "batch_size": batch_size, - "num_batches": num_batches, - "diffusion_guidance_factor": diffusion_guidance_factor - }) - - # 构建请求负载 - payload = { - "properties": normalized_args["properties"], - "batch_size": normalized_args["batch_size"], - "num_batches": normalized_args["num_batches"], - "diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"] - } - - print(f"发送请求到 {url}") - print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}") - - try: - # 添加headers参数,包含accept头 - headers = { - "Content-Type": "application/json", - "accept": "application/json" - } - - # 打印完整请求信息(调试用) - print(f"完整请求URL: {url}") - print(f"请求头: {json.dumps(headers, indent=2)}") - print(f"请求体: {json.dumps(payload, indent=2)}") - - # 禁用代理设置 - proxies = { - "http": None, - "https": None - } - - # 发送POST请求,添加headers参数,禁用代理,增加超时时间 - response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300) - - # 打印响应信息(调试用) - print(f"响应状态码: {response.status_code}") - print(f"响应头: {dict(response.headers)}") - print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符,避免输出过长 - - # 检查响应状态 - if response.status_code == 200: - result = response.json() - - if result["success"]: - print("\n生成成功!") - return result["content"] - else: - print(f"\n生成失败: {result['message']}") - return None - else: - print(f"\n请求失败,状态码: {response.status_code}") - print(f"响应内容: {response.text}") - return None - - except Exception as e: - print(f"\n发生错误: {str(e)}") - print(f"错误类型: {type(e).__name__}") - import traceback - print(f"错误堆栈: {traceback.format_exc()}") - return None - async def execute_tool_from_dict(input_dict: dict): """ 从字典中提取工具函数名称和参数,并执行相应的工具函数 @@ -416,14 +271,14 @@ def worker(data, output_file_path): print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}") if func.get("name") == 'retrieval_from_knowledge_base': - delay_time = random.uniform(1, 5) - - time.sleep(delay_time) - result = asyncio.run(process_retrieval_from_knowledge_base(data)) - func_results.append({"function": func['name'], "result": result}) - # 格式化结果 - formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" - formatted_results.append(formatted_result) + pass + # delay_time = random.uniform(5, 10) + # time.sleep(delay_time) + # result = asyncio.run(process_retrieval_from_knowledge_base(data)) + # func_results.append({"function": func['name'], "result": result}) + # # 格式化结果 + # formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" + # formatted_results.append(formatted_result) elif func.get("name") == 'generate_material': # 规范化参数 @@ -438,30 +293,30 @@ def worker(data, output_file_path): # 规范化参数 normalized_args = normalize_material_args(arguments_data) - print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") - print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") - print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") + # print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") + # print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") + # print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") # 优先使用mattergen函数 try: - output = asyncio.run(generate_material(**normalized_args)) + # output = asyncio.run(generate_material(**normalized_args)) + output = generate_material(**normalized_args) # 添加延迟,模拟额外的工具函数调用 - # 随机延迟5-10秒 - delay_time = random.uniform(5, 10) - print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}") - time.sleep(delay_time) + # delay_time = random.uniform(5, 10) + # print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}") + # time.sleep(delay_time) - # 模拟其他工具函数调用的日志输出 - print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}") - time.sleep(0.5) - print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}") - time.sleep(0.5) - print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}") - time.sleep(0.5) - print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}") + # # 模拟其他工具函数调用的日志输出 + # print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}") + # time.sleep(0.5) + # print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}") + # time.sleep(0.5) + # print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}") + # time.sleep(0.5) + # print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}") except Exception as e: print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}") @@ -478,14 +333,15 @@ def worker(data, output_file_path): print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}") else: - delay_time = random.uniform(1, 5) - time.sleep(delay_time) - result = asyncio.run(execute_tool_from_dict(func)) - func_results.append({"function": func['name'], "result": result}) - # 格式化结果 - func_name = func.get("name") - formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" - formatted_results.append(formatted_result) + # delay_time = random.uniform(5, 10) + # time.sleep(delay_time) + pass + # result = asyncio.run(execute_tool_from_dict(func)) + # func_results.append({"function": func['name'], "result": result}) + # # 格式化结果 + # func_name = func.get("name") + # formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" + # formatted_results.append(formatted_result) # 将所有格式化后的结果连接起来 final_result = "\n\n\n".join(formatted_results) @@ -557,8 +413,8 @@ if __name__ == '__main__': print(len(datas)) # print() - output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" - main(datas, output_file, max_workers=16) + output_file = f"./filter_ok_questions_solutions_agent_mattergen_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" + main(datas, output_file, max_workers=1) # 示例1:使用正确的JSON格式 # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}' diff --git a/execute_tool_other_tools.py b/execute_tool_other_tools.py new file mode 100644 index 0000000..770fe44 --- /dev/null +++ b/execute_tool_other_tools.py @@ -0,0 +1,423 @@ +import json +import asyncio +import concurrent.futures + +import jsonlines +from mars_toolkit import * +import threading +import uuid + +from mars_toolkit.compute.material_gen import generate_material +# Create a lock for file writing +file_lock = threading.Lock() +from mysql.connector import pooling +from colorama import Fore, Back, Style, init +import time +import random +# 初始化colorama +init(autoreset=True) + +from typing import Dict, Union, Any, Optional + +def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]: + """ + 规范化传递给generate_material函数的参数格式。 + + 处理以下情况: + 1. properties参数可能是字符串形式的JSON,需要解析为字典 + 2. properties中的值可能需要转换为适当的类型(数字或字符串) + 3. 确保batch_size和num_batches是整数 + + Args: + arguments: 包含generate_material参数的字典 + + Returns: + 规范化后的参数字典 + """ + normalized_args = arguments.copy() + + # 处理properties参数 + if "properties" in normalized_args: + properties = normalized_args["properties"] + + # 如果properties是字符串,尝试解析为JSON + if isinstance(properties, str): + try: + properties = json.loads(properties) + except json.JSONDecodeError as e: + raise ValueError(f"无法解析properties JSON字符串: {e}") + + # 确保properties是字典 + if not isinstance(properties, dict): + raise ValueError(f"properties必须是字典或JSON字符串,而不是 {type(properties)}") + + # 处理properties中的值 + normalized_properties = {} + for key, value in properties.items(): + # 处理范围值,例如 "0.0-2.0" 或 "40-50" + if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"): + # 保持范围值为字符串格式 + normalized_properties[key] = value + elif isinstance(value, str) and value.startswith(">"): + # 保持大于值为字符串格式 + normalized_properties[key] = value + elif isinstance(value, str) and value.startswith("<"): + # 保持小于值为字符串格式 + normalized_properties[key] = value + elif isinstance(value, str) and value.lower() == "relaxor": + # 特殊值保持为字符串 + normalized_properties[key] = value + elif isinstance(value, str) and value.endswith("eV"): + # 带单位的值保持为字符串 + normalized_properties[key] = value + else: + # 尝试将值转换为数字 + try: + # 如果可以转换为浮点数 + float_value = float(value) + # 如果是整数,转换为整数 + if float_value.is_integer(): + normalized_properties[key] = int(float_value) + else: + normalized_properties[key] = float_value + except (ValueError, TypeError): + # 如果无法转换为数字,保持原值 + normalized_properties[key] = value + + normalized_args["properties"] = normalized_properties + + # 确保batch_size和num_batches是整数 + if "batch_size" in normalized_args: + try: + normalized_args["batch_size"] = int(normalized_args["batch_size"]) + except (ValueError, TypeError): + raise ValueError(f"batch_size必须是整数,而不是 {normalized_args['batch_size']}") + + if "num_batches" in normalized_args: + try: + normalized_args["num_batches"] = int(normalized_args["num_batches"]) + except (ValueError, TypeError): + raise ValueError(f"num_batches必须是整数,而不是 {normalized_args['num_batches']}") + + # 确保diffusion_guidance_factor是浮点数 + if "diffusion_guidance_factor" in normalized_args: + try: + normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"]) + except (ValueError, TypeError): + raise ValueError(f"diffusion_guidance_factor必须是数字,而不是 {normalized_args['diffusion_guidance_factor']}") + + return normalized_args + + +import requests +connection_pool = pooling.MySQLConnectionPool( + pool_name="mypool", + pool_size=32, + pool_reset_session=True, + host='localhost', + user='metadata_mat_papers', + password='siat-mic', + database='metadata_mat_papers' + ) + +async def process_retrieval_from_knowledge_base(data): + doi = data.get('doi') + mp_id = data.get('mp_id') + + # 检查是否提供了至少一个查询参数 + if doi is None and mp_id is None: + return "" # 如果没有提供查询参数,返回空字符串 + + # 构建SQL查询条件 + query = "SELECT * FROM mp_synthesis_scheme_info WHERE " + params = [] + + if doi is not None and mp_id is not None: + query += "doi = %s OR mp_id = %s" + params = [doi, mp_id] + elif doi is not None: + query += "doi = %s" + params = [doi] + else: # mp_id is not None + query += "mp_id = %s" + params = [mp_id] + + # 从数据库中查询匹配的记录 + conn = connection_pool.get_connection() + try: + cursor = conn.cursor(dictionary=True) + try: + cursor.execute(query, params) + result = cursor.fetchone() # 获取第一个匹配的记录 + finally: + cursor.close() + finally: + conn.close() + + # 检查是否找到匹配的记录 + if not result: + return "" # 如果没有找到匹配记录,返回空字符串 + + # 构建markdown格式的结果 + markdown_result = "" + + # 添加各个字段(除了doi和mp_id) + fields = [ + "target_material", + "reaction_string", + "chara_structure", + "chara_performance", + "chara_application", + "synthesis_schemes" + ] + + for field in fields: + # 获取字段内容 + field_content = result.get(field, "") + # 只有当字段内容不为空时才添加该字段 + if field_content and field_content.strip(): + markdown_result += f"\n## {field}\n{field_content}\n\n" + + return markdown_result # 直接返回markdown文本 + + + +async def execute_tool_from_dict(input_dict: dict): + """ + 从字典中提取工具函数名称和参数,并执行相应的工具函数 + + Args: + input_dict: 字典,例如: + {"name": "search_material_property_from_material_project", + "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"} + + Returns: + 工具函数的执行结果,如果工具函数不存在则返回错误信息 + """ + try: + # 解析输入字符串为字典 + # input_dict = json.loads(input_str) + + # 提取函数名和参数 + func_name = input_dict.get("name") + arguments_data = input_dict.get("arguments") + #print('func_name', func_name) + #print("argument", arguments_data) + if not func_name: + return {"status": "error", "message": "未提供函数名称"} + + # 获取所有注册的工具函数 + tools = get_tools() + + # 检查函数名是否存在于工具函数字典中 + if func_name not in tools: + return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"} + + # 获取对应的工具函数 + tool_func = tools[func_name] + + # 处理参数 + arguments = {} + if arguments_data: + # 检查arguments是字符串还是字典 + if isinstance(arguments_data, dict): + # 如果已经是字典,直接使用 + arguments = arguments_data + elif isinstance(arguments_data, str): + # 如果是字符串,尝试解析为JSON + try: + # 尝试直接解析为JSON对象 + arguments = json.loads(arguments_data) + except json.JSONDecodeError: + # 如果解析失败,可能是因为字符串中包含转义字符 + # 尝试修复常见的JSON字符串问题 + fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\') + try: + arguments = json.loads(fixed_str) + except json.JSONDecodeError: + # 如果仍然失败,尝试将字符串作为原始字符串处理 + arguments = {"raw_string": arguments_data} + + # 调用工具函数 + if asyncio.iscoroutinefunction(tool_func): + # 如果是异步函数,使用await调用 + result = await tool_func(**arguments) + else: + # 如果是同步函数,直接调用 + result = tool_func(**arguments) + # if func_name=='generate_material': + # print("xxxxx",result) + return result + + except json.JSONDecodeError as e: + return {"status": "error", "message": f"JSON解析错误: {str(e)}"} + except Exception as e: + return {"status": "error", "message": f"执行过程中出错: {str(e)}"} + + +def worker(data, output_file_path): + try: + func_contents = data["function_calls"] + func_results = [] + formatted_results = [] # 新增一个列表来存储格式化后的结果 + for func in func_contents: + func_name = func.get("name") + arguments_data = func.get("arguments") + + # 使用富文本打印函数名 + print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") + + # 使用富文本打印参数 + print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}") + + if func.get("name") == 'retrieval_from_knowledge_base': + pass + # delay_time = random.uniform(5, 10) + # time.sleep(delay_time) + result = asyncio.run(process_retrieval_from_knowledge_base(data)) + func_results.append({"function": func['name'], "result": result}) + # 格式化结果 + formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" + formatted_results.append(formatted_result) + + elif func.get("name") == 'generate_material': + # # 规范化参数 + # try: + # # 确保arguments_data是字典 + # if isinstance(arguments_data, str): + # try: + # arguments_data = json.loads(arguments_data) + # except json.JSONDecodeError as e: + # print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}") + # continue + + # # 规范化参数 + # normalized_args = normalize_material_args(arguments_data) + # # print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") + # # print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") + # # print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") + + # # 优先使用mattergen函数 + # try: + # # output = asyncio.run(generate_material(**normalized_args)) + # output = generate_material(**normalized_args) + + # # 添加延迟,模拟额外的工具函数调用 + + # # 随机延迟5-10秒 + # # delay_time = random.uniform(5, 10) + # # print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}") + # # time.sleep(delay_time) + + # # # 模拟其他工具函数调用的日志输出 + # # print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}") + # # time.sleep(0.5) + # # print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}") + # # time.sleep(0.5) + # # print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}") + # # time.sleep(0.5) + # # print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}") + + # except Exception as e: + # print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}") + + # # 将结果添加到func_results + # func_results.append({"function": func_name, "result": output}) + + # # 格式化结果 + # formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]" + # formatted_results.append(formatted_result) + # except Exception as e: + # print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}") + # import traceback + # print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}") + pass + else: + # delay_time = random.uniform(5, 10) + # time.sleep(delay_time) + + result = asyncio.run(execute_tool_from_dict(func)) + func_results.append({"function": func['name'], "result": result}) + # 格式化结果 + func_name = func.get("name") + formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" + formatted_results.append(formatted_result) + + # 将所有格式化后的结果连接起来 + final_result = "\n\n\n".join(formatted_results) + data['observation'] = final_result + + # 使用富文本打印开始和结束标记 + print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}") + print(data['observation']) + print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}") + with file_lock: + with jsonlines.open(output_file_path, mode='a') as writer: + writer.write(data) # observation . data + return f"Processed successfully" + + except Exception as e: + print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}") + return f"Error processing: {str(e)}" + + +def main(datas, output_file_path, max_workers=1): + import random + from tqdm import tqdm + import os + from mysql.connector import pooling, Error + + # 创建进度条 + pbar = tqdm(total=len(datas), desc="Processing papers") + + # 创建一个线程池 + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # 提交任务到执行器 + future_to_path = {} + for path in datas: + future = executor.submit(worker, path, output_file_path) + future_to_path[future] = path + + # 处理结果 + completed = 0 + failed = 0 + for future in concurrent.futures.as_completed(future_to_path): + path = future_to_path[future] + try: + result = future.result() + if "successfully" in result: + completed += 1 + else: + failed += 1 + # 更新进度条 + pbar.update(1) + # 每100个文件更新一次统计信息 + if (completed + failed) % 100 == 0: + pbar.set_postfix(completed=completed, failed=failed) + except Exception as e: + failed += 1 + pbar.update(1) + print(f"\nWorker for {path} generated an exception: {e}") + + pbar.close() + print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}") + + +if __name__ == '__main__': + import datetime + import jsonlines + datas = [] + with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader: + for obj in reader: + datas.append(obj) + + print(len(datas)) + # print() + output_file = f"./filter_ok_questions_solutions_agent_other_tools_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" + main(datas, output_file, max_workers=32) + + # 示例1:使用正确的JSON格式 + # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}' + # argument = json.loads(argument) + # print(json.dumps(argument, indent=2)) + # asyncio.run(mattergen(**argument)) diff --git a/mars_toolkit/__pycache__/__init__.cpython-310.pyc b/mars_toolkit/__pycache__/__init__.cpython-310.pyc index e4e221594f3111f558246b768de91e493cf1bf49..6a1511fa5e6a3d9a6e5aa37bcc8a70e589e677f2 100644 GIT binary patch delta 20 acmcc3ahroXpO=@50SG=DecZ?$zybg|%>|19 delta 20 acmcc3ahroXpO=@50SKJ#zS+nfzybg{-vxjG diff --git a/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc b/mars_toolkit/compute/__pycache__/material_gen.cpython-310.pyc index e929dd8ac5d3acaa37b17b9d6978b04e2acbf899..675755e4afc3efd46fcb805f665d0680aee2b360 100644 GIT binary patch delta 3287 zcmZt|TWl29b?)r!YrI~+@dL2eHV-d`45saU0;QbsMCW$)ef;+dVzomq^X zb(aJaNH8xA&=OvDl?1Aa6xC8`Qxblv{`8~$tmdPwD%Lh1Ri#$6Qlzv!=dKM~+F9+n z=bU@)IcLu6&PQMVtWWqe`nXLRyGO zKzsCRBdmo5uIM#JtyXK)X>~@uR&O+D4MwBZXf$a}#u9A_k;Z~dJ=#pY)JOg2WFj4u zvOz^+Eu)8tgug>lmW-g*fZl4fX>CGlP+w{+D{9LGq3X+xcCB6DReFaJ(IQ5t)+y+a z-eq)a-2$)HR~Rd`l>!gztBlp!YJu12Jw{ZE3cOZdWAtjh0w?-fZ7mokx}qs;?a_5p zF)A?y*7X7wH;82;+cZsR-LyPSron$&PaE{j2P7s>ZDCudiS`n0q)ii&)<>7nW>~k; z7TOBycG|}J=~C9n`q>UPkn*%jbQx;^sy-#RN+};*PTMCG?d2eWUwyCxU_{`Z0DVDH zL_qs#7pUoGJ1KchqALK$btN01tKh$yy~6rIG4gs?EseHHYyj{z>{Y?*U65Z3>$(cr z7~4hHvmHWa1K=BB?=^6309`BVOBLJcCc1gTtLflHt zyj}luOuQSuAH45^vPEan;4;>qQt5WOWx{)sXkVgazZ4yKg2gu&m0kZgxS3>@<$Cw! znVe_NNAf_O>$P%;k?esealI*%Czgb+6$Y_ks*c`-HT2$`8^$!SM?Y z>=kK1puVF48qP=sEoEGN@kOS>1=Pi)0vMpx1$EK;gV0J;9uIskfJVyL2((tT zCVvOx6M?66O$`}2>k&)q%y_~-XZA56?$d}sFN#o4RZ<}Tl!J#%sH{fW|p4**x2hV837)p%Z{%5?^+!sRr%D%pT$U|ecpF6E#FHP~E}i+lFn~rlv+}yV zcw+AI-N&DtpMP|@bp6)cFP+(2H~#wshjdAU`4s?5_b!%QD?L0>x^s5+-8=Ksle4i3 z_hPYFzO7Pn{``Z|g{!4A_sh=oM}FjN>}V)HU0Epv?^Sy5cIhV%OQ$~Eo3L%hzrwQn z89$auvZuZOD_K_hqGajgTaQ0G`*c;MeA&_Y`zPTo(=5w)!e+Bqe^H#BxdN^}?XHxq z*oW=%nbOCT;9P!lrCen?%KLwPI9ZNd>A_EDFMkGaQI2?iIh?B$pMP{(B)02b>GUU0 zE??`9xI`WE@8R6!5E>{2*RPxDbS9g2e%IMlxF&4RnCm-E6clt@wF%%6iVZUmFAs5^G;{N_tnFl@mzUk z#1}n{gBOfv5_){ZWQjw04(W2daFgej#efOe&rz~LkYVTD>V3W;G_}k5Q}uTb?*-IVGuC2qMPy*6>Sn@r zJ(;Y{H=%@oB*C*#B2k6Ej)xP-tObJ^im}dt`3$qTP+Z}jb1U3Jrkw}j%`G+_+l9vz zJz?3dpR!Z}W`g6X$@%igZ*%_A+&1IXkl>L5&|qp5*($g~<1A+;hdA>2DFA8sizlkO zHOV2C9L|~GEYv8(t_VAP48fZS-a>E!!P^KTavfdHnf7>P{ww@z!@K61qwZaY`>1J?(C zaYdy7DbIC-j4MR4f@_X54N-@E51^Ki1Hq!MH2k~ z*RGJ?f#?Z`bbuTU4-WDcJh=KT)#>04|3R&O>Of!quASZ~Ngl!rYuSMF2bjIX$K=69xH0MTP{6WTINs~KYL_(y& zQ9C2e7%;cy*=EL!!NlQB&aTeRSC`;ki;>_~+%`>pIAh~s1|XZtqi^(T(NcbPrJ_c}l9jIHP0u*vZnFQ*%)CqINBjsS-+&mnjNf$ey^US4+vQOs#Y zo<$>xaT*vfOq$o(PDr?9{mL=Bnj3LEhsX*dd+J5L<$r|~XR>R#zc7%Q6yDig|HS{XlP1oO?IdnuH>sPZPD~0WNUZ3ugpKTv2J^y$PrmkkNem&HJ4<^Tzpc zE%Xx*Z06)skO4%%{_avDE$4Z}Zm7Sx<*s|2k1785$Dr{{9w{{q?RiMXd zZ8u~*&C2uyU8E=Psjw;B=n`F_r&yMyz+(`Kt3P3~G%OFwm9eT%c8dv*uzD*->-5BK z=sjYrQF2A8p8g5jX5D|RuKJg>4-USK{;rOHtW-%5ay?UcRW(p+ivi(c6zs&PUx{8! z#8z#~^_mXvuC4hG<9k{_V*Xy@n`jUg%-ys--%RFb4js^I0wBn?9KlVm(X?8Prql6Q z$1D3ECQp+y2j3?vWcgp_DX`gM!fR}q_9k;*gAi^4K#e7cuBRSFiTbBFf3TFkt&vOq zC)thgMNC!n)!cP+0Tae7*XlqKohA!3*AqdgWjm%9XjaDy!nV_HGZ94IGDXMgw5z(f zfd0>*X9jXww9yq#7j|GVSIF3XJM_QGjS}wvkUN#XhHgc4({0oBf(T_T6VBkOKQ+`1 zzl@48KR!8j@YxU+QmU>zn*&k;eieMng~D_%u60|8U_g_nsND3f_*BAK!; z8*1AKHxX_D1j7y2e_SXI_nirfb#7Xn`~FgKgWUFai}%ijL5lNhm~I8&T|x;}ZH1FN zI#G7?2kO3ZC$tl~rR^v?WJlf6sCG|>&eRXyFFn#0pc6AMnx4l56n2G)Evw1k0PjGk zzA9R*4Aoy25R z(>0YGj_@DV)cU#@LPNa%b%A%izP2u2!S7WVRlQFyc0gP`A861d;w`R*Bth?yI?9bPzfC&gn&mYA_!500e~Q3x~%4OJ&Rj! zGlBmWf;fsG0i2>M;wTwF33Lf}1-_~vjBCpgNm!^N<=>y3Bx(Q4+0sY`)#>gwyj|!g zu&Kcd7K8q8vr`xFT!K+<3b)}oj{TlsI3x7q#ApigX diff --git a/mars_toolkit/compute/material_gen.py b/mars_toolkit/compute/material_gen.py index b1fa653..7f89ec6 100644 --- a/mars_toolkit/compute/material_gen.py +++ b/mars_toolkit/compute/material_gen.py @@ -9,9 +9,18 @@ import asyncio import zipfile import shutil import re +import multiprocessing +from multiprocessing import Process, Queue from pathlib import Path from typing import Literal, Dict, Any, Tuple, Union, Optional, List +# 设置多进程启动方法为spawn,解决CUDA初始化错误 +try: + multiprocessing.set_start_method('spawn', force=True) +except RuntimeError: + # 如果已经设置过启动方法,会抛出RuntimeError + pass + from ase.optimize import FIRE from ase.filters import FrechetCellFilter from ase.atoms import Atoms @@ -33,6 +42,49 @@ from ..core.mattergen_wrapper import * logger = logging.getLogger(__name__) +def _process_generate_material_worker(args_queue, result_queue): + """ + 在新进程中处理材料生成的工作函数 + + Args: + args_queue: 包含生成参数的队列 + result_queue: 用于返回结果的队列 + """ + try: + # 配置日志 + import logging + logger = logging.getLogger(__name__) + logger.info("子进程开始执行材料生成...") + + # 从队列获取参数 + args = args_queue.get() + logger.info(f"子进程获取到参数: {args}") + + # 导入MatterGenService + from mars_toolkit.services.mattergen_service import MatterGenService + logger.info("子进程成功导入MatterGenService") + + # 获取MatterGenService实例 + service = MatterGenService.get_instance() + logger.info("子进程成功获取MatterGenService实例") + + # 使用服务生成材料 + logger.info("子进程开始调用generate方法...") + result = service.generate(**args) + logger.info("子进程generate方法调用完成") + + # 将结果放入结果队列 + result_queue.put(result) + logger.info("子进程材料生成完成,结果已放入队列") + except Exception as e: + # 如果发生错误,将错误信息放入结果队列 + import traceback + error_msg = f"材料生成过程中出错: {str(e)}\n{traceback.format_exc()}" + import logging + logging.getLogger(__name__).error(error_msg) + result_queue.put(f"Error: {error_msg}") + + def format_cif_content(content): """ Format CIF content by removing unnecessary headers and organizing each CIF file. @@ -233,7 +285,7 @@ def main( @llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints") -async def generate_material( +def generate_material( properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None, batch_size: int = 2, num_batches: int = 1, @@ -260,16 +312,45 @@ async def generate_material( Returns: Descriptive text with generated crystal structures in CIF format """ + # # 创建队列用于进程间通信 + # args_queue = Queue() + # result_queue = Queue() + + # # 将参数放入队列 + # args_queue.put({ + # "properties": properties, + # "batch_size": batch_size, + # "num_batches": num_batches, + # "diffusion_guidance_factor": diffusion_guidance_factor + # }) + + # # 创建并启动新进程 + # logger.info("启动新进程处理材料生成...") + # p = Process(target=_process_generate_material_worker, args=(args_queue, result_queue)) + # p.start() + + # # 等待进程完成并获取结果 + # p.join() + # result = result_queue.get() + + # # 检查结果是否为错误信息 + # if isinstance(result, str) and result.startswith("Error:"): + # # 记录错误日志 + # logger.error(result) + # 导入MatterGenService from mars_toolkit.services.mattergen_service import MatterGenService + logger.info("子进程成功导入MatterGenService") # 获取MatterGenService实例 service = MatterGenService.get_instance() + logger.info("子进程成功获取MatterGenService实例") # 使用服务生成材料 - return service.generate( - properties=properties, - batch_size=batch_size, - num_batches=num_batches, - diffusion_guidance_factor=diffusion_guidance_factor - ) + logger.info("子进程开始调用generate方法...") + result = service.generate(properties=properties, batch_size=batch_size, num_batches=num_batches, diffusion_guidance_factor=diffusion_guidance_factor) + logger.info("子进程generate方法调用完成") + if "Error generating structures" in result: + return f"Error: Invalid properties {properties}." + else: + return result diff --git a/mars_toolkit/core/__pycache__/config.cpython-310.pyc b/mars_toolkit/core/__pycache__/config.cpython-310.pyc index b6170e087a288fa6a01abc3c783fc3c7cf7e3d51..d11872dea0d40dc3215de592d786006b4a4d5436 100644 GIT binary patch delta 48 zcmX>ta8iIbpO=@50SKB6Kc>Ig$eY6~BAroEQedU8Z)jt#+&<@_pkr}NH7g* delta 33 ncmX>pa9V&jpO=@50SNAheoTM2kvE5#QFgLDv-swI=3OiRq7Mni diff --git a/mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc b/mars_toolkit/core/__pycache__/mattergen_wrapper.cpython-310.pyc index a0a2b888658ee4eaaa1bdacba1f20d748f4a58e1..33d3c6447a0824499fb9842c998649e62de6b751 100644 GIT binary patch delta 20 Zcmey${*|3OpO=@50R${RZRCE%3;;Po1#18R delta 20 acmey${*|3OpO=@50SFHNdb5%H6*B-r{|4y* diff --git a/mars_toolkit/core/config.py b/mars_toolkit/core/config.py index 4babf2c..886a480 100644 --- a/mars_toolkit/core/config.py +++ b/mars_toolkit/core/config.py @@ -35,7 +35,7 @@ class Config: DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA' # Searxng - SEARXNG_HOST="http://192.168.191.101:40032/" + SEARXNG_HOST="http://192.168.168.1:40032/" # Visualization VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization' diff --git a/mars_toolkit/query/__pycache__/web_search.cpython-310.pyc b/mars_toolkit/query/__pycache__/web_search.cpython-310.pyc index 8f99194497ce05195783a1a1da0b4d822f6f59d6..0da5f8346c5d49868955a7a5b10d5bda4b62d9fe 100644 GIT binary patch delta 741 zcmZ9KO=}ZD7{{NPotb^hZZ^I)eWPeq3hG5sgy2b}LZ!tPv`Zz5&nDVt6Pm0_%Pvt; zXjAAx_NbNQA|AYmA3^*A9{1=+@ZiOZGh2wzUG_h}=Qq#H&SCy|-%C!*aV(I}SopjC z>zy;z-IoP|4@i(Uh)?p7`c#fuLkpSDLhkcW_#)JOov1$6FhbKebLI^zw0%2g(I|wD z?+^%V)X_lmfSMMT&<b#VG)c0@9DO8Kv`9}X^M!L+Xge`p82yV-(0v)9o7{Vawnlb(l6;!8# delta 706 zcmZ9K&ubGw6vyA2nVtQS-DIOxo7noJ;vonkD0&eO;-NzIq986x4t<+YvzxS=Eu!oa zHN`Yy53=Az3kiY;Z=(N!_#X&DE}jL?-mK0>tk50yJD>NNna9BHr}9tN={SxB>iPQa zM>Bij%=ezCg1`epVgY)nW*)GRdps1L2sKX&r6)087l9rco-t)DFhk3;rYwU(XnQt- zZ;?(0nOn@Xs7wW)`wrO+RsnJp7pXYqlJ7JmxukDseU4t~{l^4*O9=27AIcaup^rnj z1#L$_j4^Al2QvuyAmwjE8}F+<#99!|5gS5Vs*1Nnv_O0aF^|PPjYJcb)wcMM#d4&N zA!P|086*cVzyWLU{H7bmm|X>ui;(ts6H<}zD{yx%)e6K!OXN!FN6sjy4-E;%_- z@wFHu!rEmjZQ?5+H!dWipHq2TOJ!^)f>@ua&gbs1R9D?4?MrmEzJHrvfK;k_=4PUi zSpg*v?8<}6752D8%GrC|ta8rORsNn`>m>39)T$p|^3#}o6K}DaJ=Lz^`Rt2!_7qI? zZhK{Q)sH74SZTK=vfEjYmi?NDj;fy_I-bWu9w)Lh(nlwIBR?Oj=HhT3^LdymO6u({ zs$h?pn?2W;Y(p7U*P~kYU4M4{ str: elif tool_name == "generate_material": from mars_toolkit.compute.material_gen import generate_material # 使用简单的属性约束进行测试 - result = await generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1) + # result = await generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1) + result = generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1) elif tool_name == "fetch_chemical_composition_from_OQMD": from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD @@ -171,7 +172,7 @@ if __name__ == "__main__": ] # 选择要测试的工具 - tool_name = tools_to_test[6] # 测试 search_online 工具 + tool_name = tools_to_test[1] # 测试 search_online 工具 # 运行测试 result = asyncio.run(test_tool(tool_name))