From 89728ada0ba71b4f565aee764d3db36c52f2edd5 Mon Sep 17 00:00:00 2001 From: Yutang Li Date: Mon, 24 Feb 2025 19:51:35 +0800 Subject: [PATCH] 234 --- _backend/api.py | 7 +- _backend/api1.py | 280 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 284 insertions(+), 3 deletions(-) create mode 100644 _backend/api1.py diff --git a/_backend/api.py b/_backend/api.py index 7201a06..955ce99 100644 --- a/_backend/api.py +++ b/_backend/api.py @@ -220,8 +220,8 @@ async def chat(websocket: WebSocket): async for message in stream: if isinstance(message, TaskResult): continue - print(f"----------------{message.source}----------------\n {message.content}") - if message.type == 'TextMessage' and message.type == 'HandoffMessage': + if message.type == 'TextMessage' or message.type == 'HandoffMessage': + print(f"----------------{message.source}----------------\n {message.content}") await websocket.send_json(message.model_dump()) if not isinstance(message, UserInputRequestedEvent): history.append(message.model_dump()) @@ -338,7 +338,8 @@ async def websocket_endpoint(websocket: WebSocket, camera_id: str): send_task.cancel() await asyncio.gather(capture_task, send_task, return_exceptions=True) finally: - process.kill() + if process and process.poll() is None: # Check if the process is still running + process.kill() await websocket.close() diff --git a/_backend/api1.py b/_backend/api1.py new file mode 100644 index 0000000..b745200 --- /dev/null +++ b/_backend/api1.py @@ -0,0 +1,280 @@ +import json +import logging +import os +from typing import Any, Awaitable, Callable, Optional, Sequence +import uuid +import aiofiles +import yaml +import cv2 +import base64 +import asyncio +import numpy as np +import time +import subprocess +import ffmpeg +import io +from PIL import Image +from collections import deque + +from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from fastapi.staticfiles import StaticFiles +from fastapi.responses import HTMLResponse +from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack +from aiortc.contrib.media import MediaRelay +from aiortc.contrib.media import MediaPlayer + +from autogen_agentchat.agents import AssistantAgent, UserProxyAgent +from autogen_agentchat.base import TaskResult +from autogen_agentchat.messages import TextMessage, UserInputRequestedEvent +from autogen_agentchat.teams import RoundRobinGroupChat +from autogen_core import CancellationToken +from autogen_core.models import ChatCompletionClient +from autogen_ext.models.openai import OpenAIChatCompletionClient +from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination +from autogen_agentchat.teams import SelectorGroupChat, RoundRobinGroupChat +from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage, ToolCallExecutionEvent +from constant import MODEL, OPENAI_API_KEY, OPENAI_BASE_URL +from scientist_team import create_scientist_team +from engineer_team import create_engineer_team +from robot_platform import create_robot_team +from analyst_team import create_analyst_team +from utils import load_agent_configs + + +logger = logging.getLogger(__name__) + + +relay = MediaRelay() + +model_client = OpenAIChatCompletionClient( + model=MODEL, + base_url=OPENAI_BASE_URL, + api_key=OPENAI_API_KEY, + model_info={ + "vision": True, + "function_calling": True, + "json_output": True, + "family": "unknown", + }, +) + + +async def get_team( + 
user_input_func: Callable[[str, Optional[CancellationToken]], Awaitable[str]],
+    session_dir: str
+) -> RoundRobinGroupChat | SelectorGroupChat:
+
+    # Create the team.
+    scientist_team = create_scientist_team(model_client=model_client)
+    engineer_team = create_engineer_team()
+    robot_platform = create_robot_team()
+    analyst_team = create_analyst_team()
+    user = UserProxyAgent(
+        name="user",
+        input_func=user_input_func,  # Use the user input function.
+    )
+    cur_path = os.path.dirname(os.path.abspath(__file__))
+    planning_agent_system_message = load_agent_configs(os.path.join(cur_path, "agent_config/planning.yaml"))
+    planning_agent = AssistantAgent(
+        "PlanningAgent",
+        description="An agent for planning tasks; it should be the first to engage when given a new task.",
+        model_client=model_client,
+        system_message=planning_agent_system_message["PlanningAgent"],
+        reflect_on_tool_use=False,
+    )
+
+    # The termination condition is a combination of text mention termination and max message termination.
+    text_mention_termination = TextMentionTermination("TERMINATE")
+    max_messages_termination = MaxMessageTermination(max_messages=200)
+    termination = text_mention_termination | max_messages_termination
+
+    # The selector function takes the current message thread of the group chat and returns the
+    # next speaker's name. If None is returned, the LLM-based selection method is used instead.
+    def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
+        if messages[-1].source != planning_agent.name:
+            return planning_agent.name  # Always return to the planning agent after another agent has spoken.
+        elif isinstance(messages[-1].content, str) and "HUMAN" in messages[-1].content:
+            return user.name  # The planning agent asked for human input.
+        return None
+
+    team = SelectorGroupChat(
+        [planning_agent, user, scientist_team, engineer_team, robot_platform, analyst_team],
+        model_client=model_client,  # Model used for LLM-based selection when selector_func returns None.
+        termination_condition=termination,
+        selector_func=selector_func,
+    )
+    # Load state from file.
+    state_path = os.path.join(session_dir, "team_state.json")
+    if not os.path.exists(state_path):
+        return team
+    async with aiofiles.open(state_path, "r") as file:
+        state = json.loads(await file.read())
+    await team.load_state(state)
+    return team
+
+
+app = FastAPI()
+current_task = None  # Used to track the current task.
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Allows all origins
+    allow_credentials=True,
+    allow_methods=["*"],  # Allows all methods
+    allow_headers=["*"],  # Allows all headers
+)
+
+# model_config_path = "model_config.yaml"
+# state_path = "team_state.json"
+# history_path = "team_history.json"
+
+# Serve static files
+app.mount("/static", StaticFiles(directory="."), name="static")
+
+
+async def get_session_history(session_dir: str) -> list[dict[str, Any]]:
+    """Load the chat history stored in the given session directory."""
+    session_history_path = os.path.join(session_dir, "team_history.json")
+    if not os.path.exists(session_history_path):
+        return []
+    async with aiofiles.open(session_history_path, "r") as file:
+        content = await file.read()
+        if content:
+            return json.loads(content)
+        else:
+            return []
+
+
+@app.websocket("/history/{session_uuid}")
+async def history(websocket: WebSocket) -> None:
+    await websocket.accept()
+    data = await websocket.receive_json()
+    session_uuid = data["uuid"]
+    try:
+        # Resolve the session directory from the UUID before loading its history.
+        cur_path = os.path.dirname(os.path.abspath(__file__))
+        session_dir = os.path.join(cur_path, "history", session_uuid)
+        session_history = await get_session_history(session_dir)
+        await websocket.send_json(session_history)
+    except Exception as e:
+        # The socket is already accepted, so report the error over it rather than raising an HTTP error.
+        await websocket.send_json({"type": "error", "content": str(e), "source": "system"})
+
+
+@app.websocket("/sessions")
+async def sessions(websocket: WebSocket) -> None:
+    """Get all history UUIDs and their main content."""
+    await websocket.accept()
+    cur_path = os.path.dirname(os.path.abspath(__file__))
+    history_dir = os.path.join(cur_path, "history")
+    session_data = []
+    for dir_name in os.listdir(history_dir):
+        session_dir = os.path.join(history_dir, dir_name)
+        if os.path.isdir(session_dir):  # Check if it's a directory
+            try:
+                history = await get_session_history(session_dir)
+                main_content = history[0]["content"] if history and "content" in history[0] else ""
+                session_data.append({"uuid": dir_name, "content": main_content})
+            except Exception as e:
+                logger.warning(f"Error reading history for {dir_name}: {e}")  # Log the error but continue
+
+    await websocket.send_json(session_data)
+
+
+@app.websocket("/ws/chat")
+async def chat(websocket: WebSocket):
+    await websocket.accept()
+
+    # User input function used by the team.
+    async def _user_input(prompt: str, cancellation_token: CancellationToken | None) -> str:
+        data = await websocket.receive_json()
+        return data['content']
+
+    try:
+        while True:
+            # Get user message.
+            data = await websocket.receive_json()
+
+            if 'session_uuid' not in data:
+                # New session.
+                request = TextMessage.model_validate(data)
+                session_uuid = str(uuid.uuid4())  # Unique UUID for the new session.
+                cur_path = os.path.dirname(os.path.abspath(__file__))
+                session_dir = os.path.join(cur_path, "history", session_uuid)  # Directory for session state.
+                os.makedirs(session_dir, exist_ok=True)  # Ensure the session directory exists.
+                history = []
+            else:
+                session_uuid = data['session_uuid']
+                cur_path = os.path.dirname(os.path.abspath(__file__))
+                session_dir = os.path.join(cur_path, "history", session_uuid)  # Directory for session state.
+                history = await get_session_history(session_dir)
+                new_data = {k: v for k, v in data.items() if k != "session_uuid"}
+                request = TextMessage.model_validate(new_data)
+                # Prepend the stored text messages so the task carries the conversation context.
+                prior_messages = [
+                    TextMessage.model_validate(m) for m in history if m.get("type") == "TextMessage"
+                ]
+                request = prior_messages + [request]
+
+            try:
+                # Get the team and respond to the message.
+                team = await get_team(_user_input, session_dir)
+
+                stream = team.run_stream(task=request)
+                async for message in stream:
+                    if isinstance(message, TaskResult):
+                        continue
+                    if message.type == 'TextMessage' or message.type == 'HandoffMessage':
+                        print(f"----------------{message.source}----------------\n {message.content}")
+                        await websocket.send_json(message.model_dump())
+                    if not isinstance(message, UserInputRequestedEvent):
+                        history.append(message.model_dump())
+
+                # Save chat history to file.
+                session_history_path = os.path.join(session_dir, "team_history.json")
+                async with aiofiles.open(session_history_path, "w") as file:
+                    await file.write(json.dumps(history))
+
+                # # Save team state to file.
+                # session_state_path = os.path.join(session_dir, "team_state.json")
+                # async with aiofiles.open(session_state_path, "w") as file:
+                #     state = await team.save_state()
+                #     await file.write(json.dumps(state))
+
+            except Exception as e:
+                # Send error message to the client.
+                error_message = {
+                    "type": "error",
+                    "content": f"Error: {str(e)}",
+                    "source": "system"
+                }
+                await websocket.send_json(error_message)
+                # Re-enable input after the error.
+                await websocket.send_json({
+                    "type": "UserInputRequestedEvent",
+                    "content": "An error occurred. Please try again.",
+                    "source": "system"
+                })
+
+    except WebSocketDisconnect:
+        logger.info("Client disconnected")
+    except Exception as e:
+        logger.error(f"Unexpected error: {str(e)}")
+        try:
+            await websocket.send_json({
+                "type": "error",
+                "content": f"Unexpected error: {str(e)}",
+                "source": "system"
+            })
+        except Exception:
+            pass  # The socket is likely already closed.
+
+
+RTSP_STREAMS = {
+    "camera1": "rtsp://admin:@192.168.1.13:554/live",
+    "camera2": "rtsp://admin:@192.168.1.10:554/live",
+}
+
+
+# Run the API server directly for local development.
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8000)
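
Note (not part of the patch): a minimal client sketch for driving the new /ws/chat endpoint in _backend/api1.py. It assumes the server is running locally on port 8000 and that the third-party websockets package is installed; the payload fields mirror the TextMessage the handler validates, and adding "session_uuid" to the payload resumes an earlier session. The sketch covers a single task submission only and does not handle the human-in-the-loop turn driven by _user_input.

import asyncio
import json

import websockets  # third-party package, assumed installed (pip install websockets)


async def main() -> None:
    # Hypothetical server address; adjust to wherever api1.py is served.
    async with websockets.connect("ws://localhost:8000/ws/chat") as ws:
        # Omitting "session_uuid" asks the server to create a new session.
        # To resume, add "session_uuid": "<uuid>" alongside "source" and "content".
        await ws.send(json.dumps({"source": "user", "content": "Plan the next experiment."}))
        while True:
            message = json.loads(await ws.recv())
            print(f"[{message.get('source')}] {message.get('content')}")
            # The team stops when an agent mentions TERMINATE (see TextMentionTermination),
            # or when the server reports an error.
            if message.get("type") == "error" or "TERMINATE" in str(message.get("content", "")):
                break


asyncio.run(main())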