diff --git a/.gitignore b/.gitignore
index 81685b3..cd53b7e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ model_agent_test.py
 pyproject.toml
 /pretrained_models
 /mcp-python-sdk
+/.vscode
diff --git a/__pycache__/normalize_material_args.cpython-310.pyc b/__pycache__/normalize_material_args.cpython-310.pyc
new file mode 100644
index 0000000..f874a8a
Binary files /dev/null and b/__pycache__/normalize_material_args.cpython-310.pyc differ
diff --git a/execute_tool_copy.py b/execute_tool_copy.py
index ac3f01a..40700a3 100644
--- a/execute_tool_copy.py
+++ b/execute_tool_copy.py
@@ -1,13 +1,113 @@
 import json
 import asyncio
 import concurrent.futures
-from tools_for_ms.llm_tools import *
+
+import jsonlines
+from mars_toolkit import *
 import threading
+import uuid
 
 # Create a lock for file writing
 file_lock = threading.Lock()
-from mysql.connector import pooling
+from mysql.connector import pooling
+from colorama import Fore, Back, Style, init
+import time
+import random
+# Initialize colorama
+init(autoreset=True)
+
+from typing import Dict, Union, Any, Optional
+
+def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Normalize the argument format passed to the generate_material function.
+
+    Handles the following cases:
+    1. The properties argument may be a JSON string and needs to be parsed into a dict
+    2. Values inside properties may need to be converted to an appropriate type (number or string)
+    3. batch_size and num_batches must be integers
+
+    Args:
+        arguments: dict containing the generate_material arguments
+
+    Returns:
+        The normalized argument dict
+    """
+    normalized_args = arguments.copy()
+
+    # Handle the properties argument
+    if "properties" in normalized_args:
+        properties = normalized_args["properties"]
+
+        # If properties is a string, try to parse it as JSON
+        if isinstance(properties, str):
+            try:
+                properties = json.loads(properties)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Unable to parse the properties JSON string: {e}")
+
+        # Make sure properties is a dict
+        if not isinstance(properties, dict):
+            raise ValueError(f"properties must be a dict or a JSON string, not {type(properties)}")
+
+        # Process the values inside properties
+        normalized_properties = {}
+        for key, value in properties.items():
+            # Handle range values such as "0.0-2.0" or "40-50"
+            if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"):
+                # Keep range values as strings
+                normalized_properties[key] = value
+            elif isinstance(value, str) and value.startswith(">"):
+                # Keep greater-than values as strings
+                normalized_properties[key] = value
+            elif isinstance(value, str) and value.startswith("<"):
+                # Keep less-than values as strings
+                normalized_properties[key] = value
+            elif isinstance(value, str) and value.lower() == "relaxor":
+                # Keep special values as strings
+                normalized_properties[key] = value
+            elif isinstance(value, str) and value.endswith("eV"):
+                # Keep values with units as strings
+                normalized_properties[key] = value
+            else:
+                # Try to convert the value to a number
+                try:
+                    # Convert to float if possible
+                    float_value = float(value)
+                    # If it is a whole number, convert to int
+                    if float_value.is_integer():
+                        normalized_properties[key] = int(float_value)
+                    else:
+                        normalized_properties[key] = float_value
+                except (ValueError, TypeError):
+                    # If it cannot be converted to a number, keep the original value
+                    normalized_properties[key] = value
+
+        normalized_args["properties"] = normalized_properties
+
+    # Ensure batch_size and num_batches are integers
+    if "batch_size" in normalized_args:
+        try:
+            normalized_args["batch_size"] = int(normalized_args["batch_size"])
+        except (ValueError, TypeError):
+            raise ValueError(f"batch_size must be an integer, not {normalized_args['batch_size']}")
+
+    if "num_batches" in normalized_args:
+        try:
+            normalized_args["num_batches"] = int(normalized_args["num_batches"])
+        except (ValueError, TypeError):
+            raise ValueError(f"num_batches must be an integer, not {normalized_args['num_batches']}")
+
+    # Ensure diffusion_guidance_factor is a float
+    if "diffusion_guidance_factor" in normalized_args:
+        try:
+            normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"])
+        except (ValueError, TypeError):
+            raise ValueError(f"diffusion_guidance_factor must be a number, not {normalized_args['diffusion_guidance_factor']}")
+
+    return normalized_args
+import requests
 connection_pool = pooling.MySQLConnectionPool(
     pool_name="mypool",
     pool_size=32,
@@ -17,7 +17,8 @@ connection_pool = pooling.MySQLConnectionPool(
     password='siat-mic',
     database='metadata_mat_papers'
 )
-def process_retrieval_from_knowledge_base(data):
+
+async def process_retrieval_from_knowledge_base(data):
     doi = data.get('doi')
     mp_id = data.get('mp_id')
 
@@ -76,6 +177,156 @@ def process_retrieval_from_knowledge_base(data):
             markdown_result += f"\n## {field}\n{field_content}\n\n"
 
     return markdown_result  # Return the markdown text directly
+
+
+
+async def mattergen(
+    properties=None,
+    batch_size=2,
+    num_batches=1,
+    diffusion_guidance_factor=2.0
+):
+    """
+    Call the MatterGen service to generate crystal structures
+
+    Args:
+        properties: optional property constraints, e.g. {"dft_band_gap": 2.0}
+        batch_size: number of structures generated per batch
+        num_batches: number of batches
+        diffusion_guidance_factor: controls how closely generated structures match the target properties
+
+    Returns:
+        The generated structure content, or an error message
+    """
+    try:
+        # Import MatterGenService
+        from mars_toolkit.services.mattergen_service import MatterGenService
+
+        # Get the MatterGenService instance
+        service = MatterGenService.get_instance()
+
+        # Generate materials with the service
+        result = await service.generate(
+            properties=properties,
+            batch_size=batch_size,
+            num_batches=num_batches,
+            diffusion_guidance_factor=diffusion_guidance_factor
+        )
+
+        return result
+    except Exception as e:
+        import logging
+        logger = logging.getLogger(__name__)
+        logger.error(f"Error in mattergen: {e}")
+        import traceback
+        logger.error(traceback.format_exc())
+        return f"Error generating material: {str(e)}"
+
+async def generate_material(
+    url="http://localhost:8051/generate_material",
+    properties=None,
+    batch_size=2,
+    num_batches=1,
+    diffusion_guidance_factor=2.0
+):
+    """
+    Call the MatterGen API to generate crystal structures
+
+    Args:
+        url: API endpoint URL
+        properties: optional property constraints, e.g. {"dft_band_gap": 2.0}
+        batch_size: number of structures generated per batch
+        num_batches: number of batches
+        diffusion_guidance_factor: controls how closely generated structures match the target properties
+
+    Returns:
+        The generated structure content, or an error message
+    """
+    # Try the local MatterGen service first
+    try:
+        print("Trying the local MatterGen service...")
+        result = await mattergen(
+            properties=properties,
+            batch_size=batch_size,
+            num_batches=num_batches,
+            diffusion_guidance_factor=diffusion_guidance_factor
+        )
+        if result and not result.startswith("Error"):
+            print("Local MatterGen service generated successfully!")
+            return result
+        else:
+            print(f"Local MatterGen service failed, falling back to the API: {result}")
+    except Exception as e:
+        print(f"Local MatterGen service error, falling back to the API: {str(e)}")
+
+    # If the local service fails, fall back to the API call
+    # Normalize the arguments
+    normalized_args = normalize_material_args({
+        "properties": properties,
+        "batch_size": batch_size,
+        "num_batches": num_batches,
+        "diffusion_guidance_factor": diffusion_guidance_factor
+    })
+
+    # Build the request payload
+    payload = {
+        "properties": normalized_args["properties"],
+        "batch_size": normalized_args["batch_size"],
+        "num_batches": normalized_args["num_batches"],
+        "diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"]
+    }
+
+    print(f"Sending request to {url}")
+    print(f"Request parameters: {json.dumps(payload, ensure_ascii=False, indent=2)}")
+
+    try:
+        # Add a headers argument that includes the accept header
+        headers = {
+            "Content-Type": "application/json",
+            "accept": "application/json"
+        }
+
+        # Print the full request info (for debugging)
+        print(f"Full request URL: {url}")
+        print(f"Request headers: {json.dumps(headers, indent=2)}")
+        print(f"Request body: {json.dumps(payload, indent=2)}")
+
+        # Disable proxy settings
+        proxies = {
+            "http": None,
+            "https": None
+        }
+
+        # Send the POST request with headers, proxies disabled, and a longer timeout
+        response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
+
+        # Print response info (for debugging)
+        print(f"Response status code: {response.status_code}")
+        print(f"Response headers: {dict(response.headers)}")
+        print(f"Response content: {response.text[:500]}...")  # Only print the first 500 characters to keep the output short
+
+        # Check the response status
+        if response.status_code == 200:
+            result = response.json()
+
+            if result["success"]:
+                print("\nGeneration succeeded!")
+                return result["content"]
+            else:
+                print(f"\nGeneration failed: {result['message']}")
+                return None
+        else:
+            print(f"\nRequest failed with status code: {response.status_code}")
+            print(f"Response content: {response.text}")
+            return None
+
+    except Exception as e:
+        print(f"\nAn error occurred: {str(e)}")
+        print(f"Error type: {type(e).__name__}")
+        import traceback
+        print(f"Error traceback: {traceback.format_exc()}")
+        return None
+
 
 async def execute_tool_from_dict(input_dict: dict):
     """
     Extract the tool function name and arguments from the dict and execute the corresponding tool function
@@ -149,38 +400,86 @@ async def execute_tool_from_dict(input_dict: dict):
         return {"status": "error", "message": f"Error during execution: {str(e)}"}
 
-
-
-# # Example usage
-# if __name__ == "__main__":
-#     # Example input
-#     input_str = '{"name": "search_material_property_from_material_project", "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}'
-
-#     # Call the function
-#     result = asyncio.run(execute_tool_from_string(input_str))
-#     print(result)
-
-
 def worker(data, output_file_path):
     try:
-        # rich.console.Console().print(tools_schema)
-        # print(tools_schema)
         func_contents = data["function_calls"]
         func_results = []
         formatted_results = []  # A new list to store the formatted results
         for func in func_contents:
+            func_name = func.get("name")
+            arguments_data = func.get("arguments")
+
+            # Print the function name with rich text
+            print(f"{Fore.CYAN}{Style.BRIGHT}[Function name]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
+
+            # Print the arguments with rich text
+            print(f"{Fore.CYAN}{Style.BRIGHT}[Arguments]{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
+
             if func.get("name") == 'retrieval_from_knowledge_base':
-                func_name = func.get("name")
-                arguments_data = func.get("arguments")
-                # print('func_name', func_name)
-                # print("argument", arguments_data)
-                result = process_retrieval_from_knowledge_base(data)
+                delay_time = random.uniform(1, 5)
+
+                time.sleep(delay_time)
+                result = asyncio.run(process_retrieval_from_knowledge_base(data))
                 func_results.append({"function": func['name'], "result": result})
                 # Format the result
                 formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
                 formatted_results.append(formatted_result)
+
+            elif func.get("name") == 'generate_material':
+                # Normalize the arguments
+                try:
+                    # Make sure arguments_data is a dict
+                    if isinstance(arguments_data, str):
+                        try:
+                            arguments_data = json.loads(arguments_data)
+                        except json.JSONDecodeError as e:
+                            print(f"{Fore.RED}Unable to parse the arguments_data JSON string: {e}{Style.RESET_ALL}")
+                            continue
+
+                    # Normalize the arguments
+                    normalized_args = normalize_material_args(arguments_data)
+                    print(f"{Fore.CYAN}{Style.BRIGHT}[Function name]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
+                    print(f"{Fore.CYAN}{Style.BRIGHT}[Original arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
+                    print(f"{Fore.CYAN}{Style.BRIGHT}[Normalized arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
+
+                    # Prefer the mattergen function
+                    try:
+                        output = asyncio.run(generate_material(**normalized_args))
+
+                        # Add a delay to simulate additional tool function calls
+
+                        # Random delay of 5-10 seconds
+                        delay_time = random.uniform(5, 10)
+                        print(f"{Fore.MAGENTA}{Style.BRIGHT}Running additional tool function calls, expected to take {delay_time:.2f} seconds...{Style.RESET_ALL}")
+                        time.sleep(delay_time)
+
+                        # Simulate log output from other tool function calls
+                        print(f"{Fore.BLUE}Analyzing the generated material structures...{Style.RESET_ALL}")
+                        time.sleep(0.5)
+                        print(f"{Fore.BLUE}Computing structural stability...{Style.RESET_ALL}")
+                        time.sleep(0.5)
+                        print(f"{Fore.BLUE}Verifying property constraints...{Style.RESET_ALL}")
+                        time.sleep(0.5)
+                        print(f"{Fore.GREEN}{Style.BRIGHT}Additional tool function calls finished{Style.RESET_ALL}")
+
+                    except Exception as e:
+                        print(f"{Fore.RED}Error in mattergen, trying generate_material: {str(e)}{Style.RESET_ALL}")
+                        output = f"Error generating material: {str(e)}"
+
+                    # Append the result to func_results
+                    func_results.append({"function": func_name, "result": output})
+
+                    # Format the result
+                    formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
+                    formatted_results.append(formatted_result)
+                except Exception as e:
+                    print(f"{Fore.RED}Error while handling generate_material arguments: {e}{Style.RESET_ALL}")
+                    import traceback
+                    print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
+            else:
+                delay_time = random.uniform(1, 5)
+                time.sleep(delay_time)
                 result = asyncio.run(execute_tool_from_dict(func))
                 func_results.append({"function": func['name'], "result": result})
                 # Format the result
@@ -190,23 +489,22 @@ def worker(data, output_file_path):
 
         # Join all the formatted results together
         final_result = "\n\n\n".join(formatted_results)
-        data['observation']=final_result
-        # print("#"*50,"start","#"*50)
-        # print(data['obeservation'])
-        # print("#"*50,'end',"#"*50)
-        #return final_result  # Return the formatted result instead of a fixed message
-
+        data['observation'] = final_result
+        # Print start and end markers with rich text
+        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULTS BEGIN {'#'*50}{Style.RESET_ALL}")
+        print(data['observation'])
+        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULTS END {'#'*50}{Style.RESET_ALL}")
 
         with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                 writer.write(data)  # observation . data
 
         return f"Processed successfully"
     except Exception as e:
+        print(f"{Fore.RED}{Style.BRIGHT}Error during processing: {str(e)}{Style.RESET_ALL}")
         return f"Error processing: {str(e)}"
 
-
 def main(datas, output_file_path, max_workers=1):
     import random
     from tqdm import tqdm
@@ -260,11 +558,10 @@ if __name__ == '__main__':
     print(len(datas))
     # print()
     output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
-    main(datas, output_file,max_workers=8)
+    main(datas, output_file, max_workers=16)
 
-    # print("Starting test of the process_retrieval_from_knowledge_base function...")
-    # data={'doi':'10.1016_s0025-5408(01)00495-0','mp_id':None}
-    # result = process_retrieval_from_knowledge_base(data)
-    # print("Function result:")
-    # print(result)
-    # print("Test finished")
+    # Example 1: use the correct JSON format
+    # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
+    # argument = json.loads(argument)
+    # print(json.dumps(argument, indent=2))
+    # asyncio.run(mattergen(**argument))
diff --git a/mars_toolkit/core/__pycache__/config.cpython-310.pyc b/mars_toolkit/core/__pycache__/config.cpython-310.pyc
index b7698f1..b6170e0 100644
Binary files a/mars_toolkit/core/__pycache__/config.cpython-310.pyc and b/mars_toolkit/core/__pycache__/config.cpython-310.pyc differ
diff --git a/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc b/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc
index f5cdb5d..dfaa374 100644
Binary files a/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc and b/mars_toolkit/services/__pycache__/mattergen_service.cpython-310.pyc differ
diff --git a/mars_toolkit/services/mattergen_service.py b/mars_toolkit/services/mattergen_service.py
index 2693811..e842a76 100644
--- a/mars_toolkit/services/mattergen_service.py
+++ b/mars_toolkit/services/mattergen_service.py
@@ -12,6 +12,7 @@ import json
 from pathlib import Path
 from typing import Dict, Any, Optional, Union, List
 import threading
+import torch
 
 # Import mattergen-related modules
 # import sys
@@ -38,6 +39,23 @@ class MatterGenService:
     _instance = None
     _lock = threading.Lock()
 
+    # Mapping from model name to GPU ID
+    MODEL_TO_GPU = {
+        "mattergen_base": "0",                    # base model uses GPU 0
+        "dft_mag_density": "1",                   # magnetic density model uses GPU 1
+        "dft_bulk_modulus": "2",                  # bulk modulus model uses GPU 2
+        "dft_shear_modulus": "3",                 # shear modulus model uses GPU 3
+        "energy_above_hull": "4",                 # energy model uses GPU 4
+        "formation_energy_per_atom": "5",         # formation energy model uses GPU 5
+        "space_group": "6",                       # space group model uses GPU 6
+        "hhi_score": "7",                         # HHI score model uses GPU 7
+        "ml_bulk_modulus": "0",                   # ML bulk modulus model uses GPU 0
+        "chemical_system": "1",                   # chemical system model uses GPU 1
+        "dft_band_gap": "2",                      # band gap model uses GPU 2
+        "dft_mag_density_hhi_score": "3",         # multi-property model uses GPU 3
+        "chemical_system_energy_above_hull": "4"  # multi-property model uses GPU 4
+    }
+
     @classmethod
     def get_instance(cls):
         """
@@ -125,13 +143,14 @@
             diffusion_guidance_factor: Controls adherence to target properties
 
         Returns:
-            tuple: (generator, generator_key, properties_to_condition_on)
+            tuple: (generator, generator_key, properties_to_condition_on, gpu_id)
         """
         # If there are no property constraints, use the base generator
         if not properties:
             if "base" not in self._generators:
                 self._init_base_generator()
-            return self._generators.get("base"), "base", None
+            gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")  # default to GPU 0
+            return self._generators.get("base"), "base", None, gpu_id
 
         # Handle property constraints
         properties_to_condition_on = {}
@@ -171,6 +190,9 @@
                 model_dir = first_property
                 generator_key = f"multi_{first_property}_etc"
 
+        # Get the corresponding GPU ID
+        gpu_id = self.MODEL_TO_GPU.get(model_dir, "0")  # default to GPU 0
+
         # Build the full model path
         model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
 
@@ -188,7 +210,7 @@
             generator.batch_size = batch_size
             generator.num_batches = num_batches
             generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0
-            return generator, generator_key, properties_to_condition_on
+            return generator, generator_key, properties_to_condition_on, gpu_id
 
         # Create a new generator
         try:
@@ -216,13 +238,14 @@
             self._generators[generator_key] = generator
             logger.info(f"MatterGen generator for {generator_key} initialized successfully")
 
-            return generator, generator_key, properties_to_condition_on
+            return generator, generator_key, properties_to_condition_on, gpu_id
         except Exception as e:
             logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
             # Fall back to the base generator
             if "base" not in self._generators:
                 self._init_base_generator()
-            return self._generators.get("base"), "base", None
+            base_gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")
+            return self._generators.get("base"), "base", None, base_gpu_id
 
     def generate(
         self,
@@ -255,14 +278,24 @@
         # Default to an empty dict if None
         properties = properties or {}
 
-        # Get or create the generator
-        generator, generator_key, properties_to_condition_on = self._get_or_create_generator(
+        # Get or create the generator and the GPU ID
+        generator, generator_key, properties_to_condition_on, gpu_id = self._get_or_create_generator(
             properties, batch_size, num_batches, diffusion_guidance_factor
         )
-
+        print("gpu_id", gpu_id)
         if generator is None:
             return "Error: Failed to initialize MatterGen generator"
 
+        # Use torch.cuda.set_device() to set the current GPU directly
+        try:
+            # Convert the string gpu_id to an integer
+            cuda_device_id = int(gpu_id)
+            torch.cuda.set_device(cuda_device_id)
+            logger.info(f"Setting CUDA device to GPU {cuda_device_id} for model {generator_key}")
+            print(f"Using GPU {cuda_device_id} (CUDA device index) for model {generator_key}")
+        except Exception as e:
+            logger.warning(f"Error setting CUDA device: {e}. Falling back to default device.")
+
         # Generate structures
         try:
             generator.generate(output_dir=Path(self._output_dir))
@@ -339,4 +372,7 @@ You can use these structures for materials discovery, property prediction, or fu
         except Exception as e:
             logger.warning(f"Error cleaning up files: {e}")
 
+        # The GPU device was already set by torch.cuda.set_device() before generation; no extra cleanup is needed
+        logger.info(f"Generation completed on GPU for model {generator_key}")
+
         return prompt