mattergen转服务

This commit is contained in:
lzy
2025-04-02 16:24:50 +08:00
parent a77c2cd377
commit 7034566ee6
30 changed files with 656 additions and 339 deletions

134
mattergen_client_example.py Normal file
View File

@@ -0,0 +1,134 @@
import requests
import json
import argparse
import sys
def generate_material(
url="http://localhost:8051/generate_material",
properties=None,
batch_size=2,
num_batches=1,
diffusion_guidance_factor=2.0
):
"""
调用MatterGen API生成晶体结构
Args:
url: API端点URL
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
batch_size: 每批生成的结构数量
num_batches: 批次数量
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
Returns:
生成的结构内容或错误信息
"""
# 构建请求负载
payload = {
"properties": properties ,
"batch_size": batch_size,
"num_batches": num_batches,
"diffusion_guidance_factor": diffusion_guidance_factor
}
print(f"发送请求到 {url}")
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
try:
# 添加headers参数包含accept头
headers = {
"Content-Type": "application/json",
"accept": "application/json"
}
# 打印完整请求信息(调试用)
print(f"完整请求URL: {url}")
print(f"请求头: {headers}")
print(f"请求体: {json.dumps(payload)}")
# 禁用代理设置
proxies = {
"http": None,
"https": None
}
# 发送POST请求添加headers参数禁用代理增加超时时间
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
# 打印响应信息(调试用)
print(f"响应状态码: {response.status_code}")
print(f"响应头: {dict(response.headers)}")
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符避免输出过长
# 检查响应状态
if response.status_code == 200:
result = response.json()
if result["success"]:
print("\n生成成功!")
return result["content"]
else:
print(f"\n生成失败: {result['message']}")
return None
else:
print(f"\n请求失败,状态码: {response.status_code}")
print(f"响应内容: {response.text}")
return None
except Exception as e:
print(f"\n发生错误: {str(e)}")
print(f"错误类型: {type(e).__name__}")
import traceback
print(f"错误堆栈: {traceback.format_exc()}")
return None
def main():
"""命令行入口函数"""
parser = argparse.ArgumentParser(description="MatterGen API客户端示例")
# 添加命令行参数
parser.add_argument("--url", default="http://localhost:8051/generate_material",
help="MatterGen API端点URL")
parser.add_argument("--property-name", default='dft_mag_density',help="属性名称例如dft_band_gap")
parser.add_argument("--property-value",default=0.15,help="属性值例如2.0")
parser.add_argument("--batch-size", type=int, default=2, help="每批生成的结构数量")
parser.add_argument("--num-batches", type=int, default=1, help="批次数量")
parser.add_argument("--guidance-factor", type=float, default=2.0,
help="控制生成结构与目标属性的符合程度")
args = parser.parse_args()
# 构建属性字典
properties = None
if args.property_name and args.property_value:
try:
# 尝试将属性值转换为数字
try:
value = float(args.property_value)
# 如果是整数,转换为整数
if value.is_integer():
value = int(value)
except ValueError:
# 如果无法转换为数字,保持为字符串
value = args.property_value
properties = {args.property_name: value}
except Exception as e:
print(f"解析属性值时出错: {str(e)}")
return
# 调用API
result = generate_material(
url=args.url,
properties=properties,
batch_size=args.batch_size,
num_batches=args.num_batches,
diffusion_guidance_factor=args.guidance_factor
)
if result:
print("\n生成的结构:")
print(result)
if __name__ == "__main__":
main()