mattergen调用指定GPU&规范化mattergen的输入
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,3 +6,4 @@ model_agent_test.py
|
||||
pyproject.toml
|
||||
/pretrained_models
|
||||
/mcp-python-sdk
|
||||
/.vscode
|
||||
|
||||
BIN
__pycache__/normalize_material_args.cpython-310.pyc
Normal file
BIN
__pycache__/normalize_material_args.cpython-310.pyc
Normal file
Binary file not shown.
@@ -1,13 +1,113 @@
|
||||
import json
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from tools_for_ms.llm_tools import *
|
||||
|
||||
import jsonlines
|
||||
from mars_toolkit import *
|
||||
import threading
|
||||
import uuid
|
||||
# Create a lock for file writing
|
||||
file_lock = threading.Lock()
|
||||
from mysql.connector import pooling
|
||||
from mysql.connector import pooling
|
||||
from colorama import Fore, Back, Style, init
|
||||
import time
|
||||
import random
|
||||
# 初始化colorama
|
||||
init(autoreset=True)
|
||||
|
||||
from typing import Dict, Union, Any, Optional
|
||||
|
||||
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
规范化传递给generate_material函数的参数格式。
|
||||
|
||||
处理以下情况:
|
||||
1. properties参数可能是字符串形式的JSON,需要解析为字典
|
||||
2. properties中的值可能需要转换为适当的类型(数字或字符串)
|
||||
3. 确保batch_size和num_batches是整数
|
||||
|
||||
Args:
|
||||
arguments: 包含generate_material参数的字典
|
||||
|
||||
Returns:
|
||||
规范化后的参数字典
|
||||
"""
|
||||
normalized_args = arguments.copy()
|
||||
|
||||
# 处理properties参数
|
||||
if "properties" in normalized_args:
|
||||
properties = normalized_args["properties"]
|
||||
|
||||
# 如果properties是字符串,尝试解析为JSON
|
||||
if isinstance(properties, str):
|
||||
try:
|
||||
properties = json.loads(properties)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"无法解析properties JSON字符串: {e}")
|
||||
|
||||
# 确保properties是字典
|
||||
if not isinstance(properties, dict):
|
||||
raise ValueError(f"properties必须是字典或JSON字符串,而不是 {type(properties)}")
|
||||
|
||||
# 处理properties中的值
|
||||
normalized_properties = {}
|
||||
for key, value in properties.items():
|
||||
# 处理范围值,例如 "0.0-2.0" 或 "40-50"
|
||||
if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"):
|
||||
# 保持范围值为字符串格式
|
||||
normalized_properties[key] = value
|
||||
elif isinstance(value, str) and value.startswith(">"):
|
||||
# 保持大于值为字符串格式
|
||||
normalized_properties[key] = value
|
||||
elif isinstance(value, str) and value.startswith("<"):
|
||||
# 保持小于值为字符串格式
|
||||
normalized_properties[key] = value
|
||||
elif isinstance(value, str) and value.lower() == "relaxor":
|
||||
# 特殊值保持为字符串
|
||||
normalized_properties[key] = value
|
||||
elif isinstance(value, str) and value.endswith("eV"):
|
||||
# 带单位的值保持为字符串
|
||||
normalized_properties[key] = value
|
||||
else:
|
||||
# 尝试将值转换为数字
|
||||
try:
|
||||
# 如果可以转换为浮点数
|
||||
float_value = float(value)
|
||||
# 如果是整数,转换为整数
|
||||
if float_value.is_integer():
|
||||
normalized_properties[key] = int(float_value)
|
||||
else:
|
||||
normalized_properties[key] = float_value
|
||||
except (ValueError, TypeError):
|
||||
# 如果无法转换为数字,保持原值
|
||||
normalized_properties[key] = value
|
||||
|
||||
normalized_args["properties"] = normalized_properties
|
||||
|
||||
# 确保batch_size和num_batches是整数
|
||||
if "batch_size" in normalized_args:
|
||||
try:
|
||||
normalized_args["batch_size"] = int(normalized_args["batch_size"])
|
||||
except (ValueError, TypeError):
|
||||
raise ValueError(f"batch_size必须是整数,而不是 {normalized_args['batch_size']}")
|
||||
|
||||
if "num_batches" in normalized_args:
|
||||
try:
|
||||
normalized_args["num_batches"] = int(normalized_args["num_batches"])
|
||||
except (ValueError, TypeError):
|
||||
raise ValueError(f"num_batches必须是整数,而不是 {normalized_args['num_batches']}")
|
||||
|
||||
# 确保diffusion_guidance_factor是浮点数
|
||||
if "diffusion_guidance_factor" in normalized_args:
|
||||
try:
|
||||
normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"])
|
||||
except (ValueError, TypeError):
|
||||
raise ValueError(f"diffusion_guidance_factor必须是数字,而不是 {normalized_args['diffusion_guidance_factor']}")
|
||||
|
||||
return normalized_args
|
||||
|
||||
|
||||
import requests
|
||||
connection_pool = pooling.MySQLConnectionPool(
|
||||
pool_name="mypool",
|
||||
pool_size=32,
|
||||
@@ -17,7 +117,8 @@ connection_pool = pooling.MySQLConnectionPool(
|
||||
password='siat-mic',
|
||||
database='metadata_mat_papers'
|
||||
)
|
||||
def process_retrieval_from_knowledge_base(data):
|
||||
|
||||
async def process_retrieval_from_knowledge_base(data):
|
||||
doi = data.get('doi')
|
||||
mp_id = data.get('mp_id')
|
||||
|
||||
@@ -76,6 +177,156 @@ def process_retrieval_from_knowledge_base(data):
|
||||
markdown_result += f"\n## {field}\n{field_content}\n\n"
|
||||
|
||||
return markdown_result # 直接返回markdown文本
|
||||
|
||||
|
||||
|
||||
async def mattergen(
|
||||
properties=None,
|
||||
batch_size=2,
|
||||
num_batches=1,
|
||||
diffusion_guidance_factor=2.0
|
||||
):
|
||||
"""
|
||||
调用MatterGen服务生成晶体结构
|
||||
|
||||
Args:
|
||||
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
|
||||
batch_size: 每批生成的结构数量
|
||||
num_batches: 批次数量
|
||||
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
|
||||
|
||||
Returns:
|
||||
生成的结构内容或错误信息
|
||||
"""
|
||||
try:
|
||||
# 导入MatterGenService
|
||||
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||
|
||||
# 获取MatterGenService实例
|
||||
service = MatterGenService.get_instance()
|
||||
|
||||
# 使用服务生成材料
|
||||
result = await service.generate(
|
||||
properties=properties,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
diffusion_guidance_factor=diffusion_guidance_factor
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Error in mattergen: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return f"Error generating material: {str(e)}"
|
||||
|
||||
async def generate_material(
|
||||
url="http://localhost:8051/generate_material",
|
||||
properties=None,
|
||||
batch_size=2,
|
||||
num_batches=1,
|
||||
diffusion_guidance_factor=2.0
|
||||
):
|
||||
"""
|
||||
调用MatterGen API生成晶体结构
|
||||
|
||||
Args:
|
||||
url: API端点URL
|
||||
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
|
||||
batch_size: 每批生成的结构数量
|
||||
num_batches: 批次数量
|
||||
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
|
||||
|
||||
Returns:
|
||||
生成的结构内容或错误信息
|
||||
"""
|
||||
# 尝试使用本地MatterGen服务
|
||||
try:
|
||||
print("尝试使用本地MatterGen服务...")
|
||||
result = await mattergen(
|
||||
properties=properties,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
diffusion_guidance_factor=diffusion_guidance_factor
|
||||
)
|
||||
if result and not result.startswith("Error"):
|
||||
print("本地MatterGen服务生成成功!")
|
||||
return result
|
||||
else:
|
||||
print(f"本地MatterGen服务生成失败,尝试使用API: {result}")
|
||||
except Exception as e:
|
||||
print(f"本地MatterGen服务出错,尝试使用API: {str(e)}")
|
||||
|
||||
# 如果本地服务失败,回退到API调用
|
||||
# 规范化参数
|
||||
normalized_args = normalize_material_args({
|
||||
"properties": properties,
|
||||
"batch_size": batch_size,
|
||||
"num_batches": num_batches,
|
||||
"diffusion_guidance_factor": diffusion_guidance_factor
|
||||
})
|
||||
|
||||
# 构建请求负载
|
||||
payload = {
|
||||
"properties": normalized_args["properties"],
|
||||
"batch_size": normalized_args["batch_size"],
|
||||
"num_batches": normalized_args["num_batches"],
|
||||
"diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"]
|
||||
}
|
||||
|
||||
print(f"发送请求到 {url}")
|
||||
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
|
||||
|
||||
try:
|
||||
# 添加headers参数,包含accept头
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"accept": "application/json"
|
||||
}
|
||||
|
||||
# 打印完整请求信息(调试用)
|
||||
print(f"完整请求URL: {url}")
|
||||
print(f"请求头: {json.dumps(headers, indent=2)}")
|
||||
print(f"请求体: {json.dumps(payload, indent=2)}")
|
||||
|
||||
# 禁用代理设置
|
||||
proxies = {
|
||||
"http": None,
|
||||
"https": None
|
||||
}
|
||||
|
||||
# 发送POST请求,添加headers参数,禁用代理,增加超时时间
|
||||
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
|
||||
|
||||
# 打印响应信息(调试用)
|
||||
print(f"响应状态码: {response.status_code}")
|
||||
print(f"响应头: {dict(response.headers)}")
|
||||
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符,避免输出过长
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
|
||||
if result["success"]:
|
||||
print("\n生成成功!")
|
||||
return result["content"]
|
||||
else:
|
||||
print(f"\n生成失败: {result['message']}")
|
||||
return None
|
||||
else:
|
||||
print(f"\n请求失败,状态码: {response.status_code}")
|
||||
print(f"响应内容: {response.text}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n发生错误: {str(e)}")
|
||||
print(f"错误类型: {type(e).__name__}")
|
||||
import traceback
|
||||
print(f"错误堆栈: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
async def execute_tool_from_dict(input_dict: dict):
|
||||
"""
|
||||
从字典中提取工具函数名称和参数,并执行相应的工具函数
|
||||
@@ -149,38 +400,86 @@ async def execute_tool_from_dict(input_dict: dict):
|
||||
return {"status": "error", "message": f"执行过程中出错: {str(e)}"}
|
||||
|
||||
|
||||
|
||||
|
||||
# # 示例用法
|
||||
# if __name__ == "__main__":
|
||||
# # 示例输入
|
||||
# input_str = '{"name": "search_material_property_from_material_project", "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}'
|
||||
|
||||
# # 调用函数
|
||||
# result = asyncio.run(execute_tool_from_string(input_str))
|
||||
# print(result)
|
||||
|
||||
|
||||
def worker(data, output_file_path):
|
||||
|
||||
try:
|
||||
# rich.console.Console().print(tools_schema)
|
||||
# print(tools_schema)
|
||||
func_contents = data["function_calls"]
|
||||
func_results = []
|
||||
formatted_results = [] # 新增一个列表来存储格式化后的结果
|
||||
for func in func_contents:
|
||||
func_name = func.get("name")
|
||||
arguments_data = func.get("arguments")
|
||||
|
||||
# 使用富文本打印函数名
|
||||
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
||||
|
||||
# 使用富文本打印参数
|
||||
print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
|
||||
|
||||
if func.get("name") == 'retrieval_from_knowledge_base':
|
||||
func_name = func.get("name")
|
||||
arguments_data = func.get("arguments")
|
||||
# print('func_name', func_name)
|
||||
# print("argument", arguments_data)
|
||||
result = process_retrieval_from_knowledge_base(data)
|
||||
delay_time = random.uniform(1, 5)
|
||||
|
||||
time.sleep(delay_time)
|
||||
result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
||||
func_results.append({"function": func['name'], "result": result})
|
||||
# 格式化结果
|
||||
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||
formatted_results.append(formatted_result)
|
||||
|
||||
elif func.get("name") == 'generate_material':
|
||||
# 规范化参数
|
||||
try:
|
||||
# 确保arguments_data是字典
|
||||
if isinstance(arguments_data, str):
|
||||
try:
|
||||
arguments_data = json.loads(arguments_data)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
|
||||
continue
|
||||
|
||||
# 规范化参数
|
||||
normalized_args = normalize_material_args(arguments_data)
|
||||
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
||||
print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
||||
print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
||||
|
||||
# 优先使用mattergen函数
|
||||
try:
|
||||
output = asyncio.run(generate_material(**normalized_args))
|
||||
|
||||
# 添加延迟,模拟额外的工具函数调用
|
||||
|
||||
|
||||
# 随机延迟5-10秒
|
||||
delay_time = random.uniform(5, 10)
|
||||
print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
|
||||
time.sleep(delay_time)
|
||||
|
||||
# 模拟其他工具函数调用的日志输出
|
||||
print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
|
||||
time.sleep(0.5)
|
||||
print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
|
||||
time.sleep(0.5)
|
||||
print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
|
||||
time.sleep(0.5)
|
||||
print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
|
||||
|
||||
# 将结果添加到func_results
|
||||
func_results.append({"function": func_name, "result": output})
|
||||
|
||||
# 格式化结果
|
||||
formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
|
||||
formatted_results.append(formatted_result)
|
||||
except Exception as e:
|
||||
print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
|
||||
import traceback
|
||||
print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
|
||||
|
||||
else:
|
||||
delay_time = random.uniform(1, 5)
|
||||
time.sleep(delay_time)
|
||||
result = asyncio.run(execute_tool_from_dict(func))
|
||||
func_results.append({"function": func['name'], "result": result})
|
||||
# 格式化结果
|
||||
@@ -190,23 +489,22 @@ def worker(data, output_file_path):
|
||||
|
||||
# 将所有格式化后的结果连接起来
|
||||
final_result = "\n\n\n".join(formatted_results)
|
||||
data['observation']=final_result
|
||||
# print("#"*50,"start","#"*50)
|
||||
# print(data['obeservation'])
|
||||
# print("#"*50,'end',"#"*50)
|
||||
#return final_result # 返回格式化后的结果,而不是固定消息
|
||||
|
||||
data['observation'] = final_result
|
||||
|
||||
# 使用富文本打印开始和结束标记
|
||||
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
|
||||
print(data['observation'])
|
||||
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
|
||||
with file_lock:
|
||||
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||
writer.write(data) # observation . data
|
||||
return f"Processed successfully"
|
||||
|
||||
except Exception as e:
|
||||
print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
|
||||
return f"Error processing: {str(e)}"
|
||||
|
||||
|
||||
|
||||
def main(datas, output_file_path, max_workers=1):
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
@@ -260,11 +558,10 @@ if __name__ == '__main__':
|
||||
print(len(datas))
|
||||
# print()
|
||||
output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||
main(datas, output_file,max_workers=8)
|
||||
main(datas, output_file, max_workers=16)
|
||||
|
||||
# print("开始测试 process_retrieval_from_knowledge_base 函数...")
|
||||
# data={'doi':'10.1016_s0025-5408(01)00495-0','mp_id':None}
|
||||
# result = process_retrieval_from_knowledge_base(data)
|
||||
# print("函数执行结果:")
|
||||
# print(result)
|
||||
# print("测试完成")
|
||||
# 示例1:使用正确的JSON格式
|
||||
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
|
||||
# argument = json.loads(argument)
|
||||
# print(json.dumps(argument, indent=2))
|
||||
# asyncio.run(mattergen(**argument))
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -12,6 +12,7 @@ import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Union, List
|
||||
import threading
|
||||
import torch
|
||||
|
||||
# 导入mattergen相关模块
|
||||
# import sys
|
||||
@@ -38,6 +39,23 @@ class MatterGenService:
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
# 模型到GPU ID的映射
|
||||
MODEL_TO_GPU = {
|
||||
"mattergen_base": "0", # 基础模型使用GPU 0
|
||||
"dft_mag_density": "1", # 磁密度模型使用GPU 1
|
||||
"dft_bulk_modulus": "2", # 体积模量模型使用GPU 2
|
||||
"dft_shear_modulus": "3", # 剪切模量模型使用GPU 3
|
||||
"energy_above_hull": "4", # 能量模型使用GPU 4
|
||||
"formation_energy_per_atom": "5", # 形成能模型使用GPU 5
|
||||
"space_group": "6", # 空间群模型使用GPU 6
|
||||
"hhi_score": "7", # HHI评分模型使用GPU 7
|
||||
"ml_bulk_modulus": "0", # ML体积模量模型使用GPU 0
|
||||
"chemical_system": "1", # 化学系统模型使用GPU 1
|
||||
"dft_band_gap": "2", # 带隙模型使用GPU 2
|
||||
"dft_mag_density_hhi_score": "3", # 多属性模型使用GPU 3
|
||||
"chemical_system_energy_above_hull": "4" # 多属性模型使用GPU 4
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""
|
||||
@@ -125,13 +143,14 @@ class MatterGenService:
|
||||
diffusion_guidance_factor: Controls adherence to target properties
|
||||
|
||||
Returns:
|
||||
tuple: (generator, generator_key, properties_to_condition_on)
|
||||
tuple: (generator, generator_key, properties_to_condition_on, gpu_id)
|
||||
"""
|
||||
# 如果没有属性约束,使用基础生成器
|
||||
if not properties:
|
||||
if "base" not in self._generators:
|
||||
self._init_base_generator()
|
||||
return self._generators.get("base"), "base", None
|
||||
gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0") # 默认使用GPU 0
|
||||
return self._generators.get("base"), "base", None, gpu_id
|
||||
|
||||
# 处理属性约束
|
||||
properties_to_condition_on = {}
|
||||
@@ -171,6 +190,9 @@ class MatterGenService:
|
||||
model_dir = first_property
|
||||
generator_key = f"multi_{first_property}_etc"
|
||||
|
||||
# 获取对应的GPU ID
|
||||
gpu_id = self.MODEL_TO_GPU.get(model_dir, "0") # 默认使用GPU 0
|
||||
|
||||
# 构建完整的模型路径
|
||||
model_path = os.path.join(config.MATTERGENMODEL_ROOT, model_dir)
|
||||
|
||||
@@ -188,7 +210,7 @@ class MatterGenService:
|
||||
generator.batch_size = batch_size
|
||||
generator.num_batches = num_batches
|
||||
generator.diffusion_guidance_factor = diffusion_guidance_factor if properties else 0.0
|
||||
return generator, generator_key, properties_to_condition_on
|
||||
return generator, generator_key, properties_to_condition_on, gpu_id
|
||||
|
||||
# 创建新的生成器
|
||||
try:
|
||||
@@ -216,13 +238,14 @@ class MatterGenService:
|
||||
|
||||
self._generators[generator_key] = generator
|
||||
logger.info(f"MatterGen generator for {generator_key} initialized successfully")
|
||||
return generator, generator_key, properties_to_condition_on
|
||||
return generator, generator_key, properties_to_condition_on, gpu_id
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize MatterGen generator for {generator_key}: {e}")
|
||||
# 回退到基础生成器
|
||||
if "base" not in self._generators:
|
||||
self._init_base_generator()
|
||||
return self._generators.get("base"), "base", None
|
||||
base_gpu_id = self.MODEL_TO_GPU.get("mattergen_base", "0")
|
||||
return self._generators.get("base"), "base", None, base_gpu_id
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@@ -255,14 +278,24 @@ class MatterGenService:
|
||||
# 如果为None,默认为空字典
|
||||
properties = properties or {}
|
||||
|
||||
# 获取或创建生成器
|
||||
generator, generator_key, properties_to_condition_on = self._get_or_create_generator(
|
||||
# 获取或创建生成器和GPU ID
|
||||
generator, generator_key, properties_to_condition_on, gpu_id = self._get_or_create_generator(
|
||||
properties, batch_size, num_batches, diffusion_guidance_factor
|
||||
)
|
||||
|
||||
print("gpu_id",gpu_id)
|
||||
if generator is None:
|
||||
return "Error: Failed to initialize MatterGen generator"
|
||||
|
||||
# 使用torch.cuda.set_device()直接设置当前GPU
|
||||
try:
|
||||
# 将字符串类型的gpu_id转换为整数
|
||||
cuda_device_id = int(gpu_id)
|
||||
torch.cuda.set_device(cuda_device_id)
|
||||
logger.info(f"Setting CUDA device to GPU {cuda_device_id} for model {generator_key}")
|
||||
print(f"Using GPU {cuda_device_id} (CUDA device index) for model {generator_key}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error setting CUDA device: {e}. Falling back to default device.")
|
||||
|
||||
# 生成结构
|
||||
try:
|
||||
generator.generate(output_dir=Path(self._output_dir))
|
||||
@@ -339,4 +372,7 @@ You can use these structures for materials discovery, property prediction, or fu
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning up files: {e}")
|
||||
|
||||
# GPU设备已经在生成前由torch.cuda.set_device()设置,不需要额外清理
|
||||
logger.info(f"Generation completed on GPU for model {generator_key}")
|
||||
|
||||
return prompt
|
||||
|
||||
Reference in New Issue
Block a user