生成数据:mattergen改成了同步

This commit is contained in:
lzy
2025-04-06 20:35:13 +08:00
parent 71d8dabd17
commit 72045e5cfe
14 changed files with 557 additions and 191 deletions

View File

@@ -9,9 +9,18 @@ import asyncio
import zipfile
import shutil
import re
import multiprocessing
from multiprocessing import Process, Queue
from pathlib import Path
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
# Set the multiprocessing start method to "spawn" to work around CUDA initialization errors.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    # A RuntimeError is raised if the start method has already been set.
    # NOTE(review): force=True is passed above, so this branch may never trigger — confirm.
    pass
from ase.optimize import FIRE
from ase.filters import FrechetCellFilter
from ase.atoms import Atoms
@@ -33,6 +42,49 @@ from ..core.mattergen_wrapper import *
logger = logging.getLogger(__name__)
def _process_generate_material_worker(args_queue, result_queue):
    """Handle one material-generation request; intended to run in a new process.

    Takes one mapping of keyword arguments from ``args_queue``, passes it to
    ``MatterGenService.get_instance().generate(**args)`` and puts the return
    value on ``result_queue``. If any ``Exception`` is raised along the way,
    the message and traceback are logged and a string starting with
    ``"Error: "`` is put on ``result_queue`` instead, so the caller can tell
    failures apart by that prefix.

    Args:
        args_queue: Queue holding the keyword arguments for ``generate``.
        result_queue: Queue that receives the generation result or the
            ``"Error: ..."`` string.
    """
    # Imported here so the error path does not rely on a file-level import.
    import traceback

    # Same logger name as the module-level logger; fetched locally so the
    # worker does not shadow or depend on the global ``logger`` binding.
    worker_log = logging.getLogger(__name__)

    try:
        worker_log.info("子进程开始执行材料生成...")

        # Blocks until the parent has put the argument mapping on the queue.
        args = args_queue.get()
        worker_log.info("子进程获取到参数: %s", args)

        # Imported inside the worker, after the process has started.
        from mars_toolkit.services.mattergen_service import MatterGenService
        worker_log.info("子进程成功导入MatterGenService")

        service = MatterGenService.get_instance()
        worker_log.info("子进程成功获取MatterGenService实例")

        worker_log.info("子进程开始调用generate方法...")
        result = service.generate(**args)
        worker_log.info("子进程generate方法调用完成")

        result_queue.put(result)
        worker_log.info("子进程材料生成完成,结果已放入队列")
    except Exception as e:
        # Report the failure through the result queue instead of raising, so
        # a caller waiting on the queue still receives an item.
        error_msg = f"材料生成过程中出错: {str(e)}\n{traceback.format_exc()}"
        worker_log.error(error_msg)
        result_queue.put(f"Error: {error_msg}")
def format_cif_content(content):
"""
Format CIF content by removing unnecessary headers and organizing each CIF file.
@@ -233,7 +285,7 @@ def main(
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
async def generate_material(
def generate_material(
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
batch_size: int = 2,
num_batches: int = 1,
@@ -260,16 +312,45 @@ async def generate_material(
Returns:
Descriptive text with generated crystal structures in CIF format
"""
# # 创建队列用于进程间通信
# args_queue = Queue()
# result_queue = Queue()
# # 将参数放入队列
# args_queue.put({
# "properties": properties,
# "batch_size": batch_size,
# "num_batches": num_batches,
# "diffusion_guidance_factor": diffusion_guidance_factor
# })
# # 创建并启动新进程
# logger.info("启动新进程处理材料生成...")
# p = Process(target=_process_generate_material_worker, args=(args_queue, result_queue))
# p.start()
# # 等待进程完成并获取结果
# p.join()
# result = result_queue.get()
# # 检查结果是否为错误信息
# if isinstance(result, str) and result.startswith("Error:"):
# # 记录错误日志
# logger.error(result)
# 导入MatterGenService
from mars_toolkit.services.mattergen_service import MatterGenService
logger.info("子进程成功导入MatterGenService")
# 获取MatterGenService实例
service = MatterGenService.get_instance()
logger.info("子进程成功获取MatterGenService实例")
# 使用服务生成材料
return service.generate(
properties=properties,
batch_size=batch_size,
num_batches=num_batches,
diffusion_guidance_factor=diffusion_guidance_factor
)
logger.info("子进程开始调用generate方法...")
result = service.generate(properties=properties, batch_size=batch_size, num_batches=num_batches, diffusion_guidance_factor=diffusion_guidance_factor)
logger.info("子进程generate方法调用完成")
if "Error generating structures" in result:
return f"Error: Invalid properties {properties}."
else:
return result

View File

@@ -35,7 +35,7 @@ class Config:
DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'
# Searxng
SEARXNG_HOST="http://192.168.191.101:40032/"
SEARXNG_HOST="http://192.168.168.1:40032/"
# Visualization
VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'

View File

@@ -5,6 +5,7 @@ This module provides functions for searching information on the web.
"""
import asyncio
import os
from typing import Annotated, Dict, Any, List
from langchain_community.utilities import SearxSearchWrapper
@@ -28,6 +29,8 @@ async def search_online(
Formatted string with search results (titles, snippets, links)
"""
# 确保 num_results 是整数
os.environ['HTTP_PROXY'] = ''
os.environ['HTTPS_PROXY'] = ''
try:
num_results = int(num_results)
except (TypeError, ValueError):