Files
mars-mcp/mattergen_client_example.py
2025-04-02 16:24:50 +08:00

135 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import json
import argparse
import sys
def generate_material(
url="http://localhost:8051/generate_material",
properties=None,
batch_size=2,
num_batches=1,
diffusion_guidance_factor=2.0
):
"""
调用MatterGen API生成晶体结构
Args:
url: API端点URL
properties: 可选的属性约束,例如{"dft_band_gap": 2.0}
batch_size: 每批生成的结构数量
num_batches: 批次数量
diffusion_guidance_factor: 控制生成结构与目标属性的符合程度
Returns:
生成的结构内容或错误信息
"""
# 构建请求负载
payload = {
"properties": properties ,
"batch_size": batch_size,
"num_batches": num_batches,
"diffusion_guidance_factor": diffusion_guidance_factor
}
print(f"发送请求到 {url}")
print(f"请求参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
try:
# 添加headers参数包含accept头
headers = {
"Content-Type": "application/json",
"accept": "application/json"
}
# 打印完整请求信息(调试用)
print(f"完整请求URL: {url}")
print(f"请求头: {headers}")
print(f"请求体: {json.dumps(payload)}")
# 禁用代理设置
proxies = {
"http": None,
"https": None
}
# 发送POST请求添加headers参数禁用代理增加超时时间
response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
# 打印响应信息(调试用)
print(f"响应状态码: {response.status_code}")
print(f"响应头: {dict(response.headers)}")
print(f"响应内容: {response.text[:500]}...") # 只打印前500个字符避免输出过长
# 检查响应状态
if response.status_code == 200:
result = response.json()
if result["success"]:
print("\n生成成功!")
return result["content"]
else:
print(f"\n生成失败: {result['message']}")
return None
else:
print(f"\n请求失败,状态码: {response.status_code}")
print(f"响应内容: {response.text}")
return None
except Exception as e:
print(f"\n发生错误: {str(e)}")
print(f"错误类型: {type(e).__name__}")
import traceback
print(f"错误堆栈: {traceback.format_exc()}")
return None
def main():
"""命令行入口函数"""
parser = argparse.ArgumentParser(description="MatterGen API客户端示例")
# 添加命令行参数
parser.add_argument("--url", default="http://localhost:8051/generate_material",
help="MatterGen API端点URL")
parser.add_argument("--property-name", default='dft_mag_density',help="属性名称例如dft_band_gap")
parser.add_argument("--property-value",default=0.15,help="属性值例如2.0")
parser.add_argument("--batch-size", type=int, default=2, help="每批生成的结构数量")
parser.add_argument("--num-batches", type=int, default=1, help="批次数量")
parser.add_argument("--guidance-factor", type=float, default=2.0,
help="控制生成结构与目标属性的符合程度")
args = parser.parse_args()
# 构建属性字典
properties = None
if args.property_name and args.property_value:
try:
# 尝试将属性值转换为数字
try:
value = float(args.property_value)
# 如果是整数,转换为整数
if value.is_integer():
value = int(value)
except ValueError:
# 如果无法转换为数字,保持为字符串
value = args.property_value
properties = {args.property_name: value}
except Exception as e:
print(f"解析属性值时出错: {str(e)}")
return
# 调用API
result = generate_material(
url=args.url,
properties=properties,
batch_size=args.batch_size,
num_batches=args.num_batches,
diffusion_guidance_factor=args.guidance_factor
)
if result:
print("\n生成的结构:")
print(result)
if __name__ == "__main__":
main()