mattergen转服务
This commit is contained in:
134
mattergen_client_example.py
Normal file
134
mattergen_client_example.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import requests
|
||||
import json
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
def generate_material(
|
||||
url="http://localhost:8051/generate_material",
|
||||
properties=None,
|
||||
batch_size=2,
|
||||
num_batches=1,
|
||||
diffusion_guidance_factor=2.0
|
||||
):
|
||||
"""
|
||||
调用MatterGen API生成晶体结构
|
||||
|
||||
Args:
|
||||
url: API端点URL
|
||||
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
|
||||
batch_size: 每批生成的结构数量
|
||||
num_batches: 批次数量
|
||||
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
|
||||
|
||||
Returns:
|
||||
生成的结构内容或错误信息
|
||||
"""
|
||||
# 构建请求负载
|
||||
payload = {
|
||||
"properties": properties ,
|
||||
"batch_size": batch_size,
|
||||
"num_batches": num_batches,
|
||||
"diffusion_guidance_factor": diffusion_guidance_factor
|
||||
}
|
||||
|
||||
print(f"发送请求到 {url}")
|
||||
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
|
||||
|
||||
try:
|
||||
# 添加headers参数,包含accept头
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"accept": "application/json"
|
||||
}
|
||||
|
||||
# 打印完整请求信息(调试用)
|
||||
print(f"完整请求URL: {url}")
|
||||
print(f"请求头: {headers}")
|
||||
print(f"请求体: {json.dumps(payload)}")
|
||||
|
||||
# 禁用代理设置
|
||||
proxies = {
|
||||
"http": None,
|
||||
"https": None
|
||||
}
|
||||
|
||||
# 发送POST请求,添加headers参数,禁用代理,增加超时时间
|
||||
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
|
||||
|
||||
# 打印响应信息(调试用)
|
||||
print(f"响应状态码: {response.status_code}")
|
||||
print(f"响应头: {dict(response.headers)}")
|
||||
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符,避免输出过长
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
|
||||
if result["success"]:
|
||||
print("\n生成成功!")
|
||||
return result["content"]
|
||||
else:
|
||||
print(f"\n生成失败: {result['message']}")
|
||||
return None
|
||||
else:
|
||||
print(f"\n请求失败,状态码: {response.status_code}")
|
||||
print(f"响应内容: {response.text}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n发生错误: {str(e)}")
|
||||
print(f"错误类型: {type(e).__name__}")
|
||||
import traceback
|
||||
print(f"错误堆栈: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
def main():
|
||||
"""命令行入口函数"""
|
||||
parser = argparse.ArgumentParser(description="MatterGen API客户端示例")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument("--url", default="http://localhost:8051/generate_material",
|
||||
help="MatterGen API端点URL")
|
||||
parser.add_argument("--property-name", default='dft_mag_density',help="属性名称,例如dft_band_gap")
|
||||
parser.add_argument("--property-value",default=0.15,help="属性值,例如2.0")
|
||||
parser.add_argument("--batch-size", type=int, default=2, help="每批生成的结构数量")
|
||||
parser.add_argument("--num-batches", type=int, default=1, help="批次数量")
|
||||
parser.add_argument("--guidance-factor", type=float, default=2.0,
|
||||
help="控制生成结构与目标属性的符合程度")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 构建属性字典
|
||||
properties = None
|
||||
if args.property_name and args.property_value:
|
||||
try:
|
||||
# 尝试将属性值转换为数字
|
||||
try:
|
||||
value = float(args.property_value)
|
||||
# 如果是整数,转换为整数
|
||||
if value.is_integer():
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
# 如果无法转换为数字,保持为字符串
|
||||
value = args.property_value
|
||||
|
||||
properties = {args.property_name: value}
|
||||
except Exception as e:
|
||||
print(f"解析属性值时出错: {str(e)}")
|
||||
return
|
||||
|
||||
# 调用API
|
||||
result = generate_material(
|
||||
url=args.url,
|
||||
properties=properties,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.num_batches,
|
||||
diffusion_guidance_factor=args.guidance_factor
|
||||
)
|
||||
|
||||
if result:
|
||||
print("\n生成的结构:")
|
||||
print(result)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user