# File: matagent/_backend/api.py  (213 lines, 7.7 KiB, Python)
import json
import logging
import os
from typing import Any, Awaitable, Callable, Optional, Sequence
import aiofiles
import yaml
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import TextMessage, UserInputRequestedEvent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination
from autogen_agentchat.teams import SelectorGroupChat, RoundRobinGroupChat
from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage, ToolCallExecutionEvent
from constant import MODEL, OPENAI_API_KEY, OPENAI_BASE_URL
from scientist_team import create_scientist_team
from engineer_team import create_engineer_team
from robot_platform import create_robot_team
from analyst_team import create_analyst_team
from utils import load_agent_configs
async def get_team(
    user_input_func: Callable[[str, Optional[CancellationToken]], Awaitable[str]],
) -> RoundRobinGroupChat | SelectorGroupChat:
    """Assemble the multi-agent team for one chat session.

    Builds the sub-teams (scientist, engineer, robot platform, analyst), a
    planning agent, and a user proxy wired to ``user_input_func``, combines
    them into a :class:`SelectorGroupChat`, and restores persisted team
    state from ``state_path`` when present.

    Args:
        user_input_func: Awaitable callback invoked whenever the team needs
            human input (the websocket handler supplies one per connection).

    Returns:
        The configured team, with prior state loaded if a state file exists.
    """
    # Sub-teams; only the scientist team takes the shared model client here.
    scientist_team = create_scientist_team(model_client=model_client)
    engineer_team = create_engineer_team()
    robot_platform = create_robot_team()
    analyst_team = create_analyst_team()

    # Proxy that forwards the team's input requests to the caller-supplied
    # function (e.g. the websocket handler in this file).
    user = UserProxyAgent(
        name="user",
        input_func=user_input_func,
    )

    # The planning agent's system prompt lives next to this file.
    cur_path = os.path.dirname(os.path.abspath(__file__))
    planning_agent_system_message = load_agent_configs(
        os.path.join(cur_path, "agent_config/planning.yaml")
    )
    planning_agent = AssistantAgent(
        "PlanningAgent",
        description="An agent for planning tasks, this agent should be the first to engage when given a new task.",
        model_client=model_client,
        system_message=planning_agent_system_message["PlanningAgent"],
        reflect_on_tool_use=False,
    )

    # Stop on an explicit "TERMINATE" mention or after 200 messages,
    # whichever comes first.
    text_mention_termination = TextMentionTermination("TERMINATE")
    max_messages_termination = MaxMessageTermination(max_messages=200)
    termination = text_mention_termination | max_messages_termination

    def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
        """Pick the next speaker, or None to defer to LLM-based selection.

        Routing: after any non-planner speaks, return control to the
        planning agent; when the planner's last message mentions "HUMAN",
        hand the turn to the user proxy.
        """
        if not messages:
            # No history yet: let the planning agent open the task.
            return planning_agent.name
        last = messages[-1]
        if last.source != planning_agent.name:
            # Always return to the planning agent after the other agents
            # have spoken.
            return planning_agent.name
        # Some events (e.g. tool-call results) carry non-string content;
        # only textual content can be scanned for the "HUMAN" marker.
        if isinstance(last.content, str) and "HUMAN" in last.content:
            return user.name
        return None

    team = SelectorGroupChat(
        [planning_agent, user, scientist_team, engineer_team, robot_platform, analyst_team],
        model_client=model_client,  # NOTE(review): same client as the agents; a smaller selector model was intended.
        termination_condition=termination,
        selector_func=selector_func,
    )

    # Restore persisted conversation state, if any.
    if not os.path.exists(state_path):
        return team
    async with aiofiles.open(state_path, "r") as file:
        state = json.loads(await file.read())
    await team.load_state(state)
    return team
# --- Module-level application wiring --------------------------------------
logger = logging.getLogger(__name__)

app = FastAPI()

current_task = None  # Tracks the current task. NOTE(review): never read in this file — confirm before removing.

# Wide-open CORS so any browser origin can reach the HTTP and websocket
# endpoints (development configuration).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Persistence locations, resolved relative to the process working directory.
model_config_path = "model_config.yaml"
state_path = "team_state.json"
history_path = "team_history.json"

# Serve static files from the current directory under /static.
app.mount("/static", StaticFiles(directory="."), name="static")
@app.get("/")
async def root():
    """Serve the single-page chat UI at the site root."""
    index_page = "app_team.html"
    return FileResponse(index_page)
# Shared chat-completion client used by the planning agent, the selector,
# and the scientist team. Endpoint and credentials come from constant.py.
model_client = OpenAIChatCompletionClient(
    model=MODEL,
    base_url=OPENAI_BASE_URL,
    api_key=OPENAI_API_KEY,
    # Capability flags for a model family autogen does not recognise.
    # NOTE(review): asserts vision/function-calling/JSON support — confirm
    # against the actual deployed model.
    model_info={
        "vision": True,
        "function_calling": True,
        "json_output": True,
        "family": "unknown",
    },
)
async def get_history() -> list[dict[str, Any]]:
    """Load the persisted chat history, or an empty list when none exists."""
    if not os.path.exists(history_path):
        return []
    async with aiofiles.open(history_path, "r") as file:
        raw = await file.read()
    return json.loads(raw)
@app.get("/history")
async def history() -> list[dict[str, Any]]:
    """Return the saved chat history; surface any failure as HTTP 500."""
    try:
        records = await get_history()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return records
@app.websocket("/ws/chat")
async def chat(websocket: WebSocket):
    """Websocket endpoint driving one chat session with the agent team.

    Protocol: the client sends JSON frames with a ``content`` field; every
    team event is echoed back as the ``model_dump()`` of the message. When
    the team requests human input, the next client frame is consumed as the
    answer. Team state and chat history are persisted to disk after each
    completed run.
    """
    await websocket.accept()

    async def _user_input(prompt: str, cancellation_token: CancellationToken | None) -> str:
        # The team is blocked waiting on the human: treat the next websocket
        # frame as the reply. `prompt` and the token are intentionally unused.
        data = await websocket.receive_json()
        return data['content']

    try:
        while True:
            # Wait for the next task from the client.
            data = await websocket.receive_json()
            request = data['content']
            try:
                # Build (or restore) the team and respond to the message.
                team = await get_team(_user_input)
                history = await get_history()
                stream = team.run_stream(task=request)
                async for message in stream:
                    if isinstance(message, TaskResult):
                        # Final summary object; individual messages were
                        # already streamed, so skip it.
                        continue
                    # Lazy %-args instead of an eager f-string/print.
                    logger.info("----------------%s----------------\n %s",
                                message.source, message.content)
                    await websocket.send_json(message.model_dump())
                    if not isinstance(message, UserInputRequestedEvent):
                        # Don't save user input events to history.
                        history.append(message.model_dump())

                # Persist team state so a later session can resume.
                async with aiofiles.open(state_path, "w") as file:
                    state = await team.save_state()
                    await file.write(json.dumps(state))

                # Persist chat history for the /history endpoint.
                async with aiofiles.open(history_path, "w") as file:
                    await file.write(json.dumps(history))
            except Exception as e:
                # Report the failure to the client instead of dropping the
                # connection.
                error_message = {
                    "type": "error",
                    "content": f"Error: {str(e)}",
                    "source": "system"
                }
                await websocket.send_json(error_message)
                # Re-enable input after error
                await websocket.send_json({
                    "type": "UserInputRequestedEvent",
                    "content": "An error occurred. Please try again.",
                    "source": "system"
                })
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error("Unexpected error: %s", e)
        try:
            await websocket.send_json({
                "type": "error",
                "content": f"Unexpected error: {str(e)}",
                "source": "system"
            })
        except Exception:
            # Narrowed from a bare `except:` (which also swallowed
            # SystemExit/KeyboardInterrupt). Socket is likely already
            # closed; nothing more we can do.
            pass
# Example usage: run the development server directly with `python api.py`.
if __name__ == "__main__":
    import uvicorn

    # Bind to all interfaces so containerised deployments are reachable.
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)