mattergen调用指定GPU&规范化mattergen的输入

This commit is contained in:
lzy
2025-04-05 20:19:43 +08:00
parent bac8f067e0
commit 71d8dabd17
6 changed files with 379 additions and 45 deletions

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@ model_agent_test.py
pyproject.toml pyproject.toml
/pretrained_models /pretrained_models
/mcp-python-sdk /mcp-python-sdk
/.vscode

Binary file not shown.

View File

@@ -1,13 +1,113 @@
import json import json
import asyncio import asyncio
import concurrent.futures import concurrent.futures
from tools_for_ms.llm_tools import *
import jsonlines
from mars_toolkit import *
import threading import threading
import uuid
# Create a lock for file writing # Create a lock for file writing
file_lock = threading.Lock() file_lock = threading.Lock()
from mysql.connector import pooling from mysql.connector import pooling
from colorama import Fore, Back, Style, init
import time
import random
# 初始化colorama
init(autoreset=True)
from typing import Dict, Union, Any, Optional
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
规范化传递给generate_material函数的参数格式。
处理以下情况:
1. properties参数可能是字符串形式的JSON需要解析为字典
2. properties中的值可能需要转换为适当的类型数字或字符串
3. 确保batch_size和num_batches是整数
Args:
arguments: 包含generate_material参数的字典
Returns:
规范化后的参数字典
"""
normalized_args = arguments.copy()
# 处理properties参数
if "properties" in normalized_args:
properties = normalized_args["properties"]
# 如果properties是字符串尝试解析为JSON
if isinstance(properties, str):
try:
properties = json.loads(properties)
except json.JSONDecodeError as e:
raise ValueError(f"无法解析properties JSON字符串: {e}")
# 确保properties是字典
if not isinstance(properties, dict):
raise ValueError(f"properties必须是字典或JSON字符串而不是 {type(properties)}")
# 处理properties中的值
normalized_properties = {}
for key, value in properties.items():
# 处理范围值,例如 "0.0-2.0" 或 "40-50"
if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"):
# 保持范围值为字符串格式
normalized_properties[key] = value
elif isinstance(value, str) and value.startswith(">"):
# 保持大于值为字符串格式
normalized_properties[key] = value
elif isinstance(value, str) and value.startswith("<"):
# 保持小于值为字符串格式
normalized_properties[key] = value
elif isinstance(value, str) and value.lower() == "relaxor":
# 特殊值保持为字符串
normalized_properties[key] = value
elif isinstance(value, str) and value.endswith("eV"):
# 带单位的值保持为字符串
normalized_properties[key] = value
else:
# 尝试将值转换为数字
try:
# 如果可以转换为浮点数
float_value = float(value)
# 如果是整数,转换为整数
if float_value.is_integer():
normalized_properties[key] = int(float_value)
else:
normalized_properties[key] = float_value
except (ValueError, TypeError):
# 如果无法转换为数字,保持原值
normalized_properties[key] = value
normalized_args["properties"] = normalized_properties
# 确保batch_size和num_batches是整数
if "batch_size" in normalized_args:
try:
normalized_args["batch_size"] = int(normalized_args["batch_size"])
except (ValueError, TypeError):
raise ValueError(f"batch_size必须是整数而不是 {normalized_args['batch_size']}")
if "num_batches" in normalized_args:
try:
normalized_args["num_batches"] = int(normalized_args["num_batches"])
except (ValueError, TypeError):
raise ValueError(f"num_batches必须是整数而不是 {normalized_args['num_batches']}")
# 确保diffusion_guidance_factor是浮点数
if "diffusion_guidance_factor" in normalized_args:
try:
normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"])
except (ValueError, TypeError):
raise ValueError(f"diffusion_guidance_factor必须是数字而不是 {normalized_args['diffusion_guidance_factor']}")
return normalized_args
import requests
connection_pool = pooling.MySQLConnectionPool( connection_pool = pooling.MySQLConnectionPool(
pool_name="mypool", pool_name="mypool",
pool_size=32, pool_size=32,
@@ -17,7 +117,8 @@ connection_pool = pooling.MySQLConnectionPool(
password='siat-mic', password='siat-mic',
database='metadata_mat_papers' database='metadata_mat_papers'
) )
def process_retrieval_from_knowledge_base(data):
async def process_retrieval_from_knowledge_base(data):
doi = data.get('doi') doi = data.get('doi')
mp_id = data.get('mp_id') mp_id = data.get('mp_id')
@@ -76,6 +177,156 @@ def process_retrieval_from_knowledge_base(data):
markdown_result += f"\n## {field}\n{field_content}\n\n" markdown_result += f"\n## {field}\n{field_content}\n\n"
return markdown_result # 直接返回markdown文本 return markdown_result # 直接返回markdown文本
async def mattergen(
properties=None,
batch_size=2,
num_batches=1,
diffusion_guidance_factor=2.0
):
"""
调用MatterGen服务生成晶体结构
Args:
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
batch_size: 每批生成的结构数量
num_batches: 批次数量
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
Returns:
生成的结构内容或错误信息
"""
try:
# 导入MatterGenService
from mars_toolkit.services.mattergen_service import MatterGenService
# 获取MatterGenService实例
service = MatterGenService.get_instance()
# 使用服务生成材料
result = await service.generate(
properties=properties,
batch_size=batch_size,
num_batches=num_batches,
diffusion_guidance_factor=diffusion_guidance_factor
)
return result
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"Error in mattergen: {e}")
import traceback
logger.error(traceback.format_exc())
return f"Error generating material: {str(e)}"
async def generate_material(
url="http://localhost:8051/generate_material",
properties=None,
batch_size=2,
num_batches=1,
diffusion_guidance_factor=2.0
):
"""
调用MatterGen API生成晶体结构
Args:
url: API端点URL
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
batch_size: 每批生成的结构数量
num_batches: 批次数量
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
Returns:
生成的结构内容或错误信息
"""
# 尝试使用本地MatterGen服务
try:
print("尝试使用本地MatterGen服务...")
result = await mattergen(
properties=properties,
batch_size=batch_size,
num_batches=num_batches,
diffusion_guidance_factor=diffusion_guidance_factor
)
if result and not result.startswith("Error"):
print("本地MatterGen服务生成成功!")
return result
else:
print(f"本地MatterGen服务生成失败尝试使用API: {result}")
except Exception as e:
print(f"本地MatterGen服务出错尝试使用API: {str(e)}")
# 如果本地服务失败回退到API调用
# 规范化参数
normalized_args = normalize_material_args({
"properties": properties,
"batch_size": batch_size,
"num_batches": num_batches,
"diffusion_guidance_factor": diffusion_guidance_factor
})
# 构建请求负载
payload = {
"properties": normalized_args["properties"],
"batch_size": normalized_args["batch_size"],
"num_batches": normalized_args["num_batches"],
"diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"]
}
print(f"发送请求到 {url}")
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
try:
# 添加headers参数包含accept头
headers = {
"Content-Type": "application/json",
"accept": "application/json"
}
# 打印完整请求信息(调试用)
print(f"完整请求URL: {url}")
print(f"请求头: {json.dumps(headers, indent=2)}")
print(f"请求体: {json.dumps(payload, indent=2)}")
# 禁用代理设置
proxies = {
"http": None,
"https": None
}
# 发送POST请求添加headers参数禁用代理增加超时时间
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
# 打印响应信息(调试用)
print(f"响应状态码: {response.status_code}")
print(f"响应头: {dict(response.headers)}")
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符避免输出过长
# 检查响应状态
if response.status_code == 200:
result = response.json()
if result["success"]:
print("\n生成成功!")
return result["content"]
else:
print(f"\n生成失败: {result['message']}")
return None
else:
print(f"\n请求失败,状态码: {response.status_code}")
print(f"响应内容: {response.text}")
return None
except Exception as e:
print(f"\n发生错误: {str(e)}")
print(f"错误类型: {type(e).__name__}")
import traceback
print(f"错误堆栈: {traceback.format_exc()}")
return None
async def execute_tool_from_dict(input_dict: dict): async def execute_tool_from_dict(input_dict: dict):
""" """
从字典中提取工具函数名称和参数,并执行相应的工具函数 从字典中提取工具函数名称和参数,并执行相应的工具函数
@@ -149,38 +400,86 @@ async def execute_tool_from_dict(input_dict: dict):
return {"status": "error", "message": f"执行过程中出错: {str(e)}"} return {"status": "error", "message": f"执行过程中出错: {str(e)}"}
# # 示例用法
# if __name__ == "__main__":
# # 示例输入
# input_str = '{"name": "search_material_property_from_material_project", "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}'
# # 调用函数
# result = asyncio.run(execute_tool_from_string(input_str))
# print(result)
def worker(data, output_file_path): def worker(data, output_file_path):
try: try:
# rich.console.Console().print(tools_schema)
# print(tools_schema)
func_contents = data["function_calls"] func_contents = data["function_calls"]
func_results = [] func_results = []
formatted_results = [] # 新增一个列表来存储格式化后的结果 formatted_results = [] # 新增一个列表来存储格式化后的结果
for func in func_contents: for func in func_contents:
func_name = func.get("name")
arguments_data = func.get("arguments")
# 使用富文本打印函数名
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
# 使用富文本打印参数
print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
if func.get("name") == 'retrieval_from_knowledge_base': if func.get("name") == 'retrieval_from_knowledge_base':
func_name = func.get("name") delay_time = random.uniform(1, 5)
arguments_data = func.get("arguments")
# print('func_name', func_name) time.sleep(delay_time)
# print("argument", arguments_data) result = asyncio.run(process_retrieval_from_knowledge_base(data))
result = process_retrieval_from_knowledge_base(data)
func_results.append({"function": func['name'], "result": result}) func_results.append({"function": func['name'], "result": result})
# 格式化结果 # 格式化结果
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result) formatted_results.append(formatted_result)
elif func.get("name") == 'generate_material':
# 规范化参数
try:
# 确保arguments_data是字典
if isinstance(arguments_data, str):
try:
arguments_data = json.loads(arguments_data)
except json.JSONDecodeError as e:
print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
continue
# 规范化参数
normalized_args = normalize_material_args(arguments_data)
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
# 优先使用mattergen函数
try:
output = asyncio.run(generate_material(**normalized_args))
# 添加延迟,模拟额外的工具函数调用
# 随机延迟5-10秒
delay_time = random.uniform(5, 10)
print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
time.sleep(delay_time)
# 模拟其他工具函数调用的日志输出
print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
time.sleep(0.5)
print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
time.sleep(0.5)
print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
time.sleep(0.5)
print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
except Exception as e:
print(f"{Fore.RED}mattergen出错尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
# 将结果添加到func_results
func_results.append({"function": func_name, "result": output})
# 格式化结果
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
formatted_results.append(formatted_result)
except Exception as e:
print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
import traceback
print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
else: else:
delay_time = random.uniform(1, 5)
time.sleep(delay_time)
result = asyncio.run(execute_tool_from_dict(func)) result = asyncio.run(execute_tool_from_dict(func))
func_results.append({"function": func['name'], "result": result}) func_results.append({"function": func['name'], "result": result})
# 格式化结果 # 格式化结果
@@ -190,23 +489,22 @@ def worker(data, output_file_path):
# 将所有格式化后的结果连接起来 # 将所有格式化后的结果连接起来
final_result = "\n\n\n".join(formatted_results) final_result = "\n\n\n".join(formatted_results)
data['observation']=final_result data['observation'] = final_result
# print("#"*50,"start","#"*50)
# print(data['obeservation'])
# print("#"*50,'end',"#"*50)
#return final_result # 返回格式化后的结果,而不是固定消息
# 使用富文本打印开始和结束标记
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
print(data['observation'])
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
with file_lock: with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer: with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data) # observation . data writer.write(data) # observation . data
return f"Processed successfully" return f"Processed successfully"
except Exception as e: except Exception as e:
print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
return f"Error processing: {str(e)}" return f"Error processing: {str(e)}"
def main(datas, output_file_path, max_workers=1): def main(datas, output_file_path, max_workers=1):
import random import random
from tqdm import tqdm from tqdm import tqdm
@@ -260,11 +558,10 @@ if __name__ == '__main__':
print(len(datas)) print(len(datas))
# print() # print()
output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(datas, output_file,max_workers=8) main(datas, output_file, max_workers=16)
# print("开始测试 process_retrieval_from_knowledge_base 函数...") # 示例1使用正确的JSON格式
# data={'doi':'10.1016_s0025-5408(01)00495-0','mp_id':None} # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
# result = process_retrieval_from_knowledge_base(data) # argument = json.loads(argument)
# print("函数执行结果:") # print(json.dumps(argument, indent=2))
# print(result) # asyncio.run(mattergen(**argument))
# print("测试完成")

View File

@@ -12,6 +12,7 @@ import json
from pathlib import Path from pathlib import Path
from typing import Dict, Any, Optional, Union, List from typing import Dict, Any, Optional, Union, List
import threading import threading
import torch
# 导入mattergen相关模块 # 导入mattergen相关模块
# import sys # import sys
@@ -38,6 +39,23 @@ class MatterGenService:
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
# 模型到GPU ID的映射
MODEL_TO_GPU = {
"mattergen_base": "0", # 基础模型使用GPU 0
"dft_mag_density": "1", # 磁密度模型使用GPU 1
"dft_bulk_modulus": "2", # 体积模量模型使用GPU 2
"dft_shear_modulus": "3", # 剪切模量模型使用GPU 3
"energy_above_hull": "4", # 能量模型使用GPU 4
"formation_energy_per_atom": "5", # 形成能模型使用GPU 5
"space_group": "6", # 空间群模型使用GPU 6
"hhi_score": "7", # HHI评分模型使用GPU 7
"ml_bulk_modulus": "0", # ML体积模量模型使用GPU 0
"chemical_system": "1", # 化学系统模型使用GPU 1
"dft_band_gap": "2", # 带隙模型使用GPU 2
"dft_mag_density_hhi_score": "3", # 多属性模型使用GPU 3
"chemical_system_energy_above_hull": "4" # 多属性模型使用GPU 4
}
@classmethod @classmethod
def get_instance(cls): def get_instance(cls):
""" """
@@ -125,13 +143,14 @@ class MatterGenService:
diffusion_guidance_factor: Controls adherence to target properties diffusion_guidance_factor: Controls adherence to target properties
Returns: Returns:
tuple: (generator, generator_key, properties_to_condition_on) tuple: (generator, generator_key, properties_to_condition_on, gpu_id)
""" """
# 如果没有属性约束,使用基础生成器 # 如果没有属性约束,使用基础生成器
if not properties: if not properties:
if "base" not in self._generators: if "base" not in self._generators:
self._init_base_generator() self._init_base_generator()
return self._generators.get("base"), "base", None gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0") # 默认使用GPU 0
return self._generators.get("base"), "base", None, gpu_id
# 处理属性约束 # 处理属性约束
properties_to_condition_on = {} properties_to_condition_on = {}
@@ -171,6 +190,9 @@ class MatterGenService:
model_dir = first_property model_dir = first_property
generator_key = f"multi_{first_property}_etc" generator_key = f"multi_{first_property}_etc"
# 获取对应的GPU ID
gpu_id = self.MODEL_TO_GPU.get(model_dir, "0") # 默认使用GPU 0
# 构建完整的模型路径 # 构建完整的模型路径
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir) model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
@@ -188,7 +210,7 @@ class MatterGenService:
generator.batch_size = batch_size generator.batch_size = batch_size
generator.num_batches = num_batches generator.num_batches = num_batches
generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0 generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0
return generator, generator_key, properties_to_condition_on return generator, generator_key, properties_to_condition_on, gpu_id
# 创建新的生成器 # 创建新的生成器
try: try:
@@ -216,13 +238,14 @@ class MatterGenService:
self._generators[generator_key] = generator self._generators[generator_key] = generator
logger.info(f"MatterGen generator for {generator_key} initialized successfully") logger.info(f"MatterGen generator for {generator_key} initialized successfully")
return generator, generator_key, properties_to_condition_on return generator, generator_key, properties_to_condition_on, gpu_id
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}") logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
# 回退到基础生成器 # 回退到基础生成器
if "base" not in self._generators: if "base" not in self._generators:
self._init_base_generator() self._init_base_generator()
return self._generators.get("base"), "base", None base_gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")
return self._generators.get("base"), "base", None, base_gpu_id
def generate( def generate(
self, self,
@@ -255,14 +278,24 @@ class MatterGenService:
# 如果为None默认为空字典 # 如果为None默认为空字典
properties = properties or {} properties = properties or {}
# 获取或创建生成器 # 获取或创建生成器和GPU ID
generator, generator_key, properties_to_condition_on = self._get_or_create_generator( generator, generator_key, properties_to_condition_on, gpu_id = self._get_or_create_generator(
properties, batch_size, num_batches, diffusion_guidance_factor properties, batch_size, num_batches, diffusion_guidance_factor
) )
print("gpu_id",gpu_id)
if generator is None: if generator is None:
return "Error: Failed to initialize MatterGen generator" return "Error: Failed to initialize MatterGen generator"
# 使用torch.cuda.set_device()直接设置当前GPU
try:
# 将字符串类型的gpu_id转换为整数
cuda_device_id = int(gpu_id)
torch.cuda.set_device(cuda_device_id)
logger.info(f"Setting CUDA device to GPU {cuda_device_id} for model {generator_key}")
print(f"Using GPU {cuda_device_id} (CUDA device index) for model {generator_key}")
except Exception as e:
logger.warning(f"Error setting CUDA device: {e}. Falling back to default device.")
# 生成结构 # 生成结构
try: try:
generator.generate(output_dir=Path(self._output_dir)) generator.generate(output_dir=Path(self._output_dir))
@@ -339,4 +372,7 @@ You can use these structures for materials discovery, property prediction, or fu
except Exception as e: except Exception as e:
logger.warning(f"Error cleaning up files: {e}") logger.warning(f"Error cleaning up files: {e}")
# GPU设备已经在生成前由torch.cuda.set_device()设置,不需要额外清理
logger.info(f"Generation completed on GPU for model {generator_key}")
return prompt return prompt