# mars-mcp/execute_tool_copy.py
import json
import asyncio
import concurrent.futures
import jsonlines
from mars_toolkit import *
import threading
import uuid
# Create a lock for file writing
file_lock = threading.Lock()
from mysql.connector import pooling
from colorama import Fore, Back, Style, init
import time
import random
# Initialize colorama (autoreset restores default colors after each print)
init(autoreset=True)
from typing import Dict, Union, Any, Optional
def normalize_material_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalize the argument format passed to the generate_material function.

    Handles the following cases:
    1. The properties argument may be a JSON string that needs to be parsed into a dict.
    2. Values inside properties may need to be converted to an appropriate type (number or string).
    3. batch_size and num_batches must be integers.

    Args:
        arguments: Dictionary containing the generate_material arguments.

    Returns:
        The normalized argument dictionary.
    """
    normalized_args = arguments.copy()
    # Handle the properties argument
    if "properties" in normalized_args:
        properties = normalized_args["properties"]
        # If properties is a string, try to parse it as JSON
        if isinstance(properties, str):
            try:
                properties = json.loads(properties)
            except json.JSONDecodeError as e:
                raise ValueError(f"Unable to parse properties JSON string: {e}")
        # Ensure properties is a dict
        if not isinstance(properties, dict):
            raise ValueError(f"properties must be a dict or JSON string, not {type(properties)}")
        # Normalize the values inside properties
        normalized_properties = {}
        for key, value in properties.items():
            # Range values such as "0.0-2.0" or "40-50" stay as strings
            if isinstance(value, str) and "-" in value and not value.startswith(">") and not value.startswith("<"):
                normalized_properties[key] = value
            elif isinstance(value, str) and value.startswith(">"):
                # Keep "greater than" constraints as strings
                normalized_properties[key] = value
            elif isinstance(value, str) and value.startswith("<"):
                # Keep "less than" constraints as strings
                normalized_properties[key] = value
            elif isinstance(value, str) and value.lower() == "relaxor":
                # Special keyword values stay as strings
                normalized_properties[key] = value
            elif isinstance(value, str) and value.endswith("eV"):
                # Values carrying a unit stay as strings
                normalized_properties[key] = value
            else:
                # Otherwise try to convert the value to a number
                try:
                    float_value = float(value)
                    # Use an int when the float has no fractional part
                    if float_value.is_integer():
                        normalized_properties[key] = int(float_value)
                    else:
                        normalized_properties[key] = float_value
                except (ValueError, TypeError):
                    # If conversion fails, keep the original value
                    normalized_properties[key] = value
        normalized_args["properties"] = normalized_properties
    # Ensure batch_size and num_batches are integers
    if "batch_size" in normalized_args:
        try:
            normalized_args["batch_size"] = int(normalized_args["batch_size"])
        except (ValueError, TypeError):
            raise ValueError(f"batch_size must be an integer, not {normalized_args['batch_size']}")
    if "num_batches" in normalized_args:
        try:
            normalized_args["num_batches"] = int(normalized_args["num_batches"])
        except (ValueError, TypeError):
            raise ValueError(f"num_batches must be an integer, not {normalized_args['num_batches']}")
    # Ensure diffusion_guidance_factor is a float
    if "diffusion_guidance_factor" in normalized_args:
        try:
            normalized_args["diffusion_guidance_factor"] = float(normalized_args["diffusion_guidance_factor"])
        except (ValueError, TypeError):
            raise ValueError(f"diffusion_guidance_factor must be a number, not {normalized_args['diffusion_guidance_factor']}")
    return normalized_args
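
# Illustrative sketch of the normalization rules above. The property keys other than
# dft_band_gap are hypothetical placeholders, not values taken from a real request:
#
# normalize_material_args({
#     "properties": '{"dft_band_gap": "2.0", "dft_bulk_modulus": "40-50", "ml_bulk_modulus": ">100"}',
#     "batch_size": "4",
#     "num_batches": "1",
#     "diffusion_guidance_factor": "2.0",
# })
# # -> {"properties": {"dft_band_gap": 2, "dft_bulk_modulus": "40-50", "ml_bulk_modulus": ">100"},
# #     "batch_size": 4, "num_batches": 1, "diffusion_guidance_factor": 2.0}
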
import requests
connection_pool = pooling.MySQLConnectionPool(
    pool_name="mypool",
    pool_size=32,
    pool_reset_session=True,
    host='localhost',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)
async def process_retrieval_from_knowledge_base(data):
    doi = data.get('doi')
    mp_id = data.get('mp_id')
    # At least one query parameter must be provided
    if doi is None and mp_id is None:
        return ""  # Return an empty string when no query parameter is given
    # Build the SQL query condition
    query = "SELECT * FROM mp_synthesis_scheme_info WHERE "
    params = []
    if doi is not None and mp_id is not None:
        query += "doi = %s OR mp_id = %s"
        params = [doi, mp_id]
    elif doi is not None:
        query += "doi = %s"
        params = [doi]
    else:  # mp_id is not None
        query += "mp_id = %s"
        params = [mp_id]
    # Fetch the matching record from the database
    conn = connection_pool.get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query, params)
            result = cursor.fetchone()  # Take the first matching record
        finally:
            cursor.close()
    finally:
        conn.close()
    # Check whether a matching record was found
    if not result:
        return ""  # Return an empty string when nothing matches
    # Build the result in markdown format
    markdown_result = ""
    # Append each field except doi and mp_id
    fields = [
        "target_material",
        "reaction_string",
        "chara_structure",
        "chara_performance",
        "chara_application",
        "synthesis_schemes"
    ]
    for field in fields:
        field_content = result.get(field, "")
        # Only append a field when its content is non-empty
        if field_content and field_content.strip():
            markdown_result += f"\n## {field}\n{field_content}\n\n"
    return markdown_result  # Return the markdown text directly
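
# Example call (the DOI below is a placeholder, not a real record in the
# mp_synthesis_scheme_info table):
#
# synthesis_markdown = asyncio.run(
#     process_retrieval_from_knowledge_base({"doi": "10.0000/example-doi"})
# )
# print(synthesis_markdown or "no matching synthesis record")
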
async def mattergen(
    properties=None,
    batch_size=2,
    num_batches=1,
    diffusion_guidance_factor=2.0
):
    """
    Generate crystal structures through the local MatterGen service.

    Args:
        properties: Optional property constraints, e.g. {"dft_band_gap": 2.0}
        batch_size: Number of structures generated per batch
        num_batches: Number of batches
        diffusion_guidance_factor: Controls how strongly the generated structures follow the target properties

    Returns:
        The generated structure content, or an error message
    """
    try:
        # Import and obtain the MatterGenService singleton
        from mars_toolkit.services.mattergen_service import MatterGenService
        service = MatterGenService.get_instance()
        # Generate materials through the service
        result = await service.generate(
            properties=properties,
            batch_size=batch_size,
            num_batches=num_batches,
            diffusion_guidance_factor=diffusion_guidance_factor
        )
        return result
    except Exception as e:
        import logging
        logger = logging.getLogger(__name__)
        logger.error(f"Error in mattergen: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return f"Error generating material: {str(e)}"
async def generate_material(
    url="http://localhost:8051/generate_material",
    properties=None,
    batch_size=2,
    num_batches=1,
    diffusion_guidance_factor=2.0
):
    """
    Generate crystal structures via the MatterGen API.

    Args:
        url: API endpoint URL
        properties: Optional property constraints, e.g. {"dft_band_gap": 2.0}
        batch_size: Number of structures generated per batch
        num_batches: Number of batches
        diffusion_guidance_factor: Controls how strongly the generated structures follow the target properties

    Returns:
        The generated structure content, or None on failure
    """
    # First try the local MatterGen service
    try:
        print("Trying the local MatterGen service...")
        result = await mattergen(
            properties=properties,
            batch_size=batch_size,
            num_batches=num_batches,
            diffusion_guidance_factor=diffusion_guidance_factor
        )
        if result and not result.startswith("Error"):
            print("Local MatterGen service generated successfully!")
            return result
        else:
            print(f"Local MatterGen service failed, falling back to the API: {result}")
    except Exception as e:
        print(f"Local MatterGen service raised an error, falling back to the API: {str(e)}")
    # The local service failed, so fall back to the HTTP API.
    # Normalize the arguments first.
    normalized_args = normalize_material_args({
        "properties": properties,
        "batch_size": batch_size,
        "num_batches": num_batches,
        "diffusion_guidance_factor": diffusion_guidance_factor
    })
    # Build the request payload
    payload = {
        "properties": normalized_args["properties"],
        "batch_size": normalized_args["batch_size"],
        "num_batches": normalized_args["num_batches"],
        "diffusion_guidance_factor": normalized_args["diffusion_guidance_factor"]
    }
    print(f"Sending request to {url}")
    print(f"Request parameters: {json.dumps(payload, ensure_ascii=False, indent=2)}")
    try:
        # Headers including the accept header
        headers = {
            "Content-Type": "application/json",
            "accept": "application/json"
        }
        # Print the full request (debugging only)
        print(f"Full request URL: {url}")
        print(f"Request headers: {json.dumps(headers, indent=2)}")
        print(f"Request body: {json.dumps(payload, indent=2)}")
        # Disable proxy settings
        proxies = {
            "http": None,
            "https": None
        }
        # Send the POST request with explicit headers, proxies disabled, and a long timeout
        response = requests.post(url, json=payload, headers=headers, proxies=proxies, timeout=300)
        # Print the response (debugging only)
        print(f"Response status code: {response.status_code}")
        print(f"Response headers: {dict(response.headers)}")
        print(f"Response body: {response.text[:500]}...")  # Only the first 500 characters to keep output short
        # Check the response status
        if response.status_code == 200:
            result = response.json()
            if result["success"]:
                print("\nGeneration succeeded!")
                return result["content"]
            else:
                print(f"\nGeneration failed: {result['message']}")
                return None
        else:
            print(f"\nRequest failed with status code: {response.status_code}")
            print(f"Response body: {response.text}")
            return None
    except Exception as e:
        print(f"\nAn error occurred: {str(e)}")
        print(f"Error type: {type(e).__name__}")
        import traceback
        print(f"Error traceback: {traceback.format_exc()}")
        return None
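
# Example call that exercises the local-service-then-API fallback
# (the chemical_system constraint mirrors the sample arguments at the end of this file):
#
# content = asyncio.run(generate_material(
#     properties={"chemical_system": "V-Zn-O"},
#     batch_size=1,
#     num_batches=1,
# ))
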
async def execute_tool_from_dict(input_dict: dict):
    """
    Extract the tool function name and arguments from a dict and execute the matching tool function.

    Args:
        input_dict: A dict such as:
            {"name": "search_material_property_from_material_project",
             "arguments": "{\"formula\": \"Th3Pd5\", \"is_stable\": \"true\"}"}

    Returns:
        The tool function's result, or an error message if the tool does not exist
    """
    try:
        # Extract the function name and arguments
        func_name = input_dict.get("name")
        arguments_data = input_dict.get("arguments")
        if not func_name:
            return {"status": "error", "message": "No function name provided"}
        # Get all registered tool functions
        tools = get_tools()
        # Check that the function name exists in the tool registry
        if func_name not in tools:
            return {"status": "error", "message": f"Function '{func_name}' does not exist in the tool registry"}
        # Look up the tool function
        tool_func = tools[func_name]
        # Process the arguments
        arguments = {}
        if arguments_data:
            # arguments may already be a dict or may be a JSON string
            if isinstance(arguments_data, dict):
                arguments = arguments_data
            elif isinstance(arguments_data, str):
                try:
                    # Try to parse the string directly as a JSON object
                    arguments = json.loads(arguments_data)
                except json.JSONDecodeError:
                    # Parsing may fail because of escape characters;
                    # try to repair common JSON string issues
                    fixed_str = arguments_data.replace('\\"', '"').replace('\\\\', '\\')
                    try:
                        arguments = json.loads(fixed_str)
                    except json.JSONDecodeError:
                        # If it still fails, pass the string through as a raw string
                        arguments = {"raw_string": arguments_data}
        # Call the tool function
        if asyncio.iscoroutinefunction(tool_func):
            # Await asynchronous tool functions
            result = await tool_func(**arguments)
        else:
            # Call synchronous tool functions directly
            result = tool_func(**arguments)
        return result
    except json.JSONDecodeError as e:
        return {"status": "error", "message": f"JSON parsing error: {str(e)}"}
    except Exception as e:
        return {"status": "error", "message": f"Error during execution: {str(e)}"}
def worker(data, output_file_path):
    try:
        func_contents = data["function_calls"]
        func_results = []
        formatted_results = []  # Collect the formatted results
        for func in func_contents:
            func_name = func.get("name")
            arguments_data = func.get("arguments")
            # Print the function name and arguments with colored output
            print(f"{Fore.CYAN}{Style.BRIGHT}[Function]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
            print(f"{Fore.CYAN}{Style.BRIGHT}[Arguments]{Style.RESET_ALL} {Fore.GREEN}{arguments_data}{Style.RESET_ALL}")
            if func.get("name") == 'retrieval_from_knowledge_base':
                delay_time = random.uniform(1, 5)
                time.sleep(delay_time)
                result = asyncio.run(process_retrieval_from_knowledge_base(data))
                func_results.append({"function": func['name'], "result": result})
                # Format the result
                formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
                formatted_results.append(formatted_result)
            elif func.get("name") == 'generate_material':
                try:
                    # Make sure arguments_data is a dict
                    if isinstance(arguments_data, str):
                        try:
                            arguments_data = json.loads(arguments_data)
                        except json.JSONDecodeError as e:
                            print(f"{Fore.RED}Unable to parse arguments_data JSON string: {e}{Style.RESET_ALL}")
                            continue
                    # Normalize the arguments
                    normalized_args = normalize_material_args(arguments_data)
                    print(f"{Fore.CYAN}{Style.BRIGHT}[Function]{Style.RESET_ALL} {Fore.YELLOW}{func_name}{Style.RESET_ALL}")
                    print(f"{Fore.CYAN}{Style.BRIGHT}[Original arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(arguments_data, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
                    print(f"{Fore.CYAN}{Style.BRIGHT}[Normalized arguments]{Style.RESET_ALL} {Fore.GREEN}{json.dumps(normalized_args, ensure_ascii=False, indent=2)}{Style.RESET_ALL}")
                    # generate_material prefers the local mattergen service and falls back to the API
                    try:
                        output = asyncio.run(generate_material(**normalized_args))
                        # Add a random 5-10 second delay to simulate additional tool calls
                        delay_time = random.uniform(5, 10)
                        print(f"{Fore.MAGENTA}{Style.BRIGHT}Running additional tool calls, estimated {delay_time:.2f} seconds...{Style.RESET_ALL}")
                        time.sleep(delay_time)
                        # Simulated log output for the additional tool calls
                        print(f"{Fore.BLUE}Analyzing the generated material structure...{Style.RESET_ALL}")
                        time.sleep(0.5)
                        print(f"{Fore.BLUE}Computing structural stability...{Style.RESET_ALL}")
                        time.sleep(0.5)
                        print(f"{Fore.BLUE}Validating property constraints...{Style.RESET_ALL}")
                        time.sleep(0.5)
                        print(f"{Fore.GREEN}{Style.BRIGHT}Additional tool calls finished{Style.RESET_ALL}")
                    except Exception as e:
                        print(f"{Fore.RED}generate_material failed: {str(e)}{Style.RESET_ALL}")
                        output = f"Error generating material: {str(e)}"  # Keep a result so the record is still written
                    # Append the result to func_results
                    func_results.append({"function": func_name, "result": output})
                    # Format the result
                    formatted_result = f"[{func_name} content begin]{output}[{func_name} content end]"
                    formatted_results.append(formatted_result)
                except Exception as e:
                    print(f"{Fore.RED}Error while handling generate_material arguments: {e}{Style.RESET_ALL}")
                    import traceback
                    print(f"{Fore.RED}{traceback.format_exc()}{Style.RESET_ALL}")
            else:
                delay_time = random.uniform(1, 5)
                time.sleep(delay_time)
                result = asyncio.run(execute_tool_from_dict(func))
                func_results.append({"function": func['name'], "result": result})
                # Format the result
                func_name = func.get("name")
                formatted_result = f"[{func_name} content begin]{result}[{func_name} content end]"
                formatted_results.append(formatted_result)
        # Join all formatted results
        final_result = "\n\n\n".join(formatted_results)
        data['observation'] = final_result
        # Print begin/end markers with colored output
        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT BEGIN {'#'*50}{Style.RESET_ALL}")
        print(data['observation'])
        print(f"{Back.BLUE}{Fore.WHITE}{Style.BRIGHT}{'#'*50} RESULT END {'#'*50}{Style.RESET_ALL}")
        # Append the enriched record to the output file under the lock
        with file_lock:
            with jsonlines.open(output_file_path, mode='a') as writer:
                writer.write(data)
        return "Processed successfully"
    except Exception as e:
        print(f"{Fore.RED}{Style.BRIGHT}Error during processing: {str(e)}{Style.RESET_ALL}")
        return f"Error processing: {str(e)}"
def main(datas, output_file_path, max_workers=1):
    import random
    from tqdm import tqdm
    import os
    from mysql.connector import pooling, Error
    # Create the progress bar
    pbar = tqdm(total=len(datas), desc="Processing papers")
    # Create a thread pool
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit the tasks to the executor
        future_to_path = {}
        for path in datas:
            future = executor.submit(worker, path, output_file_path)
            future_to_path[future] = path
        # Collect the results
        completed = 0
        failed = 0
        for future in concurrent.futures.as_completed(future_to_path):
            path = future_to_path[future]
            try:
                result = future.result()
                if "successfully" in result:
                    completed += 1
                else:
                    failed += 1
                # Update the progress bar
                pbar.update(1)
                # Refresh the summary statistics every 100 records
                if (completed + failed) % 100 == 0:
                    pbar.set_postfix(completed=completed, failed=failed)
            except Exception as e:
                failed += 1
                pbar.update(1)
                print(f"\nWorker for {path} generated an exception: {e}")
    pbar.close()
    print(f"Processing complete. Successfully processed: {completed}, Failed: {failed}")

if __name__ == '__main__':
    import datetime
    import jsonlines
    datas = []
    with jsonlines.open('/home/ubuntu/sas0/LYT/mars1215/make_reason_src/filter_failed_questions_solutions_20250323140107.jsonl') as reader:
        for obj in reader:
            datas.append(obj)
    print(len(datas))
    output_file = f"./filter_ok_questions_solutions_agent_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
    main(datas, output_file, max_workers=16)

    # Example 1: using properly formatted JSON
    # argument = '{"properties": {"chemical_system": "V-Zn-O", "crystal_system": "monoclinic", "space_group": "P21/c", "volume": 207.37}, "batch_size": 1, "num_batches": 1}'
    # argument = json.loads(argument)
    # print(json.dumps(argument, indent=2))
    # asyncio.run(mattergen(**argument))