生成数据:mattergen改成了同步

This commit is contained in:
lzy
2025-04-06 20:35:13 +08:00
parent 71d8dabd17
commit 72045e5cfe
14 changed files with 557 additions and 191 deletions

View File

@@ -6,6 +6,8 @@ import jsonlines
from mars_toolkit import *
import threading
import uuid
from mars_toolkit.compute.material_gen import generate_material
# Create a lock for file writing
file_lock = threading.Lock()
from mysql.connector import pooling
@@ -180,153 +182,6 @@ async def process_retrieval_from_knowledge_base(data):
async def mattergen(
    properties=None,
    batch_size=2,
    num_batches=1,
    diffusion_guidance_factor=2.0
):
    """
    Generate crystal structures through the MatterGen service.

    Args:
        properties: Optional property constraints, e.g. {"dft_band_gap": 2.0}.
        batch_size: Number of structures generated per batch.
        num_batches: Number of batches.
        diffusion_guidance_factor: Controls how closely the generated
            structures follow the target properties.

    Returns:
        The value returned by ``MatterGenService.generate`` on success.
        On any failure (including a failed import of the service module),
        a string of the form "Error generating material: <reason>".
    """
    try:
        # Imported inside the try block, so an import failure is also
        # reported through the error string instead of propagating.
        from mars_toolkit.services.mattergen_service import MatterGenService

        service = MatterGenService.get_instance()
        return await service.generate(
            properties=properties,
            batch_size=batch_size,
            num_batches=num_batches,
            diffusion_guidance_factor=diffusion_guidance_factor
        )
    except Exception as e:
        import logging

        # logger.exception emits the message and the active traceback as a
        # single ERROR record, replacing the manual traceback.format_exc().
        logging.getLogger(__name__).exception("Error in mattergen: %s", e)
        # Callers detect failure by checking for the "Error" prefix.
        return f"Error generating material: {e}"
async def generate_material(
    url="http://localhost:8051/generate_material",
    properties=None,
    batch_size=2,
    num_batches=1,
    diffusion_guidance_factor=2.0
):
    """
    Generate crystal structures, trying the local MatterGen service first
    and falling back to the MatterGen HTTP API.

    Args:
        url: API endpoint URL used for the fallback request.
        properties: Optional property constraints, e.g. {"dft_band_gap": 2.0}.
        batch_size: Number of structures generated per batch.
        num_batches: Number of batches.
        diffusion_guidance_factor: Controls how closely the generated
            structures follow the target properties.

    Returns:
        The local service result when it is truthy and does not start with
        "Error"; otherwise the ``content`` field of a successful API
        response; ``None`` when the API request fails or reports failure.
    """
    import asyncio
    import functools
    import traceback

    # Try the local MatterGen service first.
    try:
        print("尝试使用本地MatterGen服务...")
        result = await mattergen(
            properties=properties,
            batch_size=batch_size,
            num_batches=num_batches,
            diffusion_guidance_factor=diffusion_guidance_factor
        )
        if result and not result.startswith("Error"):
            print("本地MatterGen服务生成成功!")
            return result
        else:
            print(f"本地MatterGen服务生成失败尝试使用API: {result}")
    except Exception as e:
        print(f"本地MatterGen服务出错尝试使用API: {str(e)}")

    # The local service failed: fall back to the HTTP API.
    # Normalize the arguments before building the request.
    normalized_args = normalize_material_args({
        "properties": properties,
        "batch_size": batch_size,
        "num_batches": num_batches,
        "diffusion_guidance_factor": diffusion_guidance_factor
    })

    # Build the request payload from the normalized arguments.
    payload = {
        "properties": normalized_args["properties"],
        "batch_size": normalized_args["batch_size"],
        "num_batches": normalized_args["num_batches"],
        "diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"]
    }

    print(f"发送请求到 {url}")
    print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")

    try:
        headers = {
            "Content-Type": "application/json",
            "accept": "application/json"
        }

        # Full request dump (debugging aid).
        print(f"完整请求URL: {url}")
        print(f"请求头: {json.dumps(headers, indent=2)}")
        print(f"请求体: {json.dumps(payload, indent=2)}")

        # Disable proxies for this request.
        proxies = {
            "http": None,
            "https": None
        }

        # requests.post is a blocking call with a 300 s timeout; running it
        # in the default executor keeps the event loop free while waiting.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            functools.partial(
                requests.post,
                url,
                json=payload,
                headers=headers,
                proxies=proxies,
                timeout=300,
            ),
        )

        # Response dump (debugging aid); the body is truncated to 500 chars.
        print(f"响应状态码: {response.status_code}")
        print(f"响应头: {dict(response.headers)}")
        print(f"响应内容: {response.text[:500]}...")

        if response.status_code == 200:
            result = response.json()
            if result["success"]:
                print("\n生成成功!")
                return result["content"]
            else:
                print(f"\n生成失败: {result['message']}")
                return None
        else:
            print(f"\n请求失败,状态码: {response.status_code}")
            print(f"响应内容: {response.text}")
            return None

    except Exception as e:
        # Best-effort: any request or parsing failure is reported on stdout
        # and signalled to the caller with None.
        print(f"\n发生错误: {str(e)}")
        print(f"错误类型: {type(e).__name__}")
        print(f"错误堆栈: {traceback.format_exc()}")
        return None
async def execute_tool_from_dict(input_dict: dict):
"""
从字典中提取工具函数名称和参数,并执行相应的工具函数
@@ -416,14 +271,14 @@ def worker(data, output_file_path):
print(f"{Fore.CYAN}{Style.BRIGHT}【参数】{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
if func.get("name") == 'retrieval_from_knowledge_base':
delay_time = random.uniform(1, 5)
time.sleep(delay_time)
result = asyncio.run(process_retrieval_from_knowledge_base(data))
func_results.append({"function": func['name'], "result": result})
# 格式化结果
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result)
pass
# delay_time = random.uniform(5, 10)
# time.sleep(delay_time)
# result = asyncio.run(process_retrieval_from_knowledge_base(data))
# func_results.append({"function": func['name'], "result": result})
# # 格式化结果
# formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
# formatted_results.append(formatted_result)
elif func.get("name") == 'generate_material':
# 规范化参数
@@ -438,30 +293,30 @@ def worker(data, output_file_path):
# 规范化参数
normalized_args = normalize_material_args(arguments_data)
print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
# print(f"{Fore.CYAN}{Style.BRIGHT}【函数名】{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
# print(f"{Fore.CYAN}{Style.BRIGHT}【原始参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
# print(f"{Fore.CYAN}{Style.BRIGHT}【规范化参数】{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
# 优先使用mattergen函数
try:
output = asyncio.run(generate_material(**normalized_args))
# output = asyncio.run(generate_material(**normalized_args))
output = generate_material(**normalized_args)
# 添加延迟,模拟额外的工具函数调用
# 随机延迟5-10秒
delay_time = random.uniform(5, 10)
print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
time.sleep(delay_time)
# delay_time = random.uniform(5, 10)
# print(f"{Fore.MAGENTA}{Style.BRIGHT}正在执行额外的工具函数调用,预计需要 {delay_time:.2f} 秒...{Style.RESET_ALL}")
# time.sleep(delay_time)
# 模拟其他工具函数调用的日志输出
print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
time.sleep(0.5)
print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
time.sleep(0.5)
print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
time.sleep(0.5)
print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
# # 模拟其他工具函数调用的日志输出
# print(f"{Fore.BLUE}正在分析生成的材料结构...{Style.RESET_ALL}")
# time.sleep(0.5)
# print(f"{Fore.BLUE}正在计算结构稳定性...{Style.RESET_ALL}")
# time.sleep(0.5)
# print(f"{Fore.BLUE}正在验证属性约束条件...{Style.RESET_ALL}")
# time.sleep(0.5)
# print(f"{Fore.GREEN}{Style.BRIGHT}额外的工具函数调用完成{Style.RESET_ALL}")
except Exception as e:
print(f"{Fore.RED}mattergen出错尝试使用generate_material: {str(e)}{Style.RESET_ALL}")
@@ -478,14 +333,15 @@ def worker(data, output_file_path):
print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
else:
delay_time = random.uniform(1, 5)
time.sleep(delay_time)
result = asyncio.run(execute_tool_from_dict(func))
func_results.append({"function": func['name'], "result": result})
# 格式化结果
func_name = func.get("name")
formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
formatted_results.append(formatted_result)
# delay_time = random.uniform(5, 10)
# time.sleep(delay_time)
pass
# result = asyncio.run(execute_tool_from_dict(func))
# func_results.append({"function": func['name'], "result": result})
# # 格式化结果
# func_name = func.get("name")
# formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
# formatted_results.append(formatted_result)
# 将所有格式化后的结果连接起来
final_result = "\n\n\n".join(formatted_results)
@@ -557,8 +413,8 @@ if __name__ == '__main__':
print(len(datas))
# print()
output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(datas, output_file, max_workers=16)
output_file = f"./filter_ok_questions_solutions_agent_mattergen_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
main(datas, output_file, max_workers=1)
# 示例1使用正确的JSON格式
# argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'