初次提交

This commit is contained in:
lzy
2025-05-09 14:16:33 +08:00
commit 3a50afeec4
56 changed files with 9224 additions and 0 deletions

View File

@@ -0,0 +1,339 @@
import asyncio
import json
import sys
import os
from openai import OpenAI
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
# 添加项目根目录到sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sci_mcp import *
# Shared Rich console object used for all terminal output in this script.
console = Console()
# Accumulates one record per executed tool call; the records are written to
# tool_results.txt at the end of the script.
tool_results = []
# Section separator helper.
def print_separator(title=""):
    """Print ``title`` as a centered, cyan-bordered banner in bold magenta."""
    banner = Panel(
        f"[bold magenta]{title}[/]",
        border_style="cyan",
        expand=False,
    )
    console.print(banner, justify="center")
# SECURITY: credentials and endpoints should not live in source control.
# The environment variables GPUSTACK_API_KEY / GPUSTACK_BASE_URL take
# precedence; the literals below are kept only as fallbacks so existing runs
# keep working. The committed key should be rotated and the fallback removed.
api_key = os.environ.get(
    "GPUSTACK_API_KEY",
    "gpustack_41f8963fb74e9f39_931a02e3ab35b41b1eada85da9c40705",
)
base_url = os.environ.get(
    "GPUSTACK_BASE_URL",
    "http://gpustack.ddwtop.team/v1-openai",
)
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
# NOTE: duplicates the star import near the top of the file; kept as-is.
from sci_mcp import *
# Tool schemas are passed to the model as ``tools=``; ``tool_map`` is indexed
# by tool name to obtain the callable that implements each tool.
tools_schemas = get_domain_tool_schemas(["material", 'general'])
tool_map = get_domain_tools(["material", 'general'])
# Pretty-printer for a single chat message.
def print_message(message):
    """Render one chat message as a Rich panel colored by its role.

    ``message`` is either an object exposing a ``role`` attribute (for
    example the message object returned by the chat completion call) or a
    plain dict with ``role`` / ``content`` / ``name`` keys. For ``tool``
    messages the tool name, when available, is shown next to the role.
    """
    if hasattr(message, 'role'):
        # Attribute-style message object.
        role = message.role
        content = getattr(message, 'content', "")
        tool_name = getattr(message, 'name', None) if role == "tool" else None
    else:
        # Dict-style message.
        role = message.get("role", "unknown")
        content = message.get("content", "")
        tool_name = message.get("name") if role == "tool" else None

    # One color per role; unknown roles fall back to white.
    color = {
        "system": "bright_blue",
        "user": "green",
        "assistant": "yellow",
        "tool": "bright_red"
    }.get(role, "white")

    # Tool messages with a known tool name get "role (tool_name): " as label.
    if role == "tool" and tool_name:
        label = f"{role} ({tool_name}): "
    else:
        label = f"{role}: "

    body = Text()
    body.append(label, style=f"bold {color}")
    body.append(str(content))
    console.print(Panel(body, border_style=color))
# Initial conversation: a single user message holding a CIF description of
# Ti4V (space group Fmmm) followed by a Chinese instruction asking the model to
# report formula, crystal system, space group, density, metallicity, magnetism
# and stability as JSON. The string is sent to the model verbatim.
messages = [
{"role": "user", "content": """data_Ti4V
_symmetry_space_group_name_H-M Fmmm
_cell_length_a 3.18353600
_cell_length_b 4.52677200
_cell_length_c 22.74397000
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 69
_chemical_formula_structural Ti4V
_chemical_formula_sum 'Ti16 V4'
_cell_volume 327.76657340
_cell_formula_units_Z 4
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-x, -y, z'
4 'x, y, -z'
5 'x, -y, -z'
6 '-x, y, z'
7 '-x, y, -z'
8 'x, -y, z'
9 'x+1/2, y, z+1/2'
10 '-x+1/2, -y, -z+1/2'
11 '-x+1/2, -y, z+1/2'
12 'x+1/2, y, -z+1/2'
13 'x+1/2, -y, -z+1/2'
14 '-x+1/2, y, z+1/2'
15 '-x+1/2, y, -z+1/2'
16 'x+1/2, -y, z+1/2'
17 'x+1/2, y+1/2, z'
18 '-x+1/2, -y+1/2, -z'
19 '-x+1/2, -y+1/2, z'
20 'x+1/2, y+1/2, -z'
21 'x+1/2, -y+1/2, -z'
22 '-x+1/2, y+1/2, z'
23 '-x+1/2, y+1/2, -z'
24 'x+1/2, -y+1/2, z'
25 'x, y+1/2, z+1/2'
26 '-x, -y+1/2, -z+1/2'
27 '-x, -y+1/2, z+1/2'
28 'x, y+1/2, -z+1/2'
29 'x, -y+1/2, -z+1/2'
30 '-x, y+1/2, z+1/2'
31 '-x, y+1/2, -z+1/2'
32 'x, -y+1/2, z+1/2'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ti Ti0 8 0.00000000 0.00000000 0.19939400 1.0
Ti Ti1 8 0.00000000 0.00000000 0.40878900 1.0
V V2 4 0.00000000 0.00000000 0.00000000 1.0
根据上文提供的CIF文件请你分析该晶体材料的化学式、晶体系统、空间群、密度、是否为金属、磁性和稳定性并用JSON格式回答。"""}
]
# Presumably an alternative test prompt kept for reference:
#   how to synthesize CsPbBr3 at room temperature
#
# Finish reason of the most recent completion; None until a request is made.
finish_reason = None
# Show the conversation as it stands before the first model call.
print_separator("初始消息")
for message in messages:
    print_message(message)
async def execute_tool(tool_name, tool_arguments):
    """Look up a tool in ``tool_map`` and run it with the given arguments.

    Args:
        tool_name: Key into the module-level ``tool_map``.
        tool_arguments: Arguments for the tool. Either a dict used directly
            as keyword arguments, a JSON string encoding such a dict, or a
            falsy value meaning "no arguments".

    Returns:
        Whatever the tool returns. If the call produces an awaitable (an
        ``async def`` tool, or a sync wrapper that returns a coroutine), it
        is awaited and the awaited value is returned.

    Raises:
        KeyError: If ``tool_name`` is not present in ``tool_map``.
        Exception: Anything raised by the tool itself, including
            ``TypeError`` when the arguments do not match its signature.
    """
    # Function-scope import so this block does not depend on file-level edits.
    import inspect

    tool_func = tool_map[tool_name]
    arguments = _parse_tool_arguments(tool_arguments)

    result = tool_func(**arguments)
    # Checking the result rather than the function also covers sync wrappers
    # (e.g. decorators) that hand back a coroutine.
    if inspect.isawaitable(result):
        result = await result
    return result


def _parse_tool_arguments(tool_arguments):
    """Normalize raw tool arguments into a keyword-argument dict."""
    if not tool_arguments:
        return {}
    if isinstance(tool_arguments, dict):
        return tool_arguments
    if not isinstance(tool_arguments, str):
        # Unsupported argument container: call the tool without arguments.
        return {}

    try:
        parsed = json.loads(tool_arguments)
    except json.JSONDecodeError:
        # The string may carry an extra level of escaping; undo the common
        # cases (\" and \\) and retry once.
        fixed_str = tool_arguments.replace('\\"', '"').replace('\\\\', '\\')
        try:
            parsed = json.loads(fixed_str)
        except json.JSONDecodeError:
            # Still not JSON: hand the raw text to the tool.
            return {"raw_string": tool_arguments}

    if parsed is None:
        # JSON ``null`` means "no arguments".
        return {}
    if not isinstance(parsed, dict):
        # Valid JSON but not an object; ``**`` needs a mapping, so treat it
        # like any other string that could not be turned into keyword args.
        return {"raw_string": tool_arguments}
    return parsed
# Rough token estimator for a message list.
def estimate_tokens(messages):
    """Return a rough token estimate for a list of chat messages.

    Every character of a message's ``content`` and of each tool call's
    ``arguments`` string is counted as 1.3 tokens; the sum is truncated to
    an int.

    Args:
        messages: Iterable of messages, each either a dict or an object with
            ``content`` / ``tool_calls`` attributes.

    Returns:
        The estimated token count as an ``int``.
    """
    total = 0
    for msg in messages:
        if isinstance(msg, dict):
            content = msg.get("content", "")
            tool_calls = msg.get("tool_calls")
        else:
            content = getattr(msg, "content", "")
            tool_calls = getattr(msg, "tool_calls", None)

        if content:
            total += len(content) * 1.3

        # ``tool_calls`` can be present but None (message objects without
        # tool calls, or dicts with an explicit None); treat that as empty.
        for tool_call in tool_calls or []:
            if isinstance(tool_call, dict):
                args = (tool_call.get("function") or {}).get("arguments", "")
            else:
                function = getattr(tool_call, "function", None)
                args = getattr(function, "arguments", "") if function is not None else ""
            # Arguments may also be None; count them as empty.
            total += len(args or "") * 1.3
    return int(total)
# Keep the message history within the token budget.
def manage_message_history(messages, max_tokens=7000):
    """Trim the history so its estimated size stays within ``max_tokens``.

    The first message is always kept. When the estimate from
    ``estimate_tokens`` exceeds ``max_tokens``, the longest run of most
    recent messages that still fits next to the first message is kept. If
    that run has fewer than four messages and the history has more than
    four, the last four are kept instead, even if that exceeds the budget.
    Leading ``tool`` messages whose preceding message was trimmed away are
    dropped, because a tool result without its assistant ``tool_calls``
    message is rejected by chat-completions APIs.

    Args:
        messages: List of chat messages (dicts or message objects).
        max_tokens: Estimated-token budget for the returned history.

    Returns:
        ``messages`` itself when no trimming is needed, otherwise a new list.
    """
    if len(messages) <= 1:
        return messages
    if estimate_tokens(messages) <= max_tokens:
        return messages

    first = messages[0]
    total = len(messages)

    # Walk backwards from the newest message, extending the kept tail while
    # the first message plus the tail still fits in the budget.
    start = total
    for idx in range(total - 1, 0, -1):
        if estimate_tokens([first] + messages[idx:]) > max_tokens:
            break
        start = idx

    # Always keep some recent context, even if that exceeds the budget.
    if total - start < 4 and total > 4:
        start = total - 4

    # Skip tool results whose originating assistant message was cut.
    # ``start > 1`` means the message just before the tail was dropped.
    while 1 < start < total and _message_role(messages[start]) == "tool":
        start += 1

    return [first] + messages[start:]


def _message_role(msg):
    """Return the role of a dict-style or attribute-style message, or None."""
    if isinstance(msg, dict):
        return msg.get("role")
    return getattr(msg, "role", None)
# Main conversation loop: keep querying the model for as long as it answers
# with tool calls, running each requested tool and recording its result.
while finish_reason is None or finish_reason == "tool_calls":
    # Trim the history before every request.
    managed_messages = manage_message_history(messages)
    if len(managed_messages) < len(messages):
        print_separator(f"消息历史已裁剪,从{len(messages)}条减少到{len(managed_messages)}")
        messages = managed_messages

    completion = client.chat.completions.create(
        model="MARS-T1",
        messages=messages,
        temperature=0.3,
        tools=tools_schemas,
        timeout=120,
    )
    choice = completion.choices[0]
    finish_reason = choice.finish_reason

    if finish_reason == "tool_calls":  # the reply asks for one or more tool calls
        print_separator("Assistant消息")
        print_message(choice.message)

        # Convert the returned message object into a plain dict so it can be
        # stored in the history alongside the dict-style messages.
        assistant_message = {
            "role": "assistant",
            "content": choice.message.content if hasattr(choice.message, 'content') else None
        }
        if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
            assistant_message["tool_calls"] = [
                {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    }
                }
                for tool_call in choice.message.tool_calls
            ]

        # Store the assistant reply so the next request carries it as context.
        messages.append(assistant_message)

        print_separator("工具调用")
        # ``tool_calls`` may be None even when the finish reason says
        # otherwise; treat that as "nothing to run" instead of crashing.
        for tool_call in choice.message.tool_calls or []:
            console.print(f"[bold cyan]工具名称:[/] [yellow]{tool_call.function.name}[/]")
            console.print(f"[bold cyan]工具ID:[/] [yellow]{tool_call.id}[/]")
            console.print(f"[bold cyan]参数:[/] [yellow]{tool_call.function.arguments}[/]")
            console.print("")

            tool_call_name = tool_call.function.name
            # Pass the serialized arguments through unchanged: execute_tool
            # parses (and repairs) JSON strings itself, so malformed or empty
            # arguments become a reported tool failure, not an uncaught
            # JSONDecodeError that ends the script.
            tool_call_arguments = tool_call.function.arguments
            try:
                tool_result = asyncio.run(execute_tool(tool_name=tool_call_name, tool_arguments=tool_call_arguments))
            except Exception as e:
                # Top-level boundary: report the failure as the tool result.
                tool_result = f'工具调用失败{e}'

            # Keep the result in the separate results list, wrapped in
            # begin/end markers carrying the tool name.
            formatted_result = f"[{tool_call_name} content begin]{tool_result}[{tool_call_name} content end]"
            tool_results.append({
                "tool_name": tool_call_name,
                "tool_id": tool_call.id,
                "formatted_result": formatted_result,
                "raw_result": tool_result
            })
            console.print(f"[bold green]已保存工具结果:[/] [cyan]{tool_call_name}[/]")

            # Tool response message in chat format.
            tool_message = {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_call_name,
                "content": tool_result,
            }
            print_separator(f"工具响应: {tool_call_name}")
            print_message(tool_message)

            # NOTE(review): the tool response is intentionally not appended
            # to ``messages`` here (the call below was already disabled), so
            # the model never sees tool output and the history holds assistant
            # tool_calls without matching tool messages. Confirm this is
            # intended before re-enabling.
            #messages.append(tool_message)
# Show the text of the last model reply, if it has any.
final_text = choice.message.content
if final_text:
    print_separator("最终响应")
    console.print(Panel(final_text, border_style="green"))

# Report how many tool results were collected and dump them to a text file,
# one formatted result per entry followed by a blank line.
if tool_results:
    print_separator("所有工具调用结果")
    console.print(f"[bold cyan]共收集了 {len(tool_results)} 个工具调用结果[/]")
    with open("tool_results.txt", "w", encoding="utf-8") as out_file:
        out_file.writelines(
            f"{entry['formatted_result']}\n\n" for entry in tool_results
        )
    console.print("[bold green]所有工具调用结果已保存到 tool_results.txt 文件中[/]")