生成数据:mattergen改成了同步
This commit is contained in:
@@ -6,6 +6,8 @@ import jsonlines
|
||||
from mars_toolkit import *
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from mars_toolkit.compute.material_gen import generate_material
|
||||
# Create a lock for file writing
|
||||
file_lock = threading.Lock()
|
||||
from mysql.connector import pooling
|
||||
@@ -180,153 +182,6 @@ async def process_retrieval_from_knowledge_base(data):
|
||||
|
||||
|
||||
|
||||
async def mattergen(
    properties=None,
    batch_size=2,
    num_batches=1,
    diffusion_guidance_factor=2.0
):
    """
    Generate crystal structures through the in-process MatterGen service.

    Args:
        properties: Optional property constraints, e.g. {"dft_band_gap": 2.0}.
        batch_size: Number of structures generated per batch.
        num_batches: Number of batches to run.
        diffusion_guidance_factor: Controls how closely the generated
            structures follow the target properties.

    Returns:
        Whatever ``MatterGenService.generate`` returns on success, or a string
        starting with "Error generating material: " when anything fails.
        Errors are reported through the return value rather than raised, so
        callers can test the "Error" prefix to detect failure.
    """
    try:
        # Imported lazily so that a missing or broken service module is
        # reported through the error-string path below instead of breaking
        # the import of this file.
        from mars_toolkit.services.mattergen_service import MatterGenService

        service = MatterGenService.get_instance()

        return await service.generate(
            properties=properties,
            batch_size=batch_size,
            num_batches=num_batches,
            diffusion_guidance_factor=diffusion_guidance_factor
        )
    except Exception as e:
        # Kept as a local import, as in the original block, so this function
        # does not depend on a module-level ``logging`` import.
        import logging

        logger = logging.getLogger(__name__)
        # logger.exception records the message and the active traceback in a
        # single ERROR record.
        logger.exception("Error in mattergen: %s", e)
        return f"Error generating material: {str(e)}"
|
||||
|
||||
async def generate_material(
    url="http://localhost:8051/generate_material",
    properties=None,
    batch_size=2,
    num_batches=1,
    diffusion_guidance_factor=2.0
):
    """
    Generate crystal structures, preferring the local MatterGen service and
    falling back to the MatterGen HTTP API.

    Args:
        url: API endpoint URL used for the fallback request.
        properties: Optional property constraints, e.g. {"dft_band_gap": 2.0}.
        batch_size: Number of structures generated per batch.
        num_batches: Number of batches to run.
        diffusion_guidance_factor: Controls how closely the generated
            structures follow the target properties.

    Returns:
        The generated structure content on success. ``None`` when the API
        fallback fails (non-200 status, unsuccessful response body, or any
        exception during the request).
    """
    import asyncio
    import traceback

    # First choice: the in-process MatterGen service.
    try:
        print("尝试使用本地MatterGen服务...")
        result = await mattergen(
            properties=properties,
            batch_size=batch_size,
            num_batches=num_batches,
            diffusion_guidance_factor=diffusion_guidance_factor
        )
        # mattergen() reports failures as strings starting with "Error".
        # The isinstance guard keeps a non-string result from raising
        # AttributeError on .startswith; it falls through to the API instead.
        if isinstance(result, str) and result and not result.startswith("Error"):
            print("本地MatterGen服务生成成功!")
            return result
        print(f"本地MatterGen服务生成失败,尝试使用API: {result}")
    except Exception as e:
        print(f"本地MatterGen服务出错,尝试使用API: {str(e)}")

    # Local service failed: fall back to the HTTP API with normalized arguments.
    normalized_args = normalize_material_args({
        "properties": properties,
        "batch_size": batch_size,
        "num_batches": num_batches,
        "diffusion_guidance_factor": diffusion_guidance_factor
    })

    # Build the request payload.
    payload = {
        "properties": normalized_args["properties"],
        "batch_size": normalized_args["batch_size"],
        "num_batches": normalized_args["num_batches"],
        "diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"]
    }

    print(f"发送请求到 {url}")
    print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")

    try:
        headers = {
            "Content-Type": "application/json",
            "accept": "application/json"
        }

        # Full request dump, for debugging.
        print(f"完整请求URL: {url}")
        print(f"请求头: {json.dumps(headers, indent=2)}")
        print(f"请求体: {json.dumps(payload, indent=2)}")

        # Explicit None entries disable proxies for this request.
        proxies = {
            "http": None,
            "https": None
        }

        # requests.post is blocking and may wait up to the 300 s timeout; run
        # it in the default executor so the event loop is not stalled.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(
                url, json=payload, headers=headers, proxies=proxies, timeout=300
            ),
        )

        # Response dump, for debugging.
        print(f"响应状态码: {response.status_code}")
        print(f"响应头: {dict(response.headers)}")
        # Only the first 500 characters, to keep the output short.
        print(f"响应内容: {response.text[:500]}...")

        if response.status_code != 200:
            print(f"\n请求失败,状态码: {response.status_code}")
            print(f"响应内容: {response.text}")
            return None

        result = response.json()

        # .get() keeps a response that lacks an expected key on the ordinary
        # failure path instead of relying on the broad except below.
        if result.get("success"):
            print("\n生成成功!")
            return result.get("content")

        print(f"\n生成失败: {result.get('message')}")
        return None

    except Exception as e:
        # Best-effort boundary: any request or parse failure is reported on
        # stdout and surfaced to the caller as None.
        print(f"\n发生错误: {str(e)}")
        print(f"错误类型: {type(e).__name__}")
        print(f"错误堆栈: {traceback.format_exc()}")
        return None
|
||||
|
||||
async def execute_tool_from_dict(input_dict: dict):
|
||||
"""
|
||||
从字典中提取工具函数名称和参数,并执行相应的工具函数
|
||||
@@ -416,14 +271,14 @@ def worker(data, output_file_path):
|
||||
print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
|
||||
|
||||
if func.get("name") == 'retrieval_from_knowledge_base':
|
||||
delay_time = random.uniform(1, 5)
|
||||
|
||||
time.sleep(delay_time)
|
||||
result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
||||
func_results.append({"function": func['name'], "result": result})
|
||||
# 格式化结果
|
||||
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||
formatted_results.append(formatted_result)
|
||||
pass
|
||||
# delay_time = random.uniform(5, 10)
|
||||
# time.sleep(delay_time)
|
||||
# result = asyncio.run(process_retrieval_from_knowledge_base(data))
|
||||
# func_results.append({"function": func['name'], "result": result})
|
||||
# # 格式化结果
|
||||
# formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||
# formatted_results.append(formatted_result)
|
||||
|
||||
elif func.get("name") == 'generate_material':
|
||||
# 规范化参数
|
||||
@@ -438,30 +293,30 @@ def worker(data, output_file_path):
|
||||
|
||||
# 规范化参数
|
||||
normalized_args = normalize_material_args(arguments_data)
|
||||
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
||||
print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
||||
print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
||||
# print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
|
||||
# print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
||||
# print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
|
||||
|
||||
# 优先使用mattergen函数
|
||||
try:
|
||||
output = asyncio.run(generate_material(**normalized_args))
|
||||
# output = asyncio.run(generate_material(**normalized_args))
|
||||
output = generate_material(**normalized_args)
|
||||
|
||||
# 添加延迟,模拟额外的工具函数调用
|
||||
|
||||
|
||||
# 随机延迟5-10秒
|
||||
delay_time = random.uniform(5, 10)
|
||||
print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
|
||||
time.sleep(delay_time)
|
||||
# delay_time = random.uniform(5, 10)
|
||||
# print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
|
||||
# time.sleep(delay_time)
|
||||
|
||||
# 模拟其他工具函数调用的日志输出
|
||||
print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
|
||||
time.sleep(0.5)
|
||||
print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
|
||||
time.sleep(0.5)
|
||||
print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
|
||||
time.sleep(0.5)
|
||||
print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
|
||||
# # 模拟其他工具函数调用的日志输出
|
||||
# print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
|
||||
# time.sleep(0.5)
|
||||
# print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
|
||||
# time.sleep(0.5)
|
||||
# print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
|
||||
# time.sleep(0.5)
|
||||
# print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"{Fore.RED}mattergen出错,尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
|
||||
@@ -478,14 +333,15 @@ def worker(data, output_file_path):
|
||||
print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
|
||||
|
||||
else:
|
||||
delay_time = random.uniform(1, 5)
|
||||
time.sleep(delay_time)
|
||||
result = asyncio.run(execute_tool_from_dict(func))
|
||||
func_results.append({"function": func['name'], "result": result})
|
||||
# 格式化结果
|
||||
func_name = func.get("name")
|
||||
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||
formatted_results.append(formatted_result)
|
||||
# delay_time = random.uniform(5, 10)
|
||||
# time.sleep(delay_time)
|
||||
pass
|
||||
# result = asyncio.run(execute_tool_from_dict(func))
|
||||
# func_results.append({"function": func['name'], "result": result})
|
||||
# # 格式化结果
|
||||
# func_name = func.get("name")
|
||||
# formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
|
||||
# formatted_results.append(formatted_result)
|
||||
|
||||
# 将所有格式化后的结果连接起来
|
||||
final_result = "\n\n\n".join(formatted_results)
|
||||
@@ -557,8 +413,8 @@ if __name__ == '__main__':
|
||||
|
||||
print(len(datas))
|
||||
# print()
|
||||
output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||
main(datas, output_file, max_workers=16)
|
||||
output_file = f"./filter_ok_questions_solutions_agent_mattergen_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
|
||||
main(datas, output_file, max_workers=1)
|
||||
|
||||
# 示例1:使用正确的JSON格式
|
||||
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
|
||||
|
||||
Reference in New Issue
Block a user