import json
import logging
import os
from typing import Any, Awaitable, Callable, Optional, Sequence

import aiofiles
import yaml
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import TextMessage, UserInputRequestedEvent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination
from autogen_agentchat.teams import SelectorGroupChat, RoundRobinGroupChat
from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage, ToolCallExecutionEvent

from constant import MODEL, OPENAI_API_KEY, OPENAI_BASE_URL
from scientist_team import create_scientist_team
from engineer_team import create_engineer_team
from robot_platform import create_robot_team
from analyst_team import create_analyst_team
from utils import load_agent_configs


async def get_team(
    user_input_func: Callable[[str, Optional[CancellationToken]], Awaitable[str]],
) -> RoundRobinGroupChat | SelectorGroupChat:
    """Build the selector group chat and restore its saved state, if any.

    The team consists of a planning agent, a user proxy and the four
    sub-teams created by the project-local factory functions. When the file
    at ``state_path`` exists, its JSON content is loaded into the team
    before it is returned.

    Args:
        user_input_func: Coroutine function the user proxy agent calls to
            obtain input from the human.

    Returns:
        The configured ``SelectorGroupChat``.
    """
    # Create the team.
    scientist_team = create_scientist_team(model_client=model_client)
    engineer_team = create_engineer_team()
    robot_platform = create_robot_team()
    analyst_team = create_analyst_team()
    user = UserProxyAgent(
        name="user",
        input_func=user_input_func,  # Use the user input function.
    )

    # The planner's system message lives in a YAML file next to this module.
    cur_path = os.path.dirname(os.path.abspath(__file__))
    planning_agent_system_message = load_agent_configs(
        os.path.join(cur_path, "agent_config/planning.yaml")
    )
    planning_agent = AssistantAgent(
        "PlanningAgent",
        description="An agent for planning tasks, this agent should be the first to engage when given a new task.",
        model_client=model_client,
        system_message=planning_agent_system_message["PlanningAgent"],
        reflect_on_tool_use=False,
    )

    # The termination condition is a combination of text mention termination
    # and max message termination.
    text_mention_termination = TextMentionTermination("TERMINATE")
    max_messages_termination = MaxMessageTermination(max_messages=200)
    termination = text_mention_termination | max_messages_termination

    # The selector function takes the current message thread of the group
    # chat and returns the next speaker's name. If None is returned, the
    # LLM-based selection method will be used.
    def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
        """Pick the next speaker based on the most recent message."""
        last_message = messages[-1]
        if last_message.source != planning_agent.name:
            # Always return to the planning agent after the other agents
            # have spoken.
            return planning_agent.name
        if "HUMAN" in last_message.content:
            # The planner mentioned "HUMAN": hand the turn to the user proxy.
            return user.name
        return None

    team = SelectorGroupChat(
        [planning_agent, user, scientist_team, engineer_team, robot_platform, analyst_team],
        model_client=model_client,  # Use a smaller model for the selector.
        termination_condition=termination,
        selector_func=selector_func,
    )

    # Load state from file.
    if not os.path.exists(state_path):
        return team
    async with aiofiles.open(state_path, "r") as file:
        state = json.loads(await file.read())
    await team.load_state(state)
    return team


logger = logging.getLogger(__name__)

app = FastAPI()
current_task = None  # Used to track the current task.

# Add CORS middleware.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

model_config_path = "model_config.yaml"
state_path = "team_state.json"
history_path = "team_history.json"

# Serve static files.
# NOTE(review): directory="." exposes the whole working directory under
# /static, which is also where the relative state/history paths above are
# written. Consider pointing this at a dedicated assets directory.
app.mount("/static", StaticFiles(directory="."), name="static")


@app.get("/")
async def root():
    """Serve the chat interface HTML file."""
    return FileResponse("app_team.html")


model_client = OpenAIChatCompletionClient(
    model=MODEL,
    base_url=OPENAI_BASE_URL,
    api_key=OPENAI_API_KEY,
    model_info={
        "vision": True,
        "function_calling": True,
        "json_output": True,
        "family": "unknown",
    },
)


async def get_history() -> list[dict[str, Any]]:
    """Get chat history from file.

    Returns:
        The parsed JSON content of ``history_path``, or an empty list when
        that file does not exist.
    """
    if not os.path.exists(history_path):
        return []
    async with aiofiles.open(history_path, "r") as file:
        return json.loads(await file.read())


async def _write_json(path: str, payload: Any) -> None:
    """Serialize *payload* as JSON and write it to *path*.

    The payload is serialized before the file is opened, so a serialization
    failure cannot leave a truncated file behind that would break the next
    ``json.loads`` of that file.
    """
    text = json.dumps(payload)
    async with aiofiles.open(path, "w") as file:
        await file.write(text)


@app.get("/history")
async def history() -> list[dict[str, Any]]:
    """Return the stored chat history, mapping any failure to HTTP 500."""
    try:
        return await get_history()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.websocket("/ws/chat")
async def chat(websocket: WebSocket):
    """Run the agent team over a WebSocket connection.

    Each JSON message received from the client supplies a task in its
    ``content`` field. The team is rebuilt for every task, its messages are
    streamed back to the client, and the team state and chat history are
    written to disk once the run finishes. Errors raised while handling a
    task are reported to the client and the loop continues; a client
    disconnect ends the handler.
    """
    await websocket.accept()

    # User input function used by the team.
    async def _user_input(prompt: str, cancellation_token: CancellationToken | None) -> str:
        """Wait for the next client message and return its ``content``."""
        data = await websocket.receive_json()
        return data['content']

    try:
        while True:
            # Get user message.
            data = await websocket.receive_json()
            request = data['content']

            try:
                # Get the team and respond to the message.
                team = await get_team(_user_input)
                # Named chat_history so the module-level `history` route is
                # not shadowed inside this handler.
                chat_history = await get_history()
                stream = team.run_stream(task=request)
                async for message in stream:
                    if isinstance(message, TaskResult):
                        continue
                    print(f"----------------{message.source}----------------\n {message.content}")
                    await websocket.send_json(message.model_dump())
                    if not isinstance(message, UserInputRequestedEvent):
                        # Don't save user input events to history.
                        chat_history.append(message.model_dump())

                # Save team state to file. The state is collected before the
                # file is opened so a failure cannot truncate the old state.
                state = await team.save_state()
                await _write_json(state_path, state)

                # Save chat history to file.
                await _write_json(history_path, chat_history)

            except WebSocketDisconnect:
                # The client is gone; nothing can be sent to it any more, so
                # let the outer handler deal with the disconnect.
                raise
            except Exception as e:
                # Send error message to client
                error_message = {
                    "type": "error",
                    "content": f"Error: {str(e)}",
                    "source": "system"
                }
                await websocket.send_json(error_message)
                # Re-enable input after error
                await websocket.send_json({
                    "type": "UserInputRequestedEvent",
                    "content": "An error occurred. Please try again.",
                    "source": "system"
                })

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        try:
            await websocket.send_json({
                "type": "error",
                "content": f"Unexpected error: {str(e)}",
                "source": "system"
            })
        except Exception:
            # Best effort only: the socket may already be closed.
            logger.debug("Could not report the error to the client", exc_info=True)


# Example usage
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)