生成数据:mattergen改成了同步
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -7,3 +7,5 @@ pyproject.toml
|
|||||||
/pretrained_models
|
/pretrained_models
|
||||||
/mcp-python-sdk
|
/mcp-python-sdk
|
||||||
/.vscode
|
/.vscode
|
||||||
|
|
||||||
|
/*filter_ok_questions_solutions_agent*
|
||||||
|
|||||||
BIN
__pycache__/execute_tool_copy.cpython-310.pyc
Normal file
BIN
__pycache__/execute_tool_copy.cpython-310.pyc
Normal file
Binary file not shown.
@@ -6,6 +6,8 @@ import jsonlines
|
|||||||
from mars_toolkit import *
|
from mars_toolkit import *
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from mars_toolkit.compute.material_gen import generate_material
|
||||||
# Create a lock for file writing
|
# Create a lock for file writing
|
||||||
file_lock = threading.Lock()
|
file_lock = threading.Lock()
|
||||||
from mysql.connector import pooling
|
from mysql.connector import pooling
|
||||||
@@ -180,153 +182,6 @@ async def process_retrieval_from_knowledge_base(data):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def mattergen(
|
|
||||||
properties=None,
|
|
||||||
batch_size=2,
|
|
||||||
num_batches=1,
|
|
||||||
diffusion_guidance_factor=2.0
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
调用MatterGen服务生成晶体结构
|
|
||||||
|
|
||||||
Args:
|
|
||||||
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
|
|
||||||
batch_size: 每批生成的结构数量
|
|
||||||
num_batches: 批次数量
|
|
||||||
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
生成的结构内容或错误信息
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 导入MatterGenService
|
|
||||||
from mars_toolkit.services.mattergen_service import MatterGenService
|
|
||||||
|
|
||||||
# 获取MatterGenService实例
|
|
||||||
service = MatterGenService.get_instance()
|
|
||||||
|
|
||||||
# 使用服务生成材料
|
|
||||||
result = await service.generate(
|
|
||||||
properties=properties,
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_batches=num_batches,
|
|
||||||
diffusion_guidance_factor=diffusion_guidance_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
import logging
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.error(f"Error in mattergen: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return f"Error generating material: {str(e)}"
|
|
||||||
|
|
||||||
async def generate_material(
|
|
||||||
url="http://localhost:8051/generate_material",
|
|
||||||
properties=None,
|
|
||||||
batch_size=2,
|
|
||||||
num_batches=1,
|
|
||||||
diffusion_guidance_factor=2.0
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
调用MatterGen API生成晶体结构
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: API端点URL
|
|
||||||
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
|
|
||||||
batch_size: 每批生成的结构数量
|
|
||||||
num_batches: 批次数量
|
|
||||||
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
生成的结构内容或错误信息
|
|
||||||
"""
|
|
||||||
# 尝试使用本地MatterGen服务
|
|
||||||
try:
|
|
||||||
print("尝试使用本地MatterGen服务...")
|
|
||||||
result = await mattergen(
|
|
||||||
properties=properties,
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_batches=num_batches,
|
|
||||||
diffusion_guidance_factor=diffusion_guidance_factor
|
|
||||||
)
|
|
||||||
if result and not result.startswith("Error"):
|
|
||||||
print("本地MatterGen服务生成成功!")
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
print(f"本地MatterGen服务生成失败,尝试使用API: {result}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"本地MatterGen服务出错,尝试使用API: {str(e)}")
|
|
||||||
|
|
||||||
# 如果本地服务失败,回退到API调用
|
|
||||||
# 规范化参数
|
|
||||||
normalized_args = normalize_material_args({
|
|
||||||
"properties": properties,
|
|
||||||
"batch_size": batch_size,
|
|
||||||
"num_batches": num_batches,
|
|
||||||
"diffusion_guidance_factor": diffusion_guidance_factor
|
|
||||||
})
|
|
||||||
|
|
||||||
# 构建请求负载
|
|
||||||
payload = {
|
|
||||||
"properties": normalized_args["properties"],
|
|
||||||
"batch_size": normalized_args["batch_size"],
|
|
||||||
"num_batches": normalized_args["num_batches"],
|
|
||||||
"diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"]
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"发送请求到 {url}")
|
|
||||||
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 添加headers参数,包含accept头
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"accept": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 打印完整请求信息(调试用)
|
|
||||||
print(f"完整请求URL: {url}")
|
|
||||||
print(f"请求头: {json.dumps(headers, indent=2)}")
|
|
||||||
print(f"请求体: {json.dumps(payload, indent=2)}")
|
|
||||||
|
|
||||||
# 禁用代理设置
|
|
||||||
proxies = {
|
|
||||||
"http": None,
|
|
||||||
"https": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送POST请求,添加headers参数,禁用代理,增加超时时间
|
|
||||||
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
|
|
||||||
|
|
||||||
# 打印响应信息(调试用)
|
|
||||||
print(f"响应状态码: {response.status_code}")
|
|
||||||
print(f"响应头: {dict(response.headers)}")
|
|
||||||
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符,避免输出过长
|
|
||||||
|
|
||||||
# 检查响应状态
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
if result["success"]:
|
|
||||||
print("\n生成成功!")
|
|
||||||
return result["content"]
|
|
||||||
else:
|
|
||||||
print(f"\n生成失败: {result['message']}")
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
print(f"\n请求失败,状态码: {response.status_code}")
|
|
||||||
print(f"响应内容: {response.text}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n发生错误: {str(e)}")
|
|
||||||
print(f"错误类型: {type(e).__name__}")
|
|
||||||
import traceback
|
|
||||||
print(f"错误堆栈: {traceback.format_exc()}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def execute_tool_from_dict(input_dict: dict):
|
async def execute_tool_from_dict(input_dict: dict):
|
||||||
"""
|
"""
|
||||||
从字典中提取工具函数名称和参数,并执行相应的工具函数
|
从字典中提取工具函数名称和参数,并执行相应的工具函数
|
||||||
@@ -416,14 +271,14 @@ def worker(data, output_file_path):
|
|||||||
print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
|
print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
|
||||||
|
|
||||||
if func.get("name") == 'retrieval_from_knowledge_base':
|
if func.get("name") == 'retrieval_from_knowledge_base':
|
||||||
delay_time = random.uniform(1, 5)
|
pass
|
||||||
|
# delay_time = random.uniform(5, 10)
|
||||||
time.sleep(delay_time)
|
# time.sleep(delay_time)
|
||||||
result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
# result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
||||||
func_results.append({"function": func['name'], "result": result})
|
# func_results.append({"function": func['name'], "result": result})
|
||||||
# 格式化结果
|
# # 格式化结果
|
||||||
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
# formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||||
formatted_results.append(formatted_result)
|
# formatted_results.append(formatted_result)
|
||||||
|
|
||||||
elif func.get("name") == 'generate_material':
|
elif func.get("name") == 'generate_material':
|
||||||
# 规范化参数
|
# 规范化参数
|
||||||
@@ -438,30 +293,30 @@ def worker(data, output_file_path):
|
|||||||
|
|
||||||
# 规范化参数
|
# 规范化参数
|
||||||
normalized_args = normalize_material_args(arguments_data)
|
normalized_args = normalize_material_args(arguments_data)
|
||||||
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
# print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
||||||
print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
# print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
||||||
print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
# print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
||||||
|
|
||||||
# 优先使用mattergen函数
|
# 优先使用mattergen函数
|
||||||
try:
|
try:
|
||||||
output = asyncio.run(generate_material(**normalized_args))
|
# output = asyncio.run(generate_material(**normalized_args))
|
||||||
|
output = generate_material(**normalized_args)
|
||||||
|
|
||||||
# 添加延迟,模拟额外的工具函数调用
|
# 添加延迟,模拟额外的工具函数调用
|
||||||
|
|
||||||
|
|
||||||
# 随机延迟5-10秒
|
# 随机延迟5-10秒
|
||||||
delay_time = random.uniform(5, 10)
|
# delay_time = random.uniform(5, 10)
|
||||||
print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
|
# print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
|
||||||
time.sleep(delay_time)
|
# time.sleep(delay_time)
|
||||||
|
|
||||||
# 模拟其他工具函数调用的日志输出
|
# # 模拟其他工具函数调用的日志输出
|
||||||
print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
|
# print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
|
||||||
time.sleep(0.5)
|
# time.sleep(0.5)
|
||||||
print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
|
# print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
|
||||||
time.sleep(0.5)
|
# time.sleep(0.5)
|
||||||
print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
|
# print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
|
||||||
time.sleep(0.5)
|
# time.sleep(0.5)
|
||||||
print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
|
# print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
|
print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
|
||||||
@@ -478,14 +333,15 @@ def worker(data, output_file_path):
|
|||||||
print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
|
print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
delay_time = random.uniform(1, 5)
|
# delay_time = random.uniform(5, 10)
|
||||||
time.sleep(delay_time)
|
# time.sleep(delay_time)
|
||||||
result = asyncio.run(execute_tool_from_dict(func))
|
pass
|
||||||
func_results.append({"function": func['name'], "result": result})
|
# result = asyncio.run(execute_tool_from_dict(func))
|
||||||
# 格式化结果
|
# func_results.append({"function": func['name'], "result": result})
|
||||||
func_name = func.get("name")
|
# # 格式化结果
|
||||||
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
# func_name = func.get("name")
|
||||||
formatted_results.append(formatted_result)
|
# formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||||
|
# formatted_results.append(formatted_result)
|
||||||
|
|
||||||
# 将所有格式化后的结果连接起来
|
# 将所有格式化后的结果连接起来
|
||||||
final_result = "\n\n\n".join(formatted_results)
|
final_result = "\n\n\n".join(formatted_results)
|
||||||
@@ -557,8 +413,8 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
print(len(datas))
|
print(len(datas))
|
||||||
# print()
|
# print()
|
||||||
output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
output_file = f"./filter_ok_questions_solutions_agent_mattergen_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||||
main(datas, output_file, max_workers=16)
|
main(datas, output_file, max_workers=1)
|
||||||
|
|
||||||
# 示例1:使用正确的JSON格式
|
# 示例1:使用正确的JSON格式
|
||||||
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
|
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
|
||||||
|
|||||||
423
execute_tool_other_tools.py
Normal file
423
execute_tool_other_tools.py
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
from mars_toolkit import *
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from mars_toolkit.compute.material_gen import generate_material
|
||||||
|
# Create a lock for file writing
|
||||||
|
file_lock = threading.Lock()
|
||||||
|
from mysql.connector import pooling
|
||||||
|
from colorama import Fore, Back, Style, init
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
# 初始化colorama
|
||||||
|
init(autoreset=True)
|
||||||
|
|
||||||
|
from typing import Dict, Union, Any, Optional
|
||||||
|
|
||||||
|
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
规范化传递给generate_material函数的参数格式。
|
||||||
|
|
||||||
|
处理以下情况:
|
||||||
|
1. properties参数可能是字符串形式的JSON,需要解析为字典
|
||||||
|
2. properties中的值可能需要转换为适当的类型(数字或字符串)
|
||||||
|
3. 确保batch_size和num_batches是整数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: 包含generate_material参数的字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
规范化后的参数字典
|
||||||
|
"""
|
||||||
|
normalized_args = arguments.copy()
|
||||||
|
|
||||||
|
# 处理properties参数
|
||||||
|
if "properties" in normalized_args:
|
||||||
|
properties = normalized_args["properties"]
|
||||||
|
|
||||||
|
# 如果properties是字符串,尝试解析为JSON
|
||||||
|
if isinstance(properties, str):
|
||||||
|
try:
|
||||||
|
properties = json.loads(properties)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"无法解析properties JSON字符串: {e}")
|
||||||
|
|
||||||
|
# 确保properties是字典
|
||||||
|
if not isinstance(properties, dict):
|
||||||
|
raise ValueError(f"properties必须是字典或JSON字符串,而不是 {type(properties)}")
|
||||||
|
|
||||||
|
# 处理properties中的值
|
||||||
|
normalized_properties = {}
|
||||||
|
for key, value in properties.items():
|
||||||
|
# 处理范围值,例如 "0.0-2.0" 或 "40-50"
|
||||||
|
if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"):
|
||||||
|
# 保持范围值为字符串格式
|
||||||
|
normalized_properties[key] = value
|
||||||
|
elif isinstance(value, str) and value.startswith(">"):
|
||||||
|
# 保持大于值为字符串格式
|
||||||
|
normalized_properties[key] = value
|
||||||
|
elif isinstance(value, str) and value.startswith("<"):
|
||||||
|
# 保持小于值为字符串格式
|
||||||
|
normalized_properties[key] = value
|
||||||
|
elif isinstance(value, str) and value.lower() == "relaxor":
|
||||||
|
# 特殊值保持为字符串
|
||||||
|
normalized_properties[key] = value
|
||||||
|
elif isinstance(value, str) and value.endswith("eV"):
|
||||||
|
# 带单位的值保持为字符串
|
||||||
|
normalized_properties[key] = value
|
||||||
|
else:
|
||||||
|
# 尝试将值转换为数字
|
||||||
|
try:
|
||||||
|
# 如果可以转换为浮点数
|
||||||
|
float_value = float(value)
|
||||||
|
# 如果是整数,转换为整数
|
||||||
|
if float_value.is_integer():
|
||||||
|
normalized_properties[key] = int(float_value)
|
||||||
|
else:
|
||||||
|
normalized_properties[key] = float_value
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# 如果无法转换为数字,保持原值
|
||||||
|
normalized_properties[key] = value
|
||||||
|
|
||||||
|
normalized_args["properties"] = normalized_properties
|
||||||
|
|
||||||
|
# 确保batch_size和num_batches是整数
|
||||||
|
if "batch_size" in normalized_args:
|
||||||
|
try:
|
||||||
|
normalized_args["batch_size"] = int(normalized_args["batch_size"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise ValueError(f"batch_size必须是整数,而不是 {normalized_args['batch_size']}")
|
||||||
|
|
||||||
|
if "num_batches" in normalized_args:
|
||||||
|
try:
|
||||||
|
normalized_args["num_batches"] = int(normalized_args["num_batches"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise ValueError(f"num_batches必须是整数,而不是 {normalized_args['num_batches']}")
|
||||||
|
|
||||||
|
# 确保diffusion_guidance_factor是浮点数
|
||||||
|
if "diffusion_guidance_factor" in normalized_args:
|
||||||
|
try:
|
||||||
|
normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise ValueError(f"diffusion_guidance_factor必须是数字,而不是 {normalized_args['diffusion_guidance_factor']}")
|
||||||
|
|
||||||
|
return normalized_args
|
||||||
|
|
||||||
|
|
||||||
|
import requests
|
||||||
|
connection_pool = pooling.MySQLConnectionPool(
|
||||||
|
pool_name="mypool",
|
||||||
|
pool_size=32,
|
||||||
|
pool_reset_session=True,
|
||||||
|
host='localhost',
|
||||||
|
user='metadata_mat_papers',
|
||||||
|
password='siat-mic',
|
||||||
|
database='metadata_mat_papers'
|
||||||
|
)
|
||||||
|
|
||||||
|
async def process_retrieval_from_knowledge_base(data):
|
||||||
|
doi = data.get('doi')
|
||||||
|
mp_id = data.get('mp_id')
|
||||||
|
|
||||||
|
# 检查是否提供了至少一个查询参数
|
||||||
|
if doi is None and mp_id is None:
|
||||||
|
return "" # 如果没有提供查询参数,返回空字符串
|
||||||
|
|
||||||
|
# 构建SQL查询条件
|
||||||
|
query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
|
||||||
|
params = []
|
||||||
|
|
||||||
|
if doi is not None and mp_id is not None:
|
||||||
|
query += "doi = %s OR mp_id = %s"
|
||||||
|
params = [doi, mp_id]
|
||||||
|
elif doi is not None:
|
||||||
|
query += "doi = %s"
|
||||||
|
params = [doi]
|
||||||
|
else: # mp_id is not None
|
||||||
|
query += "mp_id = %s"
|
||||||
|
params = [mp_id]
|
||||||
|
|
||||||
|
# 从数据库中查询匹配的记录
|
||||||
|
conn = connection_pool.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
try:
|
||||||
|
cursor.execute(query, params)
|
||||||
|
result = cursor.fetchone() # 获取第一个匹配的记录
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# 检查是否找到匹配的记录
|
||||||
|
if not result:
|
||||||
|
return "" # 如果没有找到匹配记录,返回空字符串
|
||||||
|
|
||||||
|
# 构建markdown格式的结果
|
||||||
|
markdown_result = ""
|
||||||
|
|
||||||
|
# 添加各个字段(除了doi和mp_id)
|
||||||
|
fields = [
|
||||||
|
"target_material",
|
||||||
|
"reaction_string",
|
||||||
|
"chara_structure",
|
||||||
|
"chara_performance",
|
||||||
|
"chara_application",
|
||||||
|
"synthesis_schemes"
|
||||||
|
]
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
# 获取字段内容
|
||||||
|
field_content = result.get(field, "")
|
||||||
|
# 只有当字段内容不为空时才添加该字段
|
||||||
|
if field_content and field_content.strip():
|
||||||
|
markdown_result += f"\n## {field}\n{field_content}\n\n"
|
||||||
|
|
||||||
|
return markdown_result # 直接返回markdown文本
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_tool_from_dict(input_dict: dict):
|
||||||
|
"""
|
||||||
|
从字典中提取工具函数名称和参数,并执行相应的工具函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dict: 字典,例如:
|
||||||
|
{"name": "search_material_property_from_material_project",
|
||||||
|
"arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具函数的执行结果,如果工具函数不存在则返回错误信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 解析输入字符串为字典
|
||||||
|
# input_dict = json.loads(input_str)
|
||||||
|
|
||||||
|
# 提取函数名和参数
|
||||||
|
func_name = input_dict.get("name")
|
||||||
|
arguments_data = input_dict.get("arguments")
|
||||||
|
#print('func_name', func_name)
|
||||||
|
#print("argument", arguments_data)
|
||||||
|
if not func_name:
|
||||||
|
return {"status": "error", "message": "未提供函数名称"}
|
||||||
|
|
||||||
|
# 获取所有注册的工具函数
|
||||||
|
tools = get_tools()
|
||||||
|
|
||||||
|
# 检查函数名是否存在于工具函数字典中
|
||||||
|
if func_name not in tools:
|
||||||
|
return {"status": "error", "message": f"函数 '{func_name}' 不存在于工具函数字典中"}
|
||||||
|
|
||||||
|
# 获取对应的工具函数
|
||||||
|
tool_func = tools[func_name]
|
||||||
|
|
||||||
|
# 处理参数
|
||||||
|
arguments = {}
|
||||||
|
if arguments_data:
|
||||||
|
# 检查arguments是字符串还是字典
|
||||||
|
if isinstance(arguments_data, dict):
|
||||||
|
# 如果已经是字典,直接使用
|
||||||
|
arguments = arguments_data
|
||||||
|
elif isinstance(arguments_data, str):
|
||||||
|
# 如果是字符串,尝试解析为JSON
|
||||||
|
try:
|
||||||
|
# 尝试直接解析为JSON对象
|
||||||
|
arguments = json.loads(arguments_data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 如果解析失败,可能是因为字符串中包含转义字符
|
||||||
|
# 尝试修复常见的JSON字符串问题
|
||||||
|
fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\')
|
||||||
|
try:
|
||||||
|
arguments = json.loads(fixed_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 如果仍然失败,尝试将字符串作为原始字符串处理
|
||||||
|
arguments = {"raw_string": arguments_data}
|
||||||
|
|
||||||
|
# 调用工具函数
|
||||||
|
if asyncio.iscoroutinefunction(tool_func):
|
||||||
|
# 如果是异步函数,使用await调用
|
||||||
|
result = await tool_func(**arguments)
|
||||||
|
else:
|
||||||
|
# 如果是同步函数,直接调用
|
||||||
|
result = tool_func(**arguments)
|
||||||
|
# if func_name=='generate_material':
|
||||||
|
# print("xxxxx",result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return {"status": "error", "message": f"JSON解析错误: {str(e)}"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"status": "error", "message": f"执行过程中出错: {str(e)}"}
|
||||||
|
|
||||||
|
|
||||||
|
def worker(data, output_file_path):
|
||||||
|
try:
|
||||||
|
func_contents = data["function_calls"]
|
||||||
|
func_results = []
|
||||||
|
formatted_results = [] # 新增一个列表来存储格式化后的结果
|
||||||
|
for func in func_contents:
|
||||||
|
func_name = func.get("name")
|
||||||
|
arguments_data = func.get("arguments")
|
||||||
|
|
||||||
|
# 使用富文本打印函数名
|
||||||
|
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
# 使用富文本打印参数
|
||||||
|
print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
if func.get("name") == 'retrieval_from_knowledge_base':
|
||||||
|
pass
|
||||||
|
# delay_time = random.uniform(5, 10)
|
||||||
|
# time.sleep(delay_time)
|
||||||
|
result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
||||||
|
func_results.append({"function": func['name'], "result": result})
|
||||||
|
# 格式化结果
|
||||||
|
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||||
|
formatted_results.append(formatted_result)
|
||||||
|
|
||||||
|
elif func.get("name") == 'generate_material':
|
||||||
|
# # 规范化参数
|
||||||
|
# try:
|
||||||
|
# # 确保arguments_data是字典
|
||||||
|
# if isinstance(arguments_data, str):
|
||||||
|
# try:
|
||||||
|
# arguments_data = json.loads(arguments_data)
|
||||||
|
# except json.JSONDecodeError as e:
|
||||||
|
# print(f"{Fore.RED}无法解析arguments_data JSON字符串: {e}{Style.RESET_ALL}")
|
||||||
|
# continue
|
||||||
|
|
||||||
|
# # 规范化参数
|
||||||
|
# normalized_args = normalize_material_args(arguments_data)
|
||||||
|
# # print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
||||||
|
# # print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
||||||
|
# # print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
# # 优先使用mattergen函数
|
||||||
|
# try:
|
||||||
|
# # output = asyncio.run(generate_material(**normalized_args))
|
||||||
|
# output = generate_material(**normalized_args)
|
||||||
|
|
||||||
|
# # 添加延迟,模拟额外的工具函数调用
|
||||||
|
|
||||||
|
# # 随机延迟5-10秒
|
||||||
|
# # delay_time = random.uniform(5, 10)
|
||||||
|
# # print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
|
||||||
|
# # time.sleep(delay_time)
|
||||||
|
|
||||||
|
# # # 模拟其他工具函数调用的日志输出
|
||||||
|
# # print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
|
||||||
|
# # time.sleep(0.5)
|
||||||
|
# # print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
|
||||||
|
# # time.sleep(0.5)
|
||||||
|
# # print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
|
||||||
|
# # time.sleep(0.5)
|
||||||
|
# # print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
# # 将结果添加到func_results
|
||||||
|
# func_results.append({"function": func_name, "result": output})
|
||||||
|
|
||||||
|
# # 格式化结果
|
||||||
|
# formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
|
||||||
|
# formatted_results.append(formatted_result)
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"{Fore.RED}处理generate_material参数时出错: {e}{Style.RESET_ALL}")
|
||||||
|
# import traceback
|
||||||
|
# print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# delay_time = random.uniform(5, 10)
|
||||||
|
# time.sleep(delay_time)
|
||||||
|
|
||||||
|
result = asyncio.run(execute_tool_from_dict(func))
|
||||||
|
func_results.append({"function": func['name'], "result": result})
|
||||||
|
# 格式化结果
|
||||||
|
func_name = func.get("name")
|
||||||
|
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||||
|
formatted_results.append(formatted_result)
|
||||||
|
|
||||||
|
# 将所有格式化后的结果连接起来
|
||||||
|
final_result = "\n\n\n".join(formatted_results)
|
||||||
|
data['observation'] = final_result
|
||||||
|
|
||||||
|
# 使用富文本打印开始和结束标记
|
||||||
|
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果开始 {'#'*50}{Style.RESET_ALL}")
|
||||||
|
print(data['observation'])
|
||||||
|
print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} 结果结束 {'#'*50}{Style.RESET_ALL}")
|
||||||
|
with file_lock:
|
||||||
|
with jsonlines.open(output_file_path, mode='a') as writer:
|
||||||
|
writer.write(data) # observation . data
|
||||||
|
return f"Processed successfully"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"{Fore.RED}{Style.BRIGHT}处理过程中出错: {str(e)}{Style.RESET_ALL}")
|
||||||
|
return f"Error processing: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def main(datas, output_file_path, max_workers=1):
|
||||||
|
import random
|
||||||
|
from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
from mysql.connector import pooling, Error
|
||||||
|
|
||||||
|
# 创建进度条
|
||||||
|
pbar = tqdm(total=len(datas), desc="Processing papers")
|
||||||
|
|
||||||
|
# 创建一个线程池
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
# 提交任务到执行器
|
||||||
|
future_to_path = {}
|
||||||
|
for path in datas:
|
||||||
|
future = executor.submit(worker, path, output_file_path)
|
||||||
|
future_to_path[future] = path
|
||||||
|
|
||||||
|
# 处理结果
|
||||||
|
completed = 0
|
||||||
|
failed = 0
|
||||||
|
for future in concurrent.futures.as_completed(future_to_path):
|
||||||
|
path = future_to_path[future]
|
||||||
|
try:
|
||||||
|
result = future.result()
|
||||||
|
if "successfully" in result:
|
||||||
|
completed += 1
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
# 更新进度条
|
||||||
|
pbar.update(1)
|
||||||
|
# 每100个文件更新一次统计信息
|
||||||
|
if (completed + failed) % 100 == 0:
|
||||||
|
pbar.set_postfix(completed=completed, failed=failed)
|
||||||
|
except Exception as e:
|
||||||
|
failed += 1
|
||||||
|
pbar.update(1)
|
||||||
|
print(f"\nWorker for {path} generated an exception: {e}")
|
||||||
|
|
||||||
|
pbar.close()
|
||||||
|
print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import datetime
|
||||||
|
import jsonlines
|
||||||
|
datas = []
|
||||||
|
with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
|
||||||
|
for obj in reader:
|
||||||
|
datas.append(obj)
|
||||||
|
|
||||||
|
print(len(datas))
|
||||||
|
# print()
|
||||||
|
output_file = f"./filter_ok_questions_solutions_agent_other_tools_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||||
|
main(datas, output_file, max_workers=32)
|
||||||
|
|
||||||
|
# 示例1:使用正确的JSON格式
|
||||||
|
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
|
||||||
|
# argument = json.loads(argument)
|
||||||
|
# print(json.dumps(argument, indent=2))
|
||||||
|
# asyncio.run(mattergen(**argument))
|
||||||
Binary file not shown.
Binary file not shown.
@@ -9,9 +9,18 @@ import asyncio
|
|||||||
import zipfile
|
import zipfile
|
||||||
import shutil
|
import shutil
|
||||||
import re
|
import re
|
||||||
|
import multiprocessing
|
||||||
|
from multiprocessing import Process, Queue
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
|
from typing import Literal, Dict, Any, Tuple, Union, Optional, List
|
||||||
|
|
||||||
|
# 设置多进程启动方法为spawn,解决CUDA初始化错误
|
||||||
|
try:
|
||||||
|
multiprocessing.set_start_method('spawn', force=True)
|
||||||
|
except RuntimeError:
|
||||||
|
# 如果已经设置过启动方法,会抛出RuntimeError
|
||||||
|
pass
|
||||||
|
|
||||||
from ase.optimize import FIRE
|
from ase.optimize import FIRE
|
||||||
from ase.filters import FrechetCellFilter
|
from ase.filters import FrechetCellFilter
|
||||||
from ase.atoms import Atoms
|
from ase.atoms import Atoms
|
||||||
@@ -33,6 +42,49 @@ from ..core.mattergen_wrapper import *
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _process_generate_material_worker(args_queue, result_queue):
|
||||||
|
"""
|
||||||
|
在新进程中处理材料生成的工作函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args_queue: 包含生成参数的队列
|
||||||
|
result_queue: 用于返回结果的队列
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 配置日志
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info("子进程开始执行材料生成...")
|
||||||
|
|
||||||
|
# 从队列获取参数
|
||||||
|
args = args_queue.get()
|
||||||
|
logger.info(f"子进程获取到参数: {args}")
|
||||||
|
|
||||||
|
# 导入MatterGenService
|
||||||
|
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||||
|
logger.info("子进程成功导入MatterGenService")
|
||||||
|
|
||||||
|
# 获取MatterGenService实例
|
||||||
|
service = MatterGenService.get_instance()
|
||||||
|
logger.info("子进程成功获取MatterGenService实例")
|
||||||
|
|
||||||
|
# 使用服务生成材料
|
||||||
|
logger.info("子进程开始调用generate方法...")
|
||||||
|
result = service.generate(**args)
|
||||||
|
logger.info("子进程generate方法调用完成")
|
||||||
|
|
||||||
|
# 将结果放入结果队列
|
||||||
|
result_queue.put(result)
|
||||||
|
logger.info("子进程材料生成完成,结果已放入队列")
|
||||||
|
except Exception as e:
|
||||||
|
# 如果发生错误,将错误信息放入结果队列
|
||||||
|
import traceback
|
||||||
|
error_msg = f"材料生成过程中出错: {str(e)}\n{traceback.format_exc()}"
|
||||||
|
import logging
|
||||||
|
logging.getLogger(__name__).error(error_msg)
|
||||||
|
result_queue.put(f"Error: {error_msg}")
|
||||||
|
|
||||||
|
|
||||||
def format_cif_content(content):
|
def format_cif_content(content):
|
||||||
"""
|
"""
|
||||||
Format CIF content by removing unnecessary headers and organizing each CIF file.
|
Format CIF content by removing unnecessary headers and organizing each CIF file.
|
||||||
@@ -233,7 +285,7 @@ def main(
|
|||||||
|
|
||||||
|
|
||||||
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
|
@llm_tool(name="generate_material", description="Generate crystal structures with optional property constraints")
|
||||||
async def generate_material(
|
def generate_material(
|
||||||
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
|
properties: Optional[Dict[str, Union[float, str, Dict[str, Union[float, str]]]]] = None,
|
||||||
batch_size: int = 2,
|
batch_size: int = 2,
|
||||||
num_batches: int = 1,
|
num_batches: int = 1,
|
||||||
@@ -260,16 +312,45 @@ async def generate_material(
|
|||||||
Returns:
|
Returns:
|
||||||
Descriptive text with generated crystal structures in CIF format
|
Descriptive text with generated crystal structures in CIF format
|
||||||
"""
|
"""
|
||||||
|
# # 创建队列用于进程间通信
|
||||||
|
# args_queue = Queue()
|
||||||
|
# result_queue = Queue()
|
||||||
|
|
||||||
|
# # 将参数放入队列
|
||||||
|
# args_queue.put({
|
||||||
|
# "properties": properties,
|
||||||
|
# "batch_size": batch_size,
|
||||||
|
# "num_batches": num_batches,
|
||||||
|
# "diffusion_guidance_factor": diffusion_guidance_factor
|
||||||
|
# })
|
||||||
|
|
||||||
|
# # 创建并启动新进程
|
||||||
|
# logger.info("启动新进程处理材料生成...")
|
||||||
|
# p = Process(target=_process_generate_material_worker, args=(args_queue, result_queue))
|
||||||
|
# p.start()
|
||||||
|
|
||||||
|
# # 等待进程完成并获取结果
|
||||||
|
# p.join()
|
||||||
|
# result = result_queue.get()
|
||||||
|
|
||||||
|
# # 检查结果是否为错误信息
|
||||||
|
# if isinstance(result, str) and result.startswith("Error:"):
|
||||||
|
# # 记录错误日志
|
||||||
|
# logger.error(result)
|
||||||
|
|
||||||
# 导入MatterGenService
|
# 导入MatterGenService
|
||||||
from mars_toolkit.services.mattergen_service import MatterGenService
|
from mars_toolkit.services.mattergen_service import MatterGenService
|
||||||
|
logger.info("子进程成功导入MatterGenService")
|
||||||
|
|
||||||
# 获取MatterGenService实例
|
# 获取MatterGenService实例
|
||||||
service = MatterGenService.get_instance()
|
service = MatterGenService.get_instance()
|
||||||
|
logger.info("子进程成功获取MatterGenService实例")
|
||||||
|
|
||||||
# 使用服务生成材料
|
# 使用服务生成材料
|
||||||
return service.generate(
|
logger.info("子进程开始调用generate方法...")
|
||||||
properties=properties,
|
result = service.generate(properties=properties, batch_size=batch_size, num_batches=num_batches, diffusion_guidance_factor=diffusion_guidance_factor)
|
||||||
batch_size=batch_size,
|
logger.info("子进程generate方法调用完成")
|
||||||
num_batches=num_batches,
|
if "Error generating structures" in result:
|
||||||
diffusion_guidance_factor=diffusion_guidance_factor
|
return f"Error: Invalid properties {properties}."
|
||||||
)
|
else:
|
||||||
|
return result
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -35,7 +35,7 @@ class Config:
|
|||||||
DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'
|
DIFY_API_KEY = 'app-IKZrS1RqIyurPSzR73mz6XSA'
|
||||||
|
|
||||||
# Searxng
|
# Searxng
|
||||||
SEARXNG_HOST="http://192.168.191.101:40032/"
|
SEARXNG_HOST="http://192.168.168.1:40032/"
|
||||||
|
|
||||||
# Visualization
|
# Visualization
|
||||||
VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'
|
VIZ_CIF_OUTPUT_ROOT = '/home/ubuntu/50T/lzy/mars-mcp/outputs/cif_visualization'
|
||||||
|
|||||||
Binary file not shown.
@@ -5,6 +5,7 @@ This module provides functions for searching information on the web.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from typing import Annotated, Dict, Any, List
|
from typing import Annotated, Dict, Any, List
|
||||||
|
|
||||||
from langchain_community.utilities import SearxSearchWrapper
|
from langchain_community.utilities import SearxSearchWrapper
|
||||||
@@ -28,6 +29,8 @@ async def search_online(
|
|||||||
Formatted string with search results (titles, snippets, links)
|
Formatted string with search results (titles, snippets, links)
|
||||||
"""
|
"""
|
||||||
# 确保 num_results 是整数
|
# 确保 num_results 是整数
|
||||||
|
os.environ['HTTP_PROXY'] = ''
|
||||||
|
os.environ['HTTPS_PROXY'] = ''
|
||||||
try:
|
try:
|
||||||
num_results = int(num_results)
|
num_results = int(num_results)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
|
|||||||
Binary file not shown.
@@ -62,7 +62,8 @@ async def test_tool(tool_name: str) -> str:
|
|||||||
elif tool_name == "generate_material":
|
elif tool_name == "generate_material":
|
||||||
from mars_toolkit.compute.material_gen import generate_material
|
from mars_toolkit.compute.material_gen import generate_material
|
||||||
# 使用简单的属性约束进行测试
|
# 使用简单的属性约束进行测试
|
||||||
result = await generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1)
|
# result = await generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1)
|
||||||
|
result = generate_material(properties={'dft_mag_density': 0.15}, batch_size=2, num_batches=1)
|
||||||
|
|
||||||
elif tool_name == "fetch_chemical_composition_from_OQMD":
|
elif tool_name == "fetch_chemical_composition_from_OQMD":
|
||||||
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
from mars_toolkit.query.oqmd_query import fetch_chemical_composition_from_OQMD
|
||||||
@@ -171,7 +172,7 @@ if __name__ == "__main__":
|
|||||||
]
|
]
|
||||||
|
|
||||||
# 选择要测试的工具
|
# 选择要测试的工具
|
||||||
tool_name = tools_to_test[6] # 测试 search_online 工具
|
tool_name = tools_to_test[1] # 测试 search_online 工具
|
||||||
|
|
||||||
# 运行测试
|
# 运行测试
|
||||||
result = asyncio.run(test_tool(tool_name))
|
result = asyncio.run(test_tool(tool_name))
|
||||||
|
|||||||
Reference in New Issue
Block a user