135 lines
4.3 KiB
Python
135 lines
4.3 KiB
Python
import requests
|
||
import json
|
||
import argparse
|
||
import sys
|
||
|
||
def generate_material(
|
||
url="http://localhost:8051/generate_material",
|
||
properties=None,
|
||
batch_size=2,
|
||
num_batches=1,
|
||
diffusion_guidance_factor=2.0
|
||
):
|
||
"""
|
||
调用MatterGen API生成晶体结构
|
||
|
||
Args:
|
||
url: API端点URL
|
||
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
|
||
batch_size: 每批生成的结构数量
|
||
num_batches: 批次数量
|
||
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
|
||
|
||
Returns:
|
||
生成的结构内容或错误信息
|
||
"""
|
||
# 构建请求负载
|
||
payload = {
|
||
"properties": properties ,
|
||
"batch_size": batch_size,
|
||
"num_batches": num_batches,
|
||
"diffusion_guidance_factor": diffusion_guidance_factor
|
||
}
|
||
|
||
print(f"发送请求到 {url}")
|
||
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
|
||
|
||
try:
|
||
# 添加headers参数,包含accept头
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
"accept": "application/json"
|
||
}
|
||
|
||
# 打印完整请求信息(调试用)
|
||
print(f"完整请求URL: {url}")
|
||
print(f"请求头: {headers}")
|
||
print(f"请求体: {json.dumps(payload)}")
|
||
|
||
# 禁用代理设置
|
||
proxies = {
|
||
"http": None,
|
||
"https": None
|
||
}
|
||
|
||
# 发送POST请求,添加headers参数,禁用代理,增加超时时间
|
||
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
|
||
|
||
# 打印响应信息(调试用)
|
||
print(f"响应状态码: {response.status_code}")
|
||
print(f"响应头: {dict(response.headers)}")
|
||
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符,避免输出过长
|
||
|
||
# 检查响应状态
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
|
||
if result["success"]:
|
||
print("\n生成成功!")
|
||
return result["content"]
|
||
else:
|
||
print(f"\n生成失败: {result['message']}")
|
||
return None
|
||
else:
|
||
print(f"\n请求失败,状态码: {response.status_code}")
|
||
print(f"响应内容: {response.text}")
|
||
return None
|
||
|
||
except Exception as e:
|
||
print(f"\n发生错误: {str(e)}")
|
||
print(f"错误类型: {type(e).__name__}")
|
||
import traceback
|
||
print(f"错误堆栈: {traceback.format_exc()}")
|
||
return None
|
||
|
||
def main():
|
||
"""命令行入口函数"""
|
||
parser = argparse.ArgumentParser(description="MatterGen API客户端示例")
|
||
|
||
# 添加命令行参数
|
||
parser.add_argument("--url", default="http://localhost:8051/generate_material",
|
||
help="MatterGen API端点URL")
|
||
parser.add_argument("--property-name", default='dft_mag_density',help="属性名称,例如dft_band_gap")
|
||
parser.add_argument("--property-value",default=0.15,help="属性值,例如2.0")
|
||
parser.add_argument("--batch-size", type=int, default=2, help="每批生成的结构数量")
|
||
parser.add_argument("--num-batches", type=int, default=1, help="批次数量")
|
||
parser.add_argument("--guidance-factor", type=float, default=2.0,
|
||
help="控制生成结构与目标属性的符合程度")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 构建属性字典
|
||
properties = None
|
||
if args.property_name and args.property_value:
|
||
try:
|
||
# 尝试将属性值转换为数字
|
||
try:
|
||
value = float(args.property_value)
|
||
# 如果是整数,转换为整数
|
||
if value.is_integer():
|
||
value = int(value)
|
||
except ValueError:
|
||
# 如果无法转换为数字,保持为字符串
|
||
value = args.property_value
|
||
|
||
properties = {args.property_name: value}
|
||
except Exception as e:
|
||
print(f"解析属性值时出错: {str(e)}")
|
||
return
|
||
|
||
# 调用API
|
||
result = generate_material(
|
||
url=args.url,
|
||
properties=properties,
|
||
batch_size=args.batch_size,
|
||
num_batches=args.num_batches,
|
||
diffusion_guidance_factor=args.guidance_factor
|
||
)
|
||
|
||
if result:
|
||
print("\n生成的结构:")
|
||
print(result)
|
||
|
||
if __name__ == "__main__":
|
||
main()
|