生成数据:mattergen改成了同步

This commit is contained in:
lzy
2025-04-06 20:35:13 +08:00
parent 71d8dabd17
commit 72045e5cfe
14 changed files with 557 additions and 191 deletions

2
.gitignore vendored
View File

@@ -7,3 +7,5 @@ pyproject.toml
/pretrained_models /pretrained_models
/mcp-python-sdk /mcp-python-sdk
/.vscode /.vscode
/*filter_ok_questions_solutions_agent*

Binary file not shown.

View File

@@ -6,6 +6,8 @@ import jsonlines
from mars_toolkit import * from mars_toolkit import *
import threading import threading
import uuid import uuid
from mars_toolkit.compute.material_gen import generate_material
# Create a lock for file writing # Create a lock for file writing
file_lock = threading.Lock() file_lock = threading.Lock()
from mysql.connector import pooling from mysql.connector import pooling
@@ -180,153 +182,6 @@ async def process_retrieval_from_knowledge_base(data):
async def mattergen(
properties=None,
batch_size=2,
num_batches=1,
diffusion_guidance_factor=2.0
):
"""
调用MatterGen服务生成晶体结构
Args:
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
batch_size: 每批生成的结构数量
num_batches: 批次数量
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
Returns:
生成的结构内容或错误信息
"""
try:
# 导入MatterGenService
from mars_toolkit.services.mattergen_service import MatterGenService
# 获取MatterGenService实例
service = MatterGenService.get_instance()
# 使用服务生成材料
result = await service.generate(
properties=properties,
batch_size=batch_size,
num_batches=num_batches,
diffusion_guidance_factor=diffusion_guidance_factor
)
return result
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"Error in mattergen: {e}")
import traceback
logger.error(traceback.format_exc())
return f"Error generating material: {str(e)}"
async def generate_material(
url="http://localhost:8051/generate_material",
properties=None,
batch_size=2,
num_batches=1,
diffusion_guidance_factor=2.0
):
"""
调用MatterGen API生成晶体结构
Args:
url: API端点URL
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
batch_size: 每批生成的结构数量
num_batches: 批次数量
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
Returns:
生成的结构内容或错误信息
"""
# 尝试使用本地MatterGen服务
try:
print("尝试使用本地MatterGen服务...")
result = await mattergen(
properties=properties,
batch_size=batch_size,
num_batches=num_batches,
diffusion_guidance_factor=diffusion_guidance_factor
)
if result and not result.startswith("Error"):
print("本地MatterGen服务生成成功!")
return result
else:
print(f"本地MatterGen服务生成失败尝试使用API: {result}")
except Exception as e:
print(f"本地MatterGen服务出错尝试使用API: {str(e)}")
# 如果本地服务失败回退到API调用
# 规范化参数
normalized_args = normalize_material_args({
"properties": properties,
"batch_size": batch_size,
"num_batches": num_batches,
"diffusion_guidance_factor": diffusion_guidance_factor
})
# 构建请求负载
payload = {
"properties": normalized_args["properties"],
"batch_size": normalized_args["batch_size"],
"num_batches": normalized_args["num_batches"],
"diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"]
}
print(f"发送请求到 {url}")
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
try:
# 添加headers参数包含accept头
headers = {
"Content-Type": "application/json",
"accept": "application/json"
}
# 打印完整请求信息(调试用)
print(f"完整请求URL: {url}")
print(f"请求头: {json.dumps(headers, indent=2)}")
print(f"请求体: {json.dumps(payload, indent=2)}")
# 禁用代理设置
proxies = {
"http": None,
"https": None
}
# 发送POST请求添加headers参数禁用代理增加超时时间
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
# 打印响应信息(调试用)
print(f"响应状态码: {response.status_code}")
print(f"响应头: {dict(response.headers)}")
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符避免输出过长
# 检查响应状态
if response.status_code == 200:
result = response.json()
if result["success"]:
print("\n生成成功!")
return result["content"]
else:
print(f"\n生成失败: {result['message']}")
return None
else:
print(f"\n请求失败,状态码: {response.status_code}")
print(f"响应内容: {response.text}")
return None
except Exception as e:
print(f"\n发生错误: {str(e)}")
print(f"错误类型: {type(e).__name__}")
import traceback
print(f"错误堆栈: {traceback.format_exc()}")
return None
async def execute_tool_from_dict(input_dict: dict): async def execute_tool_from_dict(input_dict: dict):
""" """
从字典中提取工具函数名称和参数,并执行相应的工具函数 从字典中提取工具函数名称和参数,并执行相应的工具函数
@@ -416,14 +271,14 @@ def worker(data, output_file_path):
print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}") print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
if func.get("name") == 'retrieval_from_knowledge_base': if func.get("name") == 'retrieval_from_knowledge_base':
delay_time = random.uniform(1, 5) pass
# delay_time = random.uniform(5, 10)
time.sleep(delay_time) # time.sleep(delay_time)
result = asyncio.run(process_retrieval_from_knowledge_base(data)) # result = asyncio.run(process_retrieval_from_knowledge_base(data))
func_results.append({"function": func['name'], "result": result}) # func_results.append({"function": func['name'], "result": result})
# 格式化结果 # # 格式化结果
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" # formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result) # formatted_results.append(formatted_result)
elif func.get("name") == 'generate_material': elif func.get("name") == 'generate_material':
# 规范化参数 # 规范化参数
@@ -438,30 +293,30 @@ def worker(data, output_file_path):
# 规范化参数 # 规范化参数
normalized_args = normalize_material_args(arguments_data) normalized_args = normalize_material_args(arguments_data)
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}") # print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") # print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}") # print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
# 优先使用mattergen函数 # 优先使用mattergen函数
try: try:
output = asyncio.run(generate_material(**normalized_args)) # output = asyncio.run(generate_material(**normalized_args))
output = generate_material(**normalized_args)
# 添加延迟,模拟额外的工具函数调用 # 添加延迟,模拟额外的工具函数调用
# 随机延迟5-10秒 # 随机延迟5-10秒
delay_time = random.uniform(5, 10) # delay_time = random.uniform(5, 10)
print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}") # print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
time.sleep(delay_time) # time.sleep(delay_time)
# 模拟其他工具函数调用的日志输出 # # 模拟其他工具函数调用的日志输出
print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}") # print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
time.sleep(0.5) # time.sleep(0.5)
print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}") # print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
time.sleep(0.5) # time.sleep(0.5)
print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}") # print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
time.sleep(0.5) # time.sleep(0.5)
print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}") # print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
except Exception as e: except Exception as e:
print(f"{Fore.RED}mattergen出错尝试使用generate_material: {str(e)}{Style.RESET_ALL}") print(f"{Fore.RED}mattergen出错尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
@@ -478,14 +333,15 @@ def worker(data, output_file_path):
print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}") print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
else: else:
delay_time = random.uniform(1, 5) # delay_time = random.uniform(5, 10)
time.sleep(delay_time) # time.sleep(delay_time)
result = asyncio.run(execute_tool_from_dict(func)) pass
func_results.append({"function": func['name'], "result": result}) # result = asyncio.run(execute_tool_from_dict(func))
# 格式化结果 # func_results.append({"function": func['name'], "result": result})
func_name = func.get("name") # # 格式化结果
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]" # func_name = func.get("name")
formatted_results.append(formatted_result) # formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
# formatted_results.append(formatted_result)
# 将所有格式化后的结果连接起来 # 将所有格式化后的结果连接起来
final_result = "\n\n\n".join(formatted_results) final_result = "\n\n\n".join(formatted_results)
@@ -557,8 +413,8 @@ if __name__ == '__main__':
print(len(datas)) print(len(datas))
# print() # print()
output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl" output_file = f"./filter_ok_questions_solutions_agent_mattergen_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(datas, output_file, max_workers=16) main(datas, output_file, max_workers=1)
# 示例1使用正确的JSON格式 # 示例1使用正确的JSON格式
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}' # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'

423
execute_tool_other_tools.py Normal file
View File

@@ -0,0 +1,423 @@
import json
import asyncio
import concurrent.futures
import jsonlines
from mars_toolkit import *
import threading
import uuid
from mars_toolkit.compute.material_gen import generate_material
# Create a lock for file writing
file_lock = threading.Lock()
from mysql.connector import pooling
from colorama import Fore, Back, Style, init
import time
import random
# 初始化colorama
init(autoreset=True)
from typing import Dict, Union, Any, Optional
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
规范化传递给generate_material函数的参数格式。
处理以下情况:
1. properties参数可能是字符串形式的JSON需要解析为字典
2. properties中的值可能需要转换为适当的类型数字或字符串
3. 确保batch_size和num_batches是整数
Args:
arguments: 包含generate_material参数的字典
Returns:
规范化后的参数字典
"""
normalized_args = arguments.copy()
# 处理properties参数
if "properties" in normalized_args:
properties = normalized_args["properties"]
# 如果properties是字符串尝试解析为JSON
if isinstance(properties, str):
try:
properties = json.loads(properties)
except json.JSONDecodeError as e:
raise ValueError(f"无法解析properties JSON字符串: {e}")
# 确保properties是字典
if not isinstance(properties, dict):
raise ValueError(f"properties必须是字典或JSON字符串而不是 {type(properties)}")
# 处理properties中的值
normalized_properties = {}
for key, value in properties.items():
# 处理范围值,例如 "0.0-2.0" 或 "40-50"
if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"):
# 保持范围值为字符串格式
normalized_properties[key] = value
elif isinstance(value, str) and value.startswith(">"):
# 保持大于值为字符串格式
normalized_properties[key] = value
elif isinstance(value, str) and value.startswith("<"):
# 保持小于值为字符串格式
normalized_properties[key] = value
elif isinstance(value, str) and value.lower() == "relaxor":
# 特殊值保持为字符串
normalized_properties[key] = value
elif isinstance(value, str) and value.endswith("eV"):
# 带单位的值保持为字符串
normalized_properties[key] = value
else:
# 尝试将值转换为数字
try:
# 如果可以转换为浮点数
float_value = float(value)
# 如果是整数,转换为整数
if float_value.is_integer():
normalized_properties[key] = int(float_value)
else:
normalized_properties[key] = float_value
except (ValueError, TypeError):
# 如果无法转换为数字,保持原值
normalized_properties[key] = value
normalized_args["properties"] = normalized_properties
# 确保batch_size和num_batches是整数
if "batch_size" in normalized_args:
try:
normalized_args["batch_size"] = int(normalized_args["batch_size"])
except (ValueError, TypeError):
raise ValueError(f"batch_size必须是整数而不是 {normalized_args['batch_size']}")
if "num_batches" in normalized_args:
try:
normalized_args["num_batches"] = int(normalized_args["num_batches"])
except (ValueError, TypeError):
raise ValueError(f"num_batches必须是整数而不是 {normalized_args['num_batches']}")
# 确保diffusion_guidance_factor是浮点数
if "diffusion_guidance_factor" in normalized_args:
try:
normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"])
except (ValueError, TypeError):
raise ValueError(f"diffusion_guidance_factor必须是数字而不是 {normalized_args['diffusion_guidance_factor']}")
return normalized_args
import requests
connection_pool = pooling.MySQLConnectionPool(
pool_name="mypool",
pool_size=32,
pool_reset_session=True,
host='localhost',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers'
)
async def process_retrieval_from_knowledge_base(data):
doi = data.get('doi')
mp_id = data.get('mp_id')
# 检查是否提供了至少一个查询参数
if doi is None and mp_id is None:
return "" # 如果没有提供查询参数,返回空字符串
# 构建SQL查询条件
query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
params = []
if doi is not None and mp_id is not None:
query += "doi = %s OR mp_id = %s"
params = [doi, mp_id]
elif doi is not None:
query += "doi = %s"
params = [doi]
else: # mp_id is not None
query += "mp_id = %s"
params = [mp_id]
# 从数据库中查询匹配的记录
conn = connection_pool.get_connection()
try:
cursor = conn.cursor(dictionary=True)
try:
cursor.execute(query, params)
result = cursor.fetchone() # 获取第一个匹配的记录
finally:
cursor.close()
finally:
conn.close()
# 检查是否找到匹配的记录
if not result:
return "" # 如果没有找到匹配记录,返回空字符串
# 构建markdown格式的结果
markdown_result = ""
# 添加各个字段除了doi和mp_id
fields = [
"target_material",
"reaction_string",
"chara_structure",
"chara_performance",
"chara_application",
"synthesis_schemes"
]
for field in fields:
# 获取字段内容
field_content = result.get(field, "")
# 只有当字段内容不为空时才添加该字段
if field_content and field_content.strip():
markdown_result += f"\n## {field}\n{field_content}\n\n"
return markdown_result # 直接返回markdown文本
async def execute_tool_from_dict(input_dict: dict):
"""
从字典中提取工具函数名称和参数,并执行相应的工具函数
Args:
input_dict: 字典,例如:
{"name": "search_material_property_from_material_project",
"arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}
Returns:
工具函数的执行结果,如果工具函数不存在则返回错误信息
"""
try:
# 解析输入字符串为字典
# input_dict = json.loads(input_str)
# 提取函数名和参数
func_name = input_dict.get("name")
arguments_data = input_dict.get("arguments")
#print('func_name', func_name)
#print("argument", arguments_data)
if not func_name:
return {"status": "error", "message": "未提供函数名称"}
# 获取所有注册的工具函数
tools = get_tools()
# 检查函数名是否存在于工具函数字典中
if func_name not in tools:
return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"}
# 获取对应的工具函数
tool_func = tools[func_name]
# 处理参数
arguments = {}
if arguments_data:
# 检查arguments是字符串还是字典
if isinstance(arguments_data, dict):
# 如果已经是字典,直接使用
arguments = arguments_data
elif isinstance(arguments_data, str):
# 如果是字符串尝试解析为JSON
try:
# 尝试直接解析为JSON对象
arguments = json.loads(arguments_data)
except json.JSONDecodeError:
# 如果解析失败,可能是因为字符串中包含转义字符
# 尝试修复常见的JSON字符串问题
fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\')
try:
arguments = json.loads(fixed_str)
except json.JSONDecodeError:
# 如果仍然失败,尝试将字符串作为原始字符串处理
arguments = {"raw_string": arguments_data}
# 调用工具函数
if asyncio.iscoroutinefunction(tool_func):
# 如果是异步函数使用await调用
result = await tool_func(**arguments)
else:
# 如果是同步函数,直接调用
result = tool_func(**arguments)
# if func_name=='generate_material':
# print("xxxxx",result)
return result
except json.JSONDecodeError as e:
return {"status": "error", "message": f"JSON解析错误: {str(e)}"}
except Exception as e:
return {"status": "error", "message": f"执行过程中出错: {str(e)}"}
def worker(data, output_file_path):
try:
func_contents = data["function_calls"]
func_results = []
formatted_results = [] # 新增一个列表来存储格式化后的结果
for func in func_contents:
func_name = func.get("name")
arguments_data = func.get("arguments")
# 使用富文本打印函数名
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
# 使用富文本打印参数
print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
if func.get("name") == 'retrieval_from_knowledge_base':
pass
# delay_time = random.uniform(5, 10)
# time.sleep(delay_time)
result = asyncio.run(process_retrieval_from_knowledge_base(data))
func_results.append({"function": func['name'], "result": result})
# 格式化结果
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result)
elif func.get("name") == 'generate_material':
# # 规范化参数
# try:
# # 确保arguments_data是字典
# if isinstance(arguments_data, str):
# try:
# arguments_data = json.loads(arguments_data)
# except json.JSONDecodeError as e:
# print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
# continue
# # 规范化参数
# normalized_args = normalize_material_args(arguments_data)
# # print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
# # print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
# # print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
# # 优先使用mattergen函数
# try:
# # output = asyncio.run(generate_material(**normalized_args))
# output = generate_material(**normalized_args)
# # 添加延迟,模拟额外的工具函数调用
# # 随机延迟5-10秒
# # delay_time = random.uniform(5, 10)
# # print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
# # time.sleep(delay_time)
# # # 模拟其他工具函数调用的日志输出
# # print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
# # time.sleep(0.5)
# # print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
# # time.sleep(0.5)
# # print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
# # time.sleep(0.5)
# # print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
# except Exception as e:
# print(f"{Fore.RED}mattergen出错尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
# # 将结果添加到func_results
# func_results.append({"function": func_name, "result": output})
# # 格式化结果
# formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
# formatted_results.append(formatted_result)
# except Exception as e:
# print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
# import traceback
# print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
pass
else:
# delay_time = random.uniform(5, 10)
# time.sleep(delay_time)
result = asyncio.run(execute_tool_from_dict(func))
func_results.append({"function": func['name'], "result": result})
# 格式化结果
func_name = func.get("name")
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result)
# 将所有格式化后的结果连接起来
final_result = "\n\n\n".join(formatted_results)
data['observation'] = final_result
# 使用富文本打印开始和结束标记
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
print(data['observation'])
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
with file_lock:
with jsonlines.open(output_file_path, mode='a') as writer:
writer.write(data) # observation . data
return f"Processed successfully"
except Exception as e:
print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
return f"Error processing: {str(e)}"
def main(datas, output_file_path, max_workers=1):
import random
from tqdm import tqdm
import os
from mysql.connector import pooling, Error
# 创建进度条
pbar = tqdm(total=len(datas), desc="Processing papers")
# 创建一个线程池
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交任务到执行器
future_to_path = {}
for path in datas:
future = executor.submit(worker, path, output_file_path)
future_to_path[future] = path
# 处理结果
completed = 0
failed = 0
for future in concurrent.futures.as_completed(future_to_path):
path = future_to_path[future]
try:
result = future.result()
if "successfully" in result:
completed += 1
else:
failed += 1
# 更新进度条
pbar.update(1)
# 每100个文件更新一次统计信息
if (completed + failed) % 100 == 0:
pbar.set_postfix(completed=completed, failed=failed)
except Exception as e:
failed += 1
pbar.update(1)
print(f"\nWorker for {path} generated an exception: {e}")
pbar.close()
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
if __name__ == '__main__':
import datetime
import jsonlines
datas = []
with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
for obj in reader:
datas.append(obj)
print(len(datas))
# print()
output_file = f"./filter_ok_questions_solutions_agent_other_tools_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(datas, output_file, max_workers=32)
# 示例1使用正确的JSON格式
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
# argument = json.loads(argument)
# print(json.dumps(argument, indent=2))
# asyncio.run(mattergen(**argument))

View File

@@ -9,9 +9,18 @@ import asyncio
import zipfile import zipfile
import shutil import shutil
import re import re
import multiprocessing
from multiprocessing import Process, Queue
from pathlib import Path from pathlib import Path
from typing import Literal, Dict, Any, Tuple, Union, Optional, List from typing import Literal, Dict, Any, Tuple, Union, Optional, List
# 设置多进程启动方法为spawn解决CUDA初始化错误
try:
multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
# 如果已经设置过启动方法会抛出RuntimeError
pass
from ase.optimize import FIRE from ase.optimize import FIRE
from ase.filters import FrechetCellFilter from ase.filters import FrechetCellFilter
from ase.atoms import Atoms from ase.atoms import Atoms
@@ -33,6 +42,49 @@ from ..core.mattergen_wrapper import *
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _process_generate_material_worker(args_queue, result_queue):
"""
在新进程中处理材料生成的工作函数
Args:
args_queue: 包含生成参数的队列
result_queue: 用于返回结果的队列
"""
try:
# 配置日志
import logging
logger = logging.getLogger(__name__)
logger.info("子进程开始执行材料生成...")
# 从队列获取参数
args = args_queue.get()
logger.info(f"子进程获取到参数: {args}")
# 导入MatterGenService
from mars_toolkit.services.mattergen_service import MatterGenService
logger.info("子进程成功导入MatterGenService")
# 获取MatterGenService实例
service = MatterGenService.get_instance()
logger.info("子进程成功获取MatterGenService实例")
# 使用服务生成材料
logger.info("子进程开始调用generate方法...")
result = service.generate(**args)
logger.info("子进程generate方法调用完成")
# 将结果放入结果队列
result_queue.put(result)
logger.info("子进程材料生成完成,结果已放入队列")
except Exception as e:
# 如果发生错误,将错误信息放入结果队列
import traceback
error_msg = f"材料生成过程中出错: {str(e)}\n{traceback.format_exc()}"
import logging
logging.getLogger(__name__).error(error_msg)
result_queue.put(f"Error: {error_msg}")
def format_cif_content(content): def format_cif_content(content):
""" """
Format CIF content by removing unnecessary headers and organizing each CIF file. Format CIF content by removing unnecessary headers and organizing each CIF file.
@@ -233,7 +285,7 @@ def main(
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints") @llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
async def generate_material( def generate_material(
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None, properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
batch_size: int = 2, batch_size: int = 2,
num_batches: int = 1, num_batches: int = 1,
@@ -260,16 +312,45 @@ async def generate_material(
Returns: Returns:
Descriptive text with generated crystal structures in CIF format Descriptive text with generated crystal structures in CIF format
""" """
# # 创建队列用于进程间通信
# args_queue = Queue()
# result_queue = Queue()
# # 将参数放入队列
# args_queue.put({
# "properties": properties,
# "batch_size": batch_size,
# "num_batches": num_batches,
# "diffusion_guidance_factor": diffusion_guidance_factor
# })
# # 创建并启动新进程
# logger.info("启动新进程处理材料生成...")
# p = Process(target=_process_generate_material_worker, args=(args_queue, result_queue))
# p.start()
# # 等待进程完成并获取结果
# p.join()
# result = result_queue.get()
# # 检查结果是否为错误信息
# if isinstance(result, str) and result.startswith("Error:"):
# # 记录错误日志
# logger.error(result)
# 导入MatterGenService # 导入MatterGenService
from mars_toolkit.services.mattergen_service import MatterGenService from mars_toolkit.services.mattergen_service import MatterGenService
logger.info("子进程成功导入MatterGenService")
# 获取MatterGenService实例 # 获取MatterGenService实例
service = MatterGenService.get_instance() service = MatterGenService.get_instance()
logger.info("子进程成功获取MatterGenService实例")
# 使用服务生成材料 # 使用服务生成材料
return service.generate( logger.info("子进程开始调用generate方法...")
properties=properties, result = service.generate(properties=properties, batch_size=batch_size, num_batches=num_batches, diffusion_guidance_factor=diffusion_guidance_factor)
batch_size=batch_size, logger.info("子进程generate方法调用完成")
num_batches=num_batches, if "Error generating structures" in result:
diffusion_guidance_factor=diffusion_guidance_factor return f"Error: Invalid properties {properties}."
) else:
return result

View File

@@ -35,7 +35,7 @@ class Config:
DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA' DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'
# Searxng # Searxng
SEARXNG_HOST="http://192.168.191.101:40032/" SEARXNG_HOST="http://192.168.168.1:40032/"
# Visualization # Visualization
VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization' VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'

View File

@@ -5,6 +5,7 @@ This module provides functions for searching information on the web.
""" """
import asyncio import asyncio
import os
from typing import Annotated, Dict, Any, List from typing import Annotated, Dict, Any, List
from langchain_community.utilities import SearxSearchWrapper from langchain_community.utilities import SearxSearchWrapper
@@ -28,6 +29,8 @@ async def search_online(
Formatted string with search results (titles, snippets, links) Formatted string with search results (titles, snippets, links)
""" """
# 确保 num_results 是整数 # 确保 num_results 是整数
os.environ['HTTP_PROXY'] = ''
os.environ['HTTPS_PROXY'] = ''
try: try:
num_results = int(num_results) num_results = int(num_results)
except (TypeError, ValueError): except (TypeError, ValueError):

View File

@@ -62,7 +62,8 @@ async def test_tool(tool_name: str) -> str:
elif tool_name == "generate_material": elif tool_name == "generate_material":
from mars_toolkit.compute.material_gen import generate_material from mars_toolkit.compute.material_gen import generate_material
# 使用简单的属性约束进行测试 # 使用简单的属性约束进行测试
result = await generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1) # result = await generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1)
result = generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1)
elif tool_name == "fetch_chemical_composition_from_OQMD": elif tool_name == "fetch_chemical_composition_from_OQMD":
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
@@ -171,7 +172,7 @@ if __name__ == "__main__":
] ]
# 选择要测试的工具 # 选择要测试的工具
tool_name = tools_to_test[6] # 测试 search_online 工具 tool_name = tools_to_test[1] # 测试 search_online 工具
# 运行测试 # 运行测试
result = asyncio.run(test_tool(tool_name)) result = asyncio.run(test_tool(tool_name))