生成数据:mattergen改成了同步
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -9,9 +9,18 @@ import asyncio
|
||||
import zipfile
|
||||
import shutil
|
||||
import re
|
||||
import multiprocessing
|
||||
from multiprocessing import Process, Queue
|
||||
from pathlib import Path
|
||||
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
|
||||
|
||||
# 设置多进程启动方法为spawn,解决CUDA初始化错误
|
||||
try:
|
||||
multiprocessing.set_start_method('spawn', force=True)
|
||||
except RuntimeError:
|
||||
# 如果已经设置过启动方法,会抛出RuntimeError
|
||||
pass
|
||||
|
||||
from ase.optimize import FIRE
|
||||
from ase.filters import FrechetCellFilter
|
||||
from ase.atoms import Atoms
|
||||
@@ -33,6 +42,49 @@ from ..core.mattergen_wrapper import *
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _process_generate_material_worker(args_queue, result_queue):
|
||||
"""
|
||||
在新进程中处理材料生成的工作函数
|
||||
|
||||
Args:
|
||||
args_queue: 包含生成参数的队列
|
||||
result_queue: 用于返回结果的队列
|
||||
"""
|
||||
try:
|
||||
# 配置日志
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("子进程开始执行材料生成...")
|
||||
|
||||
# 从队列获取参数
|
||||
args = args_queue.get()
|
||||
logger.info(f"子进程获取到参数: {args}")
|
||||
|
||||
# 导入MatterGenService
|
||||
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||
logger.info("子进程成功导入MatterGenService")
|
||||
|
||||
# 获取MatterGenService实例
|
||||
service = MatterGenService.get_instance()
|
||||
logger.info("子进程成功获取MatterGenService实例")
|
||||
|
||||
# 使用服务生成材料
|
||||
logger.info("子进程开始调用generate方法...")
|
||||
result = service.generate(**args)
|
||||
logger.info("子进程generate方法调用完成")
|
||||
|
||||
# 将结果放入结果队列
|
||||
result_queue.put(result)
|
||||
logger.info("子进程材料生成完成,结果已放入队列")
|
||||
except Exception as e:
|
||||
# 如果发生错误,将错误信息放入结果队列
|
||||
import traceback
|
||||
error_msg = f"材料生成过程中出错: {str(e)}\n{traceback.format_exc()}"
|
||||
import logging
|
||||
logging.getLogger(__name__).error(error_msg)
|
||||
result_queue.put(f"Error: {error_msg}")
|
||||
|
||||
|
||||
def format_cif_content(content):
|
||||
"""
|
||||
Format CIF content by removing unnecessary headers and organizing each CIF file.
|
||||
@@ -233,7 +285,7 @@ def main(
|
||||
|
||||
|
||||
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
|
||||
async def generate_material(
|
||||
def generate_material(
|
||||
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
|
||||
batch_size: int = 2,
|
||||
num_batches: int = 1,
|
||||
@@ -260,16 +312,45 @@ async def generate_material(
|
||||
Returns:
|
||||
Descriptive text with generated crystal structures in CIF format
|
||||
"""
|
||||
# # 创建队列用于进程间通信
|
||||
# args_queue = Queue()
|
||||
# result_queue = Queue()
|
||||
|
||||
# # 将参数放入队列
|
||||
# args_queue.put({
|
||||
# "properties": properties,
|
||||
# "batch_size": batch_size,
|
||||
# "num_batches": num_batches,
|
||||
# "diffusion_guidance_factor": diffusion_guidance_factor
|
||||
# })
|
||||
|
||||
# # 创建并启动新进程
|
||||
# logger.info("启动新进程处理材料生成...")
|
||||
# p = Process(target=_process_generate_material_worker, args=(args_queue, result_queue))
|
||||
# p.start()
|
||||
|
||||
# # 等待进程完成并获取结果
|
||||
# p.join()
|
||||
# result = result_queue.get()
|
||||
|
||||
# # 检查结果是否为错误信息
|
||||
# if isinstance(result, str) and result.startswith("Error:"):
|
||||
# # 记录错误日志
|
||||
# logger.error(result)
|
||||
|
||||
# 导入MatterGenService
|
||||
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||
logger.info("子进程成功导入MatterGenService")
|
||||
|
||||
# 获取MatterGenService实例
|
||||
service = MatterGenService.get_instance()
|
||||
logger.info("子进程成功获取MatterGenService实例")
|
||||
|
||||
# 使用服务生成材料
|
||||
return service.generate(
|
||||
properties=properties,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
diffusion_guidance_factor=diffusion_guidance_factor
|
||||
)
|
||||
logger.info("子进程开始调用generate方法...")
|
||||
result = service.generate(properties=properties, batch_size=batch_size, num_batches=num_batches, diffusion_guidance_factor=diffusion_guidance_factor)
|
||||
logger.info("子进程generate方法调用完成")
|
||||
if "Error generating structures" in result:
|
||||
return f"Error: Invalid properties {properties}."
|
||||
else:
|
||||
return result
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -35,7 +35,7 @@ class Config:
|
||||
DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'
|
||||
|
||||
# Searxng
|
||||
SEARXNG_HOST="http://192.168.191.101:40032/"
|
||||
SEARXNG_HOST="http://192.168.168.1:40032/"
|
||||
|
||||
# Visualization
|
||||
VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'
|
||||
|
||||
Binary file not shown.
@@ -5,6 +5,7 @@ This module provides functions for searching information on the web.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Annotated, Dict, Any, List
|
||||
|
||||
from langchain_community.utilities import SearxSearchWrapper
|
||||
@@ -28,6 +29,8 @@ async def search_online(
|
||||
Formatted string with search results (titles, snippets, links)
|
||||
"""
|
||||
# 确保 num_results 是整数
|
||||
os.environ['HTTP_PROXY'] = ''
|
||||
os.environ['HTTPS_PROXY'] = ''
|
||||
try:
|
||||
num_results = int(num_results)
|
||||
except (TypeError, ValueError):
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user